Normalize chat messages sent to gemini funcs to work with prompt tracer
Previously, messages passed to the Gemini (chat) completion functions had some Gemini-specific formatting mixed in. These functions expect messages of type list[ChatMessage] for prompt tracing etc. to work. Move the code that converts list[ChatMessage] messages into the Gemini-specific format down into the Gemini (chat) completion functions. This allows the rest of the functionality, like prompt tracing, to work with normalized list[ChatMessage] chat messages across providers.
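For context, the payoff of this normalization is that any consumer that understands list[ChatMessage], such as the prompt tracer, can operate on the exact messages handed to every provider's completion function. A minimal sketch of that idea, assuming ChatMessage from langchain.schema (the trace_prompt helper here is hypothetical, not khoj's actual tracer):

from langchain.schema import ChatMessage


def trace_prompt(messages: list[ChatMessage]) -> str:
    # Hypothetical tracer: works uniformly across providers because
    # provider-specific formatting now happens inside each provider's
    # completion function, not before this point.
    return "\n".join(f"{message.role}: {message.content}" for message in messages)


messages = [
    ChatMessage(role="system", content="You are a helpful assistant."),
    ChatMessage(role="user", content="What is Khoj?"),
]
print(trace_prompt(messages))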
@@ -134,8 +134,6 @@ def gemini_send_message_to_model(
     """
     Send message to model
     """
-    messages_for_gemini, system_prompt = format_messages_for_gemini(messages)
-
     model_kwargs = {}
 
     # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
@@ -145,8 +143,8 @@ def gemini_send_message_to_model(
 
     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages_for_gemini,
-        system_prompt=system_prompt,
+        messages=messages,
+        system_prompt="",
         model_name=model,
         api_key=api_key,
         api_base_url=api_base_url,
@@ -244,12 +242,11 @@ def converse_gemini(
         program_execution_context=program_execution_context,
     )
 
-    messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
 
     # Get Response from Google AI
    return gemini_chat_completion_with_backoff(
-        messages=messages_for_gemini,
+        messages=messages,
         compiled_references=references,
         online_results=online_results,
         model_name=model,
@@ -80,6 +80,8 @@ def gemini_completion_with_backoff(
         client = get_gemini_client(api_key, api_base_url)
         gemini_clients[api_key] = client
 
+    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
+
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
@@ -91,8 +93,6 @@ def gemini_completion_with_backoff(
         seed=seed,
     )
 
-    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
-
    try:
         # Generate the response
         response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
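With formatting now done inside gemini_completion_with_backoff, the system prompt extracted by format_messages_for_gemini feeds Gemini's system_instruction field rather than travelling as a chat message. A standalone sketch of that SDK usage with google-genai (the model name and environment variable are assumptions for illustration, not from the diff):

import os

from google import genai
from google.genai import types as gtypes

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])  # assumed env var
config = gtypes.GenerateContentConfig(system_instruction="You are a concise assistant.")
response = client.models.generate_content(
    model="gemini-2.0-flash",  # assumed model name for illustration
    config=config,
    contents="Summarize what Khoj does in one line.",
)
print(response.text)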
@@ -165,6 +165,8 @@ def gemini_llm_thread(
         client = get_gemini_client(api_key, api_base_url)
         gemini_clients[api_key] = client
 
+    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
+
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
@@ -176,7 +178,6 @@ def gemini_llm_thread(
     )
 
     aggregated_response = ""
-    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
 
     for chunk in client.models.generate_content_stream(
         model=model_name, config=config, contents=formatted_messages
@@ -300,4 +301,5 @@ def format_messages_for_gemini(
     if len(messages) == 1:
         messages[0].role = "user"
 
-    return messages, system_prompt
+    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
+    return formatted_messages, system_prompt
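format_messages_for_gemini thus becomes the single place where ChatMessage objects turn into google.genai Content objects. A standalone sketch of that final conversion step, assuming plain-text message content (in khoj, message.content may already be a list of parts, as the diff's parts=message.content suggests):

from google.genai import types as gtypes

raw_messages = [
    {"role": "user", "content": "Hello"},
    {"role": "model", "content": "Hi! How can I help?"},
]

# Mirrors the list comprehension in the diff, except the text is wrapped
# into a Part explicitly since these example contents are plain strings.
formatted_messages = [
    gtypes.Content(role=m["role"], parts=[gtypes.Part.from_text(text=m["content"])])
    for m in raw_messages
]
print(formatted_messages[0].role, formatted_messages[0].parts[0].text)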