mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Upgrade to new Gemini package to interface with Google AI
This commit is contained in:
@@ -61,14 +61,14 @@ dependencies = [
|
||||
"langchain-community == 0.2.5",
|
||||
"requests >= 2.26.0",
|
||||
"tenacity == 8.3.0",
|
||||
"anyio == 3.7.1",
|
||||
"anyio ~= 4.8.0",
|
||||
"pymupdf == 1.24.11",
|
||||
"django == 5.0.10",
|
||||
"django-unfold == 0.42.0",
|
||||
"authlib == 1.2.1",
|
||||
"llama-cpp-python == 0.2.88",
|
||||
"itsdangerous == 2.1.2",
|
||||
"httpx == 0.27.2",
|
||||
"httpx == 0.28.1",
|
||||
"pgvector == 0.2.4",
|
||||
"psycopg2-binary == 2.9.9",
|
||||
"lxml == 4.9.3",
|
||||
@@ -79,7 +79,7 @@ dependencies = [
|
||||
"phonenumbers == 8.13.27",
|
||||
"markdownify ~= 0.11.6",
|
||||
"markdown-it-py ~= 3.0.0",
|
||||
"websockets == 12.0",
|
||||
"websockets == 13.0",
|
||||
"psutil >= 5.8.0",
|
||||
"huggingface-hub >= 0.22.2",
|
||||
"apscheduler ~= 3.10.0",
|
||||
@@ -88,7 +88,7 @@ dependencies = [
|
||||
"django_apscheduler == 0.6.2",
|
||||
"anthropic == 0.49.0",
|
||||
"docx2txt == 0.8",
|
||||
"google-generativeai == 0.8.3",
|
||||
"google-genai == 1.5.0",
|
||||
"pyjson5 == 1.6.7",
|
||||
"resend == 1.0.1",
|
||||
"email-validator == 2.2.0",
|
||||
|
||||
@@ -128,7 +128,7 @@ def gemini_send_message_to_model(
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
messages, system_prompt = format_messages_for_gemini(messages)
|
||||
messages_for_gemini, system_prompt = format_messages_for_gemini(messages)
|
||||
|
||||
model_kwargs = {}
|
||||
|
||||
@@ -138,7 +138,7 @@ def gemini_send_message_to_model(
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
messages=messages,
|
||||
messages=messages_for_gemini,
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
@@ -236,12 +236,12 @@ def converse_gemini(
|
||||
program_execution_context=program_execution_context,
|
||||
)
|
||||
|
||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||
|
||||
# Get Response from Google AI
|
||||
return gemini_chat_completion_with_backoff(
|
||||
messages=messages,
|
||||
messages=messages_for_gemini,
|
||||
compiled_references=references,
|
||||
online_results=online_results,
|
||||
model_name=model,
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import logging
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from threading import Thread
|
||||
|
||||
import google.generativeai as genai
|
||||
from google.generativeai.types.answer_types import FinishReason
|
||||
from google import genai
|
||||
from google.genai import types as gtypes
|
||||
from google.generativeai.types.generation_types import StopCandidateException
|
||||
from google.generativeai.types.safety_types import (
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
HarmProbability,
|
||||
)
|
||||
from langchain.schema import ChatMessage
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
@@ -24,7 +20,6 @@ from khoj.processor.conversation.utils import (
|
||||
commit_conversation_trace,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
get_chat_usage_metrics,
|
||||
is_none_or_empty,
|
||||
@@ -35,6 +30,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||
SAFETY_SETTINGS = [
|
||||
gtypes.SafetySetting(
|
||||
category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
),
|
||||
gtypes.SafetySetting(
|
||||
category=gtypes.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
),
|
||||
gtypes.SafetySetting(
|
||||
category=gtypes.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
),
|
||||
gtypes.SafetySetting(
|
||||
category=gtypes.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -46,30 +59,19 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||
) -> str:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
model_kwargs["temperature"] = temperature
|
||||
model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
|
||||
model = genai.GenerativeModel(
|
||||
model_name,
|
||||
generation_config=model_kwargs,
|
||||
client = genai.Client(api_key=api_key)
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
safety_settings={
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
},
|
||||
temperature=temperature,
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
|
||||
# Start chat session. All messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
|
||||
try:
|
||||
# Generate the response. The last message is considered to be the current prompt
|
||||
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
# Generate the response
|
||||
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
|
||||
response_text = response.text
|
||||
except StopCandidateException as e:
|
||||
response = None
|
||||
@@ -125,30 +127,21 @@ def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||
):
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
model_kwargs = model_kwargs or dict()
|
||||
model_kwargs["temperature"] = temperature
|
||||
model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
|
||||
model_kwargs["stop_sequences"] = ["Notes:\n["]
|
||||
model = genai.GenerativeModel(
|
||||
model_name,
|
||||
generation_config=model_kwargs,
|
||||
client = genai.Client(api_key=api_key)
|
||||
config = gtypes.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
safety_settings={
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
},
|
||||
temperature=temperature,
|
||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||
stop_sequences=["Notes:\n["],
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
)
|
||||
|
||||
aggregated_response = ""
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||
for chunk in client.models.generate_content_stream(
|
||||
model=model_name, config=config, contents=formatted_messages
|
||||
):
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
aggregated_response += message
|
||||
@@ -177,14 +170,16 @@ def gemini_llm_thread(
|
||||
g.close()
|
||||
|
||||
|
||||
def handle_gemini_response(candidates, prompt_feedback=None):
|
||||
def handle_gemini_response(
|
||||
candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
|
||||
):
|
||||
"""Check if Gemini response was blocked and return an explanatory error message."""
|
||||
# Check if the response was blocked due to safety concerns with the prompt
|
||||
if len(candidates) == 0 and prompt_feedback:
|
||||
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
|
||||
stopped = True
|
||||
# Check if the response was blocked due to safety concerns with the generated content
|
||||
elif candidates[0].finish_reason == FinishReason.SAFETY:
|
||||
elif candidates[0].finish_reason == gtypes.FinishReason.SAFETY:
|
||||
message = generate_safety_response(candidates[0].safety_ratings)
|
||||
stopped = True
|
||||
# Check if finish reason is empty, therefore generation is in progress
|
||||
@@ -192,7 +187,7 @@ def handle_gemini_response(candidates, prompt_feedback=None):
|
||||
message = None
|
||||
stopped = False
|
||||
# Check if the response was stopped due to reaching maximum token limit or other reasons
|
||||
elif candidates[0].finish_reason != FinishReason.STOP:
|
||||
elif candidates[0].finish_reason != gtypes.FinishReason.STOP:
|
||||
message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
|
||||
stopped = True
|
||||
# Otherwise, the response is valid and can be used
|
||||
@@ -202,18 +197,18 @@ def handle_gemini_response(candidates, prompt_feedback=None):
|
||||
return message, stopped
|
||||
|
||||
|
||||
def generate_safety_response(safety_ratings):
|
||||
def generate_safety_response(safety_ratings: list[gtypes.SafetyRating]):
|
||||
"""Generate a conversational response based on the safety ratings of the response."""
|
||||
# Get the safety rating with the highest probability
|
||||
max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
|
||||
max_safety_rating: gtypes.SafetyRating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
|
||||
# Remove the "HARM_CATEGORY_" prefix and title case the category name
|
||||
max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
|
||||
# Add a bit of variety to the discomfort level based on the safety rating probability
|
||||
discomfort_level = {
|
||||
HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
|
||||
HarmProbability.LOW: "a bit ",
|
||||
HarmProbability.MEDIUM: "moderately ",
|
||||
HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
|
||||
gtypes.HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
|
||||
gtypes.HarmProbability.LOW: "a bit ",
|
||||
gtypes.HarmProbability.MEDIUM: "moderately ",
|
||||
gtypes.HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
|
||||
}[max_safety_rating.probability]
|
||||
# Generate a response using a random response template
|
||||
safety_response_choice = random.choice(
|
||||
@@ -229,9 +224,12 @@ def generate_safety_response(safety_ratings):
|
||||
)
|
||||
|
||||
|
||||
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
|
||||
def format_messages_for_gemini(
|
||||
original_messages: list[ChatMessage], system_prompt: str = None
|
||||
) -> tuple[list[str], str]:
|
||||
# Extract system message
|
||||
system_prompt = system_prompt or ""
|
||||
messages = deepcopy(original_messages)
|
||||
for message in messages.copy():
|
||||
if message.role == "system":
|
||||
system_prompt += message.content
|
||||
@@ -242,14 +240,16 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
||||
# Convert message content to string list from chatml dictionary list
|
||||
if isinstance(message.content, list):
|
||||
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||
message.content = [
|
||||
get_image_from_url(item["image_url"]["url"]).content
|
||||
if item["type"] == "image_url"
|
||||
else item.get("text", "")
|
||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||
]
|
||||
message_content = []
|
||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
|
||||
if item["type"] == "image_url":
|
||||
image = get_image_from_url(item["image_url"]["url"], type="bytes")
|
||||
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
|
||||
else:
|
||||
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
|
||||
message.content = message_content
|
||||
elif isinstance(message.content, str):
|
||||
message.content = [message.content]
|
||||
message.content = [gtypes.Part.from_text(text=message.content)]
|
||||
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
@@ -673,10 +673,13 @@ def get_image_from_url(image_url: str, type="pil"):
|
||||
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
|
||||
|
||||
# Convert image to desired format
|
||||
image_data: Any = None
|
||||
if type == "b64":
|
||||
image_data = base64.b64encode(response.content).decode("utf-8")
|
||||
elif type == "pil":
|
||||
image_data = PIL.Image.open(BytesIO(response.content))
|
||||
elif type == "bytes":
|
||||
image_data = response.content
|
||||
else:
|
||||
raise ValueError(f"Invalid image type: {type}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user