From e3f6d241dd02bdd3a1a95a86e280c476670d176d Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Tue, 25 Mar 2025 13:32:07 +0530
Subject: [PATCH] Normalize chat messages sent to gemini funcs to work with
 prompt tracer

Previously, messages passed to the gemini (chat) completion functions
had some Gemini-specific formatting mixed in. These functions expect
messages of type list[ChatMessage] for the prompt tracer and similar
tooling to work.

Move the code that converts messages of type list[ChatMessage] into the
Gemini-specific format down into the gemini (chat) completion functions.

This allows the rest of the functionality, like prompt tracing, to work
with normalized list[ChatMessage] chat messages across providers.
---
 src/khoj/processor/conversation/google/gemini_chat.py |  9 +++------
 src/khoj/processor/conversation/google/utils.py       | 10 ++++++----
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 6f518c04..3c630dec 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -134,8 +134,6 @@ def gemini_send_message_to_model(
     """
     Send message to model
     """
-    messages_for_gemini, system_prompt = format_messages_for_gemini(messages)
-
     model_kwargs = {}
 
     # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
@@ -145,8 +143,8 @@ def gemini_send_message_to_model(
 
     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages_for_gemini,
-        system_prompt=system_prompt,
+        messages=messages,
+        system_prompt="",
         model_name=model,
         api_key=api_key,
         api_base_url=api_base_url,
@@ -244,12 +242,11 @@ def converse_gemini(
         program_execution_context=program_execution_context,
     )
 
-    messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
 
     # Get Response from Google AI
     return gemini_chat_completion_with_backoff(
-        messages=messages_for_gemini,
+        messages=messages,
         compiled_references=references,
         online_results=online_results,
         model_name=model,
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index d63010b2..b224e2a0 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -80,6 +80,8 @@ def gemini_completion_with_backoff(
         client = get_gemini_client(api_key, api_base_url)
         gemini_clients[api_key] = client
 
+    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
+
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
@@ -91,8 +93,6 @@ def gemini_completion_with_backoff(
         seed=seed,
     )
 
-    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
-
     try:
         # Generate the response
         response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
@@ -165,6 +165,8 @@ def gemini_llm_thread(
         client = get_gemini_client(api_key, api_base_url)
         gemini_clients[api_key] = client
 
+    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
+
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
@@ -176,7 +178,6 @@ def gemini_llm_thread(
     )
 
     aggregated_response = ""
-    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
 
     for chunk in client.models.generate_content_stream(
         model=model_name, config=config, contents=formatted_messages
@@ -300,4 +301,5 @@ def format_messages_for_gemini(
     if len(messages) == 1:
         messages[0].role = "user"
 
-    return messages, system_prompt
+    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
+    return formatted_messages, system_prompt
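
Reviewer note: below is a minimal, self-contained sketch of the message
flow this patch normalizes; it is not khoj code. ChatMessage and
GeminiContent are hypothetical stand-ins for the provider-agnostic
message type and for gtypes.Content from the google-genai library, and
trace_prompt stands in for the prompt tracer.

from dataclasses import dataclass


@dataclass
class ChatMessage:
    # Hypothetical stand-in for the normalized, provider-agnostic message type.
    role: str
    content: str


@dataclass
class GeminiContent:
    # Hypothetical stand-in for gtypes.Content.
    role: str
    parts: str


def trace_prompt(messages: list[ChatMessage]) -> None:
    # Cross-cutting tooling like the prompt tracer can now rely on every
    # provider handing it the same list[ChatMessage] shape.
    for message in messages:
        print(f"{message.role}: {message.content}")


def format_messages_for_gemini(
    messages: list[ChatMessage], system_prompt: str = ""
) -> tuple[list[GeminiContent], str]:
    # Provider-specific conversion, e.g. Gemini expects role "model"
    # where other providers use "assistant".
    for message in messages:
        if message.role == "assistant":
            message.role = "model"
    formatted = [GeminiContent(role=m.role, parts=m.content) for m in messages]
    return formatted, system_prompt


def gemini_completion_with_backoff(
    messages: list[ChatMessage], system_prompt: str = ""
) -> list[GeminiContent]:
    # After this patch the conversion happens here, at the Gemini boundary,
    # instead of in the callers.
    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
    return formatted_messages  # stand-in for calling the Gemini API


# Callers and the tracer both see the same normalized shape.
messages = [ChatMessage(role="user", content="Hello")]
trace_prompt(messages)
gemini_completion_with_backoff(messages)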