diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 73167ca2..3c42ef06 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional
 
 import pyjson5
 from langchain.schema import ChatMessage
@@ -160,7 +160,7 @@ def gemini_send_message_to_model(
     )
 
 
-def converse_gemini(
+async def converse_gemini(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -185,7 +185,7 @@ def converse_gemini(
     program_execution_context: List[str] = None,
     deepthought: Optional[bool] = False,
     tracer={},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using Google's Gemini
     """
@@ -216,11 +216,17 @@ def converse_gemini(
 
     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        completion_func(chat_response=prompts.no_notes_found.format())
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
 
     context_message = ""
     if not is_none_or_empty(references):
@@ -253,16 +259,20 @@ def converse_gemini(
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
 
     # Get Response from Google AI
-    return gemini_chat_completion_with_backoff(
+    full_response = ""
+    async for chunk in gemini_chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=temperature,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
-        completion_func=completion_func,
         deepthought=deepthought,
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once we finish streaming and have the full response
+    if completion_func:
+        await completion_func(chat_response=full_response)
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index b497edec..d85cd09e 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -2,8 +2,8 @@ import logging
 import os
 import random
 from copy import deepcopy
-from threading import Thread
-from typing import Dict
+from time import perf_counter
+from typing import AsyncGenerator, AsyncIterator, Dict
 
 from google import genai
 from google.genai import errors as gerrors
@@ -19,7 +19,6 @@ from tenacity import (
 )
 
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
@@ -121,8 +120,8 @@ def gemini_completion_with_backoff(
     )
     # Aggregate cost of chat
-    input_tokens = response.usage_metadata.prompt_token_count if response else 0
-    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
     thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
     tracer["usage"] = get_chat_usage_metrics(
         model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
     )
@@ -143,52 +142,17 @@
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def gemini_chat_completion_with_backoff(
+async def gemini_chat_completion_with_backoff(
     messages,
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     api_key,
     api_base_url,
     system_prompt,
-    completion_func=None,
     model_kwargs=None,
     deepthought=False,
     tracer: dict = {},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=gemini_llm_thread,
-        args=(
-            g,
-            messages,
-            system_prompt,
-            model_name,
-            temperature,
-            api_key,
-            api_base_url,
-            model_kwargs,
-            deepthought,
-            tracer,
-        ),
-    )
-    t.start()
-    return g
-
-
-def gemini_llm_thread(
-    g,
-    messages,
-    system_prompt,
-    model_name,
-    temperature,
-    api_key,
-    api_base_url=None,
-    model_kwargs=None,
-    deepthought=False,
-    tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     try:
         client = gemini_clients.get(api_key)
         if not client:
@@ -213,21 +177,32 @@
         )
 
         aggregated_response = ""
-
-        for chunk in client.models.generate_content_stream(
+        final_chunk = None
+        start_time = perf_counter()
+        chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
             model=model_name, config=config, contents=formatted_messages
-        ):
+        )
+        async for chunk in chat_stream:
+            # Log the time taken to start the response
+            if final_chunk is None:
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Handle streamed response chunk
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
             aggregated_response += message
-            g.send(message)
+            yield message
             if stopped:
                 raise ValueError(message)
 
+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
         # Calculate cost of chat
-        input_tokens = chunk.usage_metadata.prompt_token_count
-        output_tokens = chunk.usage_metadata.candidates_token_count
-        thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
+        input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
+        output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
+        thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
         tracer["usage"] = get_chat_usage_metrics(
             model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
         )
@@ -243,9 +218,7 @@
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )
     except Exception as e:
-        logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
-    finally:
-        g.close()
+        logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
 
 
 def handle_gemini_response(
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 6042469a..a5aa415a 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1563,10 +1563,9 @@ async def agenerate_chat_response(
             tracer=tracer,
         )
     elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
-        # Assuming converse_gemini remains sync or is refactored separately
        api_key = chat_model.ai_model_api.api_key
        api_base_url = chat_model.ai_model_api.api_base_url
-        chat_response_generator = converse_gemini(  # Needs adaptation if it becomes async
+        chat_response_generator = converse_gemini(
            compiled_references,
            query_to_run,
            online_results,
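
Usage sketch (not part of the patch): after this change, gemini_chat_completion_with_backoff is an async generator rather than a ThreadedGenerator factory, so callers drive it with async for. A minimal sketch under stated assumptions: the messages list uses the langchain ChatMessage shape seen in gemini_chat.py, the model name is illustrative, and the API key is a placeholder.

import asyncio

from langchain.schema import ChatMessage

from khoj.processor.conversation.google.utils import gemini_chat_completion_with_backoff


async def main():
    # Messages follow the ChatMessage shape that converse_gemini builds internally (assumption).
    messages = [ChatMessage(role="user", content="Say hello in one sentence.")]
    async for chunk in gemini_chat_completion_with_backoff(
        messages=messages,
        model_name="gemini-2.0-flash",  # illustrative model name, not prescribed by this patch
        temperature=0.2,
        api_key="YOUR_GEMINI_API_KEY",  # placeholder
        api_base_url=None,
        system_prompt="You are a helpful assistant.",
    ):
        # Chunks arrive incrementally as Gemini streams the response
        print(chunk, end="", flush=True)


asyncio.run(main())

The same pattern applies one level up: converse_gemini now yields chunks itself and awaits completion_func exactly once with the aggregated response after the stream ends, so any completion_func passed in must be awaitable.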