Normalize chat messages sent to Gemini functions to work with prompt tracer

Previously, messages passed to the Gemini (chat) completion functions
had some Gemini-specific formatting mixed in.

These functions expect messages of type list[ChatMessage] for the
prompt tracer etc. to work.

Move the code that formats list[ChatMessage] messages into the
Gemini-specific format down into the Gemini (chat) completion functions.

This allows the rest of the functionality, like prompt tracing, to work
with normalized list[ChatMessage] chat messages across providers.
Author: Debanjum
Date: 2025-03-25 13:32:07 +05:30
Parent: 7976aa30f8
Commit: e3f6d241dd
2 changed files with 9 additions and 10 deletions
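For context, a minimal sketch of the conversion that now lives inside the Gemini completion functions rather than at their call sites. The list comprehension and return statement are taken verbatim from the final hunk below; the langchain ChatMessage import and the elided role/system-prompt handling are assumptions, not the full implementation:

from google.genai import types as gtypes
from langchain.schema import ChatMessage  # assumed source of the ChatMessage type

def format_messages_for_gemini(
    messages: list[ChatMessage], system_prompt: str | None = None
) -> tuple[list[gtypes.Content], str | None]:
    # ... role remapping and system prompt extraction elided (assumption) ...
    if len(messages) == 1:
        messages[0].role = "user"
    # Per the diff below: the Gemini SDK expects each turn as a gtypes.Content
    # with a role and parts, so convert only at this boundary.
    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
    return formatted_messages, system_prompt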


@@ -134,8 +134,6 @@ def gemini_send_message_to_model(
     """
     Send message to model
     """
-    messages_for_gemini, system_prompt = format_messages_for_gemini(messages)
     model_kwargs = {}
     # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
@@ -145,8 +143,8 @@ def gemini_send_message_to_model(
     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages_for_gemini,
-        system_prompt=system_prompt,
+        messages=messages,
+        system_prompt="",
         model_name=model,
         api_key=api_key,
         api_base_url=api_base_url,
@@ -244,12 +242,11 @@ def converse_gemini(
         program_execution_context=program_execution_context,
     )
-    messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
     # Get Response from Google AI
     return gemini_chat_completion_with_backoff(
-        messages=messages_for_gemini,
+        messages=messages,
         compiled_references=references,
         online_results=online_results,
         model_name=model,


@@ -80,6 +80,8 @@ def gemini_completion_with_backoff(
         client = get_gemini_client(api_key, api_base_url)
         gemini_clients[api_key] = client
+    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
@@ -91,8 +93,6 @@ def gemini_completion_with_backoff(
         seed=seed,
     )
-    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
     try:
         # Generate the response
         response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
@@ -165,6 +165,8 @@ def gemini_llm_thread(
         client = get_gemini_client(api_key, api_base_url)
         gemini_clients[api_key] = client
+    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
@@ -176,7 +178,6 @@ def gemini_llm_thread(
     )
     aggregated_response = ""
-    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
     for chunk in client.models.generate_content_stream(
         model=model_name, config=config, contents=formatted_messages
@@ -300,4 +301,5 @@ def format_messages_for_gemini(
     if len(messages) == 1:
         messages[0].role = "user"
-    return messages, system_prompt
+    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
+    return formatted_messages, system_prompt
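With this change in place, call sites pass the normalized list[ChatMessage] straight through and leave the Gemini-specific formatting to the completion function, e.g. (argument values illustrative, following the gemini_send_message_to_model hunk above):

# Caller side: no format_messages_for_gemini() call needed anymore.
response = gemini_completion_with_backoff(
    messages=messages,  # plain list[ChatMessage], usable by the prompt tracer
    system_prompt="",   # extracted/built inside the completion function
    model_name=model,
    api_key=api_key,
    api_base_url=api_base_url,
)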