From 490f0a435dd46b8f12b038c0b967bda9014268ee Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 13 Jun 2025 01:04:07 -0700 Subject: [PATCH] Pass research tools directly with their varied args for flexibility Why --- Previously researcher had a uniform response schema to pick next tool, scratchpad, query and tool. This didn't allow choosing different arguments for the different tools being called. And the tool call, result format passed by khoj was custom and static across all LLMs. Passing the tools and their schemas directly to llm when picking next tool allows passing multiple, tool specific arguments for llm to select. For example, model can choose webpage urls to read or image gen aspect ratio (apart from tool query) to pass to the specific tool. Using the LLM tool calling paradigm allows model to see tool call, tool result in a format that it understands best. Using standard tool calling paradigm also allows for incorporating community builts tools more easily via MCP servers, clients tools, native llm api tools etc. What --- - Return ResponseWithThought from completion_with_backoff ai model provider methods - Show reasoning model thoughts in research mode train of thought. For non-reasoning models do not show researcher train of thought. As non-reasoning models don't (by default) think before selecing tool. Showing tool call is lame and resembles tool's action shown in next step. - Store tool calls in standardized format. - Specify tool schemas in tool for research llm definitions as well. - Transform tool calls, tool results to standardized form for use within khoj. Manage the following tool call, result transformations: - Model provider tool_call -> standardized tool call - Standardized tool call, result -> model specific tool call, result - Make researcher choose webpages urls to read as well for the webpage tool. Previously it would just decide the query but let the webpage reader infer the query url(s). But researcher has better context on which webpages it wants to have read to answer their query. This should eliminate the webpage reader deciding urls to read step and speed up webpage read tool use. Handle unset response thoughts. Useful when retry on failed request Previously resulted in unbound local variable response_thoughts error --- .../processor/conversation/anthropic/utils.py | 68 ++++- .../processor/conversation/google/utils.py | 50 +++- .../conversation/offline/chat_model.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 3 +- .../processor/conversation/openai/utils.py | 58 ++++- src/khoj/processor/conversation/prompts.py | 9 - src/khoj/processor/conversation/utils.py | 195 ++++++-------- .../processor/operator/grounding_agent.py | 2 +- .../operator/operator_agent_binary.py | 7 +- src/khoj/processor/tools/online_search.py | 18 ++ src/khoj/processor/tools/run_code.py | 2 +- src/khoj/routers/helpers.py | 54 ++-- src/khoj/routers/research.py | 242 +++++++----------- src/khoj/utils/helpers.py | 179 ++++++++++++- 14 files changed, 535 insertions(+), 354 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index a9fd1db3..09506c9e 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -16,13 +16,14 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, - ToolDefinition, + ToolCall, commit_conversation_trace, - create_tool_definition, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( + ToolDefinition, + create_tool_definition, get_anthropic_async_client, get_anthropic_client, get_chat_usage_metrics, @@ -60,7 +61,7 @@ def anthropic_completion_with_backoff( tools: List[ToolDefinition] = None, deepthought: bool = False, tracer: dict = {}, -) -> str: +) -> ResponseWithThought: client = anthropic_clients.get(api_key) if not client: client = get_anthropic_client(api_key, api_base_url) @@ -68,6 +69,7 @@ def anthropic_completion_with_backoff( formatted_messages, system = format_messages_for_anthropic(messages, system_prompt) + thoughts = "" aggregated_response = "" final_message = None model_kwargs = model_kwargs or dict() @@ -107,15 +109,30 @@ def anthropic_completion_with_backoff( max_tokens=max_tokens, **(model_kwargs), ) as stream: - for text in stream.text_stream: - aggregated_response += text + for chunk in stream: + if chunk.type != "content_block_delta": + continue + if chunk.delta.type == "thinking_delta": + thoughts += chunk.delta.thinking + elif chunk.delta.type == "text_delta": + aggregated_response += chunk.delta.text final_message = stream.get_final_message() - # Extract first tool call from final message - for item in final_message.content: - if item.type == "tool_use": - aggregated_response = json.dumps([{"name": item.name, "args": item.input}]) - break + # Extract all tool calls if tools are enabled + if tools: + tool_calls = [ + ToolCall(name=item.name, args=item.input, id=item.id).__dict__ + for item in final_message.content + if item.type == "tool_use" + ] + if tool_calls: + aggregated_response = json.dumps(tool_calls) + # If response schema is used, return the first tool call's input + elif response_schema: + for item in final_message.content: + if item.type == "tool_use": + aggregated_response = json.dumps(item.input) + break # Calculate cost of chat input_tokens = final_message.usage.input_tokens @@ -137,7 +154,7 @@ def anthropic_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) - return aggregated_response + return ResponseWithThought(response=aggregated_response, thought=thoughts) @retry( @@ -269,7 +286,34 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st # Convert image urls to base64 encoded images in Anthropic message format for message in messages: - if isinstance(message.content, list): + # Handle tool call and tool result message types from additional_kwargs + message_type = message.additional_kwargs.get("message_type") + if message_type == "tool_call": + # Convert tool_call to Anthropic tool_use format + content = [] + for part in message.content: + content.append( + { + "type": "tool_use", + "id": part.pop("id"), + "name": part.pop("name"), + "input": part, + } + ) + message.content = content + elif message_type == "tool_result": + # Convert tool_result to Anthropic tool_result format + content = [] + for part in message.content: + content.append( + { + "type": "tool_result", + "tool_use_id": part["id"], + "content": part["content"], + } + ) + message.content = content + elif isinstance(message.content, list): content = [] # Sort the content. Anthropic models prefer that text comes after images. message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 14eb5119..c0ee6d1a 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -23,12 +23,13 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, - ToolDefinition, + ToolCall, commit_conversation_trace, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( + ToolDefinition, get_chat_usage_metrics, get_gemini_client, is_none_or_empty, @@ -100,7 +101,7 @@ def gemini_completion_with_backoff( model_kwargs={}, deepthought=False, tracer={}, -) -> str: +) -> ResponseWithThought: client = gemini_clients.get(api_key) if not client: client = get_gemini_client(api_key, api_base_url) @@ -119,7 +120,7 @@ def gemini_completion_with_backoff( thinking_config = None if deepthought and is_reasoning_model(model_name): - thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI) + thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI, include_thoughts=True) max_output_tokens = MAX_OUTPUT_TOKENS_FOR_STANDARD_GEMINI if is_reasoning_model(model_name): @@ -145,12 +146,16 @@ def gemini_completion_with_backoff( response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages) if response.function_calls: function_calls = [ - {"name": function_call.name, "args": function_call.args} for function_call in response.function_calls + ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__ + for function_call in response.function_calls ] response_text = json.dumps(function_calls) else: # If no function calls, use the text response response_text = response.text + response_thoughts = "\n".join( + [part.text for part in response.candidates[0].content.parts if part.thought and isinstance(part.text, str)] + ) except gerrors.ClientError as e: response = None response_text, _ = handle_gemini_response(e.args) @@ -164,8 +169,14 @@ def gemini_completion_with_backoff( input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0 output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0 thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0 + cache_read_tokens = response.usage_metadata.cached_content_token_count or 0 if response else 0 tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") + model_name, + input_tokens, + output_tokens, + cache_read_tokens=cache_read_tokens, + thought_tokens=thought_tokens, + usage=tracer.get("usage"), ) # Validate the response. If empty, raise an error to retry. @@ -179,7 +190,7 @@ def gemini_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, response_text, tracer) - return response_text + return ResponseWithThought(response=response_text, thought=response_thoughts) @retry( @@ -359,8 +370,28 @@ def format_messages_for_gemini( system_prompt = None if is_none_or_empty(system_prompt) else system_prompt for message in messages: + if message.role == "assistant": + message.role = "model" + + # Handle tool call and tool result message types from additional_kwargs + message_type = message.additional_kwargs.get("message_type") + if message_type == "tool_call": + # Convert tool_call to Gemini function call format + tool_call_msg_content = [] + for part in message.content: + tool_call_msg_content.append(gtypes.Part.from_function_call(name=part["name"], args=part["args"])) + message.content = tool_call_msg_content + elif message_type == "tool_result": + # Convert tool_result to Gemini function response format + # Need to find the corresponding function call from previous messages + tool_result_msg_content = [] + for part in message.content: + tool_result_msg_content.append( + gtypes.Part.from_function_response(name=part["name"], response={"result": part["content"]}) + ) + message.content = tool_result_msg_content # Convert message content to string list from chatml dictionary list - if isinstance(message.content, list): + elif isinstance(message.content, list): # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) message_content = [] for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1): @@ -380,16 +411,13 @@ def format_messages_for_gemini( messages.remove(message) continue message.content = message_content - elif isinstance(message.content, str): + elif isinstance(message.content, str) and message.content.strip(): message.content = [gtypes.Part.from_text(text=message.content)] else: logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}") messages.remove(message) continue - if message.role == "assistant": - message.role = "model" - if len(messages) == 1: messages[0].role = "user" diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index ddcdc569..28893290 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -221,4 +221,4 @@ def send_message_to_model_offline( if is_promptrace_enabled(): commit_conversation_trace(messages, response_text, tracer) - return response_text + return ResponseWithThought(response=response_text) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index f123ad63..6a644074 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -15,11 +15,10 @@ from khoj.processor.conversation.utils import ( OperatorRun, ResponseWithThought, StructuredOutputSupport, - ToolDefinition, generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import is_none_or_empty, truncate_code_context +from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context from khoj.utils.rawconfig import FileAttachment, LocationData from khoj.utils.yaml import yaml_dump diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index b0c311b7..fd8edf51 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -35,10 +35,11 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, StructuredOutputSupport, - ToolDefinition, + ToolCall, commit_conversation_trace, ) from khoj.utils.helpers import ( + ToolDefinition, convert_image_data_uri, get_chat_usage_metrics, get_openai_async_client, @@ -76,7 +77,7 @@ def completion_with_backoff( deepthought: bool = False, model_kwargs: dict = {}, tracer: dict = {}, -) -> str: +) -> ResponseWithThought: client_key = f"{openai_api_key}--{api_base_url}" client = openai_clients.get(client_key) if not client: @@ -121,6 +122,9 @@ def completion_with_backoff( if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) + tool_ids = [] + tool_calls: list[ToolCall] = [] + thoughts = "" aggregated_response = "" if stream: with client.beta.chat.completions.stream( @@ -134,9 +138,16 @@ def completion_with_backoff( if chunk.type == "content.delta": aggregated_response += chunk.delta elif chunk.type == "thought.delta": - pass + thoughts += chunk.delta + elif chunk.type == "chunk" and chunk.chunk.choices and chunk.chunk.choices[0].delta.tool_calls: + tool_ids += [tool_call.id for tool_call in chunk.chunk.choices[0].delta.tool_calls] elif chunk.type == "tool_calls.function.arguments.done": - aggregated_response = json.dumps([{"name": chunk.name, "args": chunk.arguments}]) + tool_calls += [ToolCall(name=chunk.name, args=json.loads(chunk.arguments), id=None)] + if tool_calls: + tool_calls = [ + ToolCall(name=chunk.name, args=chunk.args, id=tool_id) for chunk, tool_id in zip(tool_calls, tool_ids) + ] + aggregated_response = json.dumps([tool_call.__dict__ for tool_call in tool_calls]) else: # Non-streaming chat completion chunk = client.beta.chat.completions.parse( @@ -170,7 +181,7 @@ def completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) - return aggregated_response + return ResponseWithThought(response=aggregated_response, thought=thoughts) @retry( @@ -354,6 +365,43 @@ def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> Li """ formatted_messages = [] for message in deepcopy(messages): + # Handle tool call and tool result message types + message_type = message.additional_kwargs.get("message_type") + if message_type == "tool_call": + # Convert tool_call to OpenAI function call format + content = [] + for part in message.content: + content.append( + { + "type": "function", + "id": part.get("id"), + "function": { + "name": part.get("name"), + "arguments": json.dumps(part.get("input", part.get("args", {}))), + }, + } + ) + formatted_messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": content, + } + ) + continue + if message_type == "tool_result": + # Convert tool_result to OpenAI tool result format + # Each part is a result for a tool call + for part in message.content: + formatted_messages.append( + { + "role": "tool", + "tool_call_id": part.get("id") or part.get("tool_use_id"), + "name": part.get("name"), + "content": part.get("content"), + } + ) + continue if isinstance(message.content, list) and not is_openai_api(api_base_url): assistant_texts = [] has_images = False diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index c094c5d3..94e9c4c1 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -715,15 +715,6 @@ Given the results of your previous iterations, which tool AI will you use next t """.strip() ) -previous_iteration = PromptTemplate.from_template( - """ -# Iteration {index}: -- tool: {tool} -- query: {query} -- result: {result} -""".strip() -) - pick_relevant_tools = PromptTemplate.from_template( """ You are Khoj, an extremely smart and helpful search assistant. diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 4b5f0f38..4a6d5a01 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -6,13 +6,11 @@ import mimetypes import os import re import uuid -from copy import deepcopy from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from textwrap import dedent -from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Union import PIL.Image import pyjson5 @@ -139,11 +137,17 @@ class OperatorRun: } +class ToolCall: + def __init__(self, name: str, args: dict, id: str): + self.name = name + self.args = args + self.id = id + + class ResearchIteration: def __init__( self, - tool: str, - query: str, + query: ToolCall | dict | str, context: list = None, onlineContext: dict = None, codeContext: dict = None, @@ -151,8 +155,7 @@ class ResearchIteration: summarizedResult: str = None, warning: str = None, ): - self.tool = tool - self.query = query + self.query = ToolCall(**query) if isinstance(query, dict) else query self.context = context self.onlineContext = onlineContext self.codeContext = codeContext @@ -162,37 +165,43 @@ class ResearchIteration: def to_dict(self) -> dict: data = vars(self).copy() + data["query"] = self.query.__dict__ if isinstance(self.query, ToolCall) else self.query data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None return data def construct_iteration_history( previous_iterations: List[ResearchIteration], - previous_iteration_prompt: str, query: str = None, + query_images: List[str] = None, + query_files: str = None, ) -> list[ChatMessageModel]: iteration_history: list[ChatMessageModel] = [] - previous_iteration_messages: list[dict] = [] - for idx, iteration in enumerate(previous_iterations): - iteration_data = previous_iteration_prompt.format( - tool=iteration.tool, - query=iteration.query, - result=iteration.summarizedResult, - index=idx + 1, - ) + query_message_content = construct_structured_message(query, query_images, attached_file_context=query_files) + if query_message_content: + iteration_history.append(ChatMessageModel(by="you", message=query_message_content)) - previous_iteration_messages.append({"type": "text", "text": iteration_data}) - - if previous_iteration_messages: - if query: - iteration_history.append(ChatMessageModel(by="you", message=query)) - iteration_history.append( + for iteration in previous_iterations: + iteration_history += [ ChatMessageModel( by="khoj", - intent=Intent(type="remember", query=query), - message=previous_iteration_messages, - ) - ) + message=[iteration.query.__dict__], + intent=Intent(type="tool_call", query=query), + ), + ChatMessageModel( + by="you", + intent=Intent(type="tool_result"), + message=[ + { + "type": "tool_result", + "id": iteration.query.id, + "name": iteration.query.name, + "content": iteration.summarizedResult, + } + ], + ), + ] + return iteration_history @@ -319,18 +328,18 @@ def construct_tool_chat_history( # If no tool is provided, use inferred query extractor for the tool used in the iteration # Fallback to base extractor if the tool does not have an inferred query extractor inferred_query_extractor = extract_inferred_query_map.get( - tool or ConversationCommand(iteration.tool), base_extractor + tool or ConversationCommand(iteration.query.name), base_extractor ) chat_history += [ ChatMessageModel( by="you", - message=iteration.query, + message=yaml.dump(iteration.query.args, default_flow_style=False), ), ChatMessageModel( by="khoj", intent=Intent( type="remember", - query=iteration.query, + query=yaml.dump(iteration.query.args, default_flow_style=False), inferred_queries=inferred_query_extractor(iteration), memory_type="notes", ), @@ -483,28 +492,32 @@ Khoj: "{chat_response}" def construct_structured_message( message: list[dict] | str, - images: list[str], - model_type: str, - vision_enabled: bool, + images: list[str] = None, + model_type: str = None, + vision_enabled: bool = True, attached_file_context: str = None, ): """ - Format messages into appropriate multimedia format for supported chat model types + Format messages into appropriate multimedia format for supported chat model types. + + Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise. """ - if model_type in [ + if not model_type or model_type in [ ChatModel.ModelType.OPENAI, ChatModel.ModelType.GOOGLE, ChatModel.ModelType.ANTHROPIC, ]: - constructed_messages: List[dict[str, Any]] = ( - [{"type": "text", "text": message}] if isinstance(message, str) else message - ) - + constructed_messages: List[dict[str, Any]] = [] + if not is_none_or_empty(message): + constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message + # Drop image message passed by caller if chat model does not have vision enabled + if not vision_enabled: + constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"] if not is_none_or_empty(attached_file_context): - constructed_messages.append({"type": "text", "text": attached_file_context}) + constructed_messages += [{"type": "text", "text": attached_file_context}] if vision_enabled and images: for image in images: - constructed_messages.append({"type": "image_url", "image_url": {"url": image}}) + constructed_messages += [{"type": "image_url", "image_url": {"url": image}}] return constructed_messages message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message) @@ -640,7 +653,11 @@ def generate_chatml_messages_with_context( chat_message, chat.images if role == "user" else [], model_type, vision_enabled ) - reconstructed_message = ChatMessage(content=message_content, role=role) + reconstructed_message = ChatMessage( + content=message_content, + role=role, + additional_kwargs={"message_type": chat.intent.type if chat.intent else None}, + ) chatml_messages.insert(0, reconstructed_message) if len(chatml_messages) >= 3 * lookback_turns: @@ -739,10 +756,21 @@ def count_tokens( message_content_parts: list[str] = [] # Collate message content into single string to ease token counting for part in message_content: - if isinstance(part, dict) and part.get("type") == "text": - message_content_parts.append(part["text"]) - elif isinstance(part, dict) and part.get("type") == "image_url": + if isinstance(part, dict) and part.get("type") == "image_url": image_count += 1 + elif isinstance(part, dict) and part.get("type") == "text": + message_content_parts.append(part["text"]) + elif isinstance(part, dict) and hasattr(part, "model_dump"): + message_content_parts.append(json.dumps(part.model_dump())) + elif isinstance(part, dict) and hasattr(part, "__dict__"): + message_content_parts.append(json.dumps(part.__dict__)) + elif isinstance(part, dict): + # If part is a dict but not a recognized type, convert to JSON string + try: + message_content_parts.append(json.dumps(part)) + except (TypeError, ValueError) as e: + logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.") + image_count += 1 # Treat as an image/binary if serialization fails elif isinstance(part, str): message_content_parts.append(part) else: @@ -1162,82 +1190,3 @@ class ResponseWithThought: def __init__(self, response: str = None, thought: str = None): self.response = response self.thought = thought - - -class ToolDefinition: - def __init__(self, name: str, description: str, schema: dict): - self.name = name - self.description = description - self.schema = schema - - -def create_tool_definition( - schema: Type[BaseModel], - name: str = None, - description: Optional[str] = None, -) -> ToolDefinition: - """ - Converts a response schema BaseModel class into a normalized tool definition. - - A standard AI provider agnostic tool format to specify tools the model can use. - Common logic used across models is kept here. AI provider specific adaptations - should be handled in provider code. - - Args: - response_schema: The Pydantic BaseModel class to convert. - This class defines the response schema for the tool. - tool_name: The name for the AI model tool (e.g., "get_weather", "plan_next_step"). - tool_description: Optional description for the AI model tool. - If None, it attempts to use the Pydantic model's docstring. - If that's also missing, a fallback description is generated. - - Returns: - A normalized tool definition for AI model APIs. - """ - raw_schema_dict = schema.model_json_schema() - - name = name or schema.__name__.lower() - description = description - if description is None: - docstring = schema.__doc__ - if docstring: - description = dedent(docstring).strip() - else: - # Fallback description if no explicit one or docstring is provided - description = f"Tool named '{name}' accepts specified parameters." - - # Process properties to inline enums and remove $defs dependency - processed_properties = {} - original_properties = raw_schema_dict.get("properties", {}) - defs = raw_schema_dict.get("$defs", {}) - - for prop_name, prop_schema in original_properties.items(): - current_prop_schema = deepcopy(prop_schema) # Work on a copy - # Check for enums defined directly in the property for simpler direct enum definitions. - if "$ref" in current_prop_schema: - ref_path = current_prop_schema["$ref"] - if ref_path.startswith("#/$defs/"): - def_name = ref_path.split("/")[-1] - if def_name in defs and "enum" in defs[def_name]: - enum_def = defs[def_name] - current_prop_schema["enum"] = enum_def["enum"] - current_prop_schema["type"] = enum_def.get("type", "string") - if "description" not in current_prop_schema and "description" in enum_def: - current_prop_schema["description"] = enum_def["description"] - del current_prop_schema["$ref"] # Remove the $ref as it's been inlined - - processed_properties[prop_name] = current_prop_schema - - # Generate the compiled schema dictionary for the tool definition. - compiled_schema = { - "type": "object", - "properties": processed_properties, - # Generate content in the order in which the schema properties were defined - "property_ordering": list(schema.model_fields.keys()), - } - - # Include 'required' fields if specified in the Pydantic model - if "required" in raw_schema_dict and raw_schema_dict["required"]: - compiled_schema["required"] = raw_schema_dict["required"] - - return ToolDefinition(name=name, description=description, schema=compiled_schema) diff --git a/src/khoj/processor/operator/grounding_agent.py b/src/khoj/processor/operator/grounding_agent.py index 3697391e..16f2d510 100644 --- a/src/khoj/processor/operator/grounding_agent.py +++ b/src/khoj/processor/operator/grounding_agent.py @@ -73,7 +73,7 @@ class GroundingAgent: grounding_user_prompt = self.get_instruction(instruction, self.environment_type) screenshots = [f"data:image/webp;base64,{current_state.screenshot}"] grounding_messages_content = construct_structured_message( - grounding_user_prompt, screenshots, self.model.name, vision_enabled=True + grounding_user_prompt, screenshots, self.model.model_type, vision_enabled=True ) return [{"role": "user", "content": grounding_messages_content}] diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index 8106e9cc..4f78cac7 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -121,7 +121,7 @@ class BinaryOperatorAgent(OperatorAgent): # Construct input for visual reasoner history visual_reasoner_history = self._format_message_for_api(self.messages) try: - natural_language_action = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( query=query_text, query_images=query_screenshot, system_message=reasoning_system_prompt, @@ -129,6 +129,7 @@ class BinaryOperatorAgent(OperatorAgent): agent_chat_model=self.reasoning_model, tracer=self.tracer, ) + natural_language_action = raw_response.response if not isinstance(natural_language_action, str) or not natural_language_action.strip(): raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}") @@ -255,10 +256,10 @@ class BinaryOperatorAgent(OperatorAgent): # Append summary messages to history trigger_summary = AgentMessage(role="user", content=summarize_prompt) - summary_message = AgentMessage(role="assistant", content=summary) + summary_message = AgentMessage(role="assistant", content=summary.response) self.messages.extend([trigger_summary, summary_message]) - return summary + return summary.response def _compile_response(self, response_content: str | List) -> str: """Compile response content into a string, handling OpenAI message structures.""" diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 78a4117f..5261af51 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -390,7 +390,25 @@ async def read_webpages( query_files=query_files, tracer=tracer, ) + async for result in read_webpages_content( + query, + urls, + user, + send_status_func=send_status_func, + agent=agent, + tracer=tracer, + ): + yield result + +async def read_webpages_content( + query: str, + urls: List[str], + user: KhojUser, + send_status_func: Optional[Callable] = None, + agent: Agent = None, + tracer: dict = {}, +): logger.info(f"Reading web pages at: {urls}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(urls)) diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index d9a80f71..e481b955 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -161,7 +161,7 @@ async def generate_python_code( ) # Extract python code wrapped in markdown code blocks from the response - code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response, re.DOTALL) + code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.response, re.DOTALL) if not code_blocks: raise ValueError("No Python code blocks found in response") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5b0b83e3..5a002e41 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -101,7 +101,6 @@ from khoj.processor.conversation.utils import ( OperatorRun, ResearchIteration, ResponseWithThought, - ToolDefinition, clean_json, clean_mermaidjs, construct_chat_history, @@ -121,6 +120,7 @@ from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( LRU, ConversationCommand, + ToolDefinition, get_file_type, in_debug_mode, is_none_or_empty, @@ -304,7 +304,7 @@ async def acreate_title_from_history( with timer("Chat actor: Generate title from conversation history", logger): response = await send_message_to_model_wrapper(title_generation_prompt, user=user) - return response.strip() + return response.response.strip() async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: @@ -316,7 +316,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: with timer("Chat actor: Generate title from query", logger): response = await send_message_to_model_wrapper(title_generation_prompt, user=user) - return response.strip() + return response.response.strip() async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]: @@ -340,7 +340,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck ) - response = response.strip() + response = response.response.strip() try: response = json.loads(clean_json(response)) is_safe = str(response.get("safe", "true")).lower() == "true" @@ -419,7 +419,7 @@ async def aget_data_sources_and_output_format( output: str with timer("Chat actor: Infer information sources to refer", logger): - response = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( relevant_tools_prompt, response_type="json_object", response_schema=PickTools, @@ -430,7 +430,7 @@ async def aget_data_sources_and_output_format( ) try: - response = clean_json(response) + response = clean_json(raw_response.response) response = json.loads(response) chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()] @@ -507,7 +507,7 @@ async def infer_webpage_urls( links: List[str] = Field(..., min_items=1, max_items=max_webpages) with timer("Chat actor: Infer webpage urls to read", logger): - response = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( online_queries_prompt, query_images=query_images, response_type="json_object", @@ -520,7 +520,7 @@ async def infer_webpage_urls( # Validate that the response is a non-empty, JSON-serializable list of URLs try: - response = clean_json(response) + response = clean_json(raw_response.response) urls = json.loads(response) valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)} if is_none_or_empty(valid_unique_urls): @@ -572,7 +572,7 @@ async def generate_online_subqueries( queries: List[str] = Field(..., min_items=1, max_items=max_queries) with timer("Chat actor: Generate online search subqueries", logger): - response = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( online_queries_prompt, query_images=query_images, response_type="json_object", @@ -585,7 +585,7 @@ async def generate_online_subqueries( # Validate that the response is a non-empty, JSON-serializable list try: - response = clean_json(response) + response = clean_json(raw_response.response) response = pyjson5.loads(response) response = {q.strip() for q in response["queries"] if q.strip()} if not isinstance(response, set) or not response or len(response) == 0: @@ -646,7 +646,7 @@ async def aschedule_query( # Validate that the response is a non-empty, JSON-serializable list try: - raw_response = raw_response.strip() + raw_response = raw_response.response.strip() response: Dict[str, str] = json.loads(clean_json(raw_response)) if not response or not isinstance(response, Dict) or len(response) != 3: raise AssertionError(f"Invalid response for scheduling query : {response}") @@ -684,7 +684,7 @@ async def extract_relevant_info( agent_chat_model=agent_chat_model, tracer=tracer, ) - return response.strip() + return response.response.strip() async def extract_relevant_summary( @@ -727,7 +727,7 @@ async def extract_relevant_summary( agent_chat_model=agent_chat_model, tracer=tracer, ) - return response.strip() + return response.response.strip() async def generate_summary_from_files( @@ -898,7 +898,7 @@ async def generate_better_diagram_description( agent_chat_model=agent_chat_model, tracer=tracer, ) - response = response.strip() + response = response.response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): response = response[1:-1] @@ -926,10 +926,10 @@ async def generate_excalidraw_diagram_from_description( raw_response = await send_message_to_model_wrapper( query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer ) - raw_response = clean_json(raw_response) + raw_response_text = clean_json(raw_response.response) try: # Expect response to have `elements` and `scratchpad` keys - response: Dict[str, str] = json.loads(raw_response) + response: Dict[str, str] = json.loads(raw_response_text) if ( not response or not isinstance(response, Dict) @@ -938,7 +938,7 @@ async def generate_excalidraw_diagram_from_description( ): raise AssertionError(f"Invalid response for generating Excalidraw diagram: {response}") except Exception: - raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response}") + raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response_text}") if not response or not isinstance(response["elements"], List) or not isinstance(response["elements"][0], Dict): # TODO Some additional validation here that it's a valid Excalidraw diagram raise AssertionError(f"Invalid response for improving diagram description: {response}") @@ -1049,11 +1049,11 @@ async def generate_better_mermaidjs_diagram_description( agent_chat_model=agent_chat_model, tracer=tracer, ) - response = response.strip() - if response.startswith(('"', "'")) and response.endswith(('"', "'")): - response = response[1:-1] + response_text = response.response.strip() + if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")): + response_text = response_text[1:-1] - return response + return response_text async def generate_mermaidjs_diagram_from_description( @@ -1077,7 +1077,7 @@ async def generate_mermaidjs_diagram_from_description( raw_response = await send_message_to_model_wrapper( query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer ) - return clean_mermaidjs(raw_response.strip()) + return clean_mermaidjs(raw_response.response.strip()) async def generate_better_image_prompt( @@ -1152,11 +1152,11 @@ async def generate_better_image_prompt( agent_chat_model=agent_chat_model, tracer=tracer, ) - response = response.strip() - if response.startswith(('"', "'")) and response.endswith(('"', "'")): - response = response[1:-1] + response_text = response.response.strip() + if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")): + response_text = response_text[1:-1] - return response + return response_text async def search_documents( @@ -1330,7 +1330,7 @@ async def extract_questions( # Extract questions from the response try: - response = clean_json(raw_response) + response = clean_json(raw_response.response) response = pyjson5.loads(response) queries = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(queries, list) or not queries: diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index f7467b9f..aa9cdbcc 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -3,11 +3,9 @@ import logging import os from copy import deepcopy from datetime import datetime -from enum import Enum -from typing import Callable, Dict, List, Optional, Type +from typing import Callable, Dict, List, Optional import yaml -from pydantic import BaseModel, Field from khoj.database.adapters import AgentAdapters, EntryAdapters from khoj.database.models import Agent, ChatMessageModel, KhojUser @@ -15,14 +13,13 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( OperatorRun, ResearchIteration, - ToolDefinition, + ToolCall, construct_iteration_history, construct_tool_chat_history, - create_tool_definition, load_complex_json, ) from khoj.processor.operator import operate_environment -from khoj.processor.tools.online_search import read_webpages, search_online +from khoj.processor.tools.online_search import read_webpages_content, search_online from khoj.processor.tools.run_code import run_code from khoj.routers.helpers import ( ChatEvent, @@ -32,10 +29,12 @@ from khoj.routers.helpers import ( ) from khoj.utils.helpers import ( ConversationCommand, + ToolDefinition, + dict_to_tuple, is_none_or_empty, is_operator_enabled, timer, - tool_description_for_research_llm, + tools_for_research_llm, truncate_code_context, ) from khoj.utils.rawconfig import LocationData @@ -43,47 +42,6 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) -class PlanningResponse(BaseModel): - """ - Schema for the response from planning agent when deciding the next tool to pick. - """ - - scratchpad: str = Field(..., description="Scratchpad to reason about which tool to use next") - - class Config: - arbitrary_types_allowed = True - - @classmethod - def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]: - """ - Factory method that creates a customized PlanningResponse model - with a properly typed tool field based on available tools. - - The tool field is dynamically generated based on available tools. - The query field should be filled by the model after the tool field for a more logical reasoning flow. - - Args: - tool_options: Dictionary mapping tool names to values - - Returns: - A customized PlanningResponse class - """ - # Create dynamic enum from tool options - tool_enum = Enum("ToolEnum", tool_options) # type: ignore - - # Create and return a customized response model with the enum - class PlanningResponseWithTool(PlanningResponse): - """ - Use the scratchpad to reason about which tool to use next and the query to send to the tool. - Pick tool from provided options and your query to send to the tool. - """ - - tool: tool_enum = Field(..., description="Name of the tool to use") - query: str = Field(..., description="Detailed query for the selected tool") - - return PlanningResponseWithTool - - async def apick_next_tool( query: str, conversation_history: List[ChatMessageModel], @@ -106,12 +64,13 @@ async def apick_next_tool( # Continue with previous iteration if a multi-step tool use is in progress if ( previous_iterations - and previous_iterations[-1].tool == ConversationCommand.Operator + and previous_iterations[-1].query + and isinstance(previous_iterations[-1].query, ToolCall) + and previous_iterations[-1].query.name == ConversationCommand.Operator and not previous_iterations[-1].summarizedResult ): previous_iteration = previous_iterations[-1] yield ResearchIteration( - tool=previous_iteration.tool, query=query, context=previous_iteration.context, onlineContext=previous_iteration.onlineContext, @@ -122,37 +81,35 @@ async def apick_next_tool( return # Construct tool options for the agent to choose from - tool_options = dict() + tools = [] tool_options_str = "" agent_tools = agent.input_tools if agent else [] user_has_entries = await EntryAdapters.auser_has_entries(user) - for tool, description in tool_description_for_research_llm.items(): + for tool, tool_data in tools_for_research_llm.items(): # Skip showing operator tool as an option if not enabled if tool == ConversationCommand.Operator and not is_operator_enabled(): continue # Skip showing Notes tool as an option if user has no entries - if tool == ConversationCommand.Notes: + elif tool == ConversationCommand.Notes: if not user_has_entries: continue - description = description.format(max_search_queries=max_document_searches) - if tool == ConversationCommand.Webpage: - description = description.format(max_webpages_to_read=max_webpages_to_read) - if tool == ConversationCommand.Online: - description = description.format(max_search_queries=max_online_searches) + description = tool_data.description.format(max_search_queries=max_document_searches) + elif tool == ConversationCommand.Webpage: + description = tool_data.description.format(max_webpages_to_read=max_webpages_to_read) + elif tool == ConversationCommand.Online: + description = tool_data.description.format(max_search_queries=max_online_searches) + else: + description = tool_data.description # Add tool if agent does not have any tools defined or the tool is supported by the agent. if len(agent_tools) == 0 or tool.value in agent_tools: - tool_options[tool.name] = tool.value tool_options_str += f'- "{tool.value}": "{description}"\n' - - # Create planning reponse model with dynamically populated tool enum class - planning_response_model = PlanningResponse.create_model_with_enum(tool_options) - tools = [ - create_tool_definition( - name="infer_information_sources_to_refer", - description="Infer which tool to use next and the query to send to the tool.", - schema=planning_response_model, - ) - ] + tools.append( + ToolDefinition( + name=tool.value, + description=description, + schema=tool_data.schema, + ) + ) today = datetime.today() location_data = f"{location}" if location else "Unknown" @@ -171,24 +128,16 @@ async def apick_next_tool( max_iterations=max_iterations, ) - if query_images: - query = f"[placeholder for user attached images]\n{query}" - # Construct chat history with user and iteration history with researcher agent for context - iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) + iteration_chat_history = construct_iteration_history(previous_iterations, query, query_images, query_files) chat_and_research_history = conversation_history + iteration_chat_history - # Plan function execution for the next tool - query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query - try: with timer("Chat actor: Infer information sources to refer", logger): raw_response = await send_message_to_model_wrapper( - query=query, + query="", system_message=function_planning_prompt, chat_history=chat_and_research_history, - response_type="json_object", - response_schema=planning_response_model, tools=tools, deepthought=True, user=user, @@ -200,7 +149,6 @@ async def apick_next_tool( except Exception as e: logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True) yield ResearchIteration( - tool=None, query=None, warning="Failed to infer information sources to refer. Skipping iteration. Try again.", ) @@ -208,40 +156,32 @@ async def apick_next_tool( try: # Try parse the response as function call response to infer next tool to use. - response = load_complex_json(load_complex_json(raw_response)[0]["args"]) + # TODO: Handle multiple tool calls. + response_text = raw_response.response + parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] except Exception as e: - try: - # Else try parse the text response as JSON to infer next tool to use. - response = load_complex_json(raw_response) - except Exception as e: - # Otherwise assume the model has decided to end the research run and respond to the user. - response = {"tool": ConversationCommand.Text, "query": None, "scratchpad": raw_response} + # Otherwise assume the model has decided to end the research run and respond to the user. + parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None) # If we have a valid response, extract the tool and query. - selected_tool = response.get("tool", None) - generated_query = response.get("query", None) - scratchpad = response.get("scratchpad", None) - warning = None - logger.info(f"Response for determining relevant tools: {response}") + logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})") # Detect selection of previously used query, tool combination. - previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations if i.warning is None} - if (selected_tool, generated_query) in previous_tool_query_combinations: + previous_tool_query_combinations = { + (i.query.name, dict_to_tuple(i.query.args)) + for i in previous_iterations + if i.warning is None and isinstance(i.query, ToolCall) + } + if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different." - # Only send client status updates if we'll execute this iteration - elif send_status_func and scratchpad: - determined_tool_message = "**Determined Tool**: " - determined_tool_message += ( - f"{selected_tool}({generated_query})." if selected_tool != ConversationCommand.Text else "respond." - ) - determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else "" - async for event in send_status_func(f"{scratchpad}"): + # Only send client status updates if we'll execute this iteration and model has thoughts to share. + elif send_status_func and not is_none_or_empty(raw_response.thought): + async for event in send_status_func(raw_response.thought): yield {ChatEvent.STATUS: event} yield ResearchIteration( - tool=selected_tool, - query=generated_query, + query=parsed_response, warning=warning, ) @@ -269,10 +209,10 @@ async def research( MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) # Incorporate previous partial research into current research chat history - research_conversation_history = deepcopy(conversation_history) + research_conversation_history = [chat for chat in deepcopy(conversation_history) if chat.message] if current_iteration := len(previous_iterations) > 0: logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.") - previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) + previous_iterations_history = construct_iteration_history(previous_iterations) research_conversation_history += previous_iterations_history while current_iteration < MAX_ITERATIONS: @@ -285,7 +225,7 @@ async def research( code_results: Dict = dict() document_results: List[Dict[str, str]] = [] operator_results: OperatorRun = None - this_iteration = ResearchIteration(tool=None, query=query) + this_iteration = ResearchIteration(query=query) async for result in apick_next_tool( query, @@ -315,26 +255,30 @@ async def research( logger.warning(f"Research mode: {this_iteration.warning}.") # Terminate research if selected text tool or query, tool not set for next iteration - elif not this_iteration.query or not this_iteration.tool or this_iteration.tool == ConversationCommand.Text: + elif ( + not this_iteration.query + or isinstance(this_iteration.query, str) + or this_iteration.query.name == ConversationCommand.Text + ): current_iteration = MAX_ITERATIONS - elif this_iteration.tool == ConversationCommand.Notes: + elif this_iteration.query.name == ConversationCommand.Notes: this_iteration.context = [] document_results = [] previous_inferred_queries = { c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context } async for result in search_documents( - this_iteration.query, - max_document_searches, - None, - user, - construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), - conversation_id, - [ConversationCommand.Default], - location, - send_status_func, - query_images, + **this_iteration.query.args, + n=max_document_searches, + d=None, + user=user, + chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), + conversation_id=conversation_id, + conversation_commands=[ConversationCommand.Default], + location_data=location, + send_status_func=send_status_func, + query_images=query_images, previous_inferred_queries=previous_inferred_queries, agent=agent, tracer=tracer, @@ -362,7 +306,7 @@ async def research( else: this_iteration.warning = "No matching document references found" - elif this_iteration.tool == ConversationCommand.Online: + elif this_iteration.query.name == ConversationCommand.Online: previous_subqueries = { subquery for iteration in previous_iterations @@ -371,12 +315,12 @@ async def research( } try: async for result in search_online( - this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Online), - location, - user, - send_status_func, - [], + **this_iteration.query.args, + conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Online), + location=location, + user=user, + send_status_func=send_status_func, + custom_filters=[], max_online_searches=max_online_searches, max_webpages_to_read=0, query_images=query_images, @@ -395,19 +339,15 @@ async def research( this_iteration.warning = f"Error searching online: {e}" logger.error(this_iteration.warning, exc_info=True) - elif this_iteration.tool == ConversationCommand.Webpage: + elif this_iteration.query.name == ConversationCommand.Webpage: try: - async for result in read_webpages( - this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage), - location, - user, - send_status_func, - max_webpages_to_read=max_webpages_to_read, - query_images=query_images, + async for result in read_webpages_content( + **this_iteration.query.args, + user=user, + send_status_func=send_status_func, + # max_webpages_to_read=max_webpages_to_read, agent=agent, tracer=tracer, - query_files=query_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -428,15 +368,15 @@ async def research( this_iteration.warning = f"Error reading webpages: {e}" logger.error(this_iteration.warning, exc_info=True) - elif this_iteration.tool == ConversationCommand.Code: + elif this_iteration.query.name == ConversationCommand.Code: try: async for result in run_code( - this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Code), - "", - location, - user, - send_status_func, + **this_iteration.query.args, + conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Code), + context="", + location_data=location, + user=user, + send_status_func=send_status_func, query_images=query_images, agent=agent, query_files=query_files, @@ -453,14 +393,14 @@ async def research( this_iteration.warning = f"Error running code: {e}" logger.warning(this_iteration.warning, exc_info=True) - elif this_iteration.tool == ConversationCommand.Operator: + elif this_iteration.query.name == ConversationCommand.Operator: try: async for result in operate_environment( - this_iteration.query, - user, - construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), - location, - previous_iterations[-1].operatorContext if previous_iterations else None, + **this_iteration.query.args, + user=user, + conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), + location_data=location, + previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None, send_status_func=send_status_func, query_images=query_images, agent=agent, @@ -493,7 +433,7 @@ async def research( current_iteration += 1 if document_results or online_results or code_results or operator_results or this_iteration.warning: - results_data = f"\n{current_iteration}\n{this_iteration.tool}\n{this_iteration.query}\n" + results_data = f"\n" if document_results: results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if online_results: @@ -506,7 +446,7 @@ async def research( ) if this_iteration.warning: results_data += f"\n\n{this_iteration.warning}\n" - results_data += "\n\n" + results_data += f"\n\n" # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) this_iteration.summarizedResult = results_data diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index cdbed998..e4a02320 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -12,6 +12,7 @@ import random import urllib.parse import uuid from collections import OrderedDict +from copy import deepcopy from enum import Enum from functools import lru_cache from importlib import import_module @@ -19,8 +20,9 @@ from importlib.metadata import version from itertools import islice from os import path from pathlib import Path +from textwrap import dedent from time import perf_counter -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Type, Union from urllib.parse import ParseResult, urlparse import anthropic @@ -36,6 +38,7 @@ from google.auth.credentials import Credentials from google.oauth2 import service_account from magika import Magika from PIL import Image +from pydantic import BaseModel from pytz import country_names, country_timezones from khoj.utils import constants @@ -334,6 +337,85 @@ def is_e2b_code_sandbox_enabled(): return not is_none_or_empty(os.getenv("E2B_API_KEY")) +class ToolDefinition: + def __init__(self, name: str, description: str, schema: dict): + self.name = name + self.description = description + self.schema = schema + + +def create_tool_definition( + schema: Type[BaseModel], + name: str = None, + description: Optional[str] = None, +) -> ToolDefinition: + """ + Converts a response schema BaseModel class into a normalized tool definition. + + A standard AI provider agnostic tool format to specify tools the model can use. + Common logic used across models is kept here. AI provider specific adaptations + should be handled in provider code. + + Args: + response_schema: The Pydantic BaseModel class to convert. + This class defines the response schema for the tool. + tool_name: The name for the AI model tool (e.g., "get_weather", "plan_next_step"). + tool_description: Optional description for the AI model tool. + If None, it attempts to use the Pydantic model's docstring. + If that's also missing, a fallback description is generated. + + Returns: + A normalized tool definition for AI model APIs. + """ + raw_schema_dict = schema.model_json_schema() + + name = name or schema.__name__.lower() + description = description + if description is None: + docstring = schema.__doc__ + if docstring: + description = dedent(docstring).strip() + else: + # Fallback description if no explicit one or docstring is provided + description = f"Tool named '{name}' accepts specified parameters." + + # Process properties to inline enums and remove $defs dependency + processed_properties = {} + original_properties = raw_schema_dict.get("properties", {}) + defs = raw_schema_dict.get("$defs", {}) + + for prop_name, prop_schema in original_properties.items(): + current_prop_schema = deepcopy(prop_schema) # Work on a copy + # Check for enums defined directly in the property for simpler direct enum definitions. + if "$ref" in current_prop_schema: + ref_path = current_prop_schema["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + if def_name in defs and "enum" in defs[def_name]: + enum_def = defs[def_name] + current_prop_schema["enum"] = enum_def["enum"] + current_prop_schema["type"] = enum_def.get("type", "string") + if "description" not in current_prop_schema and "description" in enum_def: + current_prop_schema["description"] = enum_def["description"] + del current_prop_schema["$ref"] # Remove the $ref as it's been inlined + + processed_properties[prop_name] = current_prop_schema + + # Generate the compiled schema dictionary for the tool definition. + compiled_schema = { + "type": "object", + "properties": processed_properties, + # Generate content in the order in which the schema properties were defined + "property_ordering": list(schema.model_fields.keys()), + } + + # Include 'required' fields if specified in the Pydantic model + if "required" in raw_schema_dict and raw_schema_dict["required"]: + compiled_schema["required"] = raw_schema_dict["required"] + + return ToolDefinition(name=name, description=description, schema=compiled_schema) + + class ConversationCommand(str, Enum): Default = "default" General = "general" @@ -385,13 +467,84 @@ tool_descriptions_for_llm = { ConversationCommand.Operator: "To use when you need to operate a computer to complete the task.", } -tool_description_for_research_llm = { - ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.", - ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.", - ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.", - ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, - ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.", - ConversationCommand.Operator: "To operate a computer to complete the task.", +tools_for_research_llm = { + ConversationCommand.Notes: ToolDefinition( + name="notes", + description="To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.", + schema={ + "type": "object", + "properties": { + "q": { + "type": "string", + "description": "The query to search in the user's personal knowledge base.", + }, + }, + "required": ["q"], + }, + ), + ConversationCommand.Online: ToolDefinition( + name="online", + description="To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.", + schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search on the internet.", + }, + }, + "required": ["query"], + }, + ), + ConversationCommand.Webpage: ToolDefinition( + name="webpage", + description="To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.", + schema={ + "type": "object", + "properties": { + "urls": { + "type": "array", + "items": { + "type": "string", + }, + "description": "The webpage URLs to extract information from.", + }, + "query": { + "type": "string", + "description": "The query to extract information from the webpages.", + }, + }, + "required": ["urls", "query"], + }, + ), + ConversationCommand.Code: ToolDefinition( + name="code", + description=e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, + schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Detailed query and all input data required to generate, execute code in the sandbox.", + }, + }, + "required": ["query"], + }, + ), + ConversationCommand.Operator: ToolDefinition( + name="operator", + description="To operate a computer to complete the task.", + schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The task to perform on the computer.", + }, + }, + "required": ["query"], + }, + ), } mode_descriptions_for_llm = { @@ -850,3 +1003,13 @@ def clean_object_for_db(data): return [clean_object_for_db(item) for item in data] else: return data + + +def dict_to_tuple(d): + # Recursively convert dicts to sorted tuples for hashability + if isinstance(d, dict): + return tuple(sorted((k, dict_to_tuple(v)) for k, v in d.items())) + elif isinstance(d, list): + return tuple(dict_to_tuple(i) for i in d) + else: + return d