From 2b8f7f3efb7dcd715dbcb614747ab1c4babb424e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 6 Oct 2024 15:42:42 -0700 Subject: [PATCH] Reuse a single func to format conversation for Gemini This deduplicates code and prevents logic from deviating across gemini chat actors --- .../conversation/google/gemini_chat.py | 20 +++-------------- .../processor/conversation/google/utils.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 84d16dbd..9936c398 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -9,6 +9,7 @@ from langchain.schema import ChatMessage from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( + format_messages_for_gemini, gemini_chat_completion_with_backoff, gemini_completion_with_backoff, ) @@ -105,15 +106,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text") """ Send message to model """ - system_prompt = None - if len(messages) == 1: - messages[0].role = "user" - else: - system_prompt = "" - for message in messages.copy(): - if message.role == "system": - system_prompt += message.content - messages.remove(message) + messages, system_prompt = format_messages_for_gemini(messages) model_kwargs = {} if response_type == "json_object": @@ -195,14 +188,7 @@ def converse_gemini( tokenizer_name=tokenizer_name, ) - for message in messages: - if message.role == "assistant": - message.role = "model" - - for message in messages.copy(): - if message.role == "system": - system_prompt += message.content - messages.remove(message) + messages, system_prompt = format_messages_for_gemini(messages, system_prompt) truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) logger.debug(f"Conversation Context for Gemini: {truncated_messages}") diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 4af724f6..7ef99a92 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -10,6 +10,7 @@ from google.generativeai.types.safety_types import ( HarmCategory, HarmProbability, ) +from langchain.schema import ChatMessage from tenacity import ( before_sleep_log, retry, @@ -19,6 +20,7 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ThreadedGenerator +from khoj.utils.helpers import is_none_or_empty logger = logging.getLogger(__name__) @@ -182,3 +184,23 @@ def generate_safety_response(safety_ratings): return safety_response_choice.format( category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level ) + + +def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]: + if len(messages) == 1: + messages[0].role = "user" + return messages, system_prompt + + for message in messages: + if message.role == "assistant": + message.role = "model" + + # Extract system message + system_prompt = system_prompt or "" + for message in messages.copy(): + if message.role == "system": + system_prompt += message.content + messages.remove(message) + system_prompt = None if is_none_or_empty(system_prompt) else system_prompt + + return messages, system_prompt