diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 44ba4ec2..036e798b 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -205,7 +205,7 @@ def configure_conversation_processor(conversation_processor_config): else: # Initialize Conversation Logs conversation_processor.meta_log = {} - conversation_processor.chat_session = "" + conversation_processor.chat_session = [] return conversation_processor @@ -225,9 +225,9 @@ def save_chat_session(): chat_session = state.processor_config.conversation.chat_session openai_api_key = state.processor_config.conversation.openai_api_key conversation_log = state.processor_config.conversation.meta_log - model = state.processor_config.conversation.model + chat_model = state.processor_config.conversation.chat_model session = { - "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key), + "summary": summarize(chat_session, model=chat_model, api_key=openai_api_key), "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], "session-end": len(conversation_log["chat"]), } @@ -242,7 +242,7 @@ def save_chat_session(): with open(conversation_logfile, "w+", encoding="utf-8") as logfile: json.dump(conversation_log, logfile, indent=2) - state.processor_config.conversation.chat_session = None + state.processor_config.conversation.chat_session = [] logger.info("📩 Saved current chat session to conversation logs") diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index fbcd4da5..226af3fb 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -3,6 +3,9 @@ import logging from datetime import datetime from typing import Optional +# External Packages +from langchain.schema import ChatMessage + # Internal Packages from khoj.utils.constants import empty_escape_sequences from khoj.processor.conversation import prompts @@ -16,22 +19,16 @@ from khoj.processor.conversation.utils import ( logger = logging.getLogger(__name__) -def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200): +def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200): """ - Summarize user input using OpenAI's GPT + Summarize conversation session using the specified OpenAI chat model """ - # Setup Prompt based on Summary Type - if summary_type == "chat": - prompt = prompts.summarize_chat.format(text=text) - elif summary_type == "notes": - prompt = prompts.summarize_notes.format(text=text, user_query=user_query) - else: - raise ValueError(f"Invalid summary type: {summary_type}") + messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session # Get Response from GPT - logger.debug(f"Prompt for GPT: {prompt}") + logger.debug(f"Prompt for GPT: {messages}") response = completion_with_backoff( - prompt=prompt, + messages=messages, model_name=model, temperature=temperature, max_tokens=max_tokens, @@ -41,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat ) # Extract, Clean Message from GPT's Response - return str(response).replace("\n\n", "") + return str(response.content).replace("\n\n", "") def extract_questions( - text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100 + text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100 ): """ Infer search queries to retrieve relevant notes to answer user query @@ -74,10 +71,11 @@ def extract_questions( chat_history=chat_history, text=text, ) + messages = [ChatMessage(content=prompt, role="assistant")] # Get Response from GPT response = completion_with_backoff( - prompt=prompt, + messages=messages, model_name=model, temperature=temperature, max_tokens=max_tokens, @@ -88,7 +86,7 @@ def extract_questions( # Extract, Clean Message from GPT's Response try: questions = ( - response.strip(empty_escape_sequences) + response.content.strip(empty_escape_sequences) .replace("['", '["') .replace("']", '"]') .replace("', '", '", "') diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 69443eaa..c04e9042 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -37,12 +37,7 @@ Question: {query} ## Summarize Chat ## -- summarize_chat = PromptTemplate.from_template( - """ -You are an AI. Summarize the conversation below from your perspective: - -{text} - -Summarize the conversation from the AI's first-person perspective:""" + f"{personality.format()} Summarize the conversation from your first person perspective" ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 99084bf0..e77b7899 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -9,7 +9,6 @@ import json # External Packages from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI from langchain.schema import ChatMessage from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.base import BaseCallbackManager @@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler): reraise=True, ) def completion_with_backoff(**kwargs): - prompt = kwargs.pop("prompt") - if "api_key" in kwargs: - kwargs["openai_api_key"] = kwargs.get("api_key") - else: + messages = kwargs.pop("messages") + if not "openai_api_key" in kwargs: kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") - llm = OpenAI(**kwargs, request_timeout=20, max_retries=1) - return llm(prompt) + llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1) + return llm(messages=messages) @retry( @@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None): streaming=True, verbose=True, callback_manager=BaseCallbackManager([callback_handler]), - model_name=model_name, + model_name=model_name, # type: ignore temperature=temperature, openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), request_timeout=20, max_retries=1, + client=None, ) chat(messages=messages) @@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair): return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] -def message_to_prompt( - user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:" -): - """Create prompt for GPT from messages and conversation history""" - gpt_message = f" {gpt_message}" if gpt_message else "" - - return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}" - - def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]): """Create json logs from messages, metadata for conversation log""" default_khoj_message_metadata = { diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index e87a80d0..7ccdf5cd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -15,7 +15,7 @@ from sentence_transformers import util # Internal Packages from khoj.configure import configure_processor, configure_search from khoj.processor.conversation.gpt import converse, extract_questions -from khoj.processor.conversation.utils import message_to_log, message_to_prompt +from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml from khoj.search_type import image_search, text_search from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -448,10 +448,9 @@ async def chat( user_message_time: str, compiled_references: List[str], inferred_queries: List[str], - chat_session: str, meta_log, ): - state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) + state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response]) state.processor_config.conversation.meta_log["chat"] = message_to_log( q, gpt_response, @@ -470,7 +469,6 @@ async def chat( ) # Load Conversation History - chat_session = state.processor_config.conversation.chat_session meta_log = state.processor_config.conversation.meta_log # If user query is empty, return nothing @@ -479,7 +477,6 @@ async def chat( # Initialize Variables api_key = state.processor_config.conversation.openai_api_key - model = state.processor_config.conversation.model chat_model = state.processor_config.conversation.chat_model user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conversation_type = "general" if q.startswith("@general") else "notes" @@ -489,7 +486,7 @@ async def chat( if conversation_type == "notes": # Infer search queries from user message with timer("Extracting search queries took", logger): - inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log) + inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log) # Collate search results as context for GPT with timer("Searching knowledge base took", logger): @@ -525,7 +522,6 @@ async def chat( user_message_time=user_message_time, compiled_references=compiled_references, inferred_queries=inferred_queries, - chat_session=chat_session, meta_log=meta_log, ) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 3adc6e9d..155cdcc6 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -72,7 +72,7 @@ class ConversationProcessorConfigModel: self.model = processor_config.model self.chat_model = processor_config.chat_model self.conversation_logfile = Path(processor_config.conversation_logfile) - self.chat_session = "" + self.chat_session: List[str] = [] self.meta_log: dict = {}