mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 05:39:12 +00:00
Reuse a single func to format conversation for Gemini
This deduplicates code and prevents logic from deviating across gemini chat actors
This commit is contained in:
@@ -9,6 +9,7 @@ from langchain.schema import ChatMessage
|
|||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.google.utils import (
|
from khoj.processor.conversation.google.utils import (
|
||||||
|
format_messages_for_gemini,
|
||||||
gemini_chat_completion_with_backoff,
|
gemini_chat_completion_with_backoff,
|
||||||
gemini_completion_with_backoff,
|
gemini_completion_with_backoff,
|
||||||
)
|
)
|
||||||
@@ -105,15 +106,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
|
|||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
system_prompt = None
|
messages, system_prompt = format_messages_for_gemini(messages)
|
||||||
if len(messages) == 1:
|
|
||||||
messages[0].role = "user"
|
|
||||||
else:
|
|
||||||
system_prompt = ""
|
|
||||||
for message in messages.copy():
|
|
||||||
if message.role == "system":
|
|
||||||
system_prompt += message.content
|
|
||||||
messages.remove(message)
|
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
if response_type == "json_object":
|
if response_type == "json_object":
|
||||||
@@ -195,14 +188,7 @@ def converse_gemini(
|
|||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
for message in messages:
|
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||||
if message.role == "assistant":
|
|
||||||
message.role = "model"
|
|
||||||
|
|
||||||
for message in messages.copy():
|
|
||||||
if message.role == "system":
|
|
||||||
system_prompt += message.content
|
|
||||||
messages.remove(message)
|
|
||||||
|
|
||||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||||
logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
|
logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from google.generativeai.types.safety_types import (
|
|||||||
HarmCategory,
|
HarmCategory,
|
||||||
HarmProbability,
|
HarmProbability,
|
||||||
)
|
)
|
||||||
|
from langchain.schema import ChatMessage
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
retry,
|
retry,
|
||||||
@@ -19,6 +20,7 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||||
|
from khoj.utils.helpers import is_none_or_empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -182,3 +184,23 @@ def generate_safety_response(safety_ratings):
|
|||||||
return safety_response_choice.format(
|
return safety_response_choice.format(
|
||||||
category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
|
category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
|
||||||
|
if len(messages) == 1:
|
||||||
|
messages[0].role = "user"
|
||||||
|
return messages, system_prompt
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if message.role == "assistant":
|
||||||
|
message.role = "model"
|
||||||
|
|
||||||
|
# Extract system message
|
||||||
|
system_prompt = system_prompt or ""
|
||||||
|
for message in messages.copy():
|
||||||
|
if message.role == "system":
|
||||||
|
system_prompt += message.content
|
||||||
|
messages.remove(message)
|
||||||
|
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||||
|
|
||||||
|
return messages, system_prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user