diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index cd893ec3..f78c420b 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -2,6 +2,7 @@ import json
 import logging
 import os
 import random
+import re
 from copy import deepcopy
 from time import perf_counter
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List
@@ -13,6 +14,7 @@ from google.genai import types as gtypes
 from langchain_core.messages.chat import ChatMessage
 from pydantic import BaseModel
 from tenacity import (
+    RetryCallState,
     before_sleep_log,
     retry,
     retry_if_exception,
@@ -73,7 +75,7 @@ SAFETY_SETTINGS = [
 def _is_retryable_error(exception: BaseException) -> bool:
     """Check if the exception is a retryable error"""
     # server errors
-    if isinstance(exception, gerrors.APIError):
+    if isinstance(exception, (gerrors.APIError, gerrors.ClientError)):
         return exception.code in [429, 502, 503, 504]
     # client errors
     if (
@@ -88,9 +90,69 @@ def _is_retryable_error(exception: BaseException) -> bool:
     return False
 
 
+def _find_retry_delay(payload: Any) -> "str | None":
+    """Depth-first search of a decoded JSON error payload for a "retryDelay" string"""
+    if isinstance(payload, dict):
+        delay = payload.get("retryDelay")
+        if isinstance(delay, str):
+            return delay
+        children = payload.values()
+    elif isinstance(payload, list):
+        children = payload
+    else:
+        return None
+    for child in children:
+        if (found := _find_retry_delay(child)) is not None:
+            return found
+    return None
+
+
+def _extract_retry_delay(exception: BaseException) -> "float | None":
+    """
+    Extract retry delay from Gemini error response, return in seconds.
+
+    Returns None when the exception is not a Gemini API error or carries no parsable retryDelay.
+    """
+    if not isinstance(exception, (gerrors.ClientError, gerrors.APIError)):
+        return None
+    # Look for retryDelay key, value pair. E.g "retryDelay": "54s"
+    # It may sit at the top level of details or nested deeper, e.g. in a RetryInfo entry under error.details
+    delay_str = _find_retry_delay(getattr(exception, "details", None))
+    if delay_str is None:
+        return None
+    # Accept fractional durations like "1.5s" as well as whole seconds like "54s"
+    delay_seconds_match = re.search(r"(\d+(?:\.\d+)?)s", delay_str)
+    return float(delay_seconds_match.group(1)) if delay_seconds_match else None
+
+
+def _wait_with_gemini_delay(min_wait: float = 4, max_wait: float = 120, multiplier: float = 1, fallback_wait=None):
+    """
+    Custom wait strategy that respects Gemini's retryDelay if present.
+
+    The Gemini-suggested delay is capped at max_wait. Without a suggestion, wait using fallback_wait
+    if provided, else exponential backoff built from multiplier, min_wait and max_wait.
+    """
+    # Resolve the backoff strategy once instead of rebuilding it on every retry
+    backoff_wait = fallback_wait or wait_exponential(multiplier=multiplier, min=min_wait, max=max_wait)
+
+    def wait_func(retry_state: RetryCallState) -> float:
+        # Use backoff time if last exception suggests a retry delay
+        if retry_state.outcome and retry_state.outcome.failed:
+            gemini_delay = _extract_retry_delay(retry_state.outcome.exception())
+            if gemini_delay:
+                # Use the Gemini-suggested delay, but cap it at max_wait
+                suggested_delay = min(gemini_delay, max_wait)
+                logger.info(f"Using Gemini suggested retry delay: {suggested_delay} seconds")
+                return suggested_delay
+        # Else use the fallback or exponential backoff
+        return backoff_wait(retry_state)
+
+    return wait_func
+
+
 @retry(
     retry=retry_if_exception(_is_retryable_error),
-    wait=wait_random_exponential(min=1, max=10),
+    wait=_wait_with_gemini_delay(min_wait=1, max_wait=10, fallback_wait=wait_random_exponential(min=1, max=10)),
     stop=stop_after_attempt(2),
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
@@ -169,7 +231,14 @@ def gemini_completion_with_backoff(
         )
     except gerrors.ClientError as e:
         response = None
-        response_text, _ = handle_gemini_response(e.args)
+        # Handle 429 rate limit errors directly
+        if e.code == 429:
+            response_text = "My brain is exhausted. Can you please try again in a bit?"
+            # Log the full error details for debugging
+            logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}")
+        # Handle other errors
+        else:
+            response_text, _ = handle_gemini_response(e.args)
         # Respond with reason for stopping
         logger.warning(
             f"LLM Response Prevented for {model_name}: {response_text}.\n"
@@ -206,7 +275,7 @@ def gemini_completion_with_backoff(
 
 @retry(
     retry=retry_if_exception(_is_retryable_error),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
+    wait=_wait_with_gemini_delay(multiplier=1, min_wait=4, max_wait=10),
     stop=stop_after_attempt(3),
     before_sleep=before_sleep_log(logger, logging.WARNING),
     reraise=False,
@@ -310,6 +379,13 @@ def handle_gemini_response(
     candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
 ):
     """Check if Gemini response was blocked and return an explanatory error message."""
+
+    # Ensure we have a proper list of candidates
+    if not isinstance(candidates, list):
+        message = "\nUnexpected response format. Try again."
+        stopped = True
+        return message, stopped
+
     # Check if the response was blocked due to safety concerns with the prompt
     if len(candidates) == 0 and prompt_feedback:
         message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
@@ -428,5 +504,12 @@ def format_messages_for_gemini(
     if len(messages) == 1:
         messages[0].role = "user"
 
-    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
+    # Build each Content object once, dropping messages whose content structure is invalid
+    formatted_messages = []
+    for message in messages:
+        try:
+            formatted_messages.append(gtypes.Content(role=message.role, parts=message.content))
+        except Exception as e:
+            logger.warning(f"Dropping message with invalid content structure: {e}. Message: {message}")
+
     return formatted_messages, system_prompt