Files
khoj/src/khoj/processor/conversation/google/utils.py
Debanjum fd591c6e6c Upgrade tenacity to respect min time for exponential backoff
The fix for this issue is in tenacity 9.0.0, but older langchain versions required
tenacity <9.0.0.

Explicitly pin version of langchain sub packages to avoid indexing
and doc parsing breakage.
2025-05-17 17:37:15 -07:00

365 lines
15 KiB
Python

import logging
import os
import random
from copy import deepcopy
from time import perf_counter
from typing import AsyncGenerator, AsyncIterator, Dict
import httpx
from google import genai
from google.genai import errors as gerrors
from google.genai import types as gtypes
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel
from tenacity import (
before_sleep_log,
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
)
from khoj.processor.conversation.utils import (
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
)
from khoj.utils.helpers import (
get_chat_usage_metrics,
get_gemini_client,
is_none_or_empty,
is_promptrace_enabled,
)
logger = logging.getLogger(__name__)

# Gemini API clients reused across calls, keyed by the API key they were created with.
gemini_clients: Dict[str, genai.Client] = {}

# Value passed as max_output_tokens on every Gemini request made in this module.
MAX_OUTPUT_TOKENS_GEMINI = 8192
# Thinking budget requested when the caller asks for deep thought.
MAX_REASONING_TOKENS_GEMINI = 10000

# Relax Gemini's safety filters uniformly: for every listed harm category, only
# block content Gemini rates as high-probability harmful.
SAFETY_SETTINGS = [
    gtypes.SafetySetting(category=harm_category, threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH)
    for harm_category in (
        gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        gtypes.HarmCategory.HARM_CATEGORY_HARASSMENT,
        gtypes.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        gtypes.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        gtypes.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
    )
]
def _is_retryable_error(exception: BaseException) -> bool:
    """Decide whether a failed Gemini call is worth retrying.

    Gemini API errors are retried only for rate limiting (429) and gateway or
    availability failures (502, 503, 504). On the client side, httpx timeouts and
    network errors are retried. Every other exception is treated as permanent.
    """
    if isinstance(exception, gerrors.APIError):
        return exception.code in (429, 502, 503, 504)
    return isinstance(exception, (httpx.TimeoutException, httpx.NetworkError))
