mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Reuse a single function to format conversations for Gemini.
This deduplicates code and prevents the formatting logic from diverging across the Gemini chat actors.
This commit is contained in:
@@ -9,6 +9,7 @@ from langchain.schema import ChatMessage
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
@@ -105,15 +106,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
system_prompt = None
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
else:
|
||||
system_prompt = ""
|
||||
for message in messages.copy():
|
||||
if message.role == "system":
|
||||
system_prompt += message.content
|
||||
messages.remove(message)
|
||||
messages, system_prompt = format_messages_for_gemini(messages)
|
||||
|
||||
model_kwargs = {}
|
||||
if response_type == "json_object":
|
||||
@@ -195,14 +188,7 @@ def converse_gemini(
|
||||
tokenizer_name=tokenizer_name,
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
for message in messages.copy():
|
||||
if message.role == "system":
|
||||
system_prompt += message.content
|
||||
messages.remove(message)
|
||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
|
||||
|
||||
@@ -10,6 +10,7 @@ from google.generativeai.types.safety_types import (
|
||||
HarmCategory,
|
||||
HarmProbability,
|
||||
)
|
||||
from langchain.schema import ChatMessage
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@@ -19,6 +20,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -182,3 +184,23 @@ def generate_safety_response(safety_ratings):
|
||||
return safety_response_choice.format(
|
||||
category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
|
||||
)
|
||||
|
||||
|
||||
def format_messages_for_gemini(
    messages: list[ChatMessage], system_prompt: str = None
) -> "tuple[list[ChatMessage], str | None]":
    """
    Normalize a chat history into the role/prompt structure Gemini expects.

    - A single-message history is coerced to the "user" role and returned as-is,
      along with the (possibly None) system_prompt passed in.
    - "assistant" roles are renamed to "model", Gemini's term for them.
    - "system" messages are removed from the list and their contents appended
      to ``system_prompt``, which is returned separately.

    Mutates ``messages`` in place and returns it along with the combined system
    prompt (``None`` when no system content was found or provided).
    """
    # A lone message is always treated as the user turn; no system extraction.
    if len(messages) == 1:
        messages[0].role = "user"
        return messages, system_prompt

    # Gemini uses "model" where OpenAI-style histories use "assistant".
    for message in messages:
        if message.role == "assistant":
            message.role = "model"

    # Extract system messages into a single combined system prompt string.
    # Iterate over a copy because we remove items from the original list.
    system_prompt = system_prompt or ""
    for message in messages.copy():
        if message.role == "system":
            system_prompt += message.content
            messages.remove(message)
    # Gemini expects no system instruction rather than an empty one.
    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

    return messages, system_prompt
|
||||
|
||||
Reference in New Issue
Block a user