Upgrade to new Gemini package to interface with Google AI

This commit is contained in:
Debanjum
2025-03-11 06:03:06 +05:30
parent 2790ba3121
commit bdfa6400ef
4 changed files with 74 additions and 71 deletions

View File

@@ -61,14 +61,14 @@ dependencies = [
"langchain-community == 0.2.5",
"requests >= 2.26.0",
"tenacity == 8.3.0",
"anyio == 3.7.1",
"anyio ~= 4.8.0",
"pymupdf == 1.24.11",
"django == 5.0.10",
"django-unfold == 0.42.0",
"authlib == 1.2.1",
"llama-cpp-python == 0.2.88",
"itsdangerous == 2.1.2",
"httpx == 0.27.2",
"httpx == 0.28.1",
"pgvector == 0.2.4",
"psycopg2-binary == 2.9.9",
"lxml == 4.9.3",
@@ -79,7 +79,7 @@ dependencies = [
"phonenumbers == 8.13.27",
"markdownify ~= 0.11.6",
"markdown-it-py ~= 3.0.0",
"websockets == 12.0",
"websockets == 13.0",
"psutil >= 5.8.0",
"huggingface-hub >= 0.22.2",
"apscheduler ~= 3.10.0",
@@ -88,7 +88,7 @@ dependencies = [
"django_apscheduler == 0.6.2",
"anthropic == 0.49.0",
"docx2txt == 0.8",
"google-generativeai == 0.8.3",
"google-genai == 1.5.0",
"pyjson5 == 1.6.7",
"resend == 1.0.1",
"email-validator == 2.2.0",

View File

@@ -128,7 +128,7 @@ def gemini_send_message_to_model(
"""
Send message to model
"""
messages, system_prompt = format_messages_for_gemini(messages)
messages_for_gemini, system_prompt = format_messages_for_gemini(messages)
model_kwargs = {}
@@ -138,7 +138,7 @@ def gemini_send_message_to_model(
# Get Response from Gemini
return gemini_completion_with_backoff(
messages=messages,
messages=messages_for_gemini,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
@@ -236,12 +236,12 @@ def converse_gemini(
program_execution_context=program_execution_context,
)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
# Get Response from Google AI
return gemini_chat_completion_with_backoff(
messages=messages,
messages=messages_for_gemini,
compiled_references=references,
online_results=online_results,
model_name=model,

View File

@@ -1,15 +1,11 @@
import logging
import random
from copy import deepcopy
from threading import Thread
import google.generativeai as genai
from google.generativeai.types.answer_types import FinishReason
from google import genai
from google.genai import types as gtypes
from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import (
HarmBlockThreshold,
HarmCategory,
HarmProbability,
)
from langchain.schema import ChatMessage
from tenacity import (
before_sleep_log,
@@ -24,7 +20,6 @@ from khoj.processor.conversation.utils import (
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import (
get_chat_usage_metrics,
is_none_or_empty,
@@ -35,6 +30,24 @@ logger = logging.getLogger(__name__)
MAX_OUTPUT_TOKENS_GEMINI = 8192
SAFETY_SETTINGS = [
gtypes.SafetySetting(
category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
),
gtypes.SafetySetting(
category=gtypes.HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
),
gtypes.SafetySetting(
category=gtypes.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
),
gtypes.SafetySetting(
category=gtypes.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
),
]
@retry(
@@ -46,30 +59,19 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
) -> str:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
model_kwargs["temperature"] = temperature
model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
model = genai.GenerativeModel(
model_name,
generation_config=model_kwargs,
client = genai.Client(api_key=api_key)
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
safety_settings={
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
},
temperature=temperature,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
safety_settings=SAFETY_SETTINGS,
)
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# Start chat session. All messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
try:
# Generate the response. The last message is considered to be the current prompt
response = chat_session.send_message(formatted_messages[-1]["parts"])
# Generate the response
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
response_text = response.text
except StopCandidateException as e:
response = None
@@ -125,30 +127,21 @@ def gemini_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
):
try:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
model_kwargs["temperature"] = temperature
model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
model_kwargs["stop_sequences"] = ["Notes:\n["]
model = genai.GenerativeModel(
model_name,
generation_config=model_kwargs,
client = genai.Client(api_key=api_key)
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
safety_settings={
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
},
temperature=temperature,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS,
)
aggregated_response = ""
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
# all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
for chunk in client.models.generate_content_stream(
model=model_name, config=config, contents=formatted_messages
):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
aggregated_response += message
@@ -177,14 +170,16 @@ def gemini_llm_thread(
g.close()
def handle_gemini_response(candidates, prompt_feedback=None):
def handle_gemini_response(
candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
):
"""Check if Gemini response was blocked and return an explanatory error message."""
# Check if the response was blocked due to safety concerns with the prompt
if len(candidates) == 0 and prompt_feedback:
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
stopped = True
# Check if the response was blocked due to safety concerns with the generated content
elif candidates[0].finish_reason == FinishReason.SAFETY:
elif candidates[0].finish_reason == gtypes.FinishReason.SAFETY:
message = generate_safety_response(candidates[0].safety_ratings)
stopped = True
# Check if finish reason is empty, therefore generation is in progress
@@ -192,7 +187,7 @@ def handle_gemini_response(candidates, prompt_feedback=None):
message = None
stopped = False
# Check if the response was stopped due to reaching maximum token limit or other reasons
elif candidates[0].finish_reason != FinishReason.STOP:
elif candidates[0].finish_reason != gtypes.FinishReason.STOP:
message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
stopped = True
# Otherwise, the response is valid and can be used
@@ -202,18 +197,18 @@ def handle_gemini_response(candidates, prompt_feedback=None):
return message, stopped
def generate_safety_response(safety_ratings):
def generate_safety_response(safety_ratings: list[gtypes.SafetyRating]):
"""Generate a conversational response based on the safety ratings of the response."""
# Get the safety rating with the highest probability
max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
max_safety_rating: gtypes.SafetyRating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
# Remove the "HARM_CATEGORY_" prefix and title case the category name
max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
# Add a bit of variety to the discomfort level based on the safety rating probability
discomfort_level = {
HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
HarmProbability.LOW: "a bit ",
HarmProbability.MEDIUM: "moderately ",
HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
gtypes.HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
gtypes.HarmProbability.LOW: "a bit ",
gtypes.HarmProbability.MEDIUM: "moderately ",
gtypes.HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
}[max_safety_rating.probability]
# Generate a response using a random response template
safety_response_choice = random.choice(
@@ -229,9 +224,12 @@ def generate_safety_response(safety_ratings):
)
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
def format_messages_for_gemini(
original_messages: list[ChatMessage], system_prompt: str = None
) -> tuple[list[str], str]:
# Extract system message
system_prompt = system_prompt or ""
messages = deepcopy(original_messages)
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
@@ -242,14 +240,16 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
# Convert message content to string list from chatml dictionary list
if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [
get_image_from_url(item["image_url"]["url"]).content
if item["type"] == "image_url"
else item.get("text", "")
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
]
message_content = []
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
if item["type"] == "image_url":
image = get_image_from_url(item["image_url"]["url"], type="bytes")
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
else:
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
message.content = message_content
elif isinstance(message.content, str):
message.content = [message.content]
message.content = [gtypes.Part.from_text(text=message.content)]
if message.role == "assistant":
message.role = "model"

View File

@@ -673,10 +673,13 @@ def get_image_from_url(image_url: str, type="pil"):
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
# Convert image to desired format
image_data: Any = None
if type == "b64":
image_data = base64.b64encode(response.content).decode("utf-8")
elif type == "pil":
image_data = PIL.Image.open(BytesIO(response.content))
elif type == "bytes":
image_data = response.content
else:
raise ValueError(f"Invalid image type: {type}")