From 384f394336c18dccd2ccc66d6bebe669a867c3f5 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Wed, 23 Oct 2024 20:01:06 -0700
Subject: [PATCH] Allow OpenAI API calling functions to save conversation
 traces

---
 src/khoj/processor/conversation/openai/gpt.py | 16 +++++-
 .../processor/conversation/openai/utils.py    | 52 +++++++++++++++----
 2 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 4a656fac..0d513268 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -33,6 +33,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -82,7 +83,13 @@ def extract_questions(
     messages = [ChatMessage(content=prompt, role="user")]
 
     response = send_message_to_model(
-        messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+        messages,
+        api_key,
+        model,
+        response_type="json_object",
+        api_base_url=api_base_url,
+        temperature=temperature,
+        tracer=tracer,
     )
 
     # Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
     return questions
 
 
-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
     """
     Send message to model
     """
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
         temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs={"response_format": {"type": response_type}},
+        tracer=tracer,
     )
 
 
@@ -137,6 +147,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -209,4 +220,5 @@ def converse(
         api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
+        tracer=tracer,
     )
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 878dbb9c..6e519f5a 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -12,7 +12,12 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode
 
 logger = logging.getLogger(__name__)
 
@@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
+    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -77,6 +82,12 @@ def completion_with_backoff(
         elif delta_chunk.content:
             aggregated_response += delta_chunk.content
 
+    # Save conversation trace
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
 
 
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
     api_base_url=None,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
-        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+        target=llm_thread,
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+    g,
+    messages,
+    model_name,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
+):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         if client_key not in openai_clients:
-            client: openai.OpenAI = openai.OpenAI(
+            client = openai.OpenAI(
                 api_key=openai_api_key,
                 base_url=api_base_url,
             )
             openai_clients[client_key] = client
         else:
-            client: openai.OpenAI = openai_clients[client_key]
+            client = openai_clients[client_key]
 
         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
         stream = True
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             **(model_kwargs or dict()),
         )
 
+        aggregated_response = ""
         if not stream:
-            g.send(chat.choices[0].message.content)
+            aggregated_response = chat.choices[0].message.content
+            g.send(aggregated_response)
         else:
             for chunk in chat:
                 if len(chunk.choices) == 0:
                     continue
                 delta_chunk = chunk.choices[0].delta
+                text_chunk = ""
                 if isinstance(delta_chunk, str):
-                    g.send(delta_chunk)
+                    text_chunk = delta_chunk
                 elif delta_chunk.content:
-                    g.send(delta_chunk.content)
+                    text_chunk = delta_chunk.content
+                if text_chunk:
+                    aggregated_response += text_chunk
+                    g.send(text_chunk)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
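
For reviewers unfamiliar with the pattern: the patch threads one mutable dict (tracer) through every call layer; each layer annotates it in place, and the layer that actually calls the OpenAI API decides whether to persist the accumulated trace. Below is a minimal, self-contained Python sketch of that pattern, not khoj code: DEBUG, commit_trace, completion, and the "session" seed key are illustrative stand-ins for the in_debug_mode()/state.verbose check, commit_conversation_trace, the API-calling helpers, and whatever context real callers seed.

    # Illustrative sketch of the tracer pattern introduced by this patch.
    DEBUG = True  # stand-in for the `in_debug_mode() or state.verbose > 1` check

    def commit_trace(messages: list, response: str, tracer: dict) -> None:
        # Stand-in for commit_conversation_trace: persist or log the trace.
        print(f"trace={tracer} messages={len(messages)} response={response!r}")

    def completion(messages: list, model: str, temperature: float = 0, tracer: dict = {}) -> str:
        response = "ok"  # stand-in for the OpenAI chat completion call
        # Annotate the caller's dict in place, as completion_with_backoff does.
        tracer["chat_model"] = model
        tracer["temperature"] = temperature
        if DEBUG:
            commit_trace(messages, response, tracer)
        return response

    tracer = {"session": "demo"}  # caller seeds whatever context it has
    completion(["hello"], model="gpt-4o-mini", tracer=tracer)
    assert tracer["chat_model"] == "gpt-4o-mini"  # caller sees the in-place annotations

One point worth raising in review: tracer: dict = {} is a mutable default argument, so any call path that omits it shares and mutates a single module-level dict across requests; callers that pass a fresh dict per conversation, as the sketch does, avoid that.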