@retry(
    retry=retry_if_exception(_is_retryable_error),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def gemini_completion_with_backoff(
    messages,
    system_prompt,
    model_name: str,
    temperature=1.0,
    api_key=None,
    api_base_url: str = None,
    model_kwargs=None,
    deepthought=False,
    tracer: dict = None,
) -> str:
    """Run a single, non-streaming Gemini chat completion.

    Retried by tenacity (up to 2 attempts) on errors accepted by ``_is_retryable_error``.

    Args:
        messages: Chat history; system messages are merged into the system prompt.
        system_prompt: Base system instructions.
        model_name: Gemini model to call.
        temperature: Sampling temperature.
        api_key: Gemini API key; also the key of the module-level client cache.
        api_base_url: Optional custom endpoint passed to the client factory.
        model_kwargs: Optional settings. ``response_schema`` (a Pydantic model) and
            ``response_mime_type`` are read from it.
        deepthought: Request an explicit thinking budget on Gemini 2.5 models.
        tracer: Optional dict updated in place with usage, model and temperature.
            A fresh dict is used when omitted.

    Returns:
        The response text. If Gemini rejects the request with a non-retryable client
        error, an explanatory message is returned instead.

    Raises:
        gerrors.APIError, httpx.TimeoutException, httpx.NetworkError: retryable
            failures, re-raised once retries are exhausted.
    """
    # A shared `{}` default would make usage accumulate across unrelated calls.
    tracer = {} if tracer is None else tracer

    client = gemini_clients.get(api_key)
    if not client:
        client = get_gemini_client(api_key, api_base_url)
        gemini_clients[api_key] = client

    formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

    # format model response schema
    response_schema = None
    if model_kwargs and model_kwargs.get("response_schema"):
        response_schema = clean_response_schema(model_kwargs["response_schema"])

    # Gemini model ids are dotted (e.g. "gemini-2.5-flash"); the dashed prefix is kept
    # for backward compatibility.
    thinking_config = None
    max_output_tokens = MAX_OUTPUT_TOKENS_GEMINI
    if deepthought and model_name.startswith(("gemini-2.5", "gemini-2-5")):
        thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
        # Leave room for the answer on top of the thinking budget in case thinking
        # tokens count against max_output_tokens. NOTE(review): confirm against Gemini docs.
        max_output_tokens += MAX_REASONING_TOKENS_GEMINI

    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
    config = gtypes.GenerateContentConfig(
        system_instruction=system_prompt,
        temperature=temperature,
        thinking_config=thinking_config,
        max_output_tokens=max_output_tokens,
        safety_settings=SAFETY_SETTINGS,
        response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
        response_schema=response_schema,
        seed=seed,
        http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
    )

    try:
        # Generate the response
        response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
        response_text = response.text
    except gerrors.ClientError as e:
        # 429 (rate limit) is a ClientError too: let tenacity retry it rather than swallow it.
        if _is_retryable_error(e):
            raise
        response = None
        # The error's args hold only a message string, not response candidates, so
        # describe the failure from the error itself.
        reason = getattr(e, "status", None) or getattr(e, "code", None) or "unknown"
        response_text = f"\nI can't talk further about that because of **{reason} issue.**"
        # Respond with reason for stopping
        logger.warning(
            f"LLM Response Prevented for {model_name}: {response_text}.\n"
            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
        )

    # Aggregate cost of chat. usage_metadata may be absent, so guard it as well.
    usage = response.usage_metadata if response else None
    input_tokens = (usage.prompt_token_count or 0) if usage else 0
    output_tokens = (usage.candidates_token_count or 0) if usage else 0
    thought_tokens = (usage.thoughts_token_count or 0) if usage else 0
    tracer["usage"] = get_chat_usage_metrics(
        model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
    )

    # Save conversation trace
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature
    if is_promptrace_enabled():
        commit_conversation_trace(messages, response_text, tracer)

    return response_text
@retry(
    retry=retry_if_exception(_is_retryable_error),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(3),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def gemini_chat_completion_with_backoff(
    messages,
    model_name,
    temperature,
    api_key,
    api_base_url,
    system_prompt,
    model_kwargs=None,
    deepthought=False,
    tracer: dict = None,
) -> AsyncGenerator[str, None]:
    """Stream a Gemini chat completion, yielding response text as it arrives.

    If Gemini stops the response early (safety block, token limit, ...), the
    explanatory message is yielded as the last chunk and the stream ends. Other
    errors are logged rather than raised.

    Args:
        messages: Chat history; system messages are merged into the system prompt.
        model_name: Gemini model to call.
        temperature: Sampling temperature.
        api_key: Gemini API key; also the key of the module-level client cache.
        api_base_url: Optional custom endpoint passed to the client factory.
        system_prompt: Base system instructions.
        model_kwargs: Accepted for interface compatibility; not used here.
        deepthought: Request an explicit thinking budget on Gemini 2.5 models.
        tracer: Optional dict updated in place with usage, model and temperature once
            the stream completes normally. A fresh dict is used when omitted.

    NOTE(review): tenacity's ``retry`` probably wraps this async *generator* function
    with its synchronous retrier, because it is not a coroutine function. The body
    also catches every exception, so the retry policy above likely never fires.
    Confirm before relying on it.
    """
    # A shared `{}` default would make usage accumulate across unrelated calls.
    tracer = {} if tracer is None else tracer
    try:
        client = gemini_clients.get(api_key)
        if not client:
            client = get_gemini_client(api_key, api_base_url)
            gemini_clients[api_key] = client

        formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

        # Gemini model ids are dotted (e.g. "gemini-2.5-flash"); the dashed prefix is
        # kept for backward compatibility.
        thinking_config = None
        max_output_tokens = MAX_OUTPUT_TOKENS_GEMINI
        if deepthought and model_name.startswith(("gemini-2.5", "gemini-2-5")):
            thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
            # Leave room for the answer on top of the thinking budget in case thinking
            # tokens count against max_output_tokens. NOTE(review): confirm against Gemini docs.
            max_output_tokens += MAX_REASONING_TOKENS_GEMINI

        seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
        config = gtypes.GenerateContentConfig(
            system_instruction=system_prompt,
            temperature=temperature,
            thinking_config=thinking_config,
            max_output_tokens=max_output_tokens,
            stop_sequences=["Notes:\n["],
            safety_settings=SAFETY_SETTINGS,
            seed=seed,
            http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
        )

        aggregated_response = ""
        final_chunk = None
        response_started = False
        start_time = perf_counter()
        chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
            model=model_name, config=config, contents=formatted_messages
        )
        async for chunk in chat_stream:
            # Log the time taken to start response
            if not response_started:
                response_started = True
                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
            # Keep track of the last chunk for usage data
            final_chunk = chunk
            # candidates can be None (e.g. when the prompt itself is blocked)
            message, stopped = handle_gemini_response(chunk.candidates or [], chunk.prompt_feedback)
            # chunk.text can be None for chunks without text parts
            message = message or chunk.text or ""
            if message:
                aggregated_response += message
                yield message
            if stopped:
                # Stop the stream here. This is not signalled with an exception so that
                # unrelated errors are not misreported as blocked responses.
                logger.warning(
                    f"LLM Response Prevented for {model_name}: {message}.\n"
                    + f"Last Message by {messages[-1].role}: {messages[-1].content}"
                )
                return

        # Log the time taken to stream the entire response
        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")

        # Calculate cost of chat. usage_metadata may be absent, so guard it as well.
        usage = final_chunk.usage_metadata if final_chunk else None
        input_tokens = (usage.prompt_token_count or 0) if usage else 0
        output_tokens = (usage.candidates_token_count or 0) if usage else 0
        thought_tokens = (usage.thoughts_token_count or 0) if usage else 0
        tracer["usage"] = get_chat_usage_metrics(
            model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
        )

        # Save conversation trace
        tracer["chat_model"] = model_name
        tracer["temperature"] = temperature
        if is_promptrace_enabled():
            commit_conversation_trace(messages, aggregated_response, tracer)
    except Exception as e:
        logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
def handle_gemini_response(
    candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
):
    """Check whether a Gemini response was stopped and, if so, explain why.

    Args:
        candidates: Response candidates; may be empty or ``None``.
        prompt_feedback: Feedback about the prompt, checked when there are no candidates.

    Returns:
        Tuple ``(message, stopped)``. ``message`` is an explanatory string when the
        response was blocked or cut short, otherwise ``None``. ``stopped`` is ``True``
        when generation should not continue.
    """
    # With no candidates the only possible explanation is a blocked prompt.
    if not candidates:
        if prompt_feedback and prompt_feedback.block_reason:
            message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
            return message, True
        # Nothing to inspect and no block reported.
        return None, False

    finish_reason = candidates[0].finish_reason
    # Blocked due to safety concerns with the generated content
    if finish_reason == gtypes.FinishReason.SAFETY:
        return generate_safety_response(candidates[0].safety_ratings), True
    # No finish reason means generation is still in progress; STOP is a normal finish.
    if not finish_reason or finish_reason == gtypes.FinishReason.STOP:
        return None, False
    # Stopped for another reason, e.g. reaching the maximum token limit
    return f"\nI can't talk further about that because of **{finish_reason.name} issue.**", True
def generate_safety_response(safety_ratings: list[gtypes.SafetyRating]):
    """Generate a conversational refusal based on the riskiest safety rating.

    The rating with the highest harm probability is selected. Its harm category is
    named in the reply, and the reply's tone scales with that probability.

    Args:
        safety_ratings: Safety ratings attached to a Gemini candidate; may be empty or ``None``.

    Returns:
        A user-facing refusal message, randomly chosen from a few templates.
    """
    if not safety_ratings:
        return "\nI'd rather not respond to that."

    # Rank probabilities by severity explicitly. HarmProbability is a string-valued
    # enum in google-genai, so sorting by the member itself orders alphabetically
    # and NEGLIGIBLE would outrank HIGH.
    severity = {
        gtypes.HarmProbability.HARM_PROBABILITY_UNSPECIFIED: 0,
        gtypes.HarmProbability.NEGLIGIBLE: 1,
        gtypes.HarmProbability.LOW: 2,
        gtypes.HarmProbability.MEDIUM: 3,
        gtypes.HarmProbability.HIGH: 4,
    }
    # max() keeps the first of equally severe ratings, as the original stable sort did.
    max_safety_rating: gtypes.SafetyRating = max(
        safety_ratings, key=lambda rating: severity.get(rating.probability, 0)
    )

    # Remove the "HARM_CATEGORY_" prefix and title case the category name
    max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()

    # Add a bit of variety to the discomfort level based on the safety rating probability
    discomfort_level = {
        gtypes.HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
        gtypes.HarmProbability.NEGLIGIBLE: "a little ",
        gtypes.HarmProbability.LOW: "a bit ",
        gtypes.HarmProbability.MEDIUM: "moderately ",
        gtypes.HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
    }.get(max_safety_rating.probability, " ")

    # Generate a response using a random response template
    safety_response_choice = random.choice(
        [
            "\nUmm, I'd rather not respond to that. The conversation has some probability of going into **{category}** territory.",
            "\nI'd prefer not to talk about **{category}** related topics. It makes me {discomfort_level}uncomfortable.",
            "\nI feel {discomfort_level}squeamish talking about **{category}** related stuff! Can we talk about something less controversial?",
            "\nThat sounds {discomfort_level}outside the [Overton Window](https://en.wikipedia.org/wiki/Overton_window) of acceptable conversation. Should we stick to something less {category} related?",
        ]
    )
    return safety_response_choice.format(category=max_safety_category, discomfort_level=discomfort_level)
def format_messages_for_gemini(
    original_messages: list[ChatMessage], system_prompt: str = None
) -> tuple[list[gtypes.Content], str]:
    """Convert chat messages into Gemini ``Content`` objects plus a system instruction.

    System messages are removed from the conversation and appended to
    ``system_prompt``. The content of each remaining message is converted to a list
    of Gemini ``Part`` objects, with images first and text after. Assistant turns are
    relabelled ``model``. Messages with no usable content are dropped with an error
    log. ``original_messages`` is deep-copied, not mutated.

    Args:
        original_messages: Chat history. Content is a string or a list of chatml-style
            parts (``{"type": "text", "text": ...}`` or
            ``{"type": "image_url", "image_url": {"url": ...}}``).
        system_prompt: Optional base system prompt.

    Returns:
        Tuple of (formatted messages, system prompt). The system prompt is ``None``
        when empty.
    """
    system_prompt = system_prompt or ""
    messages = deepcopy(original_messages)

    # Extract system messages into the system prompt
    conversation = []
    for message in messages:
        if message.role != "system":
            conversation.append(message)
            continue
        if isinstance(message.content, list):
            system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
        else:
            system_prompt += message.content
    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

    # Collect surviving messages into a new list. Removing from the list being iterated
    # would skip, and leave unconverted, the message after each dropped one.
    kept_messages = []
    for message in conversation:
        # Convert message content to a Part list from the chatml dictionary list
        if isinstance(message.content, list):
            # Convert image_urls to image byte parts placed at beginning of list (better for Gemini)
            message_content = []
            for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
                if item["type"] == "image_url":
                    image_data = item["image_url"]["url"]
                    if image_data.startswith("http"):
                        image = get_image_from_url(image_data, type="bytes")
                    else:
                        image = get_image_from_base64(image_data, type="bytes")
                    message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
                elif not is_none_or_empty(item.get("text")):
                    message_content += [gtypes.Part.from_text(text=item["text"])]
                else:
                    logger.error(f"Dropping invalid message content part: {item}")
            if not message_content:
                logger.error(f"Dropping message with empty content as not supported:\n{message}")
                continue
            message.content = message_content
        elif isinstance(message.content, str):
            message.content = [gtypes.Part.from_text(text=message.content)]
        else:
            logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}")
            continue

        if message.role == "assistant":
            message.role = "model"
        kept_messages.append(message)

    # Force a lone remaining message to the user role
    # (presumably Gemini rejects a single model turn — kept from the original behavior).
    if len(kept_messages) == 1:
        kept_messages[0].role = "user"

    formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in kept_messages]
    return formatted_messages, system_prompt
def clean_response_schema(response_schema: BaseModel) -> dict:
    """
    Build the Gemini response-schema dict from a Pydantic model.

    The model's JSON schema is returned with a ``property_ordering`` key added. It
    lists the model's field names in declaration order, so generated content follows
    the order in which the properties were defined.
    """
    schema = response_schema.model_json_schema()
    schema["property_ordering"] = [*response_schema.model_fields]
    return schema