mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Upgrade to new Gemini package to interface with Google AI
This commit is contained in:
@@ -61,14 +61,14 @@ dependencies = [
|
|||||||
"langchain-community == 0.2.5",
|
"langchain-community == 0.2.5",
|
||||||
"requests >= 2.26.0",
|
"requests >= 2.26.0",
|
||||||
"tenacity == 8.3.0",
|
"tenacity == 8.3.0",
|
||||||
"anyio == 3.7.1",
|
"anyio ~= 4.8.0",
|
||||||
"pymupdf == 1.24.11",
|
"pymupdf == 1.24.11",
|
||||||
"django == 5.0.10",
|
"django == 5.0.10",
|
||||||
"django-unfold == 0.42.0",
|
"django-unfold == 0.42.0",
|
||||||
"authlib == 1.2.1",
|
"authlib == 1.2.1",
|
||||||
"llama-cpp-python == 0.2.88",
|
"llama-cpp-python == 0.2.88",
|
||||||
"itsdangerous == 2.1.2",
|
"itsdangerous == 2.1.2",
|
||||||
"httpx == 0.27.2",
|
"httpx == 0.28.1",
|
||||||
"pgvector == 0.2.4",
|
"pgvector == 0.2.4",
|
||||||
"psycopg2-binary == 2.9.9",
|
"psycopg2-binary == 2.9.9",
|
||||||
"lxml == 4.9.3",
|
"lxml == 4.9.3",
|
||||||
@@ -79,7 +79,7 @@ dependencies = [
|
|||||||
"phonenumbers == 8.13.27",
|
"phonenumbers == 8.13.27",
|
||||||
"markdownify ~= 0.11.6",
|
"markdownify ~= 0.11.6",
|
||||||
"markdown-it-py ~= 3.0.0",
|
"markdown-it-py ~= 3.0.0",
|
||||||
"websockets == 12.0",
|
"websockets == 13.0",
|
||||||
"psutil >= 5.8.0",
|
"psutil >= 5.8.0",
|
||||||
"huggingface-hub >= 0.22.2",
|
"huggingface-hub >= 0.22.2",
|
||||||
"apscheduler ~= 3.10.0",
|
"apscheduler ~= 3.10.0",
|
||||||
@@ -88,7 +88,7 @@ dependencies = [
|
|||||||
"django_apscheduler == 0.6.2",
|
"django_apscheduler == 0.6.2",
|
||||||
"anthropic == 0.49.0",
|
"anthropic == 0.49.0",
|
||||||
"docx2txt == 0.8",
|
"docx2txt == 0.8",
|
||||||
"google-generativeai == 0.8.3",
|
"google-genai == 1.5.0",
|
||||||
"pyjson5 == 1.6.7",
|
"pyjson5 == 1.6.7",
|
||||||
"resend == 1.0.1",
|
"resend == 1.0.1",
|
||||||
"email-validator == 2.2.0",
|
"email-validator == 2.2.0",
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ def gemini_send_message_to_model(
|
|||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
messages, system_prompt = format_messages_for_gemini(messages)
|
messages_for_gemini, system_prompt = format_messages_for_gemini(messages)
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
|
|
||||||
@@ -138,7 +138,7 @@ def gemini_send_message_to_model(
|
|||||||
|
|
||||||
# Get Response from Gemini
|
# Get Response from Gemini
|
||||||
return gemini_completion_with_backoff(
|
return gemini_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages_for_gemini,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
@@ -236,12 +236,12 @@ def converse_gemini(
|
|||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
messages_for_gemini, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||||
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Google AI
|
# Get Response from Google AI
|
||||||
return gemini_chat_completion_with_backoff(
|
return gemini_chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages_for_gemini,
|
||||||
compiled_references=references,
|
compiled_references=references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
|
|||||||
@@ -1,15 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from copy import deepcopy
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
import google.generativeai as genai
|
from google import genai
|
||||||
from google.generativeai.types.answer_types import FinishReason
|
from google.genai import types as gtypes
|
||||||
from google.generativeai.types.generation_types import StopCandidateException
|
from google.generativeai.types.generation_types import StopCandidateException
|
||||||
from google.generativeai.types.safety_types import (
|
|
||||||
HarmBlockThreshold,
|
|
||||||
HarmCategory,
|
|
||||||
HarmProbability,
|
|
||||||
)
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
@@ -24,7 +20,6 @@ from khoj.processor.conversation.utils import (
|
|||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils import state
|
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
@@ -35,6 +30,24 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||||
|
SAFETY_SETTINGS = [
|
||||||
|
gtypes.SafetySetting(
|
||||||
|
category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||||
|
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||||
|
),
|
||||||
|
gtypes.SafetySetting(
|
||||||
|
category=gtypes.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||||
|
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||||
|
),
|
||||||
|
gtypes.SafetySetting(
|
||||||
|
category=gtypes.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||||
|
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||||
|
),
|
||||||
|
gtypes.SafetySetting(
|
||||||
|
category=gtypes.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||||
|
threshold=gtypes.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -46,30 +59,19 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
|
|||||||
def gemini_completion_with_backoff(
|
def gemini_completion_with_backoff(
|
||||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
|
||||||
) -> str:
|
) -> str:
|
||||||
genai.configure(api_key=api_key)
|
client = genai.Client(api_key=api_key)
|
||||||
model_kwargs = model_kwargs or dict()
|
config = gtypes.GenerateContentConfig(
|
||||||
model_kwargs["temperature"] = temperature
|
|
||||||
model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
|
|
||||||
model = genai.GenerativeModel(
|
|
||||||
model_name,
|
|
||||||
generation_config=model_kwargs,
|
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
safety_settings={
|
temperature=temperature,
|
||||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
safety_settings=SAFETY_SETTINGS,
|
||||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
|
||||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||||
|
|
||||||
# Start chat session. All messages up to the last are considered to be part of the chat history
|
|
||||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate the response. The last message is considered to be the current prompt
|
# Generate the response
|
||||||
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
|
||||||
response_text = response.text
|
response_text = response.text
|
||||||
except StopCandidateException as e:
|
except StopCandidateException as e:
|
||||||
response = None
|
response = None
|
||||||
@@ -125,30 +127,21 @@ def gemini_llm_thread(
|
|||||||
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
genai.configure(api_key=api_key)
|
client = genai.Client(api_key=api_key)
|
||||||
model_kwargs = model_kwargs or dict()
|
config = gtypes.GenerateContentConfig(
|
||||||
model_kwargs["temperature"] = temperature
|
|
||||||
model_kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS_GEMINI
|
|
||||||
model_kwargs["stop_sequences"] = ["Notes:\n["]
|
|
||||||
model = genai.GenerativeModel(
|
|
||||||
model_name,
|
|
||||||
generation_config=model_kwargs,
|
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
safety_settings={
|
temperature=temperature,
|
||||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
stop_sequences=["Notes:\n["],
|
||||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
safety_settings=SAFETY_SETTINGS,
|
||||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||||
|
|
||||||
# all messages up to the last are considered to be part of the chat history
|
for chunk in client.models.generate_content_stream(
|
||||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
model=model_name, config=config, contents=formatted_messages
|
||||||
# the last message is considered to be the current prompt
|
):
|
||||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
|
||||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||||
message = message or chunk.text
|
message = message or chunk.text
|
||||||
aggregated_response += message
|
aggregated_response += message
|
||||||
@@ -177,14 +170,16 @@ def gemini_llm_thread(
|
|||||||
g.close()
|
g.close()
|
||||||
|
|
||||||
|
|
||||||
def handle_gemini_response(candidates, prompt_feedback=None):
|
def handle_gemini_response(
|
||||||
|
candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
|
||||||
|
):
|
||||||
"""Check if Gemini response was blocked and return an explanatory error message."""
|
"""Check if Gemini response was blocked and return an explanatory error message."""
|
||||||
# Check if the response was blocked due to safety concerns with the prompt
|
# Check if the response was blocked due to safety concerns with the prompt
|
||||||
if len(candidates) == 0 and prompt_feedback:
|
if len(candidates) == 0 and prompt_feedback:
|
||||||
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
|
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
|
||||||
stopped = True
|
stopped = True
|
||||||
# Check if the response was blocked due to safety concerns with the generated content
|
# Check if the response was blocked due to safety concerns with the generated content
|
||||||
elif candidates[0].finish_reason == FinishReason.SAFETY:
|
elif candidates[0].finish_reason == gtypes.FinishReason.SAFETY:
|
||||||
message = generate_safety_response(candidates[0].safety_ratings)
|
message = generate_safety_response(candidates[0].safety_ratings)
|
||||||
stopped = True
|
stopped = True
|
||||||
# Check if finish reason is empty, therefore generation is in progress
|
# Check if finish reason is empty, therefore generation is in progress
|
||||||
@@ -192,7 +187,7 @@ def handle_gemini_response(candidates, prompt_feedback=None):
|
|||||||
message = None
|
message = None
|
||||||
stopped = False
|
stopped = False
|
||||||
# Check if the response was stopped due to reaching maximum token limit or other reasons
|
# Check if the response was stopped due to reaching maximum token limit or other reasons
|
||||||
elif candidates[0].finish_reason != FinishReason.STOP:
|
elif candidates[0].finish_reason != gtypes.FinishReason.STOP:
|
||||||
message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
|
message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
|
||||||
stopped = True
|
stopped = True
|
||||||
# Otherwise, the response is valid and can be used
|
# Otherwise, the response is valid and can be used
|
||||||
@@ -202,18 +197,18 @@ def handle_gemini_response(candidates, prompt_feedback=None):
|
|||||||
return message, stopped
|
return message, stopped
|
||||||
|
|
||||||
|
|
||||||
def generate_safety_response(safety_ratings):
|
def generate_safety_response(safety_ratings: list[gtypes.SafetyRating]):
|
||||||
"""Generate a conversational response based on the safety ratings of the response."""
|
"""Generate a conversational response based on the safety ratings of the response."""
|
||||||
# Get the safety rating with the highest probability
|
# Get the safety rating with the highest probability
|
||||||
max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
|
max_safety_rating: gtypes.SafetyRating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
|
||||||
# Remove the "HARM_CATEGORY_" prefix and title case the category name
|
# Remove the "HARM_CATEGORY_" prefix and title case the category name
|
||||||
max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
|
max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
|
||||||
# Add a bit of variety to the discomfort level based on the safety rating probability
|
# Add a bit of variety to the discomfort level based on the safety rating probability
|
||||||
discomfort_level = {
|
discomfort_level = {
|
||||||
HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
|
gtypes.HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
|
||||||
HarmProbability.LOW: "a bit ",
|
gtypes.HarmProbability.LOW: "a bit ",
|
||||||
HarmProbability.MEDIUM: "moderately ",
|
gtypes.HarmProbability.MEDIUM: "moderately ",
|
||||||
HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
|
gtypes.HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
|
||||||
}[max_safety_rating.probability]
|
}[max_safety_rating.probability]
|
||||||
# Generate a response using a random response template
|
# Generate a response using a random response template
|
||||||
safety_response_choice = random.choice(
|
safety_response_choice = random.choice(
|
||||||
@@ -229,9 +224,12 @@ def generate_safety_response(safety_ratings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
|
def format_messages_for_gemini(
|
||||||
|
original_messages: list[ChatMessage], system_prompt: str = None
|
||||||
|
) -> tuple[list[str], str]:
|
||||||
# Extract system message
|
# Extract system message
|
||||||
system_prompt = system_prompt or ""
|
system_prompt = system_prompt or ""
|
||||||
|
messages = deepcopy(original_messages)
|
||||||
for message in messages.copy():
|
for message in messages.copy():
|
||||||
if message.role == "system":
|
if message.role == "system":
|
||||||
system_prompt += message.content
|
system_prompt += message.content
|
||||||
@@ -242,14 +240,16 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
|||||||
# Convert message content to string list from chatml dictionary list
|
# Convert message content to string list from chatml dictionary list
|
||||||
if isinstance(message.content, list):
|
if isinstance(message.content, list):
|
||||||
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||||
message.content = [
|
message_content = []
|
||||||
get_image_from_url(item["image_url"]["url"]).content
|
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
|
||||||
if item["type"] == "image_url"
|
if item["type"] == "image_url":
|
||||||
else item.get("text", "")
|
image = get_image_from_url(item["image_url"]["url"], type="bytes")
|
||||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
|
||||||
]
|
else:
|
||||||
|
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
|
||||||
|
message.content = message_content
|
||||||
elif isinstance(message.content, str):
|
elif isinstance(message.content, str):
|
||||||
message.content = [message.content]
|
message.content = [gtypes.Part.from_text(text=message.content)]
|
||||||
|
|
||||||
if message.role == "assistant":
|
if message.role == "assistant":
|
||||||
message.role = "model"
|
message.role = "model"
|
||||||
|
|||||||
@@ -673,10 +673,13 @@ def get_image_from_url(image_url: str, type="pil"):
|
|||||||
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
|
content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
|
||||||
|
|
||||||
# Convert image to desired format
|
# Convert image to desired format
|
||||||
|
image_data: Any = None
|
||||||
if type == "b64":
|
if type == "b64":
|
||||||
image_data = base64.b64encode(response.content).decode("utf-8")
|
image_data = base64.b64encode(response.content).decode("utf-8")
|
||||||
elif type == "pil":
|
elif type == "pil":
|
||||||
image_data = PIL.Image.open(BytesIO(response.content))
|
image_data = PIL.Image.open(BytesIO(response.content))
|
||||||
|
elif type == "bytes":
|
||||||
|
image_data = response.content
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid image type: {type}")
|
raise ValueError(f"Invalid image type: {type}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user