From d74c3a1db4dbe5330baa81e1c5dacb01aa6b5e9f Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 17:37:52 +0530 Subject: [PATCH 1/6] Simplify OpenAI reasoning model specific arguments to OpenAI API Previously OpenAI reasoning models didn't support stream_options and response_format Add reasoning_effort arg for calls to OpenAI reasoning models via API. Right now it defaults to medium but can be changed to low or high --- .../processor/conversation/openai/utils.py | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 444b6541..88d75763 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -60,20 +60,13 @@ def completion_with_backoff( formatted_messages = [{"role": message.role, "content": message.content} for message in messages] - # Update request parameters for compatability with o1 model series - # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations + # Tune reasoning models arguments + if model_name.startswith("o1") or model_name.startswith("o3"): + temperature = 1 + model_kwargs["reasoning_effort"] = "medium" + stream = True model_kwargs["stream_options"] = {"include_usage": True} - if model_name == "o1": - temperature = 1 - stream = False - model_kwargs.pop("stream_options", None) - elif model_name.startswith("o1"): - temperature = 1 - model_kwargs.pop("response_format", None) - elif model_name.startswith("o3-"): - temperature = 1 - if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) @@ -172,20 +165,13 @@ def llm_thread( formatted_messages = [{"role": message.role, "content": message.content} for message in messages] - # Update request parameters for compatability with o1 model series - # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations - stream = True - 
model_kwargs["stream_options"] = {"include_usage": True} - if model_name == "o1": + # Tune reasoning models arguments + if model_name.startswith("o1"): temperature = 1 - stream = False - model_kwargs.pop("stream_options", None) - elif model_name.startswith("o1-"): + elif model_name.startswith("o3"): temperature = 1 - model_kwargs.pop("response_format", None) - elif model_name.startswith("o3-"): - temperature = 1 - # Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices + # Get the first system message and add the string `Formatting re-enabled` to it. + # See https://platform.openai.com/docs/guides/reasoning-best-practices if len(formatted_messages) > 0: system_messages = [ (i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system" @@ -195,7 +181,6 @@ def llm_thread( formatted_messages[first_system_message_index][ "content" ] = f"{first_system_message} Formatting re-enabled" - elif model_name.startswith("deepseek-reasoner"): # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role. # The first message should always be a user message (except system message). 
@@ -210,6 +195,8 @@ def llm_thread( formatted_messages = updated_messages + stream = True + model_kwargs["stream_options"] = {"include_usage": True} if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) From 4a4d225455f9b0962df422d557019bb4d7c3d34a Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 19:28:54 +0530 Subject: [PATCH 2/6] Only enforce json output in supported AI model APIs Deepseek reasoner does not support json object or schema via deepseek API. Azure AI API does not support json schema. Resolves #1126 --- src/khoj/processor/conversation/openai/gpt.py | 5 ++++- src/khoj/processor/conversation/openai/utils.py | 12 ++++++++++++ src/khoj/processor/conversation/utils.py | 6 ++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 389f52ab..18eaea47 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, completion_with_backoff, + get_openai_api_json_support, ) from khoj.processor.conversation.utils import ( + JsonSupport, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -126,13 +128,14 @@ def send_message_to_model( """ # Get Response from GPT + json_support = get_openai_api_json_support(model, api_base_url) return completion_with_backoff( messages=messages, model_name=model, openai_api_key=api_key, temperature=temperature, api_base_url=api_base_url, - model_kwargs={"response_format": {"type": response_type}}, + model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {}, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 88d75763..f80c446a 
100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -2,6 +2,7 @@ import logging import os from threading import Thread from typing import Dict, List +from urllib.parse import urlparse import openai from openai.types.chat.chat_completion import ChatCompletion @@ -16,6 +17,7 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( + JsonSupport, ThreadedGenerator, commit_conversation_trace, ) @@ -245,3 +247,13 @@ def llm_thread( logger.error(f"Error in llm_thread: {e}", exc_info=True) finally: g.close() + + +def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: + if model_name.startswith("deepseek-reasoner"): + return JsonSupport.NONE + if api_base_url: + host = urlparse(api_base_url).hostname + if host and host.endswith(".ai.azure.com"): + return JsonSupport.OBJECT + return JsonSupport.SCHEMA diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a7e6e694..de82f067 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str: return str(content) return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." 
for message in messages]) + + +class JsonSupport(int, Enum): + NONE = 0 + OBJECT = 1 + SCHEMA = 2 From ac4b36b9fd91005eefd1a6405693f70008214414 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 17:46:26 +0530 Subject: [PATCH 3/6] Support constraining OpenAI model output to specified response schema --- src/khoj/processor/conversation/openai/gpt.py | 19 +++++++++++--- .../processor/conversation/openai/utils.py | 25 ++++++------------- src/khoj/routers/helpers.py | 2 ++ 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 18eaea47..f087fc93 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -121,21 +121,34 @@ def extract_questions( def send_message_to_model( - messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {} + messages, + api_key, + model, + response_type="text", + response_schema=None, + api_base_url=None, + temperature=0, + tracer: dict = {}, ): """ Send message to model """ - # Get Response from GPT + model_kwargs = {} json_support = get_openai_api_json_support(model, api_base_url) + if response_schema and json_support == JsonSupport.SCHEMA: + model_kwargs["response_format"] = response_schema + elif response_type == "json_object" and json_support == JsonSupport.OBJECT: + model_kwargs["response_format"] = {"type": response_type} + + # Get Response from GPT return completion_with_backoff( messages=messages, model_name=model, openai_api_key=api_key, temperature=temperature, api_base_url=api_base_url, - model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {}, + model_kwargs=model_kwargs, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index f80c446a..25ddd60a 100644 --- 
a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -67,33 +67,24 @@ def completion_with_backoff( temperature = 1 model_kwargs["reasoning_effort"] = "medium" - stream = True model_kwargs["stream_options"] = {"include_usage": True} if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( + aggregated_response = "" + with client.beta.chat.completions.stream( messages=formatted_messages, # type: ignore - model=model_name, # type: ignore - stream=stream, + model=model_name, temperature=temperature, timeout=20, **model_kwargs, - ) - - aggregated_response = "" - if not stream: - chunk = chat - aggregated_response = chunk.choices[0].message.content - else: + ) as chat: for chunk in chat: - if len(chunk.choices) == 0: + if chunk.type == "error": + logger.error(f"Openai api response error: {chunk.error}", exc_info=True) continue - delta_chunk = chunk.choices[0].delta # type: ignore - if isinstance(delta_chunk, str): - aggregated_response += delta_chunk - elif delta_chunk.content: - aggregated_response += delta_chunk.content + elif chunk.type == "content.delta": + aggregated_response += chunk.delta # Calculate cost of chat input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 75f38948..4f3a0fcc 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1209,6 +1209,7 @@ async def send_message_to_model_wrapper( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, api_base_url=api_base_url, tracer=tracer, ) @@ -1326,6 +1327,7 @@ def send_message_to_model_wrapper_sync( api_base_url=api_base_url, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) From 
6980014838c76d4ff6248c2c73df8582b1622f8c Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 16:10:24 +0530 Subject: [PATCH 4/6] Support constraining Gemini model output to specified response schema If the response_schema argument is passed to send_message_to_model_wrapper it is used to constrain output by Gemini models --- src/khoj/processor/conversation/google/gemini_chat.py | 2 ++ src/khoj/processor/conversation/google/utils.py | 1 + src/khoj/routers/helpers.py | 4 ++++ 3 files changed, 7 insertions(+) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 77cff325..7f18b079 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -121,6 +121,7 @@ def gemini_send_message_to_model( api_key, model, response_type="text", + response_schema=None, temperature=0.6, model_kwargs=None, tracer={}, @@ -135,6 +136,7 @@ def gemini_send_message_to_model( # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series. 
if response_type == "json_object" and model in ["gemini-2.0-flash"]: model_kwargs["response_mime_type"] = "application/json" + model_kwargs["response_schema"] = response_schema # Get Response from Gemini return gemini_completion_with_backoff( diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index ebe91527..b1a5fe77 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -66,6 +66,7 @@ def gemini_completion_with_backoff( max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, safety_settings=SAFETY_SETTINGS, response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain", + response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None, ) formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 4f3a0fcc..a339642a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1129,6 +1129,7 @@ async def send_message_to_model_wrapper( query: str, system_message: str = "", response_type: str = "text", + response_schema: BaseModel = None, deepthought: bool = False, user: KhojUser = None, query_images: List[str] = None, @@ -1256,6 +1257,7 @@ async def send_message_to_model_wrapper( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) else: @@ -1266,6 +1268,7 @@ def send_message_to_model_wrapper_sync( message: str, system_message: str = "", response_type: str = "text", + response_schema: BaseModel = None, user: KhojUser = None, query_images: List[str] = None, query_files: str = "", @@ -1372,6 +1375,7 @@ def send_message_to_model_wrapper_sync( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) else: From 
2c53eb9de15035d5a65b2839e1e1cf564bcbdac1 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 01:07:35 +0530 Subject: [PATCH 5/6] Use json schema to enforce research mode tool pick format --- src/khoj/routers/research.py | 44 ++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 07a66002..6e2493c5 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,10 +1,12 @@ import logging import os from datetime import datetime -from typing import Callable, Dict, List, Optional +from enum import Enum +from typing import Callable, Dict, List, Optional, Type import yaml from fastapi import Request +from pydantic import BaseModel, Field from khoj.database.adapters import EntryAdapters from khoj.database.models import Agent, KhojUser @@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) +class PlanningResponse(BaseModel): + """ + Schema for the response from planning agent when deciding the next tool to pick. + The tool field is dynamically validated based on available tools. + """ + + scratchpad: str = Field(..., description="Reasoning about which tool to use next") + query: str = Field(..., description="Detailed query for the selected tool") + + class Config: + arbitrary_types_allowed = True + + @classmethod + def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]: + """ + Factory method that creates a customized PlanningResponse model + with a properly typed tool field based on available tools. 
+ + Args: + tool_options: Dictionary mapping tool names to values + + Returns: + A customized PlanningResponse class + """ + # Create dynamic enum from tool options + tool_enum = Enum("ToolEnum", tool_options) # type: ignore + + # Create and return a customized response model with the enum + class PlanningResponseWithTool(PlanningResponse): + tool: tool_enum = Field(..., description="Name of the tool to use") + + return PlanningResponseWithTool + + async def apick_next_tool( query: str, conversation_history: dict, @@ -61,10 +97,13 @@ async def apick_next_tool( # Skip showing Notes tool as an option if user has no entries if tool == ConversationCommand.Notes and not user_has_entries: continue - tool_options[tool.value] = description if len(agent_tools) == 0 or tool.value in agent_tools: + tool_options[tool.name] = tool.value tool_options_str += f'- "{tool.value}": "{description}"\n' + # Create planning response model with dynamically populated tool enum class + planning_response_model = PlanningResponse.create_model_with_enum(tool_options) + # Construct chat history with user and iteration history with researcher agent for context chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) @@ -96,6 +135,7 @@ async def apick_next_tool( query=query, context=function_planning_prompt, response_type="json_object", + response_schema=planning_response_model, deepthought=True, user=user, query_images=query_images, From a5627ef787681b29a2d8a4e7829fdd9ed3382dc2 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 14:31:39 +0530 Subject: [PATCH 6/6] Use json schema to enforce generate online queries format --- src/khoj/routers/helpers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a339642a..caca0eaa 100644 --- a/src/khoj/routers/helpers.py +++ 
b/src/khoj/routers/helpers.py @@ -540,11 +540,15 @@ async def generate_online_subqueries( agent_chat_model = agent.chat_model if agent else None + class OnlineQueries(BaseModel): + queries: List[str] + with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper( online_queries_prompt, query_images=query_images, response_type="json_object", + response_schema=OnlineQueries, user=user, query_files=query_files, agent_chat_model=agent_chat_model,