diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 77cff325..7f18b079 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -121,6 +121,7 @@ def gemini_send_message_to_model( api_key, model, response_type="text", + response_schema=None, temperature=0.6, model_kwargs=None, tracer={}, @@ -135,6 +136,7 @@ def gemini_send_message_to_model( # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series. if response_type == "json_object" and model in ["gemini-2.0-flash"]: model_kwargs["response_mime_type"] = "application/json" + model_kwargs["response_schema"] = response_schema # Get Response from Gemini return gemini_completion_with_backoff( diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index ebe91527..b1a5fe77 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -66,6 +66,7 @@ def gemini_completion_with_backoff( max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, safety_settings=SAFETY_SETTINGS, response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain", + response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None, ) formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages] diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 389f52ab..f087fc93 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, completion_with_backoff, + get_openai_api_json_support, ) from khoj.processor.conversation.utils import ( + JsonSupport, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -119,12 +121,26 @@ def extract_questions( def send_message_to_model( - messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {} + messages, + api_key, + model, + response_type="text", + response_schema=None, + api_base_url=None, + temperature=0, + tracer: dict = {}, ): """ Send message to model """ + model_kwargs = {} + json_support = get_openai_api_json_support(model, api_base_url) + if response_schema and json_support == JsonSupport.SCHEMA: + model_kwargs["response_format"] = response_schema + elif response_type == "json_object" and json_support == JsonSupport.OBJECT: + model_kwargs["response_format"] = {"type": response_type} + # Get Response from GPT return completion_with_backoff( messages=messages, @@ -132,7 +148,7 @@ def send_message_to_model( openai_api_key=api_key, temperature=temperature, api_base_url=api_base_url, - model_kwargs={"response_format": {"type": response_type}}, + model_kwargs=model_kwargs, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 444b6541..25ddd60a 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -2,6 +2,7 @@ import logging import os from threading import Thread from typing import Dict, List +from urllib.parse import urlparse import openai from openai.types.chat.chat_completion import ChatCompletion @@ -16,6 +17,7 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( + JsonSupport, ThreadedGenerator, commit_conversation_trace, ) @@ -60,45 +62,29 @@ def completion_with_backoff( formatted_messages = [{"role": message.role, "content": message.content} for message in messages] - # Update request parameters for compatability with o1 model series - # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations - stream = True - model_kwargs["stream_options"] = {"include_usage": True} - if model_name == "o1": - temperature = 1 - stream = False - model_kwargs.pop("stream_options", None) - elif model_name.startswith("o1"): - temperature = 1 - model_kwargs.pop("response_format", None) - elif model_name.startswith("o3-"): + # Tune reasoning models arguments + if model_name.startswith("o1") or model_name.startswith("o3"): temperature = 1 + model_kwargs["reasoning_effort"] = "medium" + model_kwargs["stream_options"] = {"include_usage": True} if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( + aggregated_response = "" + with client.beta.chat.completions.stream( messages=formatted_messages, # type: ignore - model=model_name, # type: ignore - stream=stream, + model=model_name, temperature=temperature, timeout=20, **model_kwargs, - ) - - aggregated_response = "" - if not stream: - chunk = chat - aggregated_response = chunk.choices[0].message.content - else: + ) as chat: for chunk in chat: - if len(chunk.choices) == 0: + if chunk.type == "error": + logger.error(f"Openai api response error: {chunk.error}", exc_info=True) continue - delta_chunk = chunk.choices[0].delta # type: ignore - if isinstance(delta_chunk, str): - aggregated_response += delta_chunk - elif delta_chunk.content: - aggregated_response += delta_chunk.content + elif chunk.type == "content.delta": + aggregated_response += chunk.delta # Calculate cost of chat input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 @@ -172,20 +158,13 @@ def llm_thread( formatted_messages = [{"role": message.role, "content": message.content} for message in messages] - # Update request parameters for compatability with o1 model series - # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations - stream = True - model_kwargs["stream_options"] = {"include_usage": True} - if model_name == "o1": + # Tune reasoning models arguments + if model_name.startswith("o1"): temperature = 1 - stream = False - model_kwargs.pop("stream_options", None) - elif model_name.startswith("o1-"): + elif model_name.startswith("o3"): temperature = 1 - model_kwargs.pop("response_format", None) - elif model_name.startswith("o3-"): - temperature = 1 - # Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices + # Get the first system message and add the string `Formatting re-enabled` to it. + # See https://platform.openai.com/docs/guides/reasoning-best-practices if len(formatted_messages) > 0: system_messages = [ (i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system" @@ -195,7 +174,6 @@ def llm_thread( formatted_messages[first_system_message_index][ "content" ] = f"{first_system_message} Formatting re-enabled" - elif model_name.startswith("deepseek-reasoner"): # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role. # The first message should always be a user message (except system message). @@ -210,6 +188,8 @@ def llm_thread( formatted_messages = updated_messages + stream = True + model_kwargs["stream_options"] = {"include_usage": True} if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) @@ -258,3 +238,13 @@ def llm_thread( logger.error(f"Error in llm_thread: {e}", exc_info=True) finally: g.close() + + +def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: + if model_name.startswith("deepseek-reasoner"): + return JsonSupport.NONE + if api_base_url: + host = urlparse(api_base_url).hostname + if host and host.endswith(".ai.azure.com"): + return JsonSupport.OBJECT + return JsonSupport.SCHEMA diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a7e6e694..de82f067 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str: return str(content) return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages]) + + +class JsonSupport(int, Enum): + NONE = 0 + OBJECT = 1 + SCHEMA = 2 diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 75f38948..caca0eaa 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -540,11 +540,15 @@ async def generate_online_subqueries( agent_chat_model = agent.chat_model if agent else None + class OnlineQueries(BaseModel): + queries: List[str] + with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper( online_queries_prompt, query_images=query_images, response_type="json_object", + response_schema=OnlineQueries, user=user, query_files=query_files, agent_chat_model=agent_chat_model, @@ -1129,6 +1133,7 @@ async def send_message_to_model_wrapper( query: str, system_message: str = "", response_type: str = "text", + response_schema: BaseModel = None, deepthought: bool = False, user: KhojUser = None, query_images: List[str] = None, @@ -1209,6 +1214,7 @@ async def send_message_to_model_wrapper( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, api_base_url=api_base_url, tracer=tracer, ) @@ -1255,6 +1261,7 @@ async def send_message_to_model_wrapper( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) else: @@ -1265,6 +1272,7 @@ def send_message_to_model_wrapper_sync( message: str, system_message: str = "", response_type: str = "text", + response_schema: BaseModel = None, user: KhojUser = None, query_images: List[str] = None, query_files: str = "", @@ -1326,6 +1334,7 @@ def send_message_to_model_wrapper_sync( api_base_url=api_base_url, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) @@ -1370,6 +1379,7 @@ def send_message_to_model_wrapper_sync( api_key=api_key, model=chat_model_name, response_type=response_type, + response_schema=response_schema, tracer=tracer, ) else: diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 07a66002..6e2493c5 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,10 +1,12 @@ import logging import os from datetime import datetime -from typing import Callable, Dict, List, Optional +from enum import Enum +from typing import Callable, Dict, List, Optional, Type import yaml from fastapi import Request +from pydantic import BaseModel, Field from khoj.database.adapters import EntryAdapters from khoj.database.models import Agent, KhojUser @@ -36,6 +38,40 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) +class PlanningResponse(BaseModel): + """ + Schema for the response from planning agent when deciding the next tool to pick. + The tool field is dynamically validated based on available tools. + """ + + scratchpad: str = Field(..., description="Reasoning about which tool to use next") + query: str = Field(..., description="Detailed query for the selected tool") + + class Config: + arbitrary_types_allowed = True + + @classmethod + def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]: + """ + Factory method that creates a customized PlanningResponse model + with a properly typed tool field based on available tools. + + Args: + tool_options: Dictionary mapping tool names to values + + Returns: + A customized PlanningResponse class + """ + # Create dynamic enum from tool options + tool_enum = Enum("ToolEnum", tool_options) # type: ignore + + # Create and return a customized response model with the enum + class PlanningResponseWithTool(PlanningResponse): + tool: tool_enum = Field(..., description="Name of the tool to use") + + return PlanningResponseWithTool + + async def apick_next_tool( query: str, conversation_history: dict, @@ -61,10 +97,13 @@ async def apick_next_tool( # Skip showing Notes tool as an option if user has no entries if tool == ConversationCommand.Notes and not user_has_entries: continue - tool_options[tool.value] = description if len(agent_tools) == 0 or tool.value in agent_tools: + tool_options[tool.name] = tool.value tool_options_str += f'- "{tool.value}": "{description}"\n' + # Create planning reponse model with dynamically populated tool enum class + planning_response_model = PlanningResponse.create_model_with_enum(tool_options) + # Construct chat history with user and iteration history with researcher agent for context chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) @@ -96,6 +135,7 @@ async def apick_next_tool( query=query, context=function_planning_prompt, response_type="json_object", + response_schema=planning_response_model, deepthought=True, user=user, query_images=query_images,