diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index c440b541..f9219c28 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -275,7 +275,8 @@ async def converse_gemini(
         deepthought=deepthought,
         tracer=tracer,
     ):
-        full_response += chunk
+        if chunk.response:
+            full_response += chunk.response
         yield chunk

     # Call completion_func once finish streaming and we have the full response
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 65cee030..f44ed3d4 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -21,6 +21,7 @@ from tenacity import (
 )

 from khoj.processor.conversation.utils import (
+    ResponseWithThought,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
@@ -110,7 +111,7 @@ def gemini_completion_with_backoff(
         response_schema = clean_response_schema(model_kwargs["response_schema"])

     thinking_config = None
-    if deepthought and model_name.startswith("gemini-2-5"):
+    if deepthought and model_name.startswith("gemini-2.5"):
         thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)

     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
@@ -178,7 +179,7 @@ async def gemini_chat_completion_with_backoff(
     model_kwargs=None,
     deepthought=False,
     tracer: dict = {},
-) -> AsyncGenerator[str, None]:
+) -> AsyncGenerator[ResponseWithThought, None]:
     client = gemini_clients.get(api_key)
     if not client:
         client = get_gemini_client(api_key, api_base_url)
@@ -187,8 +188,8 @@ async def gemini_chat_completion_with_backoff(
     formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)

     thinking_config = None
-    if deepthought and model_name.startswith("gemini-2-5"):
-        thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
+    if deepthought and model_name.startswith("gemini-2.5"):
+        thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI, include_thoughts=True)

     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
@@ -216,18 +217,25 @@ async def gemini_chat_completion_with_backoff(
             logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
         # Keep track of the last chunk for usage data
         final_chunk = chunk
-        # Handle streamed response chunk
+
+        # Handle safety, rate-limit, and other finish reasons
         stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
-        message = stop_message or chunk.text
-        aggregated_response += message
-        yield message
         if stopped:
+            yield ResponseWithThought(response=stop_message)
             logger.warning(
                 f"LLM Response Prevented for {model_name}: {stop_message}.\n"
                 + f"Last Message by {messages[-1].role}: {messages[-1].content}"
             )
             break

+        # Emit thought and response parts separately; check thought first so thought text is not double-counted
+        for part in chunk.candidates[0].content.parts:
+            if part.thought:
+                yield ResponseWithThought(thought=part.text)
+            elif part.text:
+                aggregated_response += part.text
+                yield ResponseWithThought(response=part.text)
+
     # Calculate cost of chat
     input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
     output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
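
For reviewers, a minimal sketch of how a caller might consume the new stream, assuming ResponseWithThought exposes optional `response` and `thought` string fields as used in the diff above; the `render_stream` helper is hypothetical and not part of this patch:

    async def render_stream(chunks):
        # chunks: async iterator of ResponseWithThought, e.g. from
        # gemini_chat_completion_with_backoff(...) or converse_gemini(...)
        async for chunk in chunks:
            if chunk.thought:
                # Reasoning text, emitted when ThinkingConfig(include_thoughts=True) is set
                print(f"[thought] {chunk.thought}")
            if chunk.response:
                # User-visible response text, streamed incrementally
                print(chunk.response, end="", flush=True)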