diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index a9fd1db3..09506c9e 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -16,13 +16,14 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, - ToolDefinition, + ToolCall, commit_conversation_trace, - create_tool_definition, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( + ToolDefinition, + create_tool_definition, get_anthropic_async_client, get_anthropic_client, get_chat_usage_metrics, @@ -60,7 +61,7 @@ def anthropic_completion_with_backoff( tools: List[ToolDefinition] = None, deepthought: bool = False, tracer: dict = {}, -) -> str: +) -> ResponseWithThought: client = anthropic_clients.get(api_key) if not client: client = get_anthropic_client(api_key, api_base_url) @@ -68,6 +69,7 @@ def anthropic_completion_with_backoff( formatted_messages, system = format_messages_for_anthropic(messages, system_prompt) + thoughts = "" aggregated_response = "" final_message = None model_kwargs = model_kwargs or dict() @@ -107,15 +109,30 @@ def anthropic_completion_with_backoff( max_tokens=max_tokens, **(model_kwargs), ) as stream: - for text in stream.text_stream: - aggregated_response += text + for chunk in stream: + if chunk.type != "content_block_delta": + continue + if chunk.delta.type == "thinking_delta": + thoughts += chunk.delta.thinking + elif chunk.delta.type == "text_delta": + aggregated_response += chunk.delta.text final_message = stream.get_final_message() - # Extract first tool call from final message - for item in final_message.content: - if item.type == "tool_use": - aggregated_response = json.dumps([{"name": item.name, "args": item.input}]) - break + # Extract all tool calls if tools are enabled + if tools: + tool_calls = [ + ToolCall(name=item.name, args=item.input, id=item.id).__dict__ + for item in final_message.content + if item.type == "tool_use" + ] + if tool_calls: + aggregated_response = json.dumps(tool_calls) + # If response schema is used, return the first tool call's input + elif response_schema: + for item in final_message.content: + if item.type == "tool_use": + aggregated_response = json.dumps(item.input) + break # Calculate cost of chat input_tokens = final_message.usage.input_tokens @@ -137,7 +154,7 @@ def anthropic_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) - return aggregated_response + return ResponseWithThought(response=aggregated_response, thought=thoughts) @retry( @@ -269,7 +286,34 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st # Convert image urls to base64 encoded images in Anthropic message format for message in messages: - if isinstance(message.content, list): + # Handle tool call and tool result message types from additional_kwargs + message_type = message.additional_kwargs.get("message_type") + if message_type == "tool_call": + # Convert tool_call to Anthropic tool_use format + content = [] + for part in message.content: + content.append( + { + "type": "tool_use", + "id": part.pop("id"), + "name": part.pop("name"), + "input": part, + } + ) + message.content = content + elif message_type == "tool_result": + # Convert tool_result to Anthropic tool_result format + content = [] + for part in message.content: + content.append( + { + "type": "tool_result", + "tool_use_id": part["id"], + "content": part["content"], + } + ) + message.content = content + elif isinstance(message.content, list): content = [] # Sort the content. Anthropic models prefer that text comes after images. message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 14eb5119..c0ee6d1a 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -23,12 +23,13 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, - ToolDefinition, + ToolCall, commit_conversation_trace, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( + ToolDefinition, get_chat_usage_metrics, get_gemini_client, is_none_or_empty, @@ -100,7 +101,7 @@ def gemini_completion_with_backoff( model_kwargs={}, deepthought=False, tracer={}, -) -> str: +) -> ResponseWithThought: client = gemini_clients.get(api_key) if not client: client = get_gemini_client(api_key, api_base_url) @@ -119,7 +120,7 @@ def gemini_completion_with_backoff( thinking_config = None if deepthought and is_reasoning_model(model_name): - thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI) + thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI, include_thoughts=True) max_output_tokens = MAX_OUTPUT_TOKENS_FOR_STANDARD_GEMINI if is_reasoning_model(model_name): @@ -145,12 +146,16 @@ def gemini_completion_with_backoff( response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages) if response.function_calls: function_calls = [ - {"name": function_call.name, "args": function_call.args} for function_call in response.function_calls + ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__ + for function_call in response.function_calls ] response_text = json.dumps(function_calls) else: # If no function calls, use the text response response_text = response.text + response_thoughts = "\n".join( + [part.text for part in response.candidates[0].content.parts if part.thought and isinstance(part.text, str)] + ) except gerrors.ClientError as e: response = None response_text, _ = handle_gemini_response(e.args) @@ -164,8 +169,14 @@ def gemini_completion_with_backoff( input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0 output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0 thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0 + cache_read_tokens = response.usage_metadata.cached_content_token_count or 0 if response else 0 tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") + model_name, + input_tokens, + output_tokens, + cache_read_tokens=cache_read_tokens, + thought_tokens=thought_tokens, + usage=tracer.get("usage"), ) # Validate the response. If empty, raise an error to retry. @@ -179,7 +190,7 @@ def gemini_completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, response_text, tracer) - return response_text + return ResponseWithThought(response=response_text, thought=response_thoughts) @retry( @@ -359,8 +370,28 @@ def format_messages_for_gemini( system_prompt = None if is_none_or_empty(system_prompt) else system_prompt for message in messages: + if message.role == "assistant": + message.role = "model" + + # Handle tool call and tool result message types from additional_kwargs + message_type = message.additional_kwargs.get("message_type") + if message_type == "tool_call": + # Convert tool_call to Gemini function call format + tool_call_msg_content = [] + for part in message.content: + tool_call_msg_content.append(gtypes.Part.from_function_call(name=part["name"], args=part["args"])) + message.content = tool_call_msg_content + elif message_type == "tool_result": + # Convert tool_result to Gemini function response format + # Need to find the corresponding function call from previous messages + tool_result_msg_content = [] + for part in message.content: + tool_result_msg_content.append( + gtypes.Part.from_function_response(name=part["name"], response={"result": part["content"]}) + ) + message.content = tool_result_msg_content # Convert message content to string list from chatml dictionary list - if isinstance(message.content, list): + elif isinstance(message.content, list): # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) message_content = [] for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1): @@ -380,16 +411,13 @@ def format_messages_for_gemini( messages.remove(message) continue message.content = message_content - elif isinstance(message.content, str): + elif isinstance(message.content, str) and message.content.strip(): message.content = [gtypes.Part.from_text(text=message.content)] else: logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}") messages.remove(message) continue - if message.role == "assistant": - message.role = "model" - if len(messages) == 1: messages[0].role = "user" diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index ddcdc569..28893290 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -221,4 +221,4 @@ def send_message_to_model_offline( if is_promptrace_enabled(): commit_conversation_trace(messages, response_text, tracer) - return response_text + return ResponseWithThought(response=response_text) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index f123ad63..6a644074 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -15,11 +15,10 @@ from khoj.processor.conversation.utils import ( OperatorRun, ResponseWithThought, StructuredOutputSupport, - ToolDefinition, generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import is_none_or_empty, truncate_code_context +from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context from khoj.utils.rawconfig import FileAttachment, LocationData from khoj.utils.yaml import yaml_dump diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index b0c311b7..fd8edf51 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -35,10 +35,11 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ResponseWithThought, StructuredOutputSupport, - ToolDefinition, + ToolCall, commit_conversation_trace, ) from khoj.utils.helpers import ( + ToolDefinition, convert_image_data_uri, get_chat_usage_metrics, get_openai_async_client, @@ -76,7 +77,7 @@ def completion_with_backoff( deepthought: bool = False, model_kwargs: dict = {}, tracer: dict = {}, -) -> str: +) -> ResponseWithThought: client_key = f"{openai_api_key}--{api_base_url}" client = openai_clients.get(client_key) if not client: @@ -121,6 +122,9 @@ def completion_with_backoff( if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) + tool_ids = [] + tool_calls: list[ToolCall] = [] + thoughts = "" aggregated_response = "" if stream: with client.beta.chat.completions.stream( @@ -134,9 +138,16 @@ def completion_with_backoff( if chunk.type == "content.delta": aggregated_response += chunk.delta elif chunk.type == "thought.delta": - pass + thoughts += chunk.delta + elif chunk.type == "chunk" and chunk.chunk.choices and chunk.chunk.choices[0].delta.tool_calls: + tool_ids += [tool_call.id for tool_call in chunk.chunk.choices[0].delta.tool_calls] elif chunk.type == "tool_calls.function.arguments.done": - aggregated_response = json.dumps([{"name": chunk.name, "args": chunk.arguments}]) + tool_calls += [ToolCall(name=chunk.name, args=json.loads(chunk.arguments), id=None)] + if tool_calls: + tool_calls = [ + ToolCall(name=chunk.name, args=chunk.args, id=tool_id) for chunk, tool_id in zip(tool_calls, tool_ids) + ] + aggregated_response = json.dumps([tool_call.__dict__ for tool_call in tool_calls]) else: # Non-streaming chat completion chunk = client.beta.chat.completions.parse( @@ -170,7 +181,7 @@ def completion_with_backoff( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) - return aggregated_response + return ResponseWithThought(response=aggregated_response, thought=thoughts) @retry( @@ -354,6 +365,43 @@ def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> Li """ formatted_messages = [] for message in deepcopy(messages): + # Handle tool call and tool result message types + message_type = message.additional_kwargs.get("message_type") + if message_type == "tool_call": + # Convert tool_call to OpenAI function call format + content = [] + for part in message.content: + content.append( + { + "type": "function", + "id": part.get("id"), + "function": { + "name": part.get("name"), + "arguments": json.dumps(part.get("input", part.get("args", {}))), + }, + } + ) + formatted_messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": content, + } + ) + continue + if message_type == "tool_result": + # Convert tool_result to OpenAI tool result format + # Each part is a result for a tool call + for part in message.content: + formatted_messages.append( + { + "role": "tool", + "tool_call_id": part.get("id") or part.get("tool_use_id"), + "name": part.get("name"), + "content": part.get("content"), + } + ) + continue if isinstance(message.content, list) and not is_openai_api(api_base_url): assistant_texts = [] has_images = False diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index c094c5d3..94e9c4c1 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -715,15 +715,6 @@ Given the results of your previous iterations, which tool AI will you use next t """.strip() ) -previous_iteration = PromptTemplate.from_template( - """ -# Iteration {index}: -- tool: {tool} -- query: {query} -- result: {result} -""".strip() -) - pick_relevant_tools = PromptTemplate.from_template( """ You are Khoj, an extremely smart and helpful search assistant. diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 4b5f0f38..4a6d5a01 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -6,13 +6,11 @@ import mimetypes import os import re import uuid -from copy import deepcopy from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from textwrap import dedent -from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Union import PIL.Image import pyjson5 @@ -139,11 +137,17 @@ class OperatorRun: } +class ToolCall: + def __init__(self, name: str, args: dict, id: str): + self.name = name + self.args = args + self.id = id + + class ResearchIteration: def __init__( self, - tool: str, - query: str, + query: ToolCall | dict | str, context: list = None, onlineContext: dict = None, codeContext: dict = None, @@ -151,8 +155,7 @@ class ResearchIteration: summarizedResult: str = None, warning: str = None, ): - self.tool = tool - self.query = query + self.query = ToolCall(**query) if isinstance(query, dict) else query self.context = context self.onlineContext = onlineContext self.codeContext = codeContext @@ -162,37 +165,43 @@ class ResearchIteration: def to_dict(self) -> dict: data = vars(self).copy() + data["query"] = self.query.__dict__ if isinstance(self.query, ToolCall) else self.query data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None return data def construct_iteration_history( previous_iterations: List[ResearchIteration], - previous_iteration_prompt: str, query: str = None, + query_images: List[str] = None, + query_files: str = None, ) -> list[ChatMessageModel]: iteration_history: list[ChatMessageModel] = [] - previous_iteration_messages: list[dict] = [] - for idx, iteration in enumerate(previous_iterations): - iteration_data = previous_iteration_prompt.format( - tool=iteration.tool, - query=iteration.query, - result=iteration.summarizedResult, - index=idx + 1, - ) + query_message_content = construct_structured_message(query, query_images, attached_file_context=query_files) + if query_message_content: + iteration_history.append(ChatMessageModel(by="you", message=query_message_content)) - previous_iteration_messages.append({"type": "text", "text": iteration_data}) - - if previous_iteration_messages: - if query: - iteration_history.append(ChatMessageModel(by="you", message=query)) - iteration_history.append( + for iteration in previous_iterations: + iteration_history += [ ChatMessageModel( by="khoj", - intent=Intent(type="remember", query=query), - message=previous_iteration_messages, - ) - ) + message=[iteration.query.__dict__], + intent=Intent(type="tool_call", query=query), + ), + ChatMessageModel( + by="you", + intent=Intent(type="tool_result"), + message=[ + { + "type": "tool_result", + "id": iteration.query.id, + "name": iteration.query.name, + "content": iteration.summarizedResult, + } + ], + ), + ] + return iteration_history @@ -319,18 +328,18 @@ def construct_tool_chat_history( # If no tool is provided, use inferred query extractor for the tool used in the iteration # Fallback to base extractor if the tool does not have an inferred query extractor inferred_query_extractor = extract_inferred_query_map.get( - tool or ConversationCommand(iteration.tool), base_extractor + tool or ConversationCommand(iteration.query.name), base_extractor ) chat_history += [ ChatMessageModel( by="you", - message=iteration.query, + message=yaml.dump(iteration.query.args, default_flow_style=False), ), ChatMessageModel( by="khoj", intent=Intent( type="remember", - query=iteration.query, + query=yaml.dump(iteration.query.args, default_flow_style=False), inferred_queries=inferred_query_extractor(iteration), memory_type="notes", ), @@ -483,28 +492,32 @@ Khoj: "{chat_response}" def construct_structured_message( message: list[dict] | str, - images: list[str], - model_type: str, - vision_enabled: bool, + images: list[str] = None, + model_type: str = None, + vision_enabled: bool = True, attached_file_context: str = None, ): """ - Format messages into appropriate multimedia format for supported chat model types + Format messages into appropriate multimedia format for supported chat model types. + + Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise. """ - if model_type in [ + if not model_type or model_type in [ ChatModel.ModelType.OPENAI, ChatModel.ModelType.GOOGLE, ChatModel.ModelType.ANTHROPIC, ]: - constructed_messages: List[dict[str, Any]] = ( - [{"type": "text", "text": message}] if isinstance(message, str) else message - ) - + constructed_messages: List[dict[str, Any]] = [] + if not is_none_or_empty(message): + constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message + # Drop image message passed by caller if chat model does not have vision enabled + if not vision_enabled: + constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"] if not is_none_or_empty(attached_file_context): - constructed_messages.append({"type": "text", "text": attached_file_context}) + constructed_messages += [{"type": "text", "text": attached_file_context}] if vision_enabled and images: for image in images: - constructed_messages.append({"type": "image_url", "image_url": {"url": image}}) + constructed_messages += [{"type": "image_url", "image_url": {"url": image}}] return constructed_messages message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message) @@ -640,7 +653,11 @@ def generate_chatml_messages_with_context( chat_message, chat.images if role == "user" else [], model_type, vision_enabled ) - reconstructed_message = ChatMessage(content=message_content, role=role) + reconstructed_message = ChatMessage( + content=message_content, + role=role, + additional_kwargs={"message_type": chat.intent.type if chat.intent else None}, + ) chatml_messages.insert(0, reconstructed_message) if len(chatml_messages) >= 3 * lookback_turns: @@ -739,10 +756,21 @@ def count_tokens( message_content_parts: list[str] = [] # Collate message content into single string to ease token counting for part in message_content: - if isinstance(part, dict) and part.get("type") == "text": - message_content_parts.append(part["text"]) - elif isinstance(part, dict) and part.get("type") == "image_url": + if isinstance(part, dict) and part.get("type") == "image_url": image_count += 1 + elif isinstance(part, dict) and part.get("type") == "text": + message_content_parts.append(part["text"]) + elif isinstance(part, dict) and hasattr(part, "model_dump"): + message_content_parts.append(json.dumps(part.model_dump())) + elif isinstance(part, dict) and hasattr(part, "__dict__"): + message_content_parts.append(json.dumps(part.__dict__)) + elif isinstance(part, dict): + # If part is a dict but not a recognized type, convert to JSON string + try: + message_content_parts.append(json.dumps(part)) + except (TypeError, ValueError) as e: + logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.") + image_count += 1 # Treat as an image/binary if serialization fails elif isinstance(part, str): message_content_parts.append(part) else: @@ -1162,82 +1190,3 @@ class ResponseWithThought: def __init__(self, response: str = None, thought: str = None): self.response = response self.thought = thought - - -class ToolDefinition: - def __init__(self, name: str, description: str, schema: dict): - self.name = name - self.description = description - self.schema = schema - - -def create_tool_definition( - schema: Type[BaseModel], - name: str = None, - description: Optional[str] = None, -) -> ToolDefinition: - """ - Converts a response schema BaseModel class into a normalized tool definition. - - A standard AI provider agnostic tool format to specify tools the model can use. - Common logic used across models is kept here. AI provider specific adaptations - should be handled in provider code. - - Args: - response_schema: The Pydantic BaseModel class to convert. - This class defines the response schema for the tool. - tool_name: The name for the AI model tool (e.g., "get_weather", "plan_next_step"). - tool_description: Optional description for the AI model tool. - If None, it attempts to use the Pydantic model's docstring. - If that's also missing, a fallback description is generated. - - Returns: - A normalized tool definition for AI model APIs. - """ - raw_schema_dict = schema.model_json_schema() - - name = name or schema.__name__.lower() - description = description - if description is None: - docstring = schema.__doc__ - if docstring: - description = dedent(docstring).strip() - else: - # Fallback description if no explicit one or docstring is provided - description = f"Tool named '{name}' accepts specified parameters." - - # Process properties to inline enums and remove $defs dependency - processed_properties = {} - original_properties = raw_schema_dict.get("properties", {}) - defs = raw_schema_dict.get("$defs", {}) - - for prop_name, prop_schema in original_properties.items(): - current_prop_schema = deepcopy(prop_schema) # Work on a copy - # Check for enums defined directly in the property for simpler direct enum definitions. - if "$ref" in current_prop_schema: - ref_path = current_prop_schema["$ref"] - if ref_path.startswith("#/$defs/"): - def_name = ref_path.split("/")[-1] - if def_name in defs and "enum" in defs[def_name]: - enum_def = defs[def_name] - current_prop_schema["enum"] = enum_def["enum"] - current_prop_schema["type"] = enum_def.get("type", "string") - if "description" not in current_prop_schema and "description" in enum_def: - current_prop_schema["description"] = enum_def["description"] - del current_prop_schema["$ref"] # Remove the $ref as it's been inlined - - processed_properties[prop_name] = current_prop_schema - - # Generate the compiled schema dictionary for the tool definition. - compiled_schema = { - "type": "object", - "properties": processed_properties, - # Generate content in the order in which the schema properties were defined - "property_ordering": list(schema.model_fields.keys()), - } - - # Include 'required' fields if specified in the Pydantic model - if "required" in raw_schema_dict and raw_schema_dict["required"]: - compiled_schema["required"] = raw_schema_dict["required"] - - return ToolDefinition(name=name, description=description, schema=compiled_schema) diff --git a/src/khoj/processor/operator/grounding_agent.py b/src/khoj/processor/operator/grounding_agent.py index 3697391e..16f2d510 100644 --- a/src/khoj/processor/operator/grounding_agent.py +++ b/src/khoj/processor/operator/grounding_agent.py @@ -73,7 +73,7 @@ class GroundingAgent: grounding_user_prompt = self.get_instruction(instruction, self.environment_type) screenshots = [f"data:image/webp;base64,{current_state.screenshot}"] grounding_messages_content = construct_structured_message( - grounding_user_prompt, screenshots, self.model.name, vision_enabled=True + grounding_user_prompt, screenshots, self.model.model_type, vision_enabled=True ) return [{"role": "user", "content": grounding_messages_content}] diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index 8106e9cc..4f78cac7 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -121,7 +121,7 @@ class BinaryOperatorAgent(OperatorAgent): # Construct input for visual reasoner history visual_reasoner_history = self._format_message_for_api(self.messages) try: - natural_language_action = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( query=query_text, query_images=query_screenshot, system_message=reasoning_system_prompt, @@ -129,6 +129,7 @@ class BinaryOperatorAgent(OperatorAgent): agent_chat_model=self.reasoning_model, tracer=self.tracer, ) + natural_language_action = raw_response.response if not isinstance(natural_language_action, str) or not natural_language_action.strip(): raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}") @@ -255,10 +256,10 @@ class BinaryOperatorAgent(OperatorAgent): # Append summary messages to history trigger_summary = AgentMessage(role="user", content=summarize_prompt) - summary_message = AgentMessage(role="assistant", content=summary) + summary_message = AgentMessage(role="assistant", content=summary.response) self.messages.extend([trigger_summary, summary_message]) - return summary + return summary.response def _compile_response(self, response_content: str | List) -> str: """Compile response content into a string, handling OpenAI message structures.""" diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 78a4117f..5261af51 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -390,7 +390,25 @@ async def read_webpages( query_files=query_files, tracer=tracer, ) + async for result in read_webpages_content( + query, + urls, + user, + send_status_func=send_status_func, + agent=agent, + tracer=tracer, + ): + yield result + +async def read_webpages_content( + query: str, + urls: List[str], + user: KhojUser, + send_status_func: Optional[Callable] = None, + agent: Agent = None, + tracer: dict = {}, +): logger.info(f"Reading web pages at: {urls}") if send_status_func: webpage_links_str = "\n- " + "\n- ".join(list(urls)) diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index d9a80f71..e481b955 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -161,7 +161,7 @@ async def generate_python_code( ) # Extract python code wrapped in markdown code blocks from the response - code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response, re.DOTALL) + code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.response, re.DOTALL) if not code_blocks: raise ValueError("No Python code blocks found in response") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5b0b83e3..5a002e41 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -101,7 +101,6 @@ from khoj.processor.conversation.utils import ( OperatorRun, ResearchIteration, ResponseWithThought, - ToolDefinition, clean_json, clean_mermaidjs, construct_chat_history, @@ -121,6 +120,7 @@ from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( LRU, ConversationCommand, + ToolDefinition, get_file_type, in_debug_mode, is_none_or_empty, @@ -304,7 +304,7 @@ async def acreate_title_from_history( with timer("Chat actor: Generate title from conversation history", logger): response = await send_message_to_model_wrapper(title_generation_prompt, user=user) - return response.strip() + return response.response.strip() async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: @@ -316,7 +316,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: with timer("Chat actor: Generate title from query", logger): response = await send_message_to_model_wrapper(title_generation_prompt, user=user) - return response.strip() + return response.response.strip() async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]: @@ -340,7 +340,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck ) - response = response.strip() + response = response.response.strip() try: response = json.loads(clean_json(response)) is_safe = str(response.get("safe", "true")).lower() == "true" @@ -419,7 +419,7 @@ async def aget_data_sources_and_output_format( output: str with timer("Chat actor: Infer information sources to refer", logger): - response = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( relevant_tools_prompt, response_type="json_object", response_schema=PickTools, @@ -430,7 +430,7 @@ async def aget_data_sources_and_output_format( ) try: - response = clean_json(response) + response = clean_json(raw_response.response) response = json.loads(response) chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()] @@ -507,7 +507,7 @@ async def infer_webpage_urls( links: List[str] = Field(..., min_items=1, max_items=max_webpages) with timer("Chat actor: Infer webpage urls to read", logger): - response = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( online_queries_prompt, query_images=query_images, response_type="json_object", @@ -520,7 +520,7 @@ async def infer_webpage_urls( # Validate that the response is a non-empty, JSON-serializable list of URLs try: - response = clean_json(response) + response = clean_json(raw_response.response) urls = json.loads(response) valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)} if is_none_or_empty(valid_unique_urls): @@ -572,7 +572,7 @@ async def generate_online_subqueries( queries: List[str] = Field(..., min_items=1, max_items=max_queries) with timer("Chat actor: Generate online search subqueries", logger): - response = await send_message_to_model_wrapper( + raw_response = await send_message_to_model_wrapper( online_queries_prompt, query_images=query_images, response_type="json_object", @@ -585,7 +585,7 @@ async def generate_online_subqueries( # Validate that the response is a non-empty, JSON-serializable list try: - response = clean_json(response) + response = clean_json(raw_response.response) response = pyjson5.loads(response) response = {q.strip() for q in response["queries"] if q.strip()} if not isinstance(response, set) or not response or len(response) == 0: @@ -646,7 +646,7 @@ async def aschedule_query( # Validate that the response is a non-empty, JSON-serializable list try: - raw_response = raw_response.strip() + raw_response = raw_response.response.strip() response: Dict[str, str] = json.loads(clean_json(raw_response)) if not response or not isinstance(response, Dict) or len(response) != 3: raise AssertionError(f"Invalid response for scheduling query : {response}") @@ -684,7 +684,7 @@ async def extract_relevant_info( agent_chat_model=agent_chat_model, tracer=tracer, ) - return response.strip() + return response.response.strip() async def extract_relevant_summary( @@ -727,7 +727,7 @@ async def extract_relevant_summary( agent_chat_model=agent_chat_model, tracer=tracer, ) - return response.strip() + return response.response.strip() async def generate_summary_from_files( @@ -898,7 +898,7 @@ async def generate_better_diagram_description( agent_chat_model=agent_chat_model, tracer=tracer, ) - response = response.strip() + response = response.response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): response = response[1:-1] @@ -926,10 +926,10 @@ async def generate_excalidraw_diagram_from_description( raw_response = await send_message_to_model_wrapper( query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer ) - raw_response = clean_json(raw_response) + raw_response_text = clean_json(raw_response.response) try: # Expect response to have `elements` and `scratchpad` keys - response: Dict[str, str] = json.loads(raw_response) + response: Dict[str, str] = json.loads(raw_response_text) if ( not response or not isinstance(response, Dict) @@ -938,7 +938,7 @@ async def generate_excalidraw_diagram_from_description( ): raise AssertionError(f"Invalid response for generating Excalidraw diagram: {response}") except Exception: - raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response}") + raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response_text}") if not response or not isinstance(response["elements"], List) or not isinstance(response["elements"][0], Dict): # TODO Some additional validation here that it's a valid Excalidraw diagram raise AssertionError(f"Invalid response for improving diagram description: {response}") @@ -1049,11 +1049,11 @@ async def generate_better_mermaidjs_diagram_description( agent_chat_model=agent_chat_model, tracer=tracer, ) - response = response.strip() - if response.startswith(('"', "'")) and response.endswith(('"', "'")): - response = response[1:-1] + response_text = response.response.strip() + if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")): + response_text = response_text[1:-1] - return response + return response_text async def generate_mermaidjs_diagram_from_description( @@ -1077,7 +1077,7 @@ async def generate_mermaidjs_diagram_from_description( raw_response = await send_message_to_model_wrapper( query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer ) - return clean_mermaidjs(raw_response.strip()) + return clean_mermaidjs(raw_response.response.strip()) async def generate_better_image_prompt( @@ -1152,11 +1152,11 @@ async def generate_better_image_prompt( agent_chat_model=agent_chat_model, tracer=tracer, ) - response = response.strip() - if response.startswith(('"', "'")) and response.endswith(('"', "'")): - response = response[1:-1] + response_text = response.response.strip() + if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")): + response_text = response_text[1:-1] - return response + return response_text async def search_documents( @@ -1330,7 +1330,7 @@ async def extract_questions( # Extract questions from the response try: - response = clean_json(raw_response) + response = clean_json(raw_response.response) response = pyjson5.loads(response) queries = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(queries, list) or not queries: diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index f7467b9f..aa9cdbcc 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -3,11 +3,9 @@ import logging import os from copy import deepcopy from datetime import datetime -from enum import Enum -from typing import Callable, Dict, List, Optional, Type +from typing import Callable, Dict, List, Optional import yaml -from pydantic import BaseModel, Field from khoj.database.adapters import AgentAdapters, EntryAdapters from khoj.database.models import Agent, ChatMessageModel, KhojUser @@ -15,14 +13,13 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( OperatorRun, ResearchIteration, - ToolDefinition, + ToolCall, construct_iteration_history, construct_tool_chat_history, - create_tool_definition, load_complex_json, ) from khoj.processor.operator import operate_environment -from khoj.processor.tools.online_search import read_webpages, search_online +from khoj.processor.tools.online_search import read_webpages_content, search_online from khoj.processor.tools.run_code import run_code from khoj.routers.helpers import ( ChatEvent, @@ -32,10 +29,12 @@ from khoj.routers.helpers import ( ) from khoj.utils.helpers import ( ConversationCommand, + ToolDefinition, + dict_to_tuple, is_none_or_empty, is_operator_enabled, timer, - tool_description_for_research_llm, + tools_for_research_llm, truncate_code_context, ) from khoj.utils.rawconfig import LocationData @@ -43,47 +42,6 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) -class PlanningResponse(BaseModel): - """ - Schema for the response from planning agent when deciding the next tool to pick. - """ - - scratchpad: str = Field(..., description="Scratchpad to reason about which tool to use next") - - class Config: - arbitrary_types_allowed = True - - @classmethod - def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]: - """ - Factory method that creates a customized PlanningResponse model - with a properly typed tool field based on available tools. - - The tool field is dynamically generated based on available tools. - The query field should be filled by the model after the tool field for a more logical reasoning flow. - - Args: - tool_options: Dictionary mapping tool names to values - - Returns: - A customized PlanningResponse class - """ - # Create dynamic enum from tool options - tool_enum = Enum("ToolEnum", tool_options) # type: ignore - - # Create and return a customized response model with the enum - class PlanningResponseWithTool(PlanningResponse): - """ - Use the scratchpad to reason about which tool to use next and the query to send to the tool. - Pick tool from provided options and your query to send to the tool. - """ - - tool: tool_enum = Field(..., description="Name of the tool to use") - query: str = Field(..., description="Detailed query for the selected tool") - - return PlanningResponseWithTool - - async def apick_next_tool( query: str, conversation_history: List[ChatMessageModel], @@ -106,12 +64,13 @@ async def apick_next_tool( # Continue with previous iteration if a multi-step tool use is in progress if ( previous_iterations - and previous_iterations[-1].tool == ConversationCommand.Operator + and previous_iterations[-1].query + and isinstance(previous_iterations[-1].query, ToolCall) + and previous_iterations[-1].query.name == ConversationCommand.Operator and not previous_iterations[-1].summarizedResult ): previous_iteration = previous_iterations[-1] yield ResearchIteration( - tool=previous_iteration.tool, query=query, context=previous_iteration.context, onlineContext=previous_iteration.onlineContext, @@ -122,37 +81,35 @@ async def apick_next_tool( return # Construct tool options for the agent to choose from - tool_options = dict() + tools = [] tool_options_str = "" agent_tools = agent.input_tools if agent else [] user_has_entries = await EntryAdapters.auser_has_entries(user) - for tool, description in tool_description_for_research_llm.items(): + for tool, tool_data in tools_for_research_llm.items(): # Skip showing operator tool as an option if not enabled if tool == ConversationCommand.Operator and not is_operator_enabled(): continue # Skip showing Notes tool as an option if user has no entries - if tool == ConversationCommand.Notes: + elif tool == ConversationCommand.Notes: if not user_has_entries: continue - description = description.format(max_search_queries=max_document_searches) - if tool == ConversationCommand.Webpage: - description = description.format(max_webpages_to_read=max_webpages_to_read) - if tool == ConversationCommand.Online: - description = description.format(max_search_queries=max_online_searches) + description = tool_data.description.format(max_search_queries=max_document_searches) + elif tool == ConversationCommand.Webpage: + description = tool_data.description.format(max_webpages_to_read=max_webpages_to_read) + elif tool == ConversationCommand.Online: + description = tool_data.description.format(max_search_queries=max_online_searches) + else: + description = tool_data.description # Add tool if agent does not have any tools defined or the tool is supported by the agent. if len(agent_tools) == 0 or tool.value in agent_tools: - tool_options[tool.name] = tool.value tool_options_str += f'- "{tool.value}": "{description}"\n' - - # Create planning reponse model with dynamically populated tool enum class - planning_response_model = PlanningResponse.create_model_with_enum(tool_options) - tools = [ - create_tool_definition( - name="infer_information_sources_to_refer", - description="Infer which tool to use next and the query to send to the tool.", - schema=planning_response_model, - ) - ] + tools.append( + ToolDefinition( + name=tool.value, + description=description, + schema=tool_data.schema, + ) + ) today = datetime.today() location_data = f"{location}" if location else "Unknown" @@ -171,24 +128,16 @@ async def apick_next_tool( max_iterations=max_iterations, ) - if query_images: - query = f"[placeholder for user attached images]\n{query}" - # Construct chat history with user and iteration history with researcher agent for context - iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) + iteration_chat_history = construct_iteration_history(previous_iterations, query, query_images, query_files) chat_and_research_history = conversation_history + iteration_chat_history - # Plan function execution for the next tool - query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query - try: with timer("Chat actor: Infer information sources to refer", logger): raw_response = await send_message_to_model_wrapper( - query=query, + query="", system_message=function_planning_prompt, chat_history=chat_and_research_history, - response_type="json_object", - response_schema=planning_response_model, tools=tools, deepthought=True, user=user, @@ -200,7 +149,6 @@ async def apick_next_tool( except Exception as e: logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True) yield ResearchIteration( - tool=None, query=None, warning="Failed to infer information sources to refer. Skipping iteration. Try again.", ) @@ -208,40 +156,32 @@ async def apick_next_tool( try: # Try parse the response as function call response to infer next tool to use. - response = load_complex_json(load_complex_json(raw_response)[0]["args"]) + # TODO: Handle multiple tool calls. + response_text = raw_response.response + parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] except Exception as e: - try: - # Else try parse the text response as JSON to infer next tool to use. - response = load_complex_json(raw_response) - except Exception as e: - # Otherwise assume the model has decided to end the research run and respond to the user. - response = {"tool": ConversationCommand.Text, "query": None, "scratchpad": raw_response} + # Otherwise assume the model has decided to end the research run and respond to the user. + parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None) # If we have a valid response, extract the tool and query. - selected_tool = response.get("tool", None) - generated_query = response.get("query", None) - scratchpad = response.get("scratchpad", None) - warning = None - logger.info(f"Response for determining relevant tools: {response}") + logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})") # Detect selection of previously used query, tool combination. - previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations if i.warning is None} - if (selected_tool, generated_query) in previous_tool_query_combinations: + previous_tool_query_combinations = { + (i.query.name, dict_to_tuple(i.query.args)) + for i in previous_iterations + if i.warning is None and isinstance(i.query, ToolCall) + } + if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different." - # Only send client status updates if we'll execute this iteration - elif send_status_func and scratchpad: - determined_tool_message = "**Determined Tool**: " - determined_tool_message += ( - f"{selected_tool}({generated_query})." if selected_tool != ConversationCommand.Text else "respond." - ) - determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else "" - async for event in send_status_func(f"{scratchpad}"): + # Only send client status updates if we'll execute this iteration and model has thoughts to share. + elif send_status_func and not is_none_or_empty(raw_response.thought): + async for event in send_status_func(raw_response.thought): yield {ChatEvent.STATUS: event} yield ResearchIteration( - tool=selected_tool, - query=generated_query, + query=parsed_response, warning=warning, ) @@ -269,10 +209,10 @@ async def research( MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) # Incorporate previous partial research into current research chat history - research_conversation_history = deepcopy(conversation_history) + research_conversation_history = [chat for chat in deepcopy(conversation_history) if chat.message] if current_iteration := len(previous_iterations) > 0: logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.") - previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) + previous_iterations_history = construct_iteration_history(previous_iterations) research_conversation_history += previous_iterations_history while current_iteration < MAX_ITERATIONS: @@ -285,7 +225,7 @@ async def research( code_results: Dict = dict() document_results: List[Dict[str, str]] = [] operator_results: OperatorRun = None - this_iteration = ResearchIteration(tool=None, query=query) + this_iteration = ResearchIteration(query=query) async for result in apick_next_tool( query, @@ -315,26 +255,30 @@ async def research( logger.warning(f"Research mode: {this_iteration.warning}.") # Terminate research if selected text tool or query, tool not set for next iteration - elif not this_iteration.query or not this_iteration.tool or this_iteration.tool == ConversationCommand.Text: + elif ( + not this_iteration.query + or isinstance(this_iteration.query, str) + or this_iteration.query.name == ConversationCommand.Text + ): current_iteration = MAX_ITERATIONS - elif this_iteration.tool == ConversationCommand.Notes: + elif this_iteration.query.name == ConversationCommand.Notes: this_iteration.context = [] document_results = [] previous_inferred_queries = { c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context } async for result in search_documents( - this_iteration.query, - max_document_searches, - None, - user, - construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), - conversation_id, - [ConversationCommand.Default], - location, - send_status_func, - query_images, + **this_iteration.query.args, + n=max_document_searches, + d=None, + user=user, + chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), + conversation_id=conversation_id, + conversation_commands=[ConversationCommand.Default], + location_data=location, + send_status_func=send_status_func, + query_images=query_images, previous_inferred_queries=previous_inferred_queries, agent=agent, tracer=tracer, @@ -362,7 +306,7 @@ async def research( else: this_iteration.warning = "No matching document references found" - elif this_iteration.tool == ConversationCommand.Online: + elif this_iteration.query.name == ConversationCommand.Online: previous_subqueries = { subquery for iteration in previous_iterations @@ -371,12 +315,12 @@ async def research( } try: async for result in search_online( - this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Online), - location, - user, - send_status_func, - [], + **this_iteration.query.args, + conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Online), + location=location, + user=user, + send_status_func=send_status_func, + custom_filters=[], max_online_searches=max_online_searches, max_webpages_to_read=0, query_images=query_images, @@ -395,19 +339,15 @@ async def research( this_iteration.warning = f"Error searching online: {e}" logger.error(this_iteration.warning, exc_info=True) - elif this_iteration.tool == ConversationCommand.Webpage: + elif this_iteration.query.name == ConversationCommand.Webpage: try: - async for result in read_webpages( - this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage), - location, - user, - send_status_func, - max_webpages_to_read=max_webpages_to_read, - query_images=query_images, + async for result in read_webpages_content( + **this_iteration.query.args, + user=user, + send_status_func=send_status_func, + # max_webpages_to_read=max_webpages_to_read, agent=agent, tracer=tracer, - query_files=query_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -428,15 +368,15 @@ async def research( this_iteration.warning = f"Error reading webpages: {e}" logger.error(this_iteration.warning, exc_info=True) - elif this_iteration.tool == ConversationCommand.Code: + elif this_iteration.query.name == ConversationCommand.Code: try: async for result in run_code( - this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Code), - "", - location, - user, - send_status_func, + **this_iteration.query.args, + conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Code), + context="", + location_data=location, + user=user, + send_status_func=send_status_func, query_images=query_images, agent=agent, query_files=query_files, @@ -453,14 +393,14 @@ async def research( this_iteration.warning = f"Error running code: {e}" logger.warning(this_iteration.warning, exc_info=True) - elif this_iteration.tool == ConversationCommand.Operator: + elif this_iteration.query.name == ConversationCommand.Operator: try: async for result in operate_environment( - this_iteration.query, - user, - construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), - location, - previous_iterations[-1].operatorContext if previous_iterations else None, + **this_iteration.query.args, + user=user, + conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), + location_data=location, + previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None, send_status_func=send_status_func, query_images=query_images, agent=agent, @@ -493,7 +433,7 @@ async def research( current_iteration += 1 if document_results or online_results or code_results or operator_results or this_iteration.warning: - results_data = f"\n{current_iteration}\n{this_iteration.tool}\n{this_iteration.query}\n" + results_data = f"\n" if document_results: results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if online_results: @@ -506,7 +446,7 @@ async def research( ) if this_iteration.warning: results_data += f"\n\n{this_iteration.warning}\n" - results_data += "\n\n" + results_data += f"\n\n" # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) this_iteration.summarizedResult = results_data diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index cdbed998..e4a02320 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -12,6 +12,7 @@ import random import urllib.parse import uuid from collections import OrderedDict +from copy import deepcopy from enum import Enum from functools import lru_cache from importlib import import_module @@ -19,8 +20,9 @@ from importlib.metadata import version from itertools import islice from os import path from pathlib import Path +from textwrap import dedent from time import perf_counter -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Type, Union from urllib.parse import ParseResult, urlparse import anthropic @@ -36,6 +38,7 @@ from google.auth.credentials import Credentials from google.oauth2 import service_account from magika import Magika from PIL import Image +from pydantic import BaseModel from pytz import country_names, country_timezones from khoj.utils import constants @@ -334,6 +337,85 @@ def is_e2b_code_sandbox_enabled(): return not is_none_or_empty(os.getenv("E2B_API_KEY")) +class ToolDefinition: + def __init__(self, name: str, description: str, schema: dict): + self.name = name + self.description = description + self.schema = schema + + +def create_tool_definition( + schema: Type[BaseModel], + name: str = None, + description: Optional[str] = None, +) -> ToolDefinition: + """ + Converts a response schema BaseModel class into a normalized tool definition. + + A standard AI provider agnostic tool format to specify tools the model can use. + Common logic used across models is kept here. AI provider specific adaptations + should be handled in provider code. + + Args: + response_schema: The Pydantic BaseModel class to convert. + This class defines the response schema for the tool. + tool_name: The name for the AI model tool (e.g., "get_weather", "plan_next_step"). + tool_description: Optional description for the AI model tool. + If None, it attempts to use the Pydantic model's docstring. + If that's also missing, a fallback description is generated. + + Returns: + A normalized tool definition for AI model APIs. + """ + raw_schema_dict = schema.model_json_schema() + + name = name or schema.__name__.lower() + description = description + if description is None: + docstring = schema.__doc__ + if docstring: + description = dedent(docstring).strip() + else: + # Fallback description if no explicit one or docstring is provided + description = f"Tool named '{name}' accepts specified parameters." + + # Process properties to inline enums and remove $defs dependency + processed_properties = {} + original_properties = raw_schema_dict.get("properties", {}) + defs = raw_schema_dict.get("$defs", {}) + + for prop_name, prop_schema in original_properties.items(): + current_prop_schema = deepcopy(prop_schema) # Work on a copy + # Check for enums defined directly in the property for simpler direct enum definitions. + if "$ref" in current_prop_schema: + ref_path = current_prop_schema["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + if def_name in defs and "enum" in defs[def_name]: + enum_def = defs[def_name] + current_prop_schema["enum"] = enum_def["enum"] + current_prop_schema["type"] = enum_def.get("type", "string") + if "description" not in current_prop_schema and "description" in enum_def: + current_prop_schema["description"] = enum_def["description"] + del current_prop_schema["$ref"] # Remove the $ref as it's been inlined + + processed_properties[prop_name] = current_prop_schema + + # Generate the compiled schema dictionary for the tool definition. + compiled_schema = { + "type": "object", + "properties": processed_properties, + # Generate content in the order in which the schema properties were defined + "property_ordering": list(schema.model_fields.keys()), + } + + # Include 'required' fields if specified in the Pydantic model + if "required" in raw_schema_dict and raw_schema_dict["required"]: + compiled_schema["required"] = raw_schema_dict["required"] + + return ToolDefinition(name=name, description=description, schema=compiled_schema) + + class ConversationCommand(str, Enum): Default = "default" General = "general" @@ -385,13 +467,84 @@ tool_descriptions_for_llm = { ConversationCommand.Operator: "To use when you need to operate a computer to complete the task.", } -tool_description_for_research_llm = { - ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.", - ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.", - ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.", - ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, - ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.", - ConversationCommand.Operator: "To operate a computer to complete the task.", +tools_for_research_llm = { + ConversationCommand.Notes: ToolDefinition( + name="notes", + description="To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.", + schema={ + "type": "object", + "properties": { + "q": { + "type": "string", + "description": "The query to search in the user's personal knowledge base.", + }, + }, + "required": ["q"], + }, + ), + ConversationCommand.Online: ToolDefinition( + name="online", + description="To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.", + schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search on the internet.", + }, + }, + "required": ["query"], + }, + ), + ConversationCommand.Webpage: ToolDefinition( + name="webpage", + description="To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.", + schema={ + "type": "object", + "properties": { + "urls": { + "type": "array", + "items": { + "type": "string", + }, + "description": "The webpage URLs to extract information from.", + }, + "query": { + "type": "string", + "description": "The query to extract information from the webpages.", + }, + }, + "required": ["urls", "query"], + }, + ), + ConversationCommand.Code: ToolDefinition( + name="code", + description=e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, + schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Detailed query and all input data required to generate, execute code in the sandbox.", + }, + }, + "required": ["query"], + }, + ), + ConversationCommand.Operator: ToolDefinition( + name="operator", + description="To operate a computer to complete the task.", + schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The task to perform on the computer.", + }, + }, + "required": ["query"], + }, + ), } mode_descriptions_for_llm = { @@ -850,3 +1003,13 @@ def clean_object_for_db(data): return [clean_object_for_db(item) for item in data] else: return data + + +def dict_to_tuple(d): + # Recursively convert dicts to sorted tuples for hashability + if isinstance(d, dict): + return tuple(sorted((k, dict_to_tuple(v)) for k, v in d.items())) + elif isinstance(d, list): + return tuple(dict_to_tuple(i) for i in d) + else: + return d