mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
Normalize chat messages sent to gemini funcs to work with prompt tracer
Previously, messages passed to the Gemini (chat) completion functions had some Gemini-specific formatting mixed in. These functions expect messages of type list[ChatMessage] to work with the prompt tracer etc. Move the code that converts messages of type list[ChatMessage] into the Gemini-specific format down into the Gemini (chat) completion functions. This allows the rest of the functionality, like prompt tracing, to work with normalized list[ChatMessage] chat messages across providers.
This commit is contained in:
@@ -134,8 +134,6 @@ def gemini_send_message_to_model(
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
messages_for_gemini, system_prompt = format_messages_for_gemini(messages)
|
||||
|
||||
model_kwargs = {}
|
||||
|
||||
# This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
|
||||
@@ -145,8 +143,8 @@ def gemini_send_message_to_model(
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
messages=messages_for_gemini,
|
||||
system_prompt=system_prompt,
|
||||
messages=messages,
|
||||
system_prompt="",
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
@@ -244,12 +242,11 @@ def converse_gemini(
|
||||
program_execution_context=program_execution_context,
|
||||
)
|
||||
|
||||
messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from Google AI
|
||||
return gemini_chat_completion_with_backoff(
|
||||
messages=messages_for_gemini,
|
||||
messages=messages,
|
||||
compiled_references=references,
|
||||
online_results=online_results,
|
||||
model_name=model,
|
||||
|
||||
@@ -80,6 +80,8 @@ def gemini_completion_with_backoff(
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
@@ -91,8 +93,6 @@ def gemini_completion_with_backoff(
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
|
||||
try:
|
||||
# Generate the response
|
||||
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
|
||||
@@ -165,6 +165,8 @@ def gemini_llm_thread(
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
@@ -176,7 +178,6 @@ def gemini_llm_thread(
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
|
||||
for chunk in client.models.generate_content_stream(
|
||||
model=model_name, config=config, contents=formatted_messages
|
||||
@@ -300,4 +301,5 @@ def format_messages_for_gemini(
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
|
||||
return messages, system_prompt
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
return formatted_messages, system_prompt
|
||||
|
||||
Reference in New Issue
Block a user