mirror of
https://github.com/khoaliber/khoj.git
synced 2026-04-28 00:19:25 +00:00
Upgrade Retrieval from KB in Research Mode. Use Function Calling for Tool Use (#1205)
## Why Move to function calling paradigm to give models tool call -> tool result in formats they're fine-tuned to understand. Previously we were giving them results in our specific format (as function calling paradigm wasn't well-established yet). And improve prompt cache hits by caching tool definitions. This is a **breaking change**. AI Models and APIs that do not support function calling will not work with Khoj in research mode. Function calling is supported by: - Standard commercial AI Models and APIs like Anthropic, Gemini, OpenAI, OpenRouter - Standard open-source AI APIs like llama.cpp server, Ollama - Standard open source models like Qwen, DeepSeek, Gemma, Llama, Mistral ## What ### Use Function Calling for Tool Use - Add Function Calling support to Anthropic, Gemini, OpenAI AI Model APIs - Move Existing Research Mode Tools to Use Function Calling ### Get More Comprehensive Results from your Knowledge Base (KB) - Give Research Agent better Document Retrieval Tools - Add grep files tool to enable researcher to find documents by regex - Add list files tool to enable researcher to find documents by path - Add file viewer tool to enable researcher to read documents ### Miscellaneous - Improve Research Prompt, Truncation, Retry and Caching - Show reasoning model thoughts in Khoj train of thought for intermediate steps as well
This commit is contained in:
@@ -1716,6 +1716,14 @@ class FileObjectAdapters:
|
||||
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_file_objects_by_path_prefix(user: KhojUser, path_prefix: str, agent: Agent = None):
|
||||
"""Get file objects from the database by path prefix."""
|
||||
return await sync_to_async(list)(
|
||||
FileObject.objects.filter(user=user, agent=agent, file_name__startswith=path_prefix)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
||||
@@ -1748,6 +1756,18 @@ class FileObjectAdapters:
|
||||
async def adelete_all_file_objects(user: KhojUser):
|
||||
return await FileObject.objects.filter(user=user).adelete()
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_file_objects_by_regex(user: KhojUser, regex_pattern: str, path_prefix: Optional[str] = None):
|
||||
"""
|
||||
Search for a regex pattern in file objects, with an optional path prefix filter.
|
||||
Outputs results in grep format.
|
||||
"""
|
||||
query = FileObject.objects.filter(user=user, agent=None, raw_text__iregex=regex_pattern)
|
||||
if path_prefix:
|
||||
query = query.filter(file_name__startswith=path_prefix)
|
||||
return await sync_to_async(list)(query)
|
||||
|
||||
|
||||
class EntryAdapters:
|
||||
word_filter = WordFilter()
|
||||
|
||||
@@ -22,12 +22,20 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(
|
||||
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
|
||||
messages,
|
||||
api_key,
|
||||
api_base_url,
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
tools=None,
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
# Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
|
||||
# Get response from model. Don't use response_type because Anthropic doesn't support it.
|
||||
return anthropic_completion_with_backoff(
|
||||
messages=messages,
|
||||
system_prompt="",
|
||||
@@ -36,6 +44,7 @@ def anthropic_send_message_to_model(
|
||||
api_base_url=api_base_url,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from textwrap import dedent
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Type
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
|
||||
import anthropic
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
@@ -18,11 +17,14 @@ from tenacity import (
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
ToolCall,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ToolDefinition,
|
||||
create_tool_definition,
|
||||
get_anthropic_async_client,
|
||||
get_anthropic_client,
|
||||
get_chat_usage_metrics,
|
||||
@@ -57,9 +59,10 @@ def anthropic_completion_with_backoff(
|
||||
max_tokens: int | None = None,
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel | None = None,
|
||||
tools: List[ToolDefinition] = None,
|
||||
deepthought: bool = False,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
) -> ResponseWithThought:
|
||||
client = anthropic_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_anthropic_client(api_key, api_base_url)
|
||||
@@ -67,12 +70,26 @@ def anthropic_completion_with_backoff(
|
||||
|
||||
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
|
||||
|
||||
thoughts = ""
|
||||
aggregated_response = ""
|
||||
final_message = None
|
||||
model_kwargs = model_kwargs or dict()
|
||||
if response_schema:
|
||||
tool = create_anthropic_tool_definition(response_schema=response_schema)
|
||||
model_kwargs["tools"] = [tool]
|
||||
|
||||
# Configure structured output
|
||||
if tools:
|
||||
# Convert tools to Anthropic format
|
||||
model_kwargs["tools"] = [
|
||||
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
|
||||
for tool in tools
|
||||
]
|
||||
# Cache tool definitions
|
||||
last_tool = model_kwargs["tools"][-1]
|
||||
last_tool["cache_control"] = {"type": "ephemeral"}
|
||||
elif response_schema:
|
||||
tool = create_tool_definition(response_schema)
|
||||
model_kwargs["tools"] = [
|
||||
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
|
||||
]
|
||||
elif response_type == "json_object" and not (is_reasoning_model(model_name) and deepthought):
|
||||
# Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
|
||||
formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{"))
|
||||
@@ -96,15 +113,41 @@ def anthropic_completion_with_backoff(
|
||||
max_tokens=max_tokens,
|
||||
**(model_kwargs),
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
aggregated_response += text
|
||||
for chunk in stream:
|
||||
if chunk.type != "content_block_delta":
|
||||
continue
|
||||
if chunk.delta.type == "thinking_delta":
|
||||
thoughts += chunk.delta.thinking
|
||||
elif chunk.delta.type == "text_delta":
|
||||
aggregated_response += chunk.delta.text
|
||||
final_message = stream.get_final_message()
|
||||
|
||||
# Extract first tool call from final message
|
||||
for item in final_message.content:
|
||||
if item.type == "tool_use":
|
||||
aggregated_response = json.dumps(item.input)
|
||||
break
|
||||
# Track raw content of model response to reuse for cache hits in multi-turn chats
|
||||
raw_content = [item.model_dump() for item in final_message.content]
|
||||
|
||||
# Extract all tool calls if tools are enabled
|
||||
if tools:
|
||||
tool_calls = [
|
||||
ToolCall(name=item.name, args=item.input, id=item.id).__dict__
|
||||
for item in final_message.content
|
||||
if item.type == "tool_use"
|
||||
]
|
||||
if tool_calls:
|
||||
# If there are tool calls, aggregate thoughts and responses into thoughts
|
||||
if thoughts and aggregated_response:
|
||||
# wrap each line of thought in italics
|
||||
thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()])
|
||||
thoughts = f"{thoughts}\n\n{aggregated_response}"
|
||||
else:
|
||||
thoughts = thoughts or aggregated_response
|
||||
# Json dump tool calls into aggregated response
|
||||
aggregated_response = json.dumps(tool_calls)
|
||||
# If response schema is used, return the first tool call's input
|
||||
elif response_schema:
|
||||
for item in final_message.content:
|
||||
if item.type == "tool_use":
|
||||
aggregated_response = json.dumps(item.input)
|
||||
break
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_message.usage.input_tokens
|
||||
@@ -126,7 +169,7 @@ def anthropic_completion_with_backoff(
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return aggregated_response
|
||||
return ResponseWithThought(text=aggregated_response, thought=thoughts, raw_content=raw_content)
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -183,10 +226,10 @@ async def anthropic_chat_completion_with_backoff(
|
||||
if chunk.type == "message_delta":
|
||||
if chunk.delta.stop_reason == "refusal":
|
||||
yield ResponseWithThought(
|
||||
response="...I'm sorry, but my safety filters prevent me from assisting with this query."
|
||||
text="...I'm sorry, but my safety filters prevent me from assisting with this query."
|
||||
)
|
||||
elif chunk.delta.stop_reason == "max_tokens":
|
||||
yield ResponseWithThought(response="...I'm sorry, but I've hit my response length limit.")
|
||||
yield ResponseWithThought(text="...I'm sorry, but I've hit my response length limit.")
|
||||
if chunk.delta.stop_reason in ["refusal", "max_tokens"]:
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {chunk.delta.stop_reason}.\n"
|
||||
@@ -199,7 +242,7 @@ async def anthropic_chat_completion_with_backoff(
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
if chunk.delta.type == "text_delta":
|
||||
response_chunk = ResponseWithThought(response=chunk.delta.text)
|
||||
response_chunk = ResponseWithThought(text=chunk.delta.text)
|
||||
aggregated_response += chunk.delta.text
|
||||
if chunk.delta.type == "thinking_delta":
|
||||
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
||||
@@ -232,13 +275,14 @@ async def anthropic_chat_completion_with_backoff(
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
|
||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||
def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt: str = None):
|
||||
"""
|
||||
Format messages for Anthropic
|
||||
"""
|
||||
# Extract system prompt
|
||||
system_prompt = system_prompt or ""
|
||||
for message in messages.copy():
|
||||
messages = deepcopy(raw_messages)
|
||||
for message in messages:
|
||||
if message.role == "system":
|
||||
if isinstance(message.content, list):
|
||||
system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
|
||||
@@ -250,15 +294,30 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
else:
|
||||
system = None
|
||||
|
||||
# Anthropic requires the first message to be a 'user' message
|
||||
if len(messages) == 1:
|
||||
# Anthropic requires the first message to be a user message unless its a tool call
|
||||
message_type = messages[0].additional_kwargs.get("message_type", None)
|
||||
if len(messages) == 1 and message_type != "tool_call":
|
||||
messages[0].role = "user"
|
||||
elif len(messages) > 1 and messages[0].role == "assistant":
|
||||
messages = messages[1:]
|
||||
|
||||
# Convert image urls to base64 encoded images in Anthropic message format
|
||||
for message in messages:
|
||||
if isinstance(message.content, list):
|
||||
# Handle tool call and tool result message types from additional_kwargs
|
||||
message_type = message.additional_kwargs.get("message_type")
|
||||
if message_type == "tool_call":
|
||||
pass
|
||||
elif message_type == "tool_result":
|
||||
# Convert tool_result to Anthropic tool_result format
|
||||
content = []
|
||||
for part in message.content:
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": part["id"],
|
||||
"content": part["content"],
|
||||
}
|
||||
)
|
||||
message.content = content
|
||||
# Convert image urls to base64 encoded images in Anthropic message format
|
||||
elif isinstance(message.content, list):
|
||||
content = []
|
||||
# Sort the content. Anthropic models prefer that text comes after images.
|
||||
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||
@@ -304,18 +363,15 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
if isinstance(block, dict) and "cache_control" in block:
|
||||
del block["cache_control"]
|
||||
|
||||
# Add cache control to the last content block of second to last message.
|
||||
# In research mode, this message content is list of iterations, updated after each research iteration.
|
||||
# Caching it should improve research efficiency.
|
||||
cache_message = messages[-2]
|
||||
# Add cache control to the last content block of last message.
|
||||
# Caching should improve research efficiency.
|
||||
cache_message = messages[-1]
|
||||
if isinstance(cache_message.content, list) and cache_message.content:
|
||||
# Add cache control to the last content block only if it's a text block with non-empty content
|
||||
last_block = cache_message.content[-1]
|
||||
if (
|
||||
isinstance(last_block, dict)
|
||||
and last_block.get("type") == "text"
|
||||
and last_block.get("text")
|
||||
and last_block.get("text").strip()
|
||||
if isinstance(last_block, dict) and (
|
||||
(last_block.get("type") == "text" and last_block.get("text", "").strip())
|
||||
or (last_block.get("type") == "tool_result" and last_block.get("content", []))
|
||||
):
|
||||
last_block["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
@@ -326,74 +382,5 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
||||
return formatted_messages, system
|
||||
|
||||
|
||||
def create_anthropic_tool_definition(
|
||||
response_schema: Type[BaseModel],
|
||||
tool_name: str = None,
|
||||
tool_description: Optional[str] = None,
|
||||
) -> anthropic.types.ToolParam:
|
||||
"""
|
||||
Converts a response schema BaseModel class into an Anthropic tool definition dictionary.
|
||||
|
||||
This format is expected by Anthropic's API when defining tools the model can use.
|
||||
|
||||
Args:
|
||||
response_schema: The Pydantic BaseModel class to convert.
|
||||
This class defines the response schema for the tool.
|
||||
tool_name: The name for the Anthropic tool (e.g., "get_weather", "plan_next_step").
|
||||
tool_description: Optional description for the Anthropic tool.
|
||||
If None, it attempts to use the Pydantic model's docstring.
|
||||
If that's also missing, a fallback description is generated.
|
||||
|
||||
Returns:
|
||||
An tool definition for Anthropic's API.
|
||||
"""
|
||||
model_schema = response_schema.model_json_schema()
|
||||
|
||||
name = tool_name or response_schema.__name__.lower()
|
||||
description = tool_description
|
||||
if description is None:
|
||||
docstring = response_schema.__doc__
|
||||
if docstring:
|
||||
description = dedent(docstring).strip()
|
||||
else:
|
||||
# Fallback description if no explicit one or docstring is provided
|
||||
description = f"Tool named '{name}' accepts specified parameters."
|
||||
|
||||
# Process properties to inline enums and remove $defs dependency
|
||||
processed_properties = {}
|
||||
original_properties = model_schema.get("properties", {})
|
||||
defs = model_schema.get("$defs", {})
|
||||
|
||||
for prop_name, prop_schema in original_properties.items():
|
||||
current_prop_schema = deepcopy(prop_schema) # Work on a copy
|
||||
# Check for enums defined directly in the property for simpler direct enum definitions.
|
||||
if "$ref" in current_prop_schema:
|
||||
ref_path = current_prop_schema["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_name = ref_path.split("/")[-1]
|
||||
if def_name in defs and "enum" in defs[def_name]:
|
||||
enum_def = defs[def_name]
|
||||
current_prop_schema["enum"] = enum_def["enum"]
|
||||
current_prop_schema["type"] = enum_def.get("type", "string")
|
||||
if "description" not in current_prop_schema and "description" in enum_def:
|
||||
current_prop_schema["description"] = enum_def["description"]
|
||||
del current_prop_schema["$ref"] # Remove the $ref as it's been inlined
|
||||
|
||||
processed_properties[prop_name] = current_prop_schema
|
||||
|
||||
# The input_schema for Anthropic tools is a JSON Schema object.
|
||||
# Pydantic's model_json_schema() provides most of what's needed.
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": processed_properties,
|
||||
}
|
||||
|
||||
# Include 'required' fields if specified in the Pydantic model
|
||||
if "required" in model_schema and model_schema["required"]:
|
||||
input_schema["required"] = model_schema["required"]
|
||||
|
||||
return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema)
|
||||
|
||||
|
||||
def is_reasoning_model(model_name: str) -> bool:
|
||||
return any(model_name.startswith(model) for model in REASONING_MODELS)
|
||||
|
||||
@@ -28,6 +28,7 @@ def gemini_send_message_to_model(
|
||||
api_base_url=None,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
tools=None,
|
||||
model_kwargs=None,
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
@@ -37,8 +38,10 @@ def gemini_send_message_to_model(
|
||||
"""
|
||||
model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
model_kwargs["tools"] = tools
|
||||
# Monitor for flakiness in 1.5+ models. This would cause unwanted behavior and terminate response early in 1.5 models.
|
||||
if response_type == "json_object" and not model.startswith("gemini-1.5"):
|
||||
elif response_type == "json_object" and not model.startswith("gemini-1.5"):
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
if response_schema:
|
||||
model_kwargs["response_schema"] = response_schema
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List
|
||||
|
||||
import httpx
|
||||
from google import genai
|
||||
@@ -22,11 +23,13 @@ from tenacity import (
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
ToolCall,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ToolDefinition,
|
||||
get_chat_usage_metrics,
|
||||
get_gemini_client,
|
||||
is_none_or_empty,
|
||||
@@ -95,26 +98,29 @@ def gemini_completion_with_backoff(
|
||||
temperature=1.2,
|
||||
api_key=None,
|
||||
api_base_url: str = None,
|
||||
model_kwargs=None,
|
||||
model_kwargs={},
|
||||
deepthought=False,
|
||||
tracer={},
|
||||
) -> str:
|
||||
) -> ResponseWithThought:
|
||||
client = gemini_clients.get(api_key)
|
||||
if not client:
|
||||
client = get_gemini_client(api_key, api_base_url)
|
||||
gemini_clients[api_key] = client
|
||||
|
||||
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
|
||||
response_thoughts: str | None = None
|
||||
raw_content, response_text, response_thoughts = [], "", None
|
||||
|
||||
# format model response schema
|
||||
# Configure structured output
|
||||
tools = None
|
||||
response_schema = None
|
||||
if model_kwargs and model_kwargs.get("response_schema"):
|
||||
if model_kwargs.get("tools"):
|
||||
tools = to_gemini_tools(model_kwargs["tools"])
|
||||
elif model_kwargs.get("response_schema"):
|
||||
response_schema = clean_response_schema(model_kwargs["response_schema"])
|
||||
|
||||
thinking_config = None
|
||||
if deepthought and is_reasoning_model(model_name):
|
||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI, include_thoughts=True)
|
||||
|
||||
max_output_tokens = MAX_OUTPUT_TOKENS_FOR_STANDARD_GEMINI
|
||||
if is_reasoning_model(model_name):
|
||||
@@ -127,8 +133,9 @@ def gemini_completion_with_backoff(
|
||||
thinking_config=thinking_config,
|
||||
max_output_tokens=max_output_tokens,
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain"),
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
seed=seed,
|
||||
top_p=0.95,
|
||||
http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
|
||||
@@ -137,7 +144,25 @@ def gemini_completion_with_backoff(
|
||||
try:
|
||||
# Generate the response
|
||||
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
|
||||
response_text = response.text
|
||||
if (
|
||||
not response.candidates
|
||||
or not response.candidates[0].content
|
||||
or response.candidates[0].content.parts is None
|
||||
):
|
||||
raise ValueError(f"Failed to get response from model.")
|
||||
raw_content = [part.model_dump() for part in response.candidates[0].content.parts]
|
||||
if response.function_calls:
|
||||
function_calls = [
|
||||
ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__
|
||||
for function_call in response.function_calls
|
||||
]
|
||||
response_text = json.dumps(function_calls)
|
||||
else:
|
||||
# If no function calls, use the text response
|
||||
response_text = response.text
|
||||
response_thoughts = "\n".join(
|
||||
[part.text for part in response.candidates[0].content.parts if part.thought and isinstance(part.text, str)]
|
||||
)
|
||||
except gerrors.ClientError as e:
|
||||
response = None
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
@@ -151,8 +176,14 @@ def gemini_completion_with_backoff(
|
||||
input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
|
||||
output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
|
||||
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
|
||||
cache_read_tokens = response.usage_metadata.cached_content_token_count or 0 if response else 0
|
||||
tracer["usage"] = get_chat_usage_metrics(
|
||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||
model_name,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
thought_tokens=thought_tokens,
|
||||
usage=tracer.get("usage"),
|
||||
)
|
||||
|
||||
# Validate the response. If empty, raise an error to retry.
|
||||
@@ -166,7 +197,7 @@ def gemini_completion_with_backoff(
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
return ResponseWithThought(text=response_text, thought=response_thoughts, raw_content=raw_content)
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -234,7 +265,7 @@ async def gemini_chat_completion_with_backoff(
|
||||
# handle safety, rate-limit, other finish reasons
|
||||
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
if stopped:
|
||||
yield ResponseWithThought(response=stop_message)
|
||||
yield ResponseWithThought(text=stop_message)
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
|
||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||
@@ -247,7 +278,7 @@ async def gemini_chat_completion_with_backoff(
|
||||
yield ResponseWithThought(thought=part.text)
|
||||
elif part.text:
|
||||
aggregated_response += part.text
|
||||
yield ResponseWithThought(response=part.text)
|
||||
yield ResponseWithThought(text=part.text)
|
||||
# Calculate cost of chat
|
||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
|
||||
@@ -346,8 +377,24 @@ def format_messages_for_gemini(
|
||||
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||
|
||||
for message in messages:
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
# Handle tool call and tool result message types from additional_kwargs
|
||||
message_type = message.additional_kwargs.get("message_type")
|
||||
if message_type == "tool_call":
|
||||
pass
|
||||
elif message_type == "tool_result":
|
||||
# Convert tool_result to Gemini function response format
|
||||
# Need to find the corresponding function call from previous messages
|
||||
tool_result_msg_content = []
|
||||
for part in message.content:
|
||||
tool_result_msg_content.append(
|
||||
gtypes.Part.from_function_response(name=part["name"], response={"result": part["content"]})
|
||||
)
|
||||
message.content = tool_result_msg_content
|
||||
# Convert message content to string list from chatml dictionary list
|
||||
if isinstance(message.content, list):
|
||||
elif isinstance(message.content, list):
|
||||
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||
message_content = []
|
||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
|
||||
@@ -367,16 +414,13 @@ def format_messages_for_gemini(
|
||||
messages.remove(message)
|
||||
continue
|
||||
message.content = message_content
|
||||
elif isinstance(message.content, str):
|
||||
elif isinstance(message.content, str) and message.content.strip():
|
||||
message.content = [gtypes.Part.from_text(text=message.content)]
|
||||
else:
|
||||
logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}")
|
||||
messages.remove(message)
|
||||
continue
|
||||
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
|
||||
@@ -404,3 +448,21 @@ def is_reasoning_model(model_name: str) -> bool:
|
||||
Check if the model is a reasoning model.
|
||||
"""
|
||||
return model_name.startswith("gemini-2.5")
|
||||
|
||||
|
||||
def to_gemini_tools(tools: List[ToolDefinition]) -> List[gtypes.ToolDict] | None:
|
||||
"Transform tool definitions from standard format to Gemini format."
|
||||
gemini_tools = [
|
||||
gtypes.ToolDict(
|
||||
function_declarations=[
|
||||
gtypes.FunctionDeclarationDict(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.schema,
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
return gemini_tools or None
|
||||
|
||||
@@ -145,12 +145,12 @@ async def converse_offline(
|
||||
aggregated_response += response_delta
|
||||
# Put chunk into the asyncio queue (non-blocking)
|
||||
try:
|
||||
queue.put_nowait(ResponseWithThought(response=response_delta))
|
||||
queue.put_nowait(ResponseWithThought(text=response_delta))
|
||||
except asyncio.QueueFull:
|
||||
# Should not happen with default queue size unless consumer is very slow
|
||||
logger.warning("Asyncio queue full during offline LLM streaming.")
|
||||
# Potentially block here or handle differently if needed
|
||||
asyncio.run(queue.put(ResponseWithThought(response=response_delta)))
|
||||
asyncio.run(queue.put(ResponseWithThought(text=response_delta)))
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
@@ -221,4 +221,4 @@ def send_message_to_model_offline(
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, response_text, tracer)
|
||||
|
||||
return response_text
|
||||
return ResponseWithThought(text=response_text)
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
clean_response_schema,
|
||||
completion_with_backoff,
|
||||
get_openai_api_json_support,
|
||||
get_structured_output_support,
|
||||
to_openai_tools,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
StructuredOutputSupport,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context
|
||||
from khoj.utils.rawconfig import FileAttachment, LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
@@ -32,6 +31,7 @@ def send_message_to_model(
|
||||
model,
|
||||
response_type="text",
|
||||
response_schema=None,
|
||||
tools: list[ToolDefinition] = None,
|
||||
deepthought=False,
|
||||
api_base_url=None,
|
||||
tracer: dict = {},
|
||||
@@ -40,9 +40,11 @@ def send_message_to_model(
|
||||
Send message to model
|
||||
"""
|
||||
|
||||
model_kwargs = {}
|
||||
json_support = get_openai_api_json_support(model, api_base_url)
|
||||
if response_schema and json_support == JsonSupport.SCHEMA:
|
||||
model_kwargs: Dict[str, Any] = {}
|
||||
json_support = get_structured_output_support(model, api_base_url)
|
||||
if tools and json_support == StructuredOutputSupport.TOOL:
|
||||
model_kwargs["tools"] = to_openai_tools(tools)
|
||||
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
||||
# Drop unsupported fields from schema passed to OpenAI APi
|
||||
cleaned_response_schema = clean_response_schema(response_schema)
|
||||
model_kwargs["response_format"] = {
|
||||
@@ -53,7 +55,7 @@ def send_message_to_model(
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
|
||||
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
|
||||
model_kwargs["response_format"] = {"type": response_type}
|
||||
|
||||
# Get Response from GPT
|
||||
@@ -171,30 +173,3 @@ async def converse_openai(
|
||||
tracer=tracer,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||
"""
|
||||
Format response schema to be compatible with OpenAI API.
|
||||
|
||||
Clean the response schema by removing unsupported fields.
|
||||
"""
|
||||
# Normalize schema to OpenAI compatible JSON schema format
|
||||
schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
|
||||
schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
|
||||
|
||||
# Recursively drop unsupported fields from schema passed to OpenAI API
|
||||
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
||||
fields_to_exclude = ["minItems", "maxItems"]
|
||||
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
|
||||
for _, prop_value in schema_json["properties"].items():
|
||||
if isinstance(prop_value, dict):
|
||||
# Remove specified fields from direct properties
|
||||
for field in fields_to_exclude:
|
||||
prop_value.pop(field, None)
|
||||
# Recursively remove specified fields from child properties
|
||||
if "items" in prop_value and isinstance(prop_value["items"], dict):
|
||||
clean_response_schema(prop_value["items"])
|
||||
|
||||
# Return cleaned schema
|
||||
return schema_json
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
@@ -9,6 +10,7 @@ from urllib.parse import urlparse
|
||||
import httpx
|
||||
import openai
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from openai.lib.streaming.chat import (
|
||||
ChatCompletionStream,
|
||||
ChatCompletionStreamEvent,
|
||||
@@ -20,6 +22,7 @@ from openai.types.chat.chat_completion_chunk import (
|
||||
Choice,
|
||||
ChoiceDelta,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@@ -30,11 +33,13 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
ResponseWithThought,
|
||||
StructuredOutputSupport,
|
||||
ToolCall,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ToolDefinition,
|
||||
convert_image_data_uri,
|
||||
get_chat_usage_metrics,
|
||||
get_openai_async_client,
|
||||
@@ -72,7 +77,7 @@ def completion_with_backoff(
|
||||
deepthought: bool = False,
|
||||
model_kwargs: dict = {},
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
) -> ResponseWithThought:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_clients.get(client_key)
|
||||
if not client:
|
||||
@@ -117,6 +122,9 @@ def completion_with_backoff(
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
tool_ids = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
thoughts = ""
|
||||
aggregated_response = ""
|
||||
if stream:
|
||||
with client.beta.chat.completions.stream(
|
||||
@@ -130,7 +138,16 @@ def completion_with_backoff(
|
||||
if chunk.type == "content.delta":
|
||||
aggregated_response += chunk.delta
|
||||
elif chunk.type == "thought.delta":
|
||||
pass
|
||||
thoughts += chunk.delta
|
||||
elif chunk.type == "chunk" and chunk.chunk.choices and chunk.chunk.choices[0].delta.tool_calls:
|
||||
tool_ids += [tool_call.id for tool_call in chunk.chunk.choices[0].delta.tool_calls]
|
||||
elif chunk.type == "tool_calls.function.arguments.done":
|
||||
tool_calls += [ToolCall(name=chunk.name, args=json.loads(chunk.arguments), id=None)]
|
||||
if tool_calls:
|
||||
tool_calls = [
|
||||
ToolCall(name=chunk.name, args=chunk.args, id=tool_id) for chunk, tool_id in zip(tool_calls, tool_ids)
|
||||
]
|
||||
aggregated_response = json.dumps([tool_call.__dict__ for tool_call in tool_calls])
|
||||
else:
|
||||
# Non-streaming chat completion
|
||||
chunk = client.beta.chat.completions.parse(
|
||||
@@ -164,7 +181,7 @@ def completion_with_backoff(
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
return aggregated_response
|
||||
return ResponseWithThought(text=aggregated_response, thought=thoughts)
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -190,6 +207,7 @@ async def chat_completion_with_backoff(
|
||||
deepthought=False,
|
||||
model_kwargs: dict = {},
|
||||
tracer: dict = {},
|
||||
tools=None,
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_async_clients.get(client_key)
|
||||
@@ -258,6 +276,8 @@ async def chat_completion_with_backoff(
|
||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
if tools:
|
||||
model_kwargs["tools"] = tools
|
||||
|
||||
aggregated_response = ""
|
||||
final_chunk = None
|
||||
@@ -277,7 +297,7 @@ async def chat_completion_with_backoff(
|
||||
raise ValueError("No response by model.")
|
||||
aggregated_response = response.choices[0].message.content
|
||||
final_chunk = response
|
||||
yield ResponseWithThought(response=aggregated_response)
|
||||
yield ResponseWithThought(text=aggregated_response)
|
||||
else:
|
||||
async for chunk in stream_processor(response):
|
||||
# Log the time taken to start response
|
||||
@@ -293,8 +313,8 @@ async def chat_completion_with_backoff(
|
||||
response_chunk: ResponseWithThought = None
|
||||
response_delta = chunk.choices[0].delta
|
||||
if response_delta.content:
|
||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||
aggregated_response += response_chunk.response
|
||||
response_chunk = ResponseWithThought(text=response_delta.content)
|
||||
aggregated_response += response_chunk.text
|
||||
elif response_delta.thought:
|
||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||
if response_chunk:
|
||||
@@ -327,16 +347,16 @@ async def chat_completion_with_backoff(
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
|
||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||
def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport:
|
||||
if model_name.startswith("deepseek-reasoner"):
|
||||
return JsonSupport.NONE
|
||||
return StructuredOutputSupport.NONE
|
||||
if api_base_url:
|
||||
host = urlparse(api_base_url).hostname
|
||||
if host and host.endswith(".ai.azure.com"):
|
||||
return JsonSupport.OBJECT
|
||||
return StructuredOutputSupport.OBJECT
|
||||
if host == "api.deepinfra.com":
|
||||
return JsonSupport.OBJECT
|
||||
return JsonSupport.SCHEMA
|
||||
return StructuredOutputSupport.OBJECT
|
||||
return StructuredOutputSupport.TOOL
|
||||
|
||||
|
||||
def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]:
|
||||
@@ -345,6 +365,43 @@ def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> Li
|
||||
"""
|
||||
formatted_messages = []
|
||||
for message in deepcopy(messages):
|
||||
# Handle tool call and tool result message types
|
||||
message_type = message.additional_kwargs.get("message_type")
|
||||
if message_type == "tool_call":
|
||||
# Convert tool_call to OpenAI function call format
|
||||
content = []
|
||||
for part in message.content:
|
||||
content.append(
|
||||
{
|
||||
"type": "function",
|
||||
"id": part.get("id"),
|
||||
"function": {
|
||||
"name": part.get("name"),
|
||||
"arguments": json.dumps(part.get("input", part.get("args", {}))),
|
||||
},
|
||||
}
|
||||
)
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": content,
|
||||
}
|
||||
)
|
||||
continue
|
||||
if message_type == "tool_result":
|
||||
# Convert tool_result to OpenAI tool result format
|
||||
# Each part is a result for a tool call
|
||||
for part in message.content:
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": part.get("id") or part.get("tool_use_id"),
|
||||
"name": part.get("name"),
|
||||
"content": part.get("content"),
|
||||
}
|
||||
)
|
||||
continue
|
||||
if isinstance(message.content, list) and not is_openai_api(api_base_url):
|
||||
assistant_texts = []
|
||||
has_images = False
|
||||
@@ -708,3 +765,47 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
|
||||
if isinstance(content_part, dict) and content_part.get("type") == "text":
|
||||
content_part["text"] += " /no_think"
|
||||
break
|
||||
|
||||
|
||||
def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None:
|
||||
"Transform tool definitions from standard format to OpenAI format."
|
||||
openai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": clean_response_schema(tool.schema),
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
return openai_tools or None
|
||||
|
||||
|
||||
def clean_response_schema(schema: BaseModel | dict) -> dict:
|
||||
"""
|
||||
Format response schema to be compatible with OpenAI API.
|
||||
|
||||
Clean the response schema by removing unsupported fields.
|
||||
"""
|
||||
# Normalize schema to OpenAI compatible JSON schema format
|
||||
schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
|
||||
schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
|
||||
|
||||
# Recursively drop unsupported fields from schema passed to OpenAI API
|
||||
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
||||
fields_to_exclude = ["minItems", "maxItems"]
|
||||
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
|
||||
for _, prop_value in schema_json["properties"].items():
|
||||
if isinstance(prop_value, dict):
|
||||
# Remove specified fields from direct properties
|
||||
for field in fields_to_exclude:
|
||||
prop_value.pop(field, None)
|
||||
# Recursively remove specified fields from child properties
|
||||
if "items" in prop_value and isinstance(prop_value["items"], dict):
|
||||
clean_response_schema(prop_value["items"])
|
||||
|
||||
# Return cleaned schema
|
||||
return schema_json
|
||||
|
||||
@@ -667,33 +667,37 @@ Here's some additional context about you:
|
||||
|
||||
plan_function_execution = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
|
||||
Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information.
|
||||
You are Khoj, a smart, creative and meticulous researcher. Use the provided tool AIs to accomplish the task assigned to you.
|
||||
Create a multi-step plan and intelligently iterate on the plan to complete the task.
|
||||
{personality_context}
|
||||
|
||||
# Instructions
|
||||
- Ask highly diverse, detailed queries to the tool AIs, one tool AI at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
|
||||
- Provide highly diverse, detailed requests to the tool AIs, one tool AI at a time, to gather information, perform actions etc. Their response will be shown to you in the next iteration.
|
||||
- Break down your research process into independent, self-contained steps that can be executed sequentially using the available tool AIs to answer the user's query. Write your step-by-step plan in the scratchpad.
|
||||
- Always ask a new query that was not asked to the tool AI in a previous iteration. Build on the results of the previous iterations.
|
||||
- Ensure that all required context is passed to the tool AIs for successful execution. Include any relevant stuff that has previously been attempted. They only know the context provided in your query.
|
||||
- Think step by step to come up with creative strategies when the previous iteration did not yield useful results.
|
||||
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
|
||||
- Stop when you have the required information by returning a JSON object with the "tool" field set to "text" and "query" field empty. E.g., {{"scratchpad": "I have all I need", "tool": "text", "query": ""}}
|
||||
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to accomplish the task assigned to you. Only stop when you have completed the task.
|
||||
|
||||
# Examples
|
||||
Assuming you can search the user's notes and the internet.
|
||||
Assuming you can search the user's files and the internet.
|
||||
- When the user asks for the population of their hometown
|
||||
1. Try look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc.
|
||||
2. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
|
||||
3. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
|
||||
1. Try look up their hometown in their notes. Ask the semantic search AI to search for their birth certificate, childhood memories, school, resume etc.
|
||||
2. Use the other document retrieval tools to build on the semantic search results, fill in the gaps, add more details or confirm your hypothesis.
|
||||
3. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
|
||||
4. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
|
||||
- When the user asks for their computer's specs
|
||||
1. Try find their computer model in their notes.
|
||||
1. Try find their computer model in their documents.
|
||||
2. Now find webpages with their computer model's spec online.
|
||||
3. Ask the webpage tool AI to extract the required information from the relevant webpages.
|
||||
- When the user asks what clothes to carry for their upcoming trip
|
||||
1. Find the itinerary of their upcoming trip in their notes.
|
||||
1. Use the semantic search tool to find the itinerary of their upcoming trip in their documents.
|
||||
2. Next find the weather forecast at the destination online.
|
||||
3. Then find if they mentioned what clothes they own in their notes.
|
||||
3. Then combine the semantic search, regex search, view file and list files tools to find if all the clothes they own in their files.
|
||||
- When the user asks you to summarize their expenses in a particular month
|
||||
1. Combine the semantic search and regex search tool AI to find all transactions in the user's documents for that month.
|
||||
2. Use the view file tool to read the line ranges in the matched files
|
||||
3. Finally summarize the expenses
|
||||
|
||||
# Background Context
|
||||
- Current Date: {day_of_week}, {current_date}
|
||||
@@ -701,31 +705,9 @@ Assuming you can search the user's notes and the internet.
|
||||
- User Name: {username}
|
||||
|
||||
# Available Tool AIs
|
||||
You decide which of the tool AIs listed below would you use to answer the user's question. You **only** have access to the following tool AIs:
|
||||
You decide which of the tool AIs listed below would you use to accomplish the user assigned task. You **only** have access to the following tool AIs:
|
||||
|
||||
{tools}
|
||||
|
||||
Your response should always be a valid JSON object with keys: "scratchpad" (str), "tool" (str) and "query" (str). Do not say anything else.
|
||||
Response format:
|
||||
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
plan_function_execution_next_tool = PromptTemplate.from_template(
|
||||
"""
|
||||
Given the results of your previous iterations, which tool AI will you use next to answer the target query?
|
||||
|
||||
# Target Query:
|
||||
{query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
previous_iteration = PromptTemplate.from_template(
|
||||
"""
|
||||
# Iteration {index}:
|
||||
- tool: {tool}
|
||||
- query: {query}
|
||||
- result: {result}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import PIL.Image
|
||||
import pyjson5
|
||||
@@ -137,60 +137,83 @@ class OperatorRun:
|
||||
}
|
||||
|
||||
|
||||
class ToolCall:
|
||||
def __init__(self, name: str, args: dict, id: str):
|
||||
self.name = name
|
||||
self.args = args
|
||||
self.id = id
|
||||
|
||||
|
||||
class ResearchIteration:
|
||||
def __init__(
|
||||
self,
|
||||
tool: str,
|
||||
query: str,
|
||||
query: ToolCall | dict | str,
|
||||
context: list = None,
|
||||
onlineContext: dict = None,
|
||||
codeContext: dict = None,
|
||||
operatorContext: dict | OperatorRun = None,
|
||||
summarizedResult: str = None,
|
||||
warning: str = None,
|
||||
raw_response: list = None,
|
||||
):
|
||||
self.tool = tool
|
||||
self.query = query
|
||||
self.query = ToolCall(**query) if isinstance(query, dict) else query
|
||||
self.context = context
|
||||
self.onlineContext = onlineContext
|
||||
self.codeContext = codeContext
|
||||
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
|
||||
self.summarizedResult = summarizedResult
|
||||
self.warning = warning
|
||||
self.raw_response = raw_response
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
data = vars(self).copy()
|
||||
data["query"] = self.query.__dict__ if isinstance(self.query, ToolCall) else self.query
|
||||
data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
|
||||
return data
|
||||
|
||||
|
||||
def construct_iteration_history(
|
||||
previous_iterations: List[ResearchIteration],
|
||||
previous_iteration_prompt: str,
|
||||
query: str = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = None,
|
||||
) -> list[ChatMessageModel]:
|
||||
iteration_history: list[ChatMessageModel] = []
|
||||
previous_iteration_messages: list[dict] = []
|
||||
for idx, iteration in enumerate(previous_iterations):
|
||||
iteration_data = previous_iteration_prompt.format(
|
||||
tool=iteration.tool,
|
||||
query=iteration.query,
|
||||
result=iteration.summarizedResult,
|
||||
index=idx + 1,
|
||||
)
|
||||
query_message_content = construct_structured_message(query, query_images, attached_file_context=query_files)
|
||||
if query_message_content:
|
||||
iteration_history.append(ChatMessageModel(by="you", message=query_message_content))
|
||||
|
||||
previous_iteration_messages.append({"type": "text", "text": iteration_data})
|
||||
|
||||
if previous_iteration_messages:
|
||||
if query:
|
||||
iteration_history.append(ChatMessageModel(by="you", message=query))
|
||||
iteration_history.append(
|
||||
for iteration in previous_iterations:
|
||||
if not iteration.query or isinstance(iteration.query, str):
|
||||
iteration_history.append(
|
||||
ChatMessageModel(
|
||||
by="you",
|
||||
message=iteration.summarizedResult
|
||||
or iteration.warning
|
||||
or "Please specify what you want to do next.",
|
||||
)
|
||||
)
|
||||
continue
|
||||
iteration_history += [
|
||||
ChatMessageModel(
|
||||
by="khoj",
|
||||
intent=Intent(type="remember", query=query),
|
||||
message=previous_iteration_messages,
|
||||
)
|
||||
)
|
||||
message=iteration.raw_response or [iteration.query.__dict__],
|
||||
intent=Intent(type="tool_call", query=query),
|
||||
),
|
||||
ChatMessageModel(
|
||||
by="you",
|
||||
intent=Intent(type="tool_result"),
|
||||
message=[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"id": iteration.query.id,
|
||||
"name": iteration.query.name,
|
||||
"content": iteration.summarizedResult,
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return iteration_history
|
||||
|
||||
|
||||
@@ -302,33 +325,44 @@ def construct_tool_chat_history(
|
||||
ConversationCommand.Notes: (
|
||||
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
|
||||
),
|
||||
ConversationCommand.Online: (
|
||||
ConversationCommand.SearchWeb: (
|
||||
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
|
||||
),
|
||||
ConversationCommand.Webpage: (
|
||||
ConversationCommand.ReadWebpage: (
|
||||
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
|
||||
),
|
||||
ConversationCommand.Code: (
|
||||
ConversationCommand.RunCode: (
|
||||
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
|
||||
),
|
||||
}
|
||||
for iteration in previous_iterations:
|
||||
if not iteration.query or isinstance(iteration.query, str):
|
||||
chat_history.append(
|
||||
ChatMessageModel(
|
||||
by="you",
|
||||
message=iteration.summarizedResult
|
||||
or iteration.warning
|
||||
or "Please specify what you want to do next.",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# If a tool is provided use the inferred query extractor for that tool if available
|
||||
# If no tool is provided, use inferred query extractor for the tool used in the iteration
|
||||
# Fallback to base extractor if the tool does not have an inferred query extractor
|
||||
inferred_query_extractor = extract_inferred_query_map.get(
|
||||
tool or ConversationCommand(iteration.tool), base_extractor
|
||||
tool or ConversationCommand(iteration.query.name), base_extractor
|
||||
)
|
||||
chat_history += [
|
||||
ChatMessageModel(
|
||||
by="you",
|
||||
message=iteration.query,
|
||||
message=yaml.dump(iteration.query.args, default_flow_style=False),
|
||||
),
|
||||
ChatMessageModel(
|
||||
by="khoj",
|
||||
intent=Intent(
|
||||
type="remember",
|
||||
query=iteration.query,
|
||||
query=yaml.dump(iteration.query.args, default_flow_style=False),
|
||||
inferred_queries=inferred_query_extractor(iteration),
|
||||
memory_type="notes",
|
||||
),
|
||||
@@ -481,28 +515,32 @@ Khoj: "{chat_response}"
|
||||
|
||||
def construct_structured_message(
|
||||
message: list[dict] | str,
|
||||
images: list[str],
|
||||
model_type: str,
|
||||
vision_enabled: bool,
|
||||
images: list[str] = None,
|
||||
model_type: str = None,
|
||||
vision_enabled: bool = True,
|
||||
attached_file_context: str = None,
|
||||
):
|
||||
"""
|
||||
Format messages into appropriate multimedia format for supported chat model types
|
||||
Format messages into appropriate multimedia format for supported chat model types.
|
||||
|
||||
Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise.
|
||||
"""
|
||||
if model_type in [
|
||||
if not model_type or model_type in [
|
||||
ChatModel.ModelType.OPENAI,
|
||||
ChatModel.ModelType.GOOGLE,
|
||||
ChatModel.ModelType.ANTHROPIC,
|
||||
]:
|
||||
constructed_messages: List[dict[str, Any]] = (
|
||||
[{"type": "text", "text": message}] if isinstance(message, str) else message
|
||||
)
|
||||
|
||||
constructed_messages: List[dict[str, Any]] = []
|
||||
if not is_none_or_empty(message):
|
||||
constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
|
||||
# Drop image message passed by caller if chat model does not have vision enabled
|
||||
if not vision_enabled:
|
||||
constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
|
||||
if not is_none_or_empty(attached_file_context):
|
||||
constructed_messages.append({"type": "text", "text": attached_file_context})
|
||||
constructed_messages += [{"type": "text", "text": attached_file_context}]
|
||||
if vision_enabled and images:
|
||||
for image in images:
|
||||
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
|
||||
constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
|
||||
return constructed_messages
|
||||
|
||||
message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
|
||||
@@ -638,7 +676,11 @@ def generate_chatml_messages_with_context(
|
||||
chat_message, chat.images if role == "user" else [], model_type, vision_enabled
|
||||
)
|
||||
|
||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||
reconstructed_message = ChatMessage(
|
||||
content=message_content,
|
||||
role=role,
|
||||
additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
|
||||
)
|
||||
chatml_messages.insert(0, reconstructed_message)
|
||||
|
||||
if len(chatml_messages) >= 3 * lookback_turns:
|
||||
@@ -737,10 +779,21 @@ def count_tokens(
|
||||
message_content_parts: list[str] = []
|
||||
# Collate message content into single string to ease token counting
|
||||
for part in message_content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
message_content_parts.append(part["text"])
|
||||
elif isinstance(part, dict) and part.get("type") == "image_url":
|
||||
if isinstance(part, dict) and part.get("type") == "image_url":
|
||||
image_count += 1
|
||||
elif isinstance(part, dict) and part.get("type") == "text":
|
||||
message_content_parts.append(part["text"])
|
||||
elif isinstance(part, dict) and hasattr(part, "model_dump"):
|
||||
message_content_parts.append(json.dumps(part.model_dump()))
|
||||
elif isinstance(part, dict) and hasattr(part, "__dict__"):
|
||||
message_content_parts.append(json.dumps(part.__dict__))
|
||||
elif isinstance(part, dict):
|
||||
# If part is a dict but not a recognized type, convert to JSON string
|
||||
try:
|
||||
message_content_parts.append(json.dumps(part))
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
|
||||
image_count += 1 # Treat as an image/binary if serialization fails
|
||||
elif isinstance(part, str):
|
||||
message_content_parts.append(part)
|
||||
else:
|
||||
@@ -753,6 +806,15 @@ def count_tokens(
|
||||
return len(encoder.encode(json.dumps(message_content)))
|
||||
|
||||
|
||||
def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
|
||||
"""Count total tokens in messages including system message"""
|
||||
system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
|
||||
message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
|
||||
# Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
|
||||
total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
|
||||
return total_tokens, system_message_tokens
|
||||
|
||||
|
||||
def truncate_messages(
|
||||
messages: list[ChatMessage],
|
||||
max_prompt_size: int,
|
||||
@@ -771,23 +833,30 @@ def truncate_messages(
|
||||
break
|
||||
|
||||
# Drop older messages until under max supported prompt size by model
|
||||
# Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
|
||||
system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
|
||||
tokens = sum([count_tokens(message.content, encoder) for message in messages])
|
||||
total_tokens = tokens + system_message_tokens + 4 * len(messages)
|
||||
total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
|
||||
|
||||
while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
|
||||
if len(messages[-1].content) > 1:
|
||||
# If the last message has more than one content part, pop the oldest content part.
|
||||
# For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
|
||||
if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call":
|
||||
# The oldest content part is earlier in content list. So pop from the front.
|
||||
messages[-1].content.pop(0)
|
||||
# Otherwise, pop the last message if it has only one content part or is a tool call.
|
||||
else:
|
||||
# The oldest message is the last one. So pop from the back.
|
||||
messages.pop()
|
||||
tokens = sum([count_tokens(message.content, encoder) for message in messages])
|
||||
total_tokens = tokens + system_message_tokens + 4 * len(messages)
|
||||
dropped_message = messages.pop()
|
||||
# Drop tool result pair of tool call, if tool call message has been removed
|
||||
if (
|
||||
dropped_message.additional_kwargs.get("message_type") == "tool_call"
|
||||
and messages
|
||||
and messages[-1].additional_kwargs.get("message_type") == "tool_result"
|
||||
):
|
||||
messages.pop()
|
||||
|
||||
total_tokens, _ = count_total_tokens(messages, encoder, system_message)
|
||||
|
||||
# Truncate current message if still over max supported prompt size by model
|
||||
total_tokens = tokens + system_message_tokens + 4 * len(messages)
|
||||
total_tokens, _ = count_total_tokens(messages, encoder, system_message)
|
||||
if total_tokens > max_prompt_size:
|
||||
# At this point, a single message with a single content part of type dict should remain
|
||||
assert (
|
||||
@@ -1149,13 +1218,15 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
|
||||
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
|
||||
|
||||
|
||||
class JsonSupport(int, Enum):
|
||||
class StructuredOutputSupport(int, Enum):
|
||||
NONE = 0
|
||||
OBJECT = 1
|
||||
SCHEMA = 2
|
||||
TOOL = 3
|
||||
|
||||
|
||||
class ResponseWithThought:
|
||||
def __init__(self, response: str = None, thought: str = None):
|
||||
self.response = response
|
||||
def __init__(self, text: str = None, thought: str = None, raw_content: list = None):
|
||||
self.text = text
|
||||
self.thought = thought
|
||||
self.raw_content = raw_content
|
||||
|
||||
@@ -73,7 +73,7 @@ class GroundingAgent:
|
||||
grounding_user_prompt = self.get_instruction(instruction, self.environment_type)
|
||||
screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
|
||||
grounding_messages_content = construct_structured_message(
|
||||
grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
|
||||
grounding_user_prompt, screenshots, self.model.model_type, vision_enabled=True
|
||||
)
|
||||
return [{"role": "user", "content": grounding_messages_content}]
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
# Construct input for visual reasoner history
|
||||
visual_reasoner_history = self._format_message_for_api(self.messages)
|
||||
try:
|
||||
natural_language_action = await send_message_to_model_wrapper(
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
query=query_text,
|
||||
query_images=query_screenshot,
|
||||
system_message=reasoning_system_prompt,
|
||||
@@ -129,6 +129,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
agent_chat_model=self.reasoning_model,
|
||||
tracer=self.tracer,
|
||||
)
|
||||
natural_language_action = raw_response.text
|
||||
|
||||
if not isinstance(natural_language_action, str) or not natural_language_action.strip():
|
||||
raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")
|
||||
@@ -255,10 +256,10 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
|
||||
# Append summary messages to history
|
||||
trigger_summary = AgentMessage(role="user", content=summarize_prompt)
|
||||
summary_message = AgentMessage(role="assistant", content=summary)
|
||||
summary_message = AgentMessage(role="assistant", content=summary.text)
|
||||
self.messages.extend([trigger_summary, summary_message])
|
||||
|
||||
return summary
|
||||
return summary.text
|
||||
|
||||
def _compile_response(self, response_content: str | List) -> str:
|
||||
"""Compile response content into a string, handling OpenAI message structures."""
|
||||
|
||||
@@ -390,7 +390,25 @@ async def read_webpages(
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
async for result in read_webpages_content(
|
||||
query,
|
||||
urls,
|
||||
user,
|
||||
send_status_func=send_status_func,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
yield result
|
||||
|
||||
|
||||
async def read_webpages_content(
|
||||
query: str,
|
||||
urls: List[str],
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
logger.info(f"Reading web pages at: {urls}")
|
||||
if send_status_func:
|
||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||
|
||||
@@ -161,7 +161,7 @@ async def generate_python_code(
|
||||
)
|
||||
|
||||
# Extract python code wrapped in markdown code blocks from the response
|
||||
code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response, re.DOTALL)
|
||||
code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.text, re.DOTALL)
|
||||
|
||||
if not code_blocks:
|
||||
raise ValueError("No Python code blocks found in response")
|
||||
|
||||
@@ -1390,7 +1390,7 @@ async def chat(
|
||||
continue
|
||||
if cancellation_event.is_set():
|
||||
break
|
||||
message = item.response
|
||||
message = item.text
|
||||
full_response += message if message else ""
|
||||
if item.thought:
|
||||
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
||||
|
||||
+293
-26
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import fnmatch
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -120,6 +121,7 @@ from khoj.utils.config import OfflineChatProcessorModel
|
||||
from khoj.utils.helpers import (
|
||||
LRU,
|
||||
ConversationCommand,
|
||||
ToolDefinition,
|
||||
get_file_type,
|
||||
in_debug_mode,
|
||||
is_none_or_empty,
|
||||
@@ -303,7 +305,7 @@ async def acreate_title_from_history(
|
||||
with timer("Chat actor: Generate title from conversation history", logger):
|
||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||
|
||||
return response.strip()
|
||||
return response.text.strip()
|
||||
|
||||
|
||||
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||
@@ -315,7 +317,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||
with timer("Chat actor: Generate title from query", logger):
|
||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||
|
||||
return response.strip()
|
||||
return response.text.strip()
|
||||
|
||||
|
||||
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]:
|
||||
@@ -339,7 +341,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
|
||||
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
|
||||
)
|
||||
|
||||
response = response.strip()
|
||||
response = response.text.strip()
|
||||
try:
|
||||
response = json.loads(clean_json(response))
|
||||
is_safe = str(response.get("safe", "true")).lower() == "true"
|
||||
@@ -418,7 +420,7 @@ async def aget_data_sources_and_output_format(
|
||||
output: str
|
||||
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
relevant_tools_prompt,
|
||||
response_type="json_object",
|
||||
response_schema=PickTools,
|
||||
@@ -429,7 +431,7 @@ async def aget_data_sources_and_output_format(
|
||||
)
|
||||
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = clean_json(raw_response.text)
|
||||
response = json.loads(response)
|
||||
|
||||
chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()]
|
||||
@@ -506,7 +508,7 @@ async def infer_webpage_urls(
|
||||
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
|
||||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
@@ -519,7 +521,7 @@ async def infer_webpage_urls(
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = clean_json(raw_response.text)
|
||||
urls = json.loads(response)
|
||||
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
|
||||
if is_none_or_empty(valid_unique_urls):
|
||||
@@ -571,7 +573,7 @@ async def generate_online_subqueries(
|
||||
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
|
||||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
@@ -584,7 +586,7 @@ async def generate_online_subqueries(
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = clean_json(raw_response.text)
|
||||
response = pyjson5.loads(response)
|
||||
response = {q.strip() for q in response["queries"] if q.strip()}
|
||||
if not isinstance(response, set) or not response or len(response) == 0:
|
||||
@@ -645,7 +647,7 @@ async def aschedule_query(
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
try:
|
||||
raw_response = raw_response.strip()
|
||||
raw_response = raw_response.text.strip()
|
||||
response: Dict[str, str] = json.loads(clean_json(raw_response))
|
||||
if not response or not isinstance(response, Dict) or len(response) != 3:
|
||||
raise AssertionError(f"Invalid response for scheduling query : {response}")
|
||||
@@ -683,7 +685,7 @@ async def extract_relevant_info(
|
||||
agent_chat_model=agent_chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
return response.text.strip()
|
||||
|
||||
|
||||
async def extract_relevant_summary(
|
||||
@@ -726,7 +728,7 @@ async def extract_relevant_summary(
|
||||
agent_chat_model=agent_chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
return response.text.strip()
|
||||
|
||||
|
||||
async def generate_summary_from_files(
|
||||
@@ -897,7 +899,7 @@ async def generate_better_diagram_description(
|
||||
agent_chat_model=agent_chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
response = response.strip()
|
||||
response = response.text.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
||||
@@ -925,10 +927,10 @@ async def generate_excalidraw_diagram_from_description(
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||
)
|
||||
raw_response = clean_json(raw_response)
|
||||
raw_response_text = clean_json(raw_response.text)
|
||||
try:
|
||||
# Expect response to have `elements` and `scratchpad` keys
|
||||
response: Dict[str, str] = json.loads(raw_response)
|
||||
response: Dict[str, str] = json.loads(raw_response_text)
|
||||
if (
|
||||
not response
|
||||
or not isinstance(response, Dict)
|
||||
@@ -937,7 +939,7 @@ async def generate_excalidraw_diagram_from_description(
|
||||
):
|
||||
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {response}")
|
||||
except Exception:
|
||||
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response}")
|
||||
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response_text}")
|
||||
if not response or not isinstance(response["elements"], List) or not isinstance(response["elements"][0], Dict):
|
||||
# TODO Some additional validation here that it's a valid Excalidraw diagram
|
||||
raise AssertionError(f"Invalid response for improving diagram description: {response}")
|
||||
@@ -1048,11 +1050,11 @@ async def generate_better_mermaidjs_diagram_description(
|
||||
agent_chat_model=agent_chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
response_text = response.text.strip()
|
||||
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
|
||||
response_text = response_text[1:-1]
|
||||
|
||||
return response
|
||||
return response_text
|
||||
|
||||
|
||||
async def generate_mermaidjs_diagram_from_description(
|
||||
@@ -1076,7 +1078,7 @@ async def generate_mermaidjs_diagram_from_description(
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||
)
|
||||
return clean_mermaidjs(raw_response.strip())
|
||||
return clean_mermaidjs(raw_response.text.strip())
|
||||
|
||||
|
||||
async def generate_better_image_prompt(
|
||||
@@ -1151,11 +1153,11 @@ async def generate_better_image_prompt(
|
||||
agent_chat_model=agent_chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
response_text = response.text.strip()
|
||||
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
|
||||
response_text = response_text[1:-1]
|
||||
|
||||
return response
|
||||
return response_text
|
||||
|
||||
|
||||
async def search_documents(
|
||||
@@ -1329,7 +1331,7 @@ async def extract_questions(
|
||||
|
||||
# Extract questions from the response
|
||||
try:
|
||||
response = clean_json(raw_response)
|
||||
response = clean_json(raw_response.text)
|
||||
response = pyjson5.loads(response)
|
||||
queries = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(queries, list) or not queries:
|
||||
@@ -1439,6 +1441,7 @@ async def send_message_to_model_wrapper(
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
response_schema: BaseModel = None,
|
||||
tools: List[ToolDefinition] = None,
|
||||
deepthought: bool = False,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
@@ -1506,6 +1509,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
@@ -1517,6 +1521,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
@@ -1528,6 +1533,7 @@ async def send_message_to_model_wrapper(
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
response_schema=response_schema,
|
||||
tools=tools,
|
||||
deepthought=deepthought,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
@@ -2796,3 +2802,264 @@ def get_notion_auth_url(user: KhojUser):
|
||||
if not NOTION_OAUTH_CLIENT_ID or not NOTION_OAUTH_CLIENT_SECRET or not NOTION_REDIRECT_URI:
|
||||
return None
|
||||
return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}"
|
||||
|
||||
|
||||
async def view_file_content(
|
||||
path: str,
|
||||
start_line: Optional[int] = None,
|
||||
end_line: Optional[int] = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
"""
|
||||
View the contents of a file from the user's document database with optional line range specification.
|
||||
"""
|
||||
query = f"View file: {path}"
|
||||
if start_line and end_line:
|
||||
query += f" (lines {start_line}-{end_line})"
|
||||
|
||||
try:
|
||||
# Get the file object from the database by name
|
||||
file_objects = await FileObjectAdapters.aget_file_objects_by_name(user, path)
|
||||
|
||||
if not file_objects:
|
||||
error_msg = f"File '{path}' not found in user documents"
|
||||
logger.warning(error_msg)
|
||||
yield [{"query": query, "file": path, "compiled": error_msg}]
|
||||
return
|
||||
|
||||
# Use the first file object if multiple exist
|
||||
file_object = file_objects[0]
|
||||
raw_text = file_object.raw_text
|
||||
|
||||
# Apply line range filtering if specified
|
||||
if start_line is None and end_line is None:
|
||||
filtered_text = raw_text
|
||||
else:
|
||||
lines = raw_text.split("\n")
|
||||
start_line = start_line or 1
|
||||
end_line = end_line or len(lines)
|
||||
|
||||
# Validate line range
|
||||
if start_line < 1 or end_line < 1 or start_line > end_line:
|
||||
error_msg = f"Invalid line range: {start_line}-{end_line}"
|
||||
logger.warning(error_msg)
|
||||
yield [{"query": query, "file": path, "compiled": error_msg}]
|
||||
return
|
||||
if start_line > len(lines):
|
||||
error_msg = f"Start line {start_line} exceeds total number of lines {len(lines)}"
|
||||
logger.warning(error_msg)
|
||||
yield [{"query": query, "file": path, "compiled": error_msg}]
|
||||
return
|
||||
|
||||
# Convert from 1-based to 0-based indexing and ensure bounds
|
||||
start_idx = max(0, start_line - 1)
|
||||
end_idx = min(len(lines), end_line)
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
filtered_text = "\n".join(selected_lines)
|
||||
|
||||
# Truncate the text if it's too long
|
||||
if len(filtered_text) > 10000:
|
||||
filtered_text = filtered_text[:10000] + "\n\n[Truncated. Use line numbers to view specific sections.]"
|
||||
|
||||
# Format the result as a document reference
|
||||
document_results = [
|
||||
{
|
||||
"query": query,
|
||||
"file": path,
|
||||
"compiled": filtered_text,
|
||||
}
|
||||
]
|
||||
|
||||
yield document_results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error viewing file {path}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
||||
# Return an error result in the expected format
|
||||
yield [{"query": query, "file": path, "compiled": error_msg}]
|
||||
|
||||
|
||||
async def grep_files(
|
||||
regex_pattern: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
lines_before: Optional[int] = None,
|
||||
lines_after: Optional[int] = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
"""
|
||||
Search for a regex pattern in files with an optional path prefix and context lines.
|
||||
"""
|
||||
|
||||
# Construct the query string based on provided parameters
|
||||
def _generate_query(line_count, doc_count, path, pattern, lines_before, lines_after, max_results=1000):
|
||||
query = f"**Found {line_count} matches for '{pattern}' in {doc_count} documents**"
|
||||
if path:
|
||||
query += f" in {path}"
|
||||
if lines_before or lines_after or line_count > max_results:
|
||||
query += " Showing"
|
||||
if lines_before or lines_after:
|
||||
context_info = []
|
||||
if lines_before:
|
||||
context_info.append(f"{lines_before} lines before")
|
||||
if lines_after:
|
||||
context_info.append(f"{lines_after} lines after")
|
||||
query += f" {' and '.join(context_info)}"
|
||||
if line_count > max_results:
|
||||
if lines_before or lines_after:
|
||||
query += f" for"
|
||||
query += f" first {max_results} results"
|
||||
return query
|
||||
|
||||
# Validate regex pattern
|
||||
path_prefix = path_prefix or ""
|
||||
lines_before = lines_before or 0
|
||||
lines_after = lines_after or 0
|
||||
|
||||
try:
|
||||
regex = re.compile(regex_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
yield {
|
||||
"query": _generate_query(0, 0, path_prefix, regex_pattern, lines_before, lines_after),
|
||||
"file": path_prefix,
|
||||
"compiled": f"Invalid regex pattern: {e}",
|
||||
}
|
||||
return
|
||||
|
||||
try:
|
||||
file_matches = await FileObjectAdapters.aget_file_objects_by_regex(user, regex_pattern, path_prefix)
|
||||
|
||||
line_matches = []
|
||||
for file_object in file_matches:
|
||||
lines = file_object.raw_text.split("\n")
|
||||
matched_line_numbers = []
|
||||
|
||||
# Find all matching line numbers first
|
||||
for i, line in enumerate(lines, 1):
|
||||
if regex.search(line):
|
||||
matched_line_numbers.append(i)
|
||||
|
||||
# Build context for each match
|
||||
for line_num in matched_line_numbers:
|
||||
context_lines = []
|
||||
|
||||
# Calculate start and end indices for context (0-based)
|
||||
start_idx = max(0, line_num - 1 - lines_before)
|
||||
end_idx = min(len(lines), line_num + lines_after)
|
||||
|
||||
# Add context lines with line numbers
|
||||
for idx in range(start_idx, end_idx):
|
||||
current_line_num = idx + 1
|
||||
line_content = lines[idx]
|
||||
|
||||
if current_line_num == line_num:
|
||||
# This is the matching line, mark it
|
||||
context_lines.append(f"{file_object.file_name}:{current_line_num}:> {line_content}")
|
||||
else:
|
||||
# This is a context line
|
||||
context_lines.append(f"{file_object.file_name}:{current_line_num}: {line_content}")
|
||||
|
||||
# Add separator between matches if showing context
|
||||
if lines_before > 0 or lines_after > 0:
|
||||
context_lines.append("--")
|
||||
|
||||
line_matches.extend(context_lines)
|
||||
|
||||
# Remove the last separator if it exists
|
||||
if line_matches and line_matches[-1] == "--":
|
||||
line_matches.pop()
|
||||
|
||||
# Check if no results found
|
||||
max_results = 1000
|
||||
query = _generate_query(
|
||||
len([m for m in line_matches if ":>" in m]),
|
||||
len(file_matches),
|
||||
path_prefix,
|
||||
regex_pattern,
|
||||
lines_before,
|
||||
lines_after,
|
||||
max_results,
|
||||
)
|
||||
if not line_matches:
|
||||
yield {"query": query, "file": path_prefix, "compiled": "No matches found."}
|
||||
return
|
||||
|
||||
# Truncate matched lines list if too long
|
||||
if len(line_matches) > max_results:
|
||||
line_matches = line_matches[:max_results] + [
|
||||
f"... {len(line_matches) - max_results} more results found. Use stricter regex or path to narrow down results."
|
||||
]
|
||||
|
||||
yield {"query": query, "file": path_prefix or "", "compiled": "\n".join(line_matches)}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error using grep files tool: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
yield [
|
||||
{
|
||||
"query": _generate_query(0, 0, path_prefix or "", regex_pattern, lines_before, lines_after),
|
||||
"file": path_prefix,
|
||||
"compiled": error_msg,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def list_files(
|
||||
path: Optional[str] = None,
|
||||
pattern: Optional[str] = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
"""
|
||||
List files under a given path or glob pattern from the user's document database.
|
||||
"""
|
||||
|
||||
# Construct the query string based on provided parameters
|
||||
def _generate_query(doc_count, path, pattern):
|
||||
query = f"**Found {doc_count} files**"
|
||||
if path:
|
||||
query += f" in {path}"
|
||||
if pattern:
|
||||
query += f" filtered by {pattern}"
|
||||
return query
|
||||
|
||||
try:
|
||||
# Get user files by path prefix when specified
|
||||
path = path or ""
|
||||
if path in ["", "/", ".", "./", "~", "~/"]:
|
||||
file_objects = await FileObjectAdapters.aget_all_file_objects(user, limit=10000)
|
||||
else:
|
||||
file_objects = await FileObjectAdapters.aget_file_objects_by_path_prefix(user, path)
|
||||
|
||||
if not file_objects:
|
||||
yield {"query": _generate_query(0, path, pattern), "file": path, "compiled": "No files found."}
|
||||
return
|
||||
|
||||
# Extract file names from file objects
|
||||
files = [f.file_name for f in file_objects]
|
||||
# Convert to relative file path (similar to ls)
|
||||
if path:
|
||||
files = [f[len(path) :] for f in files]
|
||||
|
||||
# Apply glob pattern filtering if specified
|
||||
if pattern:
|
||||
files = [f for f in files if fnmatch.fnmatch(f, pattern)]
|
||||
|
||||
query = _generate_query(len(files), path, pattern)
|
||||
if not files:
|
||||
yield {"query": query, "file": path, "compiled": "No files found."}
|
||||
return
|
||||
|
||||
# Truncate the list if it's too long
|
||||
max_files = 100
|
||||
if len(files) > max_files:
|
||||
files = files[:max_files] + [
|
||||
f"... {len(files) - max_files} more files found. Use glob pattern to narrow down results."
|
||||
]
|
||||
|
||||
yield {"query": query, "file": path, "compiled": "\n- ".join(files)}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error listing files in {path}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
yield {"query": query, "file": path, "compiled": error_msg}
|
||||
|
||||
+169
-155
@@ -3,11 +3,9 @@ import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, Type
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, EntryAdapters
|
||||
from khoj.database.models import Agent, ChatMessageModel, KhojUser
|
||||
@@ -15,25 +13,31 @@ from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResearchIteration,
|
||||
ToolCall,
|
||||
construct_iteration_history,
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
from khoj.processor.operator import operate_environment
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.online_search import read_webpages_content, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
generate_summary_from_files,
|
||||
grep_files,
|
||||
list_files,
|
||||
search_documents,
|
||||
send_message_to_model_wrapper,
|
||||
view_file_content,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
ToolDefinition,
|
||||
dict_to_tuple,
|
||||
is_none_or_empty,
|
||||
is_operator_enabled,
|
||||
timer,
|
||||
tool_description_for_research_llm,
|
||||
tools_for_research_llm,
|
||||
truncate_code_context,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
@@ -41,47 +45,6 @@ from khoj.utils.rawconfig import LocationData
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlanningResponse(BaseModel):
|
||||
"""
|
||||
Schema for the response from planning agent when deciding the next tool to pick.
|
||||
"""
|
||||
|
||||
scratchpad: str = Field(..., description="Scratchpad to reason about which tool to use next")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
|
||||
"""
|
||||
Factory method that creates a customized PlanningResponse model
|
||||
with a properly typed tool field based on available tools.
|
||||
|
||||
The tool field is dynamically generated based on available tools.
|
||||
The query field should be filled by the model after the tool field for a more logical reasoning flow.
|
||||
|
||||
Args:
|
||||
tool_options: Dictionary mapping tool names to values
|
||||
|
||||
Returns:
|
||||
A customized PlanningResponse class
|
||||
"""
|
||||
# Create dynamic enum from tool options
|
||||
tool_enum = Enum("ToolEnum", tool_options) # type: ignore
|
||||
|
||||
# Create and return a customized response model with the enum
|
||||
class PlanningResponseWithTool(PlanningResponse):
|
||||
"""
|
||||
Use the scratchpad to reason about which tool to use next and the query to send to the tool.
|
||||
Pick tool from provided options and your query to send to the tool.
|
||||
"""
|
||||
|
||||
tool: tool_enum = Field(..., description="Name of the tool to use")
|
||||
query: str = Field(..., description="Detailed query for the selected tool")
|
||||
|
||||
return PlanningResponseWithTool
|
||||
|
||||
|
||||
async def apick_next_tool(
|
||||
query: str,
|
||||
conversation_history: List[ChatMessageModel],
|
||||
@@ -104,12 +67,13 @@ async def apick_next_tool(
|
||||
# Continue with previous iteration if a multi-step tool use is in progress
|
||||
if (
|
||||
previous_iterations
|
||||
and previous_iterations[-1].tool == ConversationCommand.Operator
|
||||
and previous_iterations[-1].query
|
||||
and isinstance(previous_iterations[-1].query, ToolCall)
|
||||
and previous_iterations[-1].query.name == ConversationCommand.Operator
|
||||
and not previous_iterations[-1].summarizedResult
|
||||
):
|
||||
previous_iteration = previous_iterations[-1]
|
||||
yield ResearchIteration(
|
||||
tool=previous_iteration.tool,
|
||||
query=query,
|
||||
context=previous_iteration.context,
|
||||
onlineContext=previous_iteration.onlineContext,
|
||||
@@ -120,30 +84,40 @@ async def apick_next_tool(
|
||||
return
|
||||
|
||||
# Construct tool options for the agent to choose from
|
||||
tool_options = dict()
|
||||
tools = []
|
||||
tool_options_str = ""
|
||||
agent_tools = agent.input_tools if agent else []
|
||||
user_has_entries = await EntryAdapters.auser_has_entries(user)
|
||||
for tool, description in tool_description_for_research_llm.items():
|
||||
for tool, tool_data in tools_for_research_llm.items():
|
||||
# Skip showing operator tool as an option if not enabled
|
||||
if tool == ConversationCommand.Operator and not is_operator_enabled():
|
||||
continue
|
||||
# Skip showing Notes tool as an option if user has no entries
|
||||
if tool == ConversationCommand.Notes:
|
||||
if not user_has_entries:
|
||||
continue
|
||||
description = description.format(max_search_queries=max_document_searches)
|
||||
if tool == ConversationCommand.Webpage:
|
||||
description = description.format(max_webpages_to_read=max_webpages_to_read)
|
||||
if tool == ConversationCommand.Online:
|
||||
description = description.format(max_search_queries=max_online_searches)
|
||||
# Skip showing document related tools if user has no documents
|
||||
if (
|
||||
tool == ConversationCommand.SemanticSearchFiles
|
||||
or tool == ConversationCommand.RegexSearchFiles
|
||||
or tool == ConversationCommand.ViewFile
|
||||
or tool == ConversationCommand.ListFiles
|
||||
) and not user_has_entries:
|
||||
continue
|
||||
if tool == ConversationCommand.SemanticSearchFiles:
|
||||
description = tool_data.description.format(max_search_queries=max_document_searches)
|
||||
elif tool == ConversationCommand.Webpage:
|
||||
description = tool_data.description.format(max_webpages_to_read=max_webpages_to_read)
|
||||
elif tool == ConversationCommand.Online:
|
||||
description = tool_data.description.format(max_search_queries=max_online_searches)
|
||||
else:
|
||||
description = tool_data.description
|
||||
# Add tool if agent does not have any tools defined or the tool is supported by the agent.
|
||||
if len(agent_tools) == 0 or tool.value in agent_tools:
|
||||
tool_options[tool.name] = tool.value
|
||||
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
||||
|
||||
# Create planning reponse model with dynamically populated tool enum class
|
||||
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
|
||||
tools.append(
|
||||
ToolDefinition(
|
||||
name=tool.value,
|
||||
description=description,
|
||||
schema=tool_data.schema,
|
||||
)
|
||||
)
|
||||
|
||||
today = datetime.today()
|
||||
location_data = f"{location}" if location else "Unknown"
|
||||
@@ -162,24 +136,17 @@ async def apick_next_tool(
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
if query_images:
|
||||
query = f"[placeholder for user attached images]\n{query}"
|
||||
|
||||
# Construct chat history with user and iteration history with researcher agent for context
|
||||
iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
|
||||
iteration_chat_history = construct_iteration_history(previous_iterations, query, query_images, query_files)
|
||||
chat_and_research_history = conversation_history + iteration_chat_history
|
||||
|
||||
# Plan function execution for the next tool
|
||||
query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query
|
||||
|
||||
try:
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
query=query,
|
||||
query="",
|
||||
system_message=function_planning_prompt,
|
||||
chat_history=chat_and_research_history,
|
||||
response_type="json_object",
|
||||
response_schema=planning_response_model,
|
||||
tools=tools,
|
||||
deepthought=True,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
@@ -190,48 +157,38 @@ async def apick_next_tool(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
|
||||
yield ResearchIteration(
|
||||
tool=None,
|
||||
query=None,
|
||||
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
response = load_complex_json(response)
|
||||
if not isinstance(response, dict):
|
||||
raise ValueError(f"Expected dict response, got {type(response).__name__}: {response}")
|
||||
selected_tool = response.get("tool", None)
|
||||
generated_query = response.get("query", None)
|
||||
scratchpad = response.get("scratchpad", None)
|
||||
warning = None
|
||||
logger.info(f"Response for determining relevant tools: {response}")
|
||||
|
||||
# Detect selection of previously used query, tool combination.
|
||||
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations if i.warning is None}
|
||||
if (selected_tool, generated_query) in previous_tool_query_combinations:
|
||||
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
|
||||
# Only send client status updates if we'll execute this iteration
|
||||
elif send_status_func:
|
||||
determined_tool_message = "**Determined Tool**: "
|
||||
determined_tool_message += (
|
||||
f"{selected_tool}({generated_query})." if selected_tool != ConversationCommand.Text else "respond."
|
||||
)
|
||||
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
|
||||
async for event in send_status_func(f"{scratchpad}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
yield ResearchIteration(
|
||||
tool=selected_tool,
|
||||
query=generated_query,
|
||||
warning=warning,
|
||||
)
|
||||
# Try parse the response as function call response to infer next tool to use.
|
||||
# TODO: Handle multiple tool calls.
|
||||
response_text = response.text
|
||||
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
|
||||
yield ResearchIteration(
|
||||
tool=None,
|
||||
query=None,
|
||||
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
|
||||
)
|
||||
# Otherwise assume the model has decided to end the research run and respond to the user.
|
||||
parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)
|
||||
|
||||
# If we have a valid response, extract the tool and query.
|
||||
warning = None
|
||||
logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})")
|
||||
|
||||
# Detect selection of previously used query, tool combination.
|
||||
previous_tool_query_combinations = {
|
||||
(i.query.name, dict_to_tuple(i.query.args))
|
||||
for i in previous_iterations
|
||||
if i.warning is None and isinstance(i.query, ToolCall)
|
||||
}
|
||||
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
|
||||
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
|
||||
# Only send client status updates if we'll execute this iteration and model has thoughts to share.
|
||||
elif send_status_func and not is_none_or_empty(response.thought):
|
||||
async for event in send_status_func(response.thought):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content)
|
||||
|
||||
|
||||
async def research(
|
||||
@@ -257,10 +214,10 @@ async def research(
|
||||
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
|
||||
|
||||
# Incorporate previous partial research into current research chat history
|
||||
research_conversation_history = deepcopy(conversation_history)
|
||||
research_conversation_history = [chat for chat in deepcopy(conversation_history) if chat.message]
|
||||
if current_iteration := len(previous_iterations) > 0:
|
||||
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations)
|
||||
research_conversation_history += previous_iterations_history
|
||||
|
||||
while current_iteration < MAX_ITERATIONS:
|
||||
@@ -273,7 +230,7 @@ async def research(
|
||||
code_results: Dict = dict()
|
||||
document_results: List[Dict[str, str]] = []
|
||||
operator_results: OperatorRun = None
|
||||
this_iteration = ResearchIteration(tool=None, query=query)
|
||||
this_iteration = ResearchIteration(query=query)
|
||||
|
||||
async for result in apick_next_tool(
|
||||
query,
|
||||
@@ -303,26 +260,30 @@ async def research(
|
||||
logger.warning(f"Research mode: {this_iteration.warning}.")
|
||||
|
||||
# Terminate research if selected text tool or query, tool not set for next iteration
|
||||
elif not this_iteration.query or not this_iteration.tool or this_iteration.tool == ConversationCommand.Text:
|
||||
elif (
|
||||
not this_iteration.query
|
||||
or isinstance(this_iteration.query, str)
|
||||
or this_iteration.query.name == ConversationCommand.Text
|
||||
):
|
||||
current_iteration = MAX_ITERATIONS
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Notes:
|
||||
elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles:
|
||||
this_iteration.context = []
|
||||
document_results = []
|
||||
previous_inferred_queries = {
|
||||
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
|
||||
}
|
||||
async for result in search_documents(
|
||||
this_iteration.query,
|
||||
max_document_searches,
|
||||
None,
|
||||
user,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
|
||||
conversation_id,
|
||||
[ConversationCommand.Default],
|
||||
location,
|
||||
send_status_func,
|
||||
query_images,
|
||||
**this_iteration.query.args,
|
||||
n=max_document_searches,
|
||||
d=None,
|
||||
user=user,
|
||||
chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
|
||||
conversation_id=conversation_id,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
location_data=location,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
previous_inferred_queries=previous_inferred_queries,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
@@ -350,7 +311,7 @@ async def research(
|
||||
else:
|
||||
this_iteration.warning = "No matching document references found"
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Online:
|
||||
elif this_iteration.query.name == ConversationCommand.SearchWeb:
|
||||
previous_subqueries = {
|
||||
subquery
|
||||
for iteration in previous_iterations
|
||||
@@ -359,12 +320,12 @@ async def research(
|
||||
}
|
||||
try:
|
||||
async for result in search_online(
|
||||
this_iteration.query,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
|
||||
location,
|
||||
user,
|
||||
send_status_func,
|
||||
[],
|
||||
**this_iteration.query.args,
|
||||
conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
|
||||
location=location,
|
||||
user=user,
|
||||
send_status_func=send_status_func,
|
||||
custom_filters=[],
|
||||
max_online_searches=max_online_searches,
|
||||
max_webpages_to_read=0,
|
||||
query_images=query_images,
|
||||
@@ -383,19 +344,15 @@ async def research(
|
||||
this_iteration.warning = f"Error searching online: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Webpage:
|
||||
elif this_iteration.query.name == ConversationCommand.ReadWebpage:
|
||||
try:
|
||||
async for result in read_webpages(
|
||||
this_iteration.query,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
|
||||
location,
|
||||
user,
|
||||
send_status_func,
|
||||
max_webpages_to_read=max_webpages_to_read,
|
||||
query_images=query_images,
|
||||
async for result in read_webpages_content(
|
||||
**this_iteration.query.args,
|
||||
user=user,
|
||||
send_status_func=send_status_func,
|
||||
# max_webpages_to_read=max_webpages_to_read,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
query_files=query_files,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
@@ -416,15 +373,15 @@ async def research(
|
||||
this_iteration.warning = f"Error reading webpages: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Code:
|
||||
elif this_iteration.query.name == ConversationCommand.RunCode:
|
||||
try:
|
||||
async for result in run_code(
|
||||
this_iteration.query,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
|
||||
"",
|
||||
location,
|
||||
user,
|
||||
send_status_func,
|
||||
**this_iteration.query.args,
|
||||
conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
|
||||
context="",
|
||||
location_data=location,
|
||||
user=user,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
@@ -441,14 +398,14 @@ async def research(
|
||||
this_iteration.warning = f"Error running code: {e}"
|
||||
logger.warning(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Operator:
|
||||
elif this_iteration.query.name == ConversationCommand.OperateComputer:
|
||||
try:
|
||||
async for result in operate_environment(
|
||||
this_iteration.query,
|
||||
user,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
||||
location,
|
||||
previous_iterations[-1].operatorContext if previous_iterations else None,
|
||||
**this_iteration.query.args,
|
||||
user=user,
|
||||
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
||||
location_data=location,
|
||||
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
@@ -474,6 +431,63 @@ async def research(
|
||||
this_iteration.warning = f"Error operating browser: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif this_iteration.query.name == ConversationCommand.ViewFile:
|
||||
try:
|
||||
async for result in view_file_content(
|
||||
**this_iteration.query.args,
|
||||
user=user,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
if this_iteration.context is None:
|
||||
this_iteration.context = []
|
||||
document_results: List[Dict[str, str]] = result # type: ignore
|
||||
this_iteration.context += document_results
|
||||
async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"):
|
||||
yield result
|
||||
except Exception as e:
|
||||
this_iteration.warning = f"Error viewing file: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif this_iteration.query.name == ConversationCommand.ListFiles:
|
||||
try:
|
||||
async for result in list_files(
|
||||
**this_iteration.query.args,
|
||||
user=user,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
if this_iteration.context is None:
|
||||
this_iteration.context = []
|
||||
document_results: List[Dict[str, str]] = [result] # type: ignore
|
||||
this_iteration.context += document_results
|
||||
async for result in send_status_func(result["query"]):
|
||||
yield result
|
||||
except Exception as e:
|
||||
this_iteration.warning = f"Error listing files: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
|
||||
try:
|
||||
async for result in grep_files(
|
||||
**this_iteration.query.args,
|
||||
user=user,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
if this_iteration.context is None:
|
||||
this_iteration.context = []
|
||||
document_results: List[Dict[str, str]] = [result] # type: ignore
|
||||
this_iteration.context += document_results
|
||||
async for result in send_status_func(result["query"]):
|
||||
yield result
|
||||
except Exception as e:
|
||||
this_iteration.warning = f"Error searching with regex: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
else:
|
||||
# No valid tools. This is our exit condition.
|
||||
current_iteration = MAX_ITERATIONS
|
||||
@@ -481,7 +495,7 @@ async def research(
|
||||
current_iteration += 1
|
||||
|
||||
if document_results or online_results or code_results or operator_results or this_iteration.warning:
|
||||
results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
|
||||
results_data = f"\n<iteration_{current_iteration}_results>"
|
||||
if document_results:
|
||||
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
|
||||
if online_results:
|
||||
@@ -494,7 +508,7 @@ async def research(
|
||||
)
|
||||
if this_iteration.warning:
|
||||
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
|
||||
results_data += "\n</results>\n</iteration>"
|
||||
results_data += f"\n</results>\n</iteration_{current_iteration}_results>"
|
||||
|
||||
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
|
||||
this_iteration.summarizedResult = results_data
|
||||
|
||||
+284
-8
@@ -12,6 +12,7 @@ import random
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from importlib import import_module
|
||||
@@ -19,8 +20,9 @@ from importlib.metadata import version
|
||||
from itertools import islice
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from time import perf_counter
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Type, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import anthropic
|
||||
@@ -36,6 +38,7 @@ from google.auth.credentials import Credentials
|
||||
from google.oauth2 import service_account
|
||||
from magika import Magika
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
from pytz import country_names, country_timezones
|
||||
|
||||
from khoj.utils import constants
|
||||
@@ -334,6 +337,85 @@ def is_e2b_code_sandbox_enabled():
|
||||
return not is_none_or_empty(os.getenv("E2B_API_KEY"))
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
def __init__(self, name: str, description: str, schema: dict):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.schema = schema
|
||||
|
||||
|
||||
def create_tool_definition(
|
||||
schema: Type[BaseModel],
|
||||
name: str = None,
|
||||
description: Optional[str] = None,
|
||||
) -> ToolDefinition:
|
||||
"""
|
||||
Converts a response schema BaseModel class into a normalized tool definition.
|
||||
|
||||
A standard AI provider agnostic tool format to specify tools the model can use.
|
||||
Common logic used across models is kept here. AI provider specific adaptations
|
||||
should be handled in provider code.
|
||||
|
||||
Args:
|
||||
response_schema: The Pydantic BaseModel class to convert.
|
||||
This class defines the response schema for the tool.
|
||||
tool_name: The name for the AI model tool (e.g., "get_weather", "plan_next_step").
|
||||
tool_description: Optional description for the AI model tool.
|
||||
If None, it attempts to use the Pydantic model's docstring.
|
||||
If that's also missing, a fallback description is generated.
|
||||
|
||||
Returns:
|
||||
A normalized tool definition for AI model APIs.
|
||||
"""
|
||||
raw_schema_dict = schema.model_json_schema()
|
||||
|
||||
name = name or schema.__name__.lower()
|
||||
description = description
|
||||
if description is None:
|
||||
docstring = schema.__doc__
|
||||
if docstring:
|
||||
description = dedent(docstring).strip()
|
||||
else:
|
||||
# Fallback description if no explicit one or docstring is provided
|
||||
description = f"Tool named '{name}' accepts specified parameters."
|
||||
|
||||
# Process properties to inline enums and remove $defs dependency
|
||||
processed_properties = {}
|
||||
original_properties = raw_schema_dict.get("properties", {})
|
||||
defs = raw_schema_dict.get("$defs", {})
|
||||
|
||||
for prop_name, prop_schema in original_properties.items():
|
||||
current_prop_schema = deepcopy(prop_schema) # Work on a copy
|
||||
# Check for enums defined directly in the property for simpler direct enum definitions.
|
||||
if "$ref" in current_prop_schema:
|
||||
ref_path = current_prop_schema["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_name = ref_path.split("/")[-1]
|
||||
if def_name in defs and "enum" in defs[def_name]:
|
||||
enum_def = defs[def_name]
|
||||
current_prop_schema["enum"] = enum_def["enum"]
|
||||
current_prop_schema["type"] = enum_def.get("type", "string")
|
||||
if "description" not in current_prop_schema and "description" in enum_def:
|
||||
current_prop_schema["description"] = enum_def["description"]
|
||||
del current_prop_schema["$ref"] # Remove the $ref as it's been inlined
|
||||
|
||||
processed_properties[prop_name] = current_prop_schema
|
||||
|
||||
# Generate the compiled schema dictionary for the tool definition.
|
||||
compiled_schema = {
|
||||
"type": "object",
|
||||
"properties": processed_properties,
|
||||
# Generate content in the order in which the schema properties were defined
|
||||
"property_ordering": list(schema.model_fields.keys()),
|
||||
}
|
||||
|
||||
# Include 'required' fields if specified in the Pydantic model
|
||||
if "required" in raw_schema_dict and raw_schema_dict["required"]:
|
||||
compiled_schema["required"] = raw_schema_dict["required"]
|
||||
|
||||
return ToolDefinition(name=name, description=description, schema=compiled_schema)
|
||||
|
||||
|
||||
class ConversationCommand(str, Enum):
|
||||
Default = "default"
|
||||
General = "general"
|
||||
@@ -347,6 +429,14 @@ class ConversationCommand(str, Enum):
|
||||
Diagram = "diagram"
|
||||
Research = "research"
|
||||
Operator = "operator"
|
||||
ViewFile = "view_file"
|
||||
ListFiles = "list_files"
|
||||
RegexSearchFiles = "regex_search_files"
|
||||
SemanticSearchFiles = "semantic_search_files"
|
||||
SearchWeb = "search_web"
|
||||
ReadWebpage = "read_webpage"
|
||||
RunCode = "run_code"
|
||||
OperateComputer = "operate_computer"
|
||||
|
||||
|
||||
command_descriptions = {
|
||||
@@ -360,6 +450,9 @@ command_descriptions = {
|
||||
ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
|
||||
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
|
||||
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
|
||||
ConversationCommand.ViewFile: "View the contents of a file with optional line range specification.",
|
||||
ConversationCommand.ListFiles: "List files under a given path with optional glob pattern.",
|
||||
ConversationCommand.RegexSearchFiles: "Search for lines in files matching regex pattern with an optional path prefix.",
|
||||
}
|
||||
|
||||
command_descriptions_for_agent = {
|
||||
@@ -385,13 +478,186 @@ tool_descriptions_for_llm = {
|
||||
ConversationCommand.Operator: "To use when you need to operate a computer to complete the task.",
|
||||
}
|
||||
|
||||
tool_description_for_research_llm = {
|
||||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.",
|
||||
ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.",
|
||||
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.",
|
||||
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
|
||||
ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.",
|
||||
ConversationCommand.Operator: "To operate a computer to complete the task.",
|
||||
tools_for_research_llm = {
|
||||
ConversationCommand.SearchWeb: ToolDefinition(
|
||||
name="search_web",
|
||||
description="To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.",
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search on the internet.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
ConversationCommand.ReadWebpage: ToolDefinition(
|
||||
name="read_webpage",
|
||||
description="To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.",
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"urls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
},
|
||||
"description": "The webpage URLs to extract information from.",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to extract information from the webpages.",
|
||||
},
|
||||
},
|
||||
"required": ["urls", "query"],
|
||||
},
|
||||
),
|
||||
ConversationCommand.RunCode: ToolDefinition(
|
||||
name="run_code",
|
||||
description=e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Detailed query and all input data required to generate, execute code in the sandbox.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
ConversationCommand.OperateComputer: ToolDefinition(
|
||||
name="operate_computer",
|
||||
description="To operate a computer to complete the task.",
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The task to perform on the computer.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
ConversationCommand.ViewFile: ToolDefinition(
|
||||
name="view_file",
|
||||
description=dedent(
|
||||
"""
|
||||
To view the contents of specific note or document in the user's personal knowledge base.
|
||||
Especially helpful if the question expects context from the user's notes or documents.
|
||||
It can be used after finding the document path with the document search tool.
|
||||
Optionally specify a line range to view only specific sections of large files.
|
||||
"""
|
||||
).strip(),
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The file path to view (can be absolute or relative).",
|
||||
},
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
"description": "Optional starting line number for viewing a specific range (1-indexed).",
|
||||
},
|
||||
"end_line": {
|
||||
"type": "integer",
|
||||
"description": "Optional ending line number for viewing a specific range (1-indexed).",
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
),
|
||||
ConversationCommand.ListFiles: ToolDefinition(
|
||||
name="list_files",
|
||||
description=dedent(
|
||||
"""
|
||||
To list files in the user's knowledge base.
|
||||
|
||||
Use the path parameter to only show files under the specified path.
|
||||
"""
|
||||
).strip(),
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The directory path to list files from.",
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Optional glob pattern to filter files (e.g., '*.md').",
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
ConversationCommand.SemanticSearchFiles: ToolDefinition(
|
||||
name="semantic_search_files",
|
||||
description=dedent(
|
||||
"""
|
||||
To have the tool AI semantic search through the user's knowledge base.
|
||||
Helpful to answer questions for which finding some relevant notes or documents can complete the search. Example: "When was Tom born?"
|
||||
This tool AI cannot find all relevant notes or documents, only a subset of them.
|
||||
It is a good starting point to find keywords, discover similar topics or related concepts and some relevant notes or documents.
|
||||
The tool AI can perform a maximum of {max_search_queries} semantic search queries per iteration.
|
||||
"""
|
||||
).strip(),
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"q": {
|
||||
"type": "string",
|
||||
"description": "Your natural language query for the tool to search in the user's knowledge base.",
|
||||
},
|
||||
},
|
||||
"required": ["q"],
|
||||
},
|
||||
),
|
||||
ConversationCommand.RegexSearchFiles: ToolDefinition(
|
||||
name="regex_search_files",
|
||||
description=dedent(
|
||||
"""
|
||||
To search through the user's knowledge base using regex patterns. Returns all lines matching the pattern.
|
||||
Helpful to answer questions for which all relevant notes or documents are needed to complete the search. Example: "Notes that mention Tom".
|
||||
You need to know all the correct keywords or regex patterns for this tool to be useful.
|
||||
|
||||
REMEMBER:
|
||||
- The regex pattern will ONLY match content on a single line. Multi-line matches are NOT supported (even if you use \\n).
|
||||
|
||||
An optional path prefix can restrict search to specific files/directories.
|
||||
Use lines_before, lines_after to show context around matches.
|
||||
"""
|
||||
).strip(),
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"regex_pattern": {
|
||||
"type": "string",
|
||||
"description": "The regex pattern to search for content in the user's files.",
|
||||
},
|
||||
"path_prefix": {
|
||||
"type": "string",
|
||||
"description": "Optional path prefix to limit the search to files under a specified path.",
|
||||
},
|
||||
"lines_before": {
|
||||
"type": "integer",
|
||||
"description": "Optional number of lines to show before each line match for context.",
|
||||
"minimum": 0,
|
||||
"maximum": 20,
|
||||
},
|
||||
"lines_after": {
|
||||
"type": "integer",
|
||||
"description": "Optional number of lines to show after each line match for context.",
|
||||
"minimum": 0,
|
||||
"maximum": 20,
|
||||
},
|
||||
},
|
||||
"required": ["regex_pattern"],
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
@@ -850,3 +1116,13 @@ def clean_object_for_db(data):
|
||||
return [clean_object_for_db(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def dict_to_tuple(d):
|
||||
# Recursively convert dicts to sorted tuples for hashability
|
||||
if isinstance(d, dict):
|
||||
return tuple(sorted((k, dict_to_tuple(v)) for k, v in d.items()))
|
||||
elif isinstance(d, list):
|
||||
return tuple(dict_to_tuple(i) for i in d)
|
||||
else:
|
||||
return d
|
||||
|
||||
@@ -48,17 +48,18 @@ class TestTruncateMessage:
|
||||
big_chat_message = ChatMessage(role="user", content=content_list)
|
||||
copy_big_chat_message = deepcopy(big_chat_message)
|
||||
chat_history = [big_chat_message]
|
||||
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||
|
||||
# Act
|
||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
|
||||
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||
|
||||
def test_truncate_message_with_content_list(self):
|
||||
# Arrange
|
||||
@@ -68,11 +69,11 @@ class TestTruncateMessage:
|
||||
big_chat_message = ChatMessage(role="user", content=content_list)
|
||||
copy_big_chat_message = deepcopy(big_chat_message)
|
||||
chat_history.insert(0, big_chat_message)
|
||||
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||
|
||||
# Act
|
||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
@@ -83,7 +84,8 @@ class TestTruncateMessage:
|
||||
copy_big_chat_message.content
|
||||
), "message content list should be modified"
|
||||
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
|
||||
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||
|
||||
def test_truncate_message_first_large(self):
|
||||
# Arrange
|
||||
@@ -91,11 +93,11 @@ class TestTruncateMessage:
|
||||
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
|
||||
copy_big_chat_message = big_chat_message.copy()
|
||||
chat_history.insert(0, big_chat_message)
|
||||
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||
|
||||
# Act
|
||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
@@ -104,7 +106,8 @@ class TestTruncateMessage:
|
||||
), "Only most recent message should be present as it itself is larger than context size"
|
||||
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
|
||||
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||
|
||||
def test_truncate_message_large_system_message_first(self):
|
||||
# Arrange
|
||||
|
||||
@@ -189,7 +189,7 @@ async def test_chat_with_no_chat_history_or_retrieved_content():
|
||||
user_query="Hello, my name is Testatron. Who are you?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj"]
|
||||
@@ -217,7 +217,7 @@ async def test_answer_from_chat_history_and_no_content():
|
||||
chat_history=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
@@ -250,7 +250,7 @@ async def test_answer_from_chat_history_and_previously_retrieved_content():
|
||||
chat_history=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
@@ -279,7 +279,7 @@ async def test_answer_from_chat_history_and_currently_retrieved_content():
|
||||
chat_history=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
@@ -305,7 +305,7 @@ async def test_refuse_answering_unanswerable_question():
|
||||
chat_history=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
@@ -359,7 +359,7 @@ Expenses:Food:Dining 10.00 USD""",
|
||||
user_query="What did I have for Dinner today?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["tacos", "Tacos"]
|
||||
@@ -405,7 +405,7 @@ Expenses:Food:Dining 10.00 USD""",
|
||||
user_query="How much did I spend on dining this year?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
@@ -432,7 +432,7 @@ async def test_answer_general_question_not_in_chat_history_or_retrieved_content(
|
||||
chat_history=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "bug", "code"]
|
||||
@@ -473,7 +473,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||
user_query="How many kids does my older sister have?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
@@ -508,14 +508,14 @@ async def test_agent_prompt_should_be_used(openai_agent):
|
||||
user_query="What did I buy?",
|
||||
api_key=api_key,
|
||||
)
|
||||
no_agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
no_agent_response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
response_gen = converse_openai(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I buy?",
|
||||
api_key=api_key,
|
||||
agent=openai_agent,
|
||||
)
|
||||
agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
|
||||
agent_response = "".join([response_chunk.text async for response_chunk in response_gen])
|
||||
|
||||
# Assert that the model without the agent prompt does not include the summary of purchases
|
||||
assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (
|
||||
|
||||
Reference in New Issue
Block a user