Refactor Gemini chat response to stream async, no separate thread

Debanjum
2025-04-20 02:54:56 +05:30
parent 0751f2ea30
commit a557031447
3 changed files with 48 additions and 66 deletions
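
In short, the change swaps the worker-thread-plus-ThreadedGenerator streaming path for plain async generators: the converse and completion functions yield chunks directly and await the completion callback once streaming finishes. A minimal sketch of the pattern, with hypothetical names (stream_chat, on_complete, fake_model_stream) standing in for converse_gemini, completion_func and the Gemini SDK stream:

import asyncio
from typing import AsyncGenerator, Awaitable, Callable, Optional


async def fake_model_stream(prompt: str) -> AsyncGenerator[str, None]:
    # Toy stand-in for the model SDK's streaming call
    for token in ("Hello", ", ", "world!"):
        yield token


async def stream_chat(
    prompt: str,
    on_complete: Optional[Callable[[str], Awaitable[None]]] = None,
) -> AsyncGenerator[str, None]:
    # The function is itself an async generator, so no separate thread or
    # queue-backed generator is needed to push chunks to the caller.
    full_response = ""
    async for chunk in fake_model_stream(prompt):
        full_response += chunk
        yield chunk
    # Invoke the completion callback once the full response has been streamed
    if on_complete:
        await on_complete(full_response)

The diffs below apply this shape to converse_gemini and gemini_chat_completion_with_backoff.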

View File

@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional
 
 import pyjson5
 from langchain.schema import ChatMessage
@@ -160,7 +160,7 @@ def gemini_send_message_to_model(
     )
 
 
-def converse_gemini(
+async def converse_gemini(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -185,7 +185,7 @@ def converse_gemini(
     program_execution_context: List[str] = None,
     deepthought: Optional[bool] = False,
     tracer={},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using Google's Gemini
     """
@@ -216,11 +216,17 @@ def converse_gemini(
 
     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        completion_func(chat_response=prompts.no_notes_found.format())
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
    elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
 
     context_message = ""
     if not is_none_or_empty(references):
@@ -253,16 +259,20 @@ def converse_gemini(
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
 
     # Get Response from Google AI
-    return gemini_chat_completion_with_backoff(
+    full_response = ""
+    async for chunk in gemini_chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=temperature,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
-        completion_func=completion_func,
         deepthought=deepthought,
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once finish streaming and we have the full response
+    if completion_func:
+        await completion_func(chat_response=full_response)

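Since converse_gemini is now an async generator, callers drain it with async for instead of reading from a ThreadedGenerator. A rough, self-contained consumption sketch (toy_converse is a hypothetical stand-in for the real function and its many arguments):

import asyncio
from typing import AsyncGenerator


async def toy_converse() -> AsyncGenerator[str, None]:
    # Hypothetical stand-in yielding response chunks like converse_gemini does
    for piece in ("streamed ", "without ", "a thread"):
        yield piece


async def main() -> None:
    full_response = ""
    # Iterate the generator directly; chunks arrive as they are produced
    async for chunk in toy_converse():
        full_response += chunk
    print(full_response)


asyncio.run(main())
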
View File

@@ -2,8 +2,8 @@ import logging
 import os
 import random
 from copy import deepcopy
-from threading import Thread
-from typing import Dict
+from time import perf_counter
+from typing import AsyncGenerator, AsyncIterator, Dict
 
 from google import genai
 from google.genai import errors as gerrors
@@ -19,7 +19,6 @@ from tenacity import (
 )
 
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
@@ -121,8 +120,8 @@ def gemini_completion_with_backoff(
     )
 
     # Aggregate cost of chat
-    input_tokens = response.usage_metadata.prompt_token_count if response else 0
-    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
     thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
     tracer["usage"] = get_chat_usage_metrics(
         model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
@@ -143,52 +142,17 @@ def gemini_completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def gemini_chat_completion_with_backoff(
+async def gemini_chat_completion_with_backoff(
     messages,
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     api_key,
     api_base_url,
     system_prompt,
-    completion_func=None,
     model_kwargs=None,
     deepthought=False,
     tracer: dict = {},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=gemini_llm_thread,
-        args=(
-            g,
-            messages,
-            system_prompt,
-            model_name,
-            temperature,
-            api_key,
-            api_base_url,
-            model_kwargs,
-            deepthought,
-            tracer,
-        ),
-    )
-    t.start()
-    return g
-
-
-def gemini_llm_thread(
-    g,
-    messages,
-    system_prompt,
-    model_name,
-    temperature,
-    api_key,
-    api_base_url=None,
-    model_kwargs=None,
-    deepthought=False,
-    tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     try:
         client = gemini_clients.get(api_key)
         if not client:
@@ -213,21 +177,32 @@ def gemini_llm_thread(
         )
 
         aggregated_response = ""
-        for chunk in client.models.generate_content_stream(
+        final_chunk = None
+        start_time = perf_counter()
+        chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
             model=model_name, config=config, contents=formatted_messages
-        ):
+        )
+        async for chunk in chat_stream:
+            # Log the time taken to start response
+            if final_chunk is None:
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
             # Handle streamed response chunk
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
             aggregated_response += message
-            g.send(message)
+            yield message
             if stopped:
                 raise ValueError(message)
 
+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
         # Calculate cost of chat
-        input_tokens = chunk.usage_metadata.prompt_token_count
-        output_tokens = chunk.usage_metadata.candidates_token_count
-        thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
+        input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
+        output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
+        thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
         tracer["usage"] = get_chat_usage_metrics(
             model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
         )
@@ -243,9 +218,7 @@ def gemini_llm_thread(
                 + f"Last Message by {messages[-1].role}: {messages[-1].content}"
             )
     except Exception as e:
-        logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
-    finally:
-        g.close()
+        logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
 
 
 def handle_gemini_response(

View File

@@ -1563,10 +1563,9 @@ async def agenerate_chat_response(
                 tracer=tracer,
             )
         elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
-            # Assuming converse_gemini remains sync or is refactored separately
             api_key = chat_model.ai_model_api.api_key
             api_base_url = chat_model.ai_model_api.api_base_url
-            chat_response_generator = converse_gemini(  # Needs adaptation if it becomes async
+            chat_response_generator = converse_gemini(
                 compiled_references,
                 query_to_run,
                 online_results,