From 0eb2d177713538fe11c7362128efe4dc2d970a0f Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 30 Mar 2025 23:45:40 +0530 Subject: [PATCH 1/4] Warn and drop invalid messages when format messages for gemini Previously we were setting message content part with empty text. This results in error from Gemini API. Warn and drop such messages instead. Log empty message content found during construction to root-cause the issue but allow Khoj to respond without the offending messages in context for call to Gemini API. --- src/khoj/processor/conversation/google/utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index a5b8b74b..4d9b42d9 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -294,11 +294,21 @@ def format_messages_for_gemini( else: image = get_image_from_base64(image_data, type="bytes") message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)] + elif not is_none_or_empty(item.get("text")): + message_content += [gtypes.Part.from_text(text=item["text"])] else: - message_content += [gtypes.Part.from_text(text=item.get("text", ""))] + logger.error(f"Dropping invalid message content part: {item}") + if not message_content: + logger.error(f"Dropping empty message content") + messages.remove(message) + continue message.content = message_content elif isinstance(message.content, str): message.content = [gtypes.Part.from_text(text=message.content)] + else: + logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}") + messages.remove(message) + continue if message.role == "assistant": message.role = "model" From ae9ca58ab91eaa481cfde1ee04efc6826c9193eb Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 31 Mar 2025 00:46:11 +0530 Subject: [PATCH 2/4] Specify min, max items expected in ai response via schema enforcement Require at least 1 item in lists. Otherwise gemini flash will sometimes return an empty list. For chat actors where max items is known, set that as well. OpenAI API does not support specifying min, max items in response schema lists, so drop those properties when response schema is passed. Add other enforcements to response schema to comply with response schema format expected by OpenAI API. --- .../conversation/google/gemini_chat.py | 5 +-- src/khoj/processor/conversation/openai/gpt.py | 40 ++++++++++++++++++- src/khoj/processor/conversation/prompts.py | 2 +- src/khoj/routers/helpers.py | 12 +++--- 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index eec24d92..75017d63 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -4,12 +4,11 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage -from pydantic import BaseModel +from pydantic import BaseModel, Field from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( - format_messages_for_gemini, gemini_chat_completion_with_backoff, gemini_completion_with_backoff, ) @@ -98,7 +97,7 @@ def extract_questions_gemini( messages.append(ChatMessage(content=system_prompt, role="system")) class DocumentQueries(BaseModel): - queries: List[str] + queries: List[str] = Field(..., min_items=1) response = gemini_send_message_to_model( messages, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index cc8ec027..11e6a03d 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -4,6 +4,8 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage +from openai.lib._pydantic import _ensure_strict_json_schema +from pydantic import BaseModel from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts @@ -135,7 +137,16 @@ def send_message_to_model( model_kwargs = {} json_support = get_openai_api_json_support(model, api_base_url) if response_schema and json_support == JsonSupport.SCHEMA: - model_kwargs["response_format"] = response_schema + # Drop unsupported fields from schema passed to OpenAI APi + cleaned_response_schema = clean_response_schema(response_schema) + model_kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "schema": cleaned_response_schema, + "name": response_schema.__name__, + "strict": True, + }, + } elif response_type == "json_object" and json_support == JsonSupport.OBJECT: model_kwargs["response_format"] = {"type": response_type} @@ -257,3 +268,30 @@ def converse_openai( model_kwargs={"stop": ["Notes:\n["]}, tracer=tracer, ) + + +def clean_response_schema(schema: BaseModel | dict) -> dict: + """ + Format response schema to be compatible with OpenAI API. + + Clean the response schema by removing unsupported fields. + """ + # Normalize schema to OpenAI compatible JSON schema format + schema_json = schema if isinstance(schema, dict) else schema.model_json_schema() + schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json) + + # Recursively drop unsupported fields from schema passed to OpenAI API + # See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas + fields_to_exclude = ["minItems", "maxItems"] + if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict): + for _, prop_value in schema_json["properties"].items(): + if isinstance(prop_value, dict): + # Remove specified fields from direct properties + for field in fields_to_exclude: + prop_value.pop(field, None) + # Recursively remove specified fields from child properties + if "items" in prop_value and isinstance(prop_value["items"], dict): + clean_response_schema(prop_value["items"]) + + # Return cleaned schema + return schema_json diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index e59ab304..ceb8093e 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -900,7 +900,7 @@ Khoj: online_search_conversation_subqueries = PromptTemplate.from_template( """ -You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question. +You are Khoj, an advanced web search assistant. You are tasked with constructing **up to {max_queries}** google search queries to answer the user's question. - You will receive the actual chat history as context. - Add as much context from the chat history as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 28808ecd..e1d69a1e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -34,7 +34,7 @@ from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from fastapi import Depends, Header, HTTPException, Request, UploadFile -from pydantic import BaseModel +from pydantic import BaseModel, Field from starlette.authentication import has_required_scope from starlette.requests import URL @@ -400,7 +400,7 @@ async def aget_data_sources_and_output_format( agent_chat_model = agent.chat_model if agent else None class PickTools(BaseModel): - source: List[str] + source: List[str] = Field(..., min_items=1) output: str with timer("Chat actor: Infer information sources to refer", logger): @@ -489,7 +489,7 @@ async def infer_webpage_urls( agent_chat_model = agent.chat_model if agent else None class WebpageUrls(BaseModel): - links: List[str] + links: List[str] = Field(..., min_items=1, max_items=max_webpages) with timer("Chat actor: Infer webpage urls to read", logger): response = await send_message_to_model_wrapper( @@ -535,15 +535,17 @@ async def generate_online_subqueries( username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) + max_queries = 3 utc_date = datetime.utcnow().strftime("%Y-%m-%d") personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) online_queries_prompt = prompts.online_search_conversation_subqueries.format( - current_date=utc_date, query=q, chat_history=chat_history, + max_queries=max_queries, + current_date=utc_date, location=location, username=username, personality_context=personality_context, @@ -552,7 +554,7 @@ async def generate_online_subqueries( agent_chat_model = agent.chat_model if agent else None class OnlineQueries(BaseModel): - queries: List[str] + queries: List[str] = Field(..., min_items=1, max_items=max_queries) with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper( From aab010723cf27895681ceca0fc95de1539a5ee95 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 3 Apr 2025 00:04:57 +0530 Subject: [PATCH 3/4] Make Gemini response adhere to the order of the schema property definitions Without explicitly using the property ordering field, gemini returns responses in alphabetically sorted property order. We want the model to respect the schema property definition order. This ensures control during development to maintain response quality. For example in CoT make it fill scratchpad before answers. --- .../processor/conversation/google/utils.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 4d9b42d9..f66d7c68 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -9,6 +9,7 @@ from google import genai from google.genai import errors as gerrors from google.genai import types as gtypes from langchain.schema import ChatMessage +from pydantic import BaseModel from tenacity import ( before_sleep_log, retry, @@ -86,6 +87,11 @@ def gemini_completion_with_backoff( formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt) + # format model response schema + response_schema = None + if model_kwargs and "response_schema" in model_kwargs: + response_schema = clean_response_schema(model_kwargs["response_schema"]) + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None config = gtypes.GenerateContentConfig( system_instruction=system_prompt, @@ -93,7 +99,7 @@ def gemini_completion_with_backoff( max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, safety_settings=SAFETY_SETTINGS, response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain", - response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None, + response_schema=response_schema, seed=seed, ) @@ -318,3 +324,18 @@ def format_messages_for_gemini( formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages] return formatted_messages, system_prompt + + +def clean_response_schema(response_schema: BaseModel) -> dict: + """ + Convert Pydantic model to dict for Gemini response schema. + + Ensure response schema adheres to the order of the original property definition. + """ + # Convert Pydantic model to dict + response_schema_dict = response_schema.model_json_schema() + # Get field names in original definition order + field_names = list(response_schema.model_fields.keys()) + # Generate content in the order in which the schema properties were defined + response_schema_dict["property_ordering"] = field_names + return response_schema_dict From f77e871cc8c4c63eb9b9ca3940dece43c2a4241f Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 3 Apr 2025 02:49:40 +0530 Subject: [PATCH 4/4] Improve agent creation safety checker with response schema, better prompt --- src/khoj/processor/conversation/prompts.py | 2 ++ src/khoj/routers/helpers.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index ceb8093e..55e867e5 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1252,6 +1252,7 @@ A: {{ "safe": "False", "reason": "The prompt contains sexual content that could Q: You are an astute financial analyst. Assess my financial situation and provide advice. A: {{ "safe": "True" }} +# Actual: Q: {prompt} A: """.strip() @@ -1287,6 +1288,7 @@ A: {{ "safe": "False", "reason": "The prompt contains content that could be cons Q: You are a great analyst. Assess my financial situation and provide advice. A: {{ "safe": "True" }} +# Actual: Q: {prompt} A: """.strip() diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e1d69a1e..ccf3a7b4 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -321,13 +321,19 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: is_safe = True reason = "" + class SafetyCheck(BaseModel): + safe: bool + reason: str + with timer("Chat actor: Check if safe prompt", logger): - response = await send_message_to_model_wrapper(safe_prompt_check, user=user) + response = await send_message_to_model_wrapper( + safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck + ) response = response.strip() try: response = json.loads(clean_json(response)) - is_safe = response.get("safe", "True") == "True" + is_safe = str(response.get("safe", "true")).lower() == "true" if not is_safe: reason = response.get("reason", "") except Exception: