Refactor Gemini chat response to stream async, no separate thread

Debanjum
2025-04-20 02:54:56 +05:30
parent 0751f2ea30
commit a557031447
3 changed files with 48 additions and 66 deletions
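This commit swaps the thread-plus-queue streaming pattern (a worker Thread feeding a ThreadedGenerator) for a plain async generator, so response chunks are yielded directly on the event loop and the completion callback fires only after the stream is exhausted. A minimal sketch of that shape, where stream_reply, produce_chunks, and on_done are hypothetical stand-ins for the refactored converse_gemini, the Gemini chunk stream, and completion_func:

from typing import AsyncGenerator, Awaitable, Callable, Optional


async def stream_reply(
    produce_chunks: AsyncGenerator[str, None],
    on_done: Optional[Callable[[str], Awaitable[None]]] = None,
) -> AsyncGenerator[str, None]:
    # Yield chunks to the caller as they arrive while accumulating the full text.
    full_response = ""
    async for chunk in produce_chunks:
        full_response += chunk
        yield chunk
    # Invoke the completion callback once, after streaming has finished.
    if on_done:
        await on_done(full_response)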

View File

@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -160,7 +160,7 @@ def gemini_send_message_to_model(
 )


-def converse_gemini(
+async def converse_gemini(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -185,7 +185,7 @@ def converse_gemini(
     program_execution_context: List[str] = None,
     deepthought: Optional[bool] = False,
     tracer={},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using Google's Gemini
     """
@@ -216,11 +216,17 @@ def converse_gemini(
     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        completion_func(chat_response=prompts.no_notes_found.format())
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -253,16 +259,20 @@ def converse_gemini(
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")

     # Get Response from Google AI
-    return gemini_chat_completion_with_backoff(
+    full_response = ""
+    async for chunk in gemini_chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=temperature,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
-        completion_func=completion_func,
         deepthought=deepthought,
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once finish streaming and we have the full response
+    if completion_func:
+        await completion_func(chat_response=full_response)
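Since converse_gemini is now an async generator, callers drive it with async for on the event loop instead of pulling from a ThreadedGenerator. A self-contained usage sketch, with a stand-in generator in place of the real Gemini call:

import asyncio
from typing import AsyncGenerator


async def fake_converse(user_query: str) -> AsyncGenerator[str, None]:
    # Stand-in for converse_gemini: yields a few chunks asynchronously.
    for word in ("Hello", ", ", "world", "!"):
        await asyncio.sleep(0)  # yield control to the event loop between chunks
        yield word


async def main() -> None:
    # Consume the stream the same way a caller of the refactored converse_gemini would.
    async for chunk in fake_converse("greet me"):
        print(chunk, end="", flush=True)
    print()


asyncio.run(main())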

View File

@@ -2,8 +2,8 @@ import logging
 import os
 import random
 from copy import deepcopy
-from threading import Thread
-from typing import Dict
+from time import perf_counter
+from typing import AsyncGenerator, AsyncIterator, Dict

 from google import genai
 from google.genai import errors as gerrors
@@ -19,7 +19,6 @@ from tenacity import (
 )

 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
@@ -121,8 +120,8 @@ def gemini_completion_with_backoff(
     )

     # Aggregate cost of chat
-    input_tokens = response.usage_metadata.prompt_token_count if response else 0
-    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
     thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
     tracer["usage"] = get_chat_usage_metrics(
         model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
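The token accounting change coalesces usage_metadata fields that the SDK may return as None. Because or binds tighter than the conditional expression, count or 0 if response else 0 parses as (count or 0) if response else 0. A small check with a hypothetical namespace standing in for the SDK response:

from types import SimpleNamespace

# Hypothetical response whose prompt_token_count came back as None.
response = SimpleNamespace(usage_metadata=SimpleNamespace(prompt_token_count=None))

# Parses as (prompt_token_count or 0) if response else 0.
input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
assert input_tokens == 0  # None is coalesced to 0 instead of propagating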
@@ -143,52 +142,17 @@ def gemini_completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def gemini_chat_completion_with_backoff(
+async def gemini_chat_completion_with_backoff(
     messages,
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     api_key,
     api_base_url,
     system_prompt,
-    completion_func=None,
     model_kwargs=None,
     deepthought=False,
     tracer: dict = {},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=gemini_llm_thread,
-        args=(
-            g,
-            messages,
-            system_prompt,
-            model_name,
-            temperature,
-            api_key,
-            api_base_url,
-            model_kwargs,
-            deepthought,
-            tracer,
-        ),
-    )
-    t.start()
-    return g
-
-
-def gemini_llm_thread(
-    g,
-    messages,
-    system_prompt,
-    model_name,
-    temperature,
-    api_key,
-    api_base_url=None,
-    model_kwargs=None,
-    deepthought=False,
-    tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     try:
         client = gemini_clients.get(api_key)
         if not client:
@@ -213,21 +177,32 @@ def gemini_llm_thread(
         )

         aggregated_response = ""
-        for chunk in client.models.generate_content_stream(
+        final_chunk = None
+        start_time = perf_counter()
+        chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
             model=model_name, config=config, contents=formatted_messages
-        ):
+        )
+        async for chunk in chat_stream:
+            # Log the time taken to start response
+            if final_chunk is None:
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Handle streamed response chunk
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
             aggregated_response += message
-            g.send(message)
+            yield message
             if stopped:
                 raise ValueError(message)

+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
         # Calculate cost of chat
-        input_tokens = chunk.usage_metadata.prompt_token_count
-        output_tokens = chunk.usage_metadata.candidates_token_count
-        thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
+        input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
+        output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
+        thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
         tracer["usage"] = get_chat_usage_metrics(
             model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
         )
@@ -243,9 +218,7 @@ def gemini_llm_thread(
+ f"Last Message by {messages[-1].role}: {messages[-1].content}" + f"Last Message by {messages[-1].role}: {messages[-1].content}"
) )
except Exception as e: except Exception as e:
logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True) logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
finally:
g.close()
def handle_gemini_response( def handle_gemini_response(
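For context, the refactored stream uses the async surface of the google-genai SDK: client.aio mirrors the sync client, and awaiting generate_content_stream returns an async iterator. A minimal standalone sketch of that pattern, assuming a GEMINI_API_KEY environment variable and the gemini-2.0-flash model name:

import asyncio
import os

from google import genai


async def main() -> None:
    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
    # Awaiting the async streaming call returns an async iterator of response chunks.
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents="Say hello in one short sentence.",
    )
    async for chunk in stream:
        print(chunk.text or "", end="", flush=True)
    print()


asyncio.run(main())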

View File

@@ -1563,10 +1563,9 @@ async def agenerate_chat_response(
             tracer=tracer,
         )
     elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
-        # Assuming converse_gemini remains sync or is refactored separately
         api_key = chat_model.ai_model_api.api_key
         api_base_url = chat_model.ai_model_api.api_base_url
-        chat_response_generator = converse_gemini(  # Needs adaptation if it becomes async
+        chat_response_generator = converse_gemini(
             compiled_references,
             query_to_run,
             online_results,