diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index eec24d92..75017d63 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -4,12 +4,11 @@ from typing import Dict, List, Optional
 
 import pyjson5
 from langchain.schema import ChatMessage
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.google.utils import (
-    format_messages_for_gemini,
     gemini_chat_completion_with_backoff,
     gemini_completion_with_backoff,
 )
@@ -98,7 +97,7 @@ def extract_questions_gemini(
     messages.append(ChatMessage(content=system_prompt, role="system"))
 
     class DocumentQueries(BaseModel):
-        queries: List[str]
+        queries: List[str] = Field(..., min_items=1)
 
     response = gemini_send_message_to_model(
         messages,
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index cc8ec027..11e6a03d 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -4,6 +4,8 @@ from typing import Dict, List, Optional
 
 import pyjson5
 from langchain.schema import ChatMessage
+from openai.lib._pydantic import _ensure_strict_json_schema
+from pydantic import BaseModel
 
 from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
@@ -135,7 +137,16 @@ def send_message_to_model(
     model_kwargs = {}
     json_support = get_openai_api_json_support(model, api_base_url)
     if response_schema and json_support == JsonSupport.SCHEMA:
-        model_kwargs["response_format"] = response_schema
+        # Drop unsupported fields from schema passed to OpenAI API
+        cleaned_response_schema = clean_response_schema(response_schema)
+        model_kwargs["response_format"] = {
+            "type": "json_schema",
+            "json_schema": {
+                "schema": cleaned_response_schema,
+                "name": response_schema.__name__,
+                "strict": True,
+            },
+        }
     elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
         model_kwargs["response_format"] = {"type": response_type}
 
@@ -257,3 +268,30 @@ def converse_openai(
         model_kwargs={"stop": ["Notes:\n["]},
         tracer=tracer,
     )
+
+
+def clean_response_schema(schema: BaseModel | dict) -> dict:
+    """
+    Format response schema to be compatible with OpenAI API.
+
+    Clean the response schema by removing unsupported fields.
+    """
+    # Normalize schema to OpenAI compatible JSON schema format
+    schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
+    schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
+
+    # Recursively drop unsupported fields from schema passed to OpenAI API
+    # See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
+    fields_to_exclude = ["minItems", "maxItems"]
+    if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
+        for _, prop_value in schema_json["properties"].items():
+            if isinstance(prop_value, dict):
+                # Remove specified fields from direct properties
+                for field in fields_to_exclude:
+                    prop_value.pop(field, None)
+                # Recursively remove specified fields from child properties
+                if "items" in prop_value and isinstance(prop_value["items"], dict):
+                    clean_response_schema(prop_value["items"])
+
+    # Return cleaned schema
+    return schema_json
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index e59ab304..ceb8093e 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -900,7 +900,7 @@ Khoj:
 
 online_search_conversation_subqueries = PromptTemplate.from_template(
     """
-You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
+You are Khoj, an advanced web search assistant. You are tasked with constructing **up to {max_queries}** google search queries to answer the user's question.
 - You will receive the actual chat history as context.
 - Add as much context from the chat history as required into your search queries.
 - Break messages into multiple search queries when required to retrieve the relevant information.
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 28808ecd..e1d69a1e 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -34,7 +34,7 @@ from apscheduler.job import Job
 from apscheduler.triggers.cron import CronTrigger
 from asgiref.sync import sync_to_async
 from fastapi import Depends, Header, HTTPException, Request, UploadFile
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from starlette.authentication import has_required_scope
 from starlette.requests import URL
 
@@ -400,7 +400,7 @@ async def aget_data_sources_and_output_format(
     agent_chat_model = agent.chat_model if agent else None
 
     class PickTools(BaseModel):
-        source: List[str]
+        source: List[str] = Field(..., min_items=1)
         output: str
 
     with timer("Chat actor: Infer information sources to refer", logger):
@@ -489,7 +489,7 @@ async def infer_webpage_urls(
     agent_chat_model = agent.chat_model if agent else None
 
     class WebpageUrls(BaseModel):
-        links: List[str]
+        links: List[str] = Field(..., min_items=1, max_items=max_webpages)
 
     with timer("Chat actor: Infer webpage urls to read", logger):
         response = await send_message_to_model_wrapper(
@@ -535,15 +535,17 @@ async def generate_online_subqueries(
     username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
     chat_history = construct_chat_history(conversation_history)
 
+    max_queries = 3
     utc_date = datetime.utcnow().strftime("%Y-%m-%d")
     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
     )
online_queries_prompt = prompts.online_search_conversation_subqueries.format( - current_date=utc_date, query=q, chat_history=chat_history, + max_queries=max_queries, + current_date=utc_date, location=location, username=username, personality_context=personality_context, @@ -552,7 +554,7 @@ async def generate_online_subqueries( agent_chat_model = agent.chat_model if agent else None class OnlineQueries(BaseModel): - queries: List[str] + queries: List[str] = Field(..., min_items=1, max_items=max_queries) with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper(