From bdfa6400ef3f106750126d9bb5438ab838ecfb15 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 11 Mar 2025 06:03:06 +0530 Subject: [PATCH] Upgrade to new Gemini package to interface with Google AI --- pyproject.toml | 8 +- .../conversation/google/gemini_chat.py | 8 +- .../processor/conversation/google/utils.py | 126 +++++++++--------- src/khoj/processor/conversation/utils.py | 3 + 4 files changed, 74 insertions(+), 71 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0e06809..8f14f401 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,14 +61,14 @@ dependencies = [ "langchain-community == 0.2.5", "requests >= 2.26.0", "tenacity == 8.3.0", - "anyio == 3.7.1", + "anyio ~= 4.8.0", "pymupdf == 1.24.11", "django == 5.0.10", "django-unfold == 0.42.0", "authlib == 1.2.1", "llama-cpp-python == 0.2.88", "itsdangerous == 2.1.2", - "httpx == 0.27.2", + "httpx == 0.28.1", "pgvector == 0.2.4", "psycopg2-binary == 2.9.9", "lxml == 4.9.3", @@ -79,7 +79,7 @@ dependencies = [ "phonenumbers == 8.13.27", "markdownify ~= 0.11.6", "markdown-it-py ~= 3.0.0", - "websockets == 12.0", + "websockets == 13.0", "psutil >= 5.8.0", "huggingface-hub >= 0.22.2", "apscheduler ~= 3.10.0", @@ -88,7 +88,7 @@ dependencies = [ "django_apscheduler == 0.6.2", "anthropic == 0.49.0", "docx2txt == 0.8", - "google-generativeai == 0.8.3", + "google-genai == 1.5.0", "pyjson5 == 1.6.7", "resend == 1.0.1", "email-validator == 2.2.0", diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 067a2786..77cff325 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -128,7 +128,7 @@ def gemini_send_message_to_model( """ Send message to model """ - messages, system_prompt = format_messages_for_gemini(messages) + messages_for_gemini, system_prompt = format_messages_for_gemini(messages) model_kwargs = {} @@ -138,7 +138,7 @@ def 
gemini_send_message_to_model(
 
     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages,
+        messages=messages_for_gemini,
         system_prompt=system_prompt,
         model_name=model,
         api_key=api_key,
@@ -236,12 +236,12 @@ def converse_gemini(
         program_execution_context=program_execution_context,
     )
 
-    messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
+    messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
     logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
 
     # Get Response from Google AI
     return gemini_chat_completion_with_backoff(
-        messages=messages,
+        messages=messages_for_gemini,
         compiled_references=references,
         online_results=online_results,
         model_name=model,
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 8060ea5e..4c01be2c 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -1,15 +1,16 @@
 import logging
 import random
+from copy import deepcopy
 from threading import Thread
 
-import google.generativeai as genai
-from google.generativeai.types.answer_types import FinishReason
+from google import genai
+from google.genai import types as gtypes
-from google.generativeai.types.generation_types import StopCandidateException
+try:
+    from google.generativeai.types.generation_types import StopCandidateException
+except ImportError:
+    # google-generativeai was dropped from dependencies; the new google-genai SDK never raises this.
+    class StopCandidateException(Exception):
+        pass
-from google.generativeai.types.safety_types import (
-    HarmBlockThreshold,
-    HarmCategory,
-    HarmProbability,
-)
 from langchain.schema import ChatMessage
 from tenacity import (
     before_sleep_log,
@@ -24,7 +20,6 @@ from khoj.processor.conversation.utils import (
     commit_conversation_trace,
     get_image_from_url,
 )
-from khoj.utils import state
 from khoj.utils.helpers import (
     get_chat_usage_metrics,
     is_none_or_empty,
@@ -35,6 +30,24 @@ logger = logging.getLogger(__name__)
 
 MAX_OUTPUT_TOKENS_GEMINI = 8192
 
+SAFETY_SETTINGS = [
+    gtypes.SafetySetting(
+        category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + gtypes.SafetySetting( + category=gtypes.HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + gtypes.SafetySetting( + category=gtypes.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + gtypes.SafetySetting( + category=gtypes.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), +] @retry( @@ -46,30 +59,19 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192 def gemini_completion_with_backoff( messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={} ) -> str: - genai.configure(api_key=api_key) - model_kwargs = model_kwargs or dict() - model_kwargs["temperature"] = temperature - model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI - model = genai.GenerativeModel( - model_name, - generation_config=model_kwargs, + client = genai.Client(api_key=api_key) + config = gtypes.GenerateContentConfig( system_instruction=system_prompt, - safety_settings={ - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - }, + temperature=temperature, + max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, + safety_settings=SAFETY_SETTINGS, ) - formatted_messages = [{"role": message.role, "parts": message.content} for message in messages] - - # Start chat session. All messages up to the last are considered to be part of the chat history - chat_session = model.start_chat(history=formatted_messages[0:-1]) + formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages] try: - # Generate the response. 
The last message is considered to be the current prompt - response = chat_session.send_message(formatted_messages[-1]["parts"]) + # Generate the response + response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages) response_text = response.text except StopCandidateException as e: response = None @@ -125,30 +127,21 @@ def gemini_llm_thread( g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {} ): try: - genai.configure(api_key=api_key) - model_kwargs = model_kwargs or dict() - model_kwargs["temperature"] = temperature - model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI - model_kwargs["stop_sequences"] = ["Notes:\n["] - model = genai.GenerativeModel( - model_name, - generation_config=model_kwargs, + client = genai.Client(api_key=api_key) + config = gtypes.GenerateContentConfig( system_instruction=system_prompt, - safety_settings={ - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - }, + temperature=temperature, + max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, + stop_sequences=["Notes:\n["], + safety_settings=SAFETY_SETTINGS, ) aggregated_response = "" - formatted_messages = [{"role": message.role, "parts": message.content} for message in messages] + formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages] - # all messages up to the last are considered to be part of the chat history - chat_session = model.start_chat(history=formatted_messages[0:-1]) - # the last message is considered to be the current prompt - for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True): + for chunk in client.models.generate_content_stream( + 
model=model_name, config=config, contents=formatted_messages + ): message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) message = message or chunk.text aggregated_response += message @@ -177,14 +170,16 @@ def gemini_llm_thread( g.close() -def handle_gemini_response(candidates, prompt_feedback=None): +def handle_gemini_response( + candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None +): """Check if Gemini response was blocked and return an explanatory error message.""" # Check if the response was blocked due to safety concerns with the prompt if len(candidates) == 0 and prompt_feedback: message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query." stopped = True # Check if the response was blocked due to safety concerns with the generated content - elif candidates[0].finish_reason == FinishReason.SAFETY: + elif candidates[0].finish_reason == gtypes.FinishReason.SAFETY: message = generate_safety_response(candidates[0].safety_ratings) stopped = True # Check if finish reason is empty, therefore generation is in progress @@ -192,7 +187,7 @@ def handle_gemini_response(candidates, prompt_feedback=None): message = None stopped = False # Check if the response was stopped due to reaching maximum token limit or other reasons - elif candidates[0].finish_reason != FinishReason.STOP: + elif candidates[0].finish_reason != gtypes.FinishReason.STOP: message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**" stopped = True # Otherwise, the response is valid and can be used @@ -202,18 +197,18 @@ def handle_gemini_response(candidates, prompt_feedback=None): return message, stopped -def generate_safety_response(safety_ratings): +def generate_safety_response(safety_ratings: list[gtypes.SafetyRating]): """Generate a conversational response based on the safety ratings of the response.""" # Get the safety 
rating with the highest probability
-    max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
+    max_safety_rating: gtypes.SafetyRating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
     # Remove the "HARM_CATEGORY_" prefix and title case the category name
     max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
     # Add a bit of variety to the discomfort level based on the safety rating probability
     discomfort_level = {
-        HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
-        HarmProbability.LOW: "a bit ",
-        HarmProbability.MEDIUM: "moderately ",
-        HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
+        gtypes.HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
+        gtypes.HarmProbability.LOW: "a bit ",
+        gtypes.HarmProbability.MEDIUM: "moderately ",
+        gtypes.HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
     }[max_safety_rating.probability]
     # Generate a response using a random response template
     safety_response_choice = random.choice(
@@ -229,9 +224,12 @@ def generate_safety_response(safety_ratings):
     )
 
 
-def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
+def format_messages_for_gemini(
+    original_messages: list[ChatMessage], system_prompt: str = None
+) -> tuple[list[ChatMessage], str]:
     # Extract system message
     system_prompt = system_prompt or ""
+    messages = deepcopy(original_messages)
     for message in messages.copy():
         if message.role == "system":
             system_prompt += message.content
@@ -242,14 +240,16 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         # Convert message content to string list from chatml dictionary list
         if isinstance(message.content, list):
             # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
-            message.content = [
-                get_image_from_url(item["image_url"]["url"]).content
-                if item["type"] == "image_url"
-                else item.get("text", "")
- for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) - ] + message_content = [] + for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1): + if item["type"] == "image_url": + image = get_image_from_url(item["image_url"]["url"], type="bytes") + message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)] + else: + message_content += [gtypes.Part.from_text(text=item.get("text", ""))] + message.content = message_content elif isinstance(message.content, str): - message.content = [message.content] + message.content = [gtypes.Part.from_text(text=message.content)] if message.role == "assistant": message.role = "model" diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 015bbe6f..a7e6e694 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -673,10 +673,13 @@ def get_image_from_url(image_url: str, type="pil"): content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp" # Convert image to desired format + image_data: Any = None if type == "b64": image_data = base64.b64encode(response.content).decode("utf-8") elif type == "pil": image_data = PIL.Image.open(BytesIO(response.content)) + elif type == "bytes": + image_data = response.content else: raise ValueError(f"Invalid image type: {type}")