Upgrade Retrieval from KB in Research Mode. Use Function Calling for Tool Use (#1205)

## Why
Move to function calling paradigm to give models tool call -> tool
result in formats they're fine-tuned to understand. Previously we were
giving them results in our specific format (as function calling paradigm
wasn't well-established yet).

And improve prompt cache hits by caching tool definitions.

This is a **breaking change**. AI Models and APIs that do not support
function calling will not work with Khoj in research mode. Function
calling is supported by:
- Standard commercial AI Models and APIs like Anthropic, Gemini, OpenAI,
OpenRouter
- Standard open-source AI APIs like llama.cpp server, Ollama
- Standard open source models like Qwen, DeepSeek, Gemma, Llama, Mistral

## What
### Use Function Calling for Tool Use
- Add Function Calling support to Anthropic, Gemini, OpenAI AI Model
APIs
- Move Existing Research Mode Tools to Use Function Calling

### Get More Comprehensive Results from your Knowledge Base (KB)
- Give Research Agent better Document Retrieval Tools
  - Add grep files tool to enable researcher to find documents by regex
  - Add list files tool to enable researcher to find documents by path
  - Add file viewer tool to enable researcher to read documents

### Miscellaneous
- Improve Research Prompt, Truncation, Retry and Caching
- Show reasoning model thoughts in Khoj train of thought for
intermediate steps as well
This commit is contained in:
Debanjum
2025-07-03 00:14:07 -07:00
committed by GitHub
20 changed files with 1273 additions and 484 deletions
+20
View File
@@ -1716,6 +1716,14 @@ class FileObjectAdapters:
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod
@arequire_valid_user
async def aget_file_objects_by_path_prefix(user: KhojUser, path_prefix: str, agent: Agent = None):
"""Get file objects from the database by path prefix."""
return await sync_to_async(list)(
FileObject.objects.filter(user=user, agent=agent, file_name__startswith=path_prefix)
)
@staticmethod
@arequire_valid_user
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
@@ -1748,6 +1756,18 @@ class FileObjectAdapters:
async def adelete_all_file_objects(user: KhojUser):
return await FileObject.objects.filter(user=user).adelete()
@staticmethod
@arequire_valid_user
async def aget_file_objects_by_regex(user: KhojUser, regex_pattern: str, path_prefix: Optional[str] = None):
"""
Search for a regex pattern in file objects, with an optional path prefix filter.
Outputs results in grep format.
"""
query = FileObject.objects.filter(user=user, agent=None, raw_text__iregex=regex_pattern)
if path_prefix:
query = query.filter(file_name__startswith=path_prefix)
return await sync_to_async(list)(query)
class EntryAdapters:
word_filter = WordFilter()
@@ -22,12 +22,20 @@ logger = logging.getLogger(__name__)
def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
messages,
api_key,
api_base_url,
model,
response_type="text",
response_schema=None,
tools=None,
deepthought=False,
tracer={},
):
"""
Send message to model
"""
# Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
# Get response from model. Don't use response_type because Anthropic doesn't support it.
return anthropic_completion_with_backoff(
messages=messages,
system_prompt="",
@@ -36,6 +44,7 @@ def anthropic_send_message_to_model(
api_base_url=api_base_url,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
tracer=tracer,
)
@@ -1,9 +1,8 @@
import json
import logging
from copy import deepcopy
from textwrap import dedent
from time import perf_counter
from typing import AsyncGenerator, Dict, List, Optional, Type
from typing import AsyncGenerator, Dict, List
import anthropic
from langchain_core.messages.chat import ChatMessage
@@ -18,11 +17,14 @@ from tenacity import (
from khoj.processor.conversation.utils import (
ResponseWithThought,
ToolCall,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
)
from khoj.utils.helpers import (
ToolDefinition,
create_tool_definition,
get_anthropic_async_client,
get_anthropic_client,
get_chat_usage_metrics,
@@ -57,9 +59,10 @@ def anthropic_completion_with_backoff(
max_tokens: int | None = None,
response_type: str = "text",
response_schema: BaseModel | None = None,
tools: List[ToolDefinition] = None,
deepthought: bool = False,
tracer: dict = {},
) -> str:
) -> ResponseWithThought:
client = anthropic_clients.get(api_key)
if not client:
client = get_anthropic_client(api_key, api_base_url)
@@ -67,12 +70,26 @@ def anthropic_completion_with_backoff(
formatted_messages, system = format_messages_for_anthropic(messages, system_prompt)
thoughts = ""
aggregated_response = ""
final_message = None
model_kwargs = model_kwargs or dict()
if response_schema:
tool = create_anthropic_tool_definition(response_schema=response_schema)
model_kwargs["tools"] = [tool]
# Configure structured output
if tools:
# Convert tools to Anthropic format
model_kwargs["tools"] = [
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
for tool in tools
]
# Cache tool definitions
last_tool = model_kwargs["tools"][-1]
last_tool["cache_control"] = {"type": "ephemeral"}
elif response_schema:
tool = create_tool_definition(response_schema)
model_kwargs["tools"] = [
anthropic.types.ToolParam(name=tool.name, description=tool.description, input_schema=tool.schema)
]
elif response_type == "json_object" and not (is_reasoning_model(model_name) and deepthought):
# Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
formatted_messages.append(anthropic.types.MessageParam(role="assistant", content="{"))
@@ -96,15 +113,41 @@ def anthropic_completion_with_backoff(
max_tokens=max_tokens,
**(model_kwargs),
) as stream:
for text in stream.text_stream:
aggregated_response += text
for chunk in stream:
if chunk.type != "content_block_delta":
continue
if chunk.delta.type == "thinking_delta":
thoughts += chunk.delta.thinking
elif chunk.delta.type == "text_delta":
aggregated_response += chunk.delta.text
final_message = stream.get_final_message()
# Extract first tool call from final message
for item in final_message.content:
if item.type == "tool_use":
aggregated_response = json.dumps(item.input)
break
# Track raw content of model response to reuse for cache hits in multi-turn chats
raw_content = [item.model_dump() for item in final_message.content]
# Extract all tool calls if tools are enabled
if tools:
tool_calls = [
ToolCall(name=item.name, args=item.input, id=item.id).__dict__
for item in final_message.content
if item.type == "tool_use"
]
if tool_calls:
# If there are tool calls, aggregate thoughts and responses into thoughts
if thoughts and aggregated_response:
# wrap each line of thought in italics
thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()])
thoughts = f"{thoughts}\n\n{aggregated_response}"
else:
thoughts = thoughts or aggregated_response
# Json dump tool calls into aggregated response
aggregated_response = json.dumps(tool_calls)
# If response schema is used, return the first tool call's input
elif response_schema:
for item in final_message.content:
if item.type == "tool_use":
aggregated_response = json.dumps(item.input)
break
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
@@ -126,7 +169,7 @@ def anthropic_completion_with_backoff(
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
return ResponseWithThought(text=aggregated_response, thought=thoughts, raw_content=raw_content)
@retry(
@@ -183,10 +226,10 @@ async def anthropic_chat_completion_with_backoff(
if chunk.type == "message_delta":
if chunk.delta.stop_reason == "refusal":
yield ResponseWithThought(
response="...I'm sorry, but my safety filters prevent me from assisting with this query."
text="...I'm sorry, but my safety filters prevent me from assisting with this query."
)
elif chunk.delta.stop_reason == "max_tokens":
yield ResponseWithThought(response="...I'm sorry, but I've hit my response length limit.")
yield ResponseWithThought(text="...I'm sorry, but I've hit my response length limit.")
if chunk.delta.stop_reason in ["refusal", "max_tokens"]:
logger.warning(
f"LLM Response Prevented for {model_name}: {chunk.delta.stop_reason}.\n"
@@ -199,7 +242,7 @@ async def anthropic_chat_completion_with_backoff(
# Handle streamed response chunk
response_chunk: ResponseWithThought = None
if chunk.delta.type == "text_delta":
response_chunk = ResponseWithThought(response=chunk.delta.text)
response_chunk = ResponseWithThought(text=chunk.delta.text)
aggregated_response += chunk.delta.text
if chunk.delta.type == "thinking_delta":
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
@@ -232,13 +275,14 @@ async def anthropic_chat_completion_with_backoff(
commit_conversation_trace(messages, aggregated_response, tracer)
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
def format_messages_for_anthropic(raw_messages: list[ChatMessage], system_prompt: str = None):
"""
Format messages for Anthropic
"""
# Extract system prompt
system_prompt = system_prompt or ""
for message in messages.copy():
messages = deepcopy(raw_messages)
for message in messages:
if message.role == "system":
if isinstance(message.content, list):
system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
@@ -250,15 +294,30 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
else:
system = None
# Anthropic requires the first message to be a 'user' message
if len(messages) == 1:
# Anthropic requires the first message to be a user message unless its a tool call
message_type = messages[0].additional_kwargs.get("message_type", None)
if len(messages) == 1 and message_type != "tool_call":
messages[0].role = "user"
elif len(messages) > 1 and messages[0].role == "assistant":
messages = messages[1:]
# Convert image urls to base64 encoded images in Anthropic message format
for message in messages:
if isinstance(message.content, list):
# Handle tool call and tool result message types from additional_kwargs
message_type = message.additional_kwargs.get("message_type")
if message_type == "tool_call":
pass
elif message_type == "tool_result":
# Convert tool_result to Anthropic tool_result format
content = []
for part in message.content:
content.append(
{
"type": "tool_result",
"tool_use_id": part["id"],
"content": part["content"],
}
)
message.content = content
# Convert image urls to base64 encoded images in Anthropic message format
elif isinstance(message.content, list):
content = []
# Sort the content. Anthropic models prefer that text comes after images.
message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
@@ -304,18 +363,15 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
if isinstance(block, dict) and "cache_control" in block:
del block["cache_control"]
# Add cache control to the last content block of second to last message.
# In research mode, this message content is list of iterations, updated after each research iteration.
# Caching it should improve research efficiency.
cache_message = messages[-2]
# Add cache control to the last content block of last message.
# Caching should improve research efficiency.
cache_message = messages[-1]
if isinstance(cache_message.content, list) and cache_message.content:
# Add cache control to the last content block only if it's a text block with non-empty content
last_block = cache_message.content[-1]
if (
isinstance(last_block, dict)
and last_block.get("type") == "text"
and last_block.get("text")
and last_block.get("text").strip()
if isinstance(last_block, dict) and (
(last_block.get("type") == "text" and last_block.get("text", "").strip())
or (last_block.get("type") == "tool_result" and last_block.get("content", []))
):
last_block["cache_control"] = {"type": "ephemeral"}
@@ -326,74 +382,5 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
return formatted_messages, system
def create_anthropic_tool_definition(
response_schema: Type[BaseModel],
tool_name: str = None,
tool_description: Optional[str] = None,
) -> anthropic.types.ToolParam:
"""
Converts a response schema BaseModel class into an Anthropic tool definition dictionary.
This format is expected by Anthropic's API when defining tools the model can use.
Args:
response_schema: The Pydantic BaseModel class to convert.
This class defines the response schema for the tool.
tool_name: The name for the Anthropic tool (e.g., "get_weather", "plan_next_step").
tool_description: Optional description for the Anthropic tool.
If None, it attempts to use the Pydantic model's docstring.
If that's also missing, a fallback description is generated.
Returns:
An tool definition for Anthropic's API.
"""
model_schema = response_schema.model_json_schema()
name = tool_name or response_schema.__name__.lower()
description = tool_description
if description is None:
docstring = response_schema.__doc__
if docstring:
description = dedent(docstring).strip()
else:
# Fallback description if no explicit one or docstring is provided
description = f"Tool named '{name}' accepts specified parameters."
# Process properties to inline enums and remove $defs dependency
processed_properties = {}
original_properties = model_schema.get("properties", {})
defs = model_schema.get("$defs", {})
for prop_name, prop_schema in original_properties.items():
current_prop_schema = deepcopy(prop_schema) # Work on a copy
# Check for enums defined directly in the property for simpler direct enum definitions.
if "$ref" in current_prop_schema:
ref_path = current_prop_schema["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path.split("/")[-1]
if def_name in defs and "enum" in defs[def_name]:
enum_def = defs[def_name]
current_prop_schema["enum"] = enum_def["enum"]
current_prop_schema["type"] = enum_def.get("type", "string")
if "description" not in current_prop_schema and "description" in enum_def:
current_prop_schema["description"] = enum_def["description"]
del current_prop_schema["$ref"] # Remove the $ref as it's been inlined
processed_properties[prop_name] = current_prop_schema
# The input_schema for Anthropic tools is a JSON Schema object.
# Pydantic's model_json_schema() provides most of what's needed.
input_schema = {
"type": "object",
"properties": processed_properties,
}
# Include 'required' fields if specified in the Pydantic model
if "required" in model_schema and model_schema["required"]:
input_schema["required"] = model_schema["required"]
return anthropic.types.ToolParam(name=name, description=description, input_schema=input_schema)
def is_reasoning_model(model_name: str) -> bool:
return any(model_name.startswith(model) for model in REASONING_MODELS)
@@ -28,6 +28,7 @@ def gemini_send_message_to_model(
api_base_url=None,
response_type="text",
response_schema=None,
tools=None,
model_kwargs=None,
deepthought=False,
tracer={},
@@ -37,8 +38,10 @@ def gemini_send_message_to_model(
"""
model_kwargs = {}
if tools:
model_kwargs["tools"] = tools
# Monitor for flakiness in 1.5+ models. This would cause unwanted behavior and terminate response early in 1.5 models.
if response_type == "json_object" and not model.startswith("gemini-1.5"):
elif response_type == "json_object" and not model.startswith("gemini-1.5"):
model_kwargs["response_mime_type"] = "application/json"
if response_schema:
model_kwargs["response_schema"] = response_schema
+80 -18
View File
@@ -1,9 +1,10 @@
import json
import logging
import os
import random
from copy import deepcopy
from time import perf_counter
from typing import AsyncGenerator, AsyncIterator, Dict
from typing import AsyncGenerator, AsyncIterator, Dict, List
import httpx
from google import genai
@@ -22,11 +23,13 @@ from tenacity import (
from khoj.processor.conversation.utils import (
ResponseWithThought,
ToolCall,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
)
from khoj.utils.helpers import (
ToolDefinition,
get_chat_usage_metrics,
get_gemini_client,
is_none_or_empty,
@@ -95,26 +98,29 @@ def gemini_completion_with_backoff(
temperature=1.2,
api_key=None,
api_base_url: str = None,
model_kwargs=None,
model_kwargs={},
deepthought=False,
tracer={},
) -> str:
) -> ResponseWithThought:
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
formatted_messages, system_instruction = format_messages_for_gemini(messages, system_prompt)
response_thoughts: str | None = None
raw_content, response_text, response_thoughts = [], "", None
# format model response schema
# Configure structured output
tools = None
response_schema = None
if model_kwargs and model_kwargs.get("response_schema"):
if model_kwargs.get("tools"):
tools = to_gemini_tools(model_kwargs["tools"])
elif model_kwargs.get("response_schema"):
response_schema = clean_response_schema(model_kwargs["response_schema"])
thinking_config = None
if deepthought and is_reasoning_model(model_name):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI, include_thoughts=True)
max_output_tokens = MAX_OUTPUT_TOKENS_FOR_STANDARD_GEMINI
if is_reasoning_model(model_name):
@@ -127,8 +133,9 @@ def gemini_completion_with_backoff(
thinking_config=thinking_config,
max_output_tokens=max_output_tokens,
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_mime_type=model_kwargs.get("response_mime_type", "text/plain"),
response_schema=response_schema,
tools=tools,
seed=seed,
top_p=0.95,
http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
@@ -137,7 +144,25 @@ def gemini_completion_with_backoff(
try:
# Generate the response
response = client.models.generate_content(model=model_name, config=config, contents=formatted_messages)
response_text = response.text
if (
not response.candidates
or not response.candidates[0].content
or response.candidates[0].content.parts is None
):
raise ValueError(f"Failed to get response from model.")
raw_content = [part.model_dump() for part in response.candidates[0].content.parts]
if response.function_calls:
function_calls = [
ToolCall(name=function_call.name, args=function_call.args, id=function_call.id).__dict__
for function_call in response.function_calls
]
response_text = json.dumps(function_calls)
else:
# If no function calls, use the text response
response_text = response.text
response_thoughts = "\n".join(
[part.text for part in response.candidates[0].content.parts if part.thought and isinstance(part.text, str)]
)
except gerrors.ClientError as e:
response = None
response_text, _ = handle_gemini_response(e.args)
@@ -151,8 +176,14 @@ def gemini_completion_with_backoff(
input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0
output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
cache_read_tokens = response.usage_metadata.cached_content_token_count or 0 if response else 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
model_name,
input_tokens,
output_tokens,
cache_read_tokens=cache_read_tokens,
thought_tokens=thought_tokens,
usage=tracer.get("usage"),
)
# Validate the response. If empty, raise an error to retry.
@@ -166,7 +197,7 @@ def gemini_completion_with_backoff(
if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer)
return response_text
return ResponseWithThought(text=response_text, thought=response_thoughts, raw_content=raw_content)
@retry(
@@ -234,7 +265,7 @@ async def gemini_chat_completion_with_backoff(
# handle safety, rate-limit, other finish reasons
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
if stopped:
yield ResponseWithThought(response=stop_message)
yield ResponseWithThought(text=stop_message)
logger.warning(
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
@@ -247,7 +278,7 @@ async def gemini_chat_completion_with_backoff(
yield ResponseWithThought(thought=part.text)
elif part.text:
aggregated_response += part.text
yield ResponseWithThought(response=part.text)
yield ResponseWithThought(text=part.text)
# Calculate cost of chat
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
@@ -346,8 +377,24 @@ def format_messages_for_gemini(
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
for message in messages:
if message.role == "assistant":
message.role = "model"
# Handle tool call and tool result message types from additional_kwargs
message_type = message.additional_kwargs.get("message_type")
if message_type == "tool_call":
pass
elif message_type == "tool_result":
# Convert tool_result to Gemini function response format
# Need to find the corresponding function call from previous messages
tool_result_msg_content = []
for part in message.content:
tool_result_msg_content.append(
gtypes.Part.from_function_response(name=part["name"], response={"result": part["content"]})
)
message.content = tool_result_msg_content
# Convert message content to string list from chatml dictionary list
if isinstance(message.content, list):
elif isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message_content = []
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
@@ -367,16 +414,13 @@ def format_messages_for_gemini(
messages.remove(message)
continue
message.content = message_content
elif isinstance(message.content, str):
elif isinstance(message.content, str) and message.content.strip():
message.content = [gtypes.Part.from_text(text=message.content)]
else:
logger.error(f"Dropping invalid type: {type(message.content)} of message content: {message.content}")
messages.remove(message)
continue
if message.role == "assistant":
message.role = "model"
if len(messages) == 1:
messages[0].role = "user"
@@ -404,3 +448,21 @@ def is_reasoning_model(model_name: str) -> bool:
Check if the model is a reasoning model.
"""
return model_name.startswith("gemini-2.5")
def to_gemini_tools(tools: List[ToolDefinition]) -> List[gtypes.ToolDict] | None:
"Transform tool definitions from standard format to Gemini format."
gemini_tools = [
gtypes.ToolDict(
function_declarations=[
gtypes.FunctionDeclarationDict(
name=tool.name,
description=tool.description,
parameters=tool.schema,
)
for tool in tools
]
)
]
return gemini_tools or None
@@ -145,12 +145,12 @@ async def converse_offline(
aggregated_response += response_delta
# Put chunk into the asyncio queue (non-blocking)
try:
queue.put_nowait(ResponseWithThought(response=response_delta))
queue.put_nowait(ResponseWithThought(text=response_delta))
except asyncio.QueueFull:
# Should not happen with default queue size unless consumer is very slow
logger.warning("Asyncio queue full during offline LLM streaming.")
# Potentially block here or handle differently if needed
asyncio.run(queue.put(ResponseWithThought(response=response_delta)))
asyncio.run(queue.put(ResponseWithThought(text=response_delta)))
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
@@ -221,4 +221,4 @@ def send_message_to_model_offline(
if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer)
return response_text
return ResponseWithThought(text=response_text)
+13 -38
View File
@@ -1,25 +1,24 @@
import logging
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel
from typing import Any, AsyncGenerator, Dict, List, Optional
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
clean_response_schema,
completion_with_backoff,
get_openai_api_json_support,
get_structured_output_support,
to_openai_tools,
)
from khoj.processor.conversation.utils import (
JsonSupport,
OperatorRun,
ResponseWithThought,
StructuredOutputSupport,
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump
@@ -32,6 +31,7 @@ def send_message_to_model(
model,
response_type="text",
response_schema=None,
tools: list[ToolDefinition] = None,
deepthought=False,
api_base_url=None,
tracer: dict = {},
@@ -40,9 +40,11 @@ def send_message_to_model(
Send message to model
"""
model_kwargs = {}
json_support = get_openai_api_json_support(model, api_base_url)
if response_schema and json_support == JsonSupport.SCHEMA:
model_kwargs: Dict[str, Any] = {}
json_support = get_structured_output_support(model, api_base_url)
if tools and json_support == StructuredOutputSupport.TOOL:
model_kwargs["tools"] = to_openai_tools(tools)
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
# Drop unsupported fields from schema passed to OpenAI APi
cleaned_response_schema = clean_response_schema(response_schema)
model_kwargs["response_format"] = {
@@ -53,7 +55,7 @@ def send_message_to_model(
"strict": True,
},
}
elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
model_kwargs["response_format"] = {"type": response_type}
# Get Response from GPT
@@ -171,30 +173,3 @@ async def converse_openai(
tracer=tracer,
):
yield chunk
def clean_response_schema(schema: BaseModel | dict) -> dict:
"""
Format response schema to be compatible with OpenAI API.
Clean the response schema by removing unsupported fields.
"""
# Normalize schema to OpenAI compatible JSON schema format
schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
# Recursively drop unsupported fields from schema passed to OpenAI API
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
fields_to_exclude = ["minItems", "maxItems"]
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
for _, prop_value in schema_json["properties"].items():
if isinstance(prop_value, dict):
# Remove specified fields from direct properties
for field in fields_to_exclude:
prop_value.pop(field, None)
# Recursively remove specified fields from child properties
if "items" in prop_value and isinstance(prop_value["items"], dict):
clean_response_schema(prop_value["items"])
# Return cleaned schema
return schema_json
+113 -12
View File
@@ -1,3 +1,4 @@
import json
import logging
import os
from copy import deepcopy
@@ -9,6 +10,7 @@ from urllib.parse import urlparse
import httpx
import openai
from langchain_core.messages.chat import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema
from openai.lib.streaming.chat import (
ChatCompletionStream,
ChatCompletionStreamEvent,
@@ -20,6 +22,7 @@ from openai.types.chat.chat_completion_chunk import (
Choice,
ChoiceDelta,
)
from pydantic import BaseModel
from tenacity import (
before_sleep_log,
retry,
@@ -30,11 +33,13 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
ResponseWithThought,
StructuredOutputSupport,
ToolCall,
commit_conversation_trace,
)
from khoj.utils.helpers import (
ToolDefinition,
convert_image_data_uri,
get_chat_usage_metrics,
get_openai_async_client,
@@ -72,7 +77,7 @@ def completion_with_backoff(
deepthought: bool = False,
model_kwargs: dict = {},
tracer: dict = {},
) -> str:
) -> ResponseWithThought:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_clients.get(client_key)
if not client:
@@ -117,6 +122,9 @@ def completion_with_backoff(
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
tool_ids = []
tool_calls: list[ToolCall] = []
thoughts = ""
aggregated_response = ""
if stream:
with client.beta.chat.completions.stream(
@@ -130,7 +138,16 @@ def completion_with_backoff(
if chunk.type == "content.delta":
aggregated_response += chunk.delta
elif chunk.type == "thought.delta":
pass
thoughts += chunk.delta
elif chunk.type == "chunk" and chunk.chunk.choices and chunk.chunk.choices[0].delta.tool_calls:
tool_ids += [tool_call.id for tool_call in chunk.chunk.choices[0].delta.tool_calls]
elif chunk.type == "tool_calls.function.arguments.done":
tool_calls += [ToolCall(name=chunk.name, args=json.loads(chunk.arguments), id=None)]
if tool_calls:
tool_calls = [
ToolCall(name=chunk.name, args=chunk.args, id=tool_id) for chunk, tool_id in zip(tool_calls, tool_ids)
]
aggregated_response = json.dumps([tool_call.__dict__ for tool_call in tool_calls])
else:
# Non-streaming chat completion
chunk = client.beta.chat.completions.parse(
@@ -164,7 +181,7 @@ def completion_with_backoff(
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
return ResponseWithThought(text=aggregated_response, thought=thoughts)
@retry(
@@ -190,6 +207,7 @@ async def chat_completion_with_backoff(
deepthought=False,
model_kwargs: dict = {},
tracer: dict = {},
tools=None,
) -> AsyncGenerator[ResponseWithThought, None]:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_async_clients.get(client_key)
@@ -258,6 +276,8 @@ async def chat_completion_with_backoff(
read_timeout = 300 if is_local_api(api_base_url) else 60
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
if tools:
model_kwargs["tools"] = tools
aggregated_response = ""
final_chunk = None
@@ -277,7 +297,7 @@ async def chat_completion_with_backoff(
raise ValueError("No response by model.")
aggregated_response = response.choices[0].message.content
final_chunk = response
yield ResponseWithThought(response=aggregated_response)
yield ResponseWithThought(text=aggregated_response)
else:
async for chunk in stream_processor(response):
# Log the time taken to start response
@@ -293,8 +313,8 @@ async def chat_completion_with_backoff(
response_chunk: ResponseWithThought = None
response_delta = chunk.choices[0].delta
if response_delta.content:
response_chunk = ResponseWithThought(response=response_delta.content)
aggregated_response += response_chunk.response
response_chunk = ResponseWithThought(text=response_delta.content)
aggregated_response += response_chunk.text
elif response_delta.thought:
response_chunk = ResponseWithThought(thought=response_delta.thought)
if response_chunk:
@@ -327,16 +347,16 @@ async def chat_completion_with_backoff(
commit_conversation_trace(messages, aggregated_response, tracer)
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport:
if model_name.startswith("deepseek-reasoner"):
return JsonSupport.NONE
return StructuredOutputSupport.NONE
if api_base_url:
host = urlparse(api_base_url).hostname
if host and host.endswith(".ai.azure.com"):
return JsonSupport.OBJECT
return StructuredOutputSupport.OBJECT
if host == "api.deepinfra.com":
return JsonSupport.OBJECT
return JsonSupport.SCHEMA
return StructuredOutputSupport.OBJECT
return StructuredOutputSupport.TOOL
def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]:
@@ -345,6 +365,43 @@ def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> Li
"""
formatted_messages = []
for message in deepcopy(messages):
# Handle tool call and tool result message types
message_type = message.additional_kwargs.get("message_type")
if message_type == "tool_call":
# Convert tool_call to OpenAI function call format
content = []
for part in message.content:
content.append(
{
"type": "function",
"id": part.get("id"),
"function": {
"name": part.get("name"),
"arguments": json.dumps(part.get("input", part.get("args", {}))),
},
}
)
formatted_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": content,
}
)
continue
if message_type == "tool_result":
# Convert tool_result to OpenAI tool result format
# Each part is a result for a tool call
for part in message.content:
formatted_messages.append(
{
"role": "tool",
"tool_call_id": part.get("id") or part.get("tool_use_id"),
"name": part.get("name"),
"content": part.get("content"),
}
)
continue
if isinstance(message.content, list) and not is_openai_api(api_base_url):
assistant_texts = []
has_images = False
@@ -708,3 +765,47 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
if isinstance(content_part, dict) and content_part.get("type") == "text":
content_part["text"] += " /no_think"
break
def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None:
"Transform tool definitions from standard format to OpenAI format."
openai_tools = [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": clean_response_schema(tool.schema),
},
}
for tool in tools
]
return openai_tools or None
def clean_response_schema(schema: BaseModel | dict) -> dict:
"""
Format response schema to be compatible with OpenAI API.
Clean the response schema by removing unsupported fields.
"""
# Normalize schema to OpenAI compatible JSON schema format
schema_json = schema if isinstance(schema, dict) else schema.model_json_schema()
schema_json = _ensure_strict_json_schema(schema_json, path=(), root=schema_json)
# Recursively drop unsupported fields from schema passed to OpenAI API
# See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
fields_to_exclude = ["minItems", "maxItems"]
if isinstance(schema_json, dict) and isinstance(schema_json.get("properties"), dict):
for _, prop_value in schema_json["properties"].items():
if isinstance(prop_value, dict):
# Remove specified fields from direct properties
for field in fields_to_exclude:
prop_value.pop(field, None)
# Recursively remove specified fields from child properties
if "items" in prop_value and isinstance(prop_value["items"], dict):
clean_response_schema(prop_value["items"])
# Return cleaned schema
return schema_json
+17 -35
View File
@@ -667,33 +667,37 @@ Here's some additional context about you:
plan_function_execution = PromptTemplate.from_template(
"""
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information.
You are Khoj, a smart, creative and meticulous researcher. Use the provided tool AIs to accomplish the task assigned to you.
Create a multi-step plan and intelligently iterate on the plan to complete the task.
{personality_context}
# Instructions
- Ask highly diverse, detailed queries to the tool AIs, one tool AI at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
- Provide highly diverse, detailed requests to the tool AIs, one tool AI at a time, to gather information, perform actions etc. Their response will be shown to you in the next iteration.
- Break down your research process into independent, self-contained steps that can be executed sequentially using the available tool AIs to answer the user's query. Write your step-by-step plan in the scratchpad.
- Always ask a new query that was not asked to the tool AI in a previous iteration. Build on the results of the previous iterations.
- Ensure that all required context is passed to the tool AIs for successful execution. Include any relevant stuff that has previously been attempted. They only know the context provided in your query.
- Think step by step to come up with creative strategies when the previous iteration did not yield useful results.
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
- Stop when you have the required information by returning a JSON object with the "tool" field set to "text" and "query" field empty. E.g., {{"scratchpad": "I have all I need", "tool": "text", "query": ""}}
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to accomplish the task assigned to you. Only stop when you have completed the task.
# Examples
Assuming you can search the user's notes and the internet.
Assuming you can search the user's files and the internet.
- When the user asks for the population of their hometown
1. Try look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc.
2. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
3. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
1. Try look up their hometown in their notes. Ask the semantic search AI to search for their birth certificate, childhood memories, school, resume etc.
2. Use the other document retrieval tools to build on the semantic search results, fill in the gaps, add more details or confirm your hypothesis.
3. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
4. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
- When the user asks for their computer's specs
1. Try find their computer model in their notes.
1. Try find their computer model in their documents.
2. Now find webpages with their computer model's spec online.
3. Ask the webpage tool AI to extract the required information from the relevant webpages.
- When the user asks what clothes to carry for their upcoming trip
1. Find the itinerary of their upcoming trip in their notes.
1. Use the semantic search tool to find the itinerary of their upcoming trip in their documents.
2. Next find the weather forecast at the destination online.
3. Then find if they mentioned what clothes they own in their notes.
3. Then combine the semantic search, regex search, view file and list files tools to find if all the clothes they own in their files.
- When the user asks you to summarize their expenses in a particular month
1. Combine the semantic search and regex search tool AI to find all transactions in the user's documents for that month.
2. Use the view file tool to read the line ranges in the matched files
3. Finally summarize the expenses
# Background Context
- Current Date: {day_of_week}, {current_date}
@@ -701,31 +705,9 @@ Assuming you can search the user's notes and the internet.
- User Name: {username}
# Available Tool AIs
You decide which of the tool AIs listed below would you use to answer the user's question. You **only** have access to the following tool AIs:
You decide which of the tool AIs listed below would you use to accomplish the user assigned task. You **only** have access to the following tool AIs:
{tools}
Your response should always be a valid JSON object with keys: "scratchpad" (str), "tool" (str) and "query" (str). Do not say anything else.
Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
""".strip()
)
plan_function_execution_next_tool = PromptTemplate.from_template(
"""
Given the results of your previous iterations, which tool AI will you use next to answer the target query?
# Target Query:
{query}
""".strip()
)
previous_iteration = PromptTemplate.from_template(
"""
# Iteration {index}:
- tool: {tool}
- query: {query}
- result: {result}
""".strip()
)
+128 -57
View File
@@ -10,7 +10,7 @@ from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import PIL.Image
import pyjson5
@@ -137,60 +137,83 @@ class OperatorRun:
}
class ToolCall:
def __init__(self, name: str, args: dict, id: str):
self.name = name
self.args = args
self.id = id
class ResearchIteration:
def __init__(
self,
tool: str,
query: str,
query: ToolCall | dict | str,
context: list = None,
onlineContext: dict = None,
codeContext: dict = None,
operatorContext: dict | OperatorRun = None,
summarizedResult: str = None,
warning: str = None,
raw_response: list = None,
):
self.tool = tool
self.query = query
self.query = ToolCall(**query) if isinstance(query, dict) else query
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.operatorContext = OperatorRun(**operatorContext) if isinstance(operatorContext, dict) else operatorContext
self.summarizedResult = summarizedResult
self.warning = warning
self.raw_response = raw_response
def to_dict(self) -> dict:
data = vars(self).copy()
data["query"] = self.query.__dict__ if isinstance(self.query, ToolCall) else self.query
data["operatorContext"] = self.operatorContext.to_dict() if self.operatorContext else None
return data
def construct_iteration_history(
previous_iterations: List[ResearchIteration],
previous_iteration_prompt: str,
query: str = None,
query_images: List[str] = None,
query_files: str = None,
) -> list[ChatMessageModel]:
iteration_history: list[ChatMessageModel] = []
previous_iteration_messages: list[dict] = []
for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format(
tool=iteration.tool,
query=iteration.query,
result=iteration.summarizedResult,
index=idx + 1,
)
query_message_content = construct_structured_message(query, query_images, attached_file_context=query_files)
if query_message_content:
iteration_history.append(ChatMessageModel(by="you", message=query_message_content))
previous_iteration_messages.append({"type": "text", "text": iteration_data})
if previous_iteration_messages:
if query:
iteration_history.append(ChatMessageModel(by="you", message=query))
iteration_history.append(
for iteration in previous_iterations:
if not iteration.query or isinstance(iteration.query, str):
iteration_history.append(
ChatMessageModel(
by="you",
message=iteration.summarizedResult
or iteration.warning
or "Please specify what you want to do next.",
)
)
continue
iteration_history += [
ChatMessageModel(
by="khoj",
intent=Intent(type="remember", query=query),
message=previous_iteration_messages,
)
)
message=iteration.raw_response or [iteration.query.__dict__],
intent=Intent(type="tool_call", query=query),
),
ChatMessageModel(
by="you",
intent=Intent(type="tool_result"),
message=[
{
"type": "tool_result",
"id": iteration.query.id,
"name": iteration.query.name,
"content": iteration.summarizedResult,
}
],
),
]
return iteration_history
@@ -302,33 +325,44 @@ def construct_tool_chat_history(
ConversationCommand.Notes: (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
),
ConversationCommand.Online: (
ConversationCommand.SearchWeb: (
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
),
ConversationCommand.Webpage: (
ConversationCommand.ReadWebpage: (
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
),
ConversationCommand.Code: (
ConversationCommand.RunCode: (
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
),
}
for iteration in previous_iterations:
if not iteration.query or isinstance(iteration.query, str):
chat_history.append(
ChatMessageModel(
by="you",
message=iteration.summarizedResult
or iteration.warning
or "Please specify what you want to do next.",
)
)
continue
# If a tool is provided use the inferred query extractor for that tool if available
# If no tool is provided, use inferred query extractor for the tool used in the iteration
# Fallback to base extractor if the tool does not have an inferred query extractor
inferred_query_extractor = extract_inferred_query_map.get(
tool or ConversationCommand(iteration.tool), base_extractor
tool or ConversationCommand(iteration.query.name), base_extractor
)
chat_history += [
ChatMessageModel(
by="you",
message=iteration.query,
message=yaml.dump(iteration.query.args, default_flow_style=False),
),
ChatMessageModel(
by="khoj",
intent=Intent(
type="remember",
query=iteration.query,
query=yaml.dump(iteration.query.args, default_flow_style=False),
inferred_queries=inferred_query_extractor(iteration),
memory_type="notes",
),
@@ -481,28 +515,32 @@ Khoj: "{chat_response}"
def construct_structured_message(
message: list[dict] | str,
images: list[str],
model_type: str,
vision_enabled: bool,
images: list[str] = None,
model_type: str = None,
vision_enabled: bool = True,
attached_file_context: str = None,
):
"""
Format messages into appropriate multimedia format for supported chat model types
Format messages into appropriate multimedia format for supported chat model types.
Assume vision is enabled and chat model provider supports messages in chatml format, unless specified otherwise.
"""
if model_type in [
if not model_type or model_type in [
ChatModel.ModelType.OPENAI,
ChatModel.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC,
]:
constructed_messages: List[dict[str, Any]] = (
[{"type": "text", "text": message}] if isinstance(message, str) else message
)
constructed_messages: List[dict[str, Any]] = []
if not is_none_or_empty(message):
constructed_messages += [{"type": "text", "text": message}] if isinstance(message, str) else message
# Drop image message passed by caller if chat model does not have vision enabled
if not vision_enabled:
constructed_messages = [m for m in constructed_messages if m.get("type") != "image_url"]
if not is_none_or_empty(attached_file_context):
constructed_messages.append({"type": "text", "text": attached_file_context})
constructed_messages += [{"type": "text", "text": attached_file_context}]
if vision_enabled and images:
for image in images:
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
constructed_messages += [{"type": "image_url", "image_url": {"url": image}}]
return constructed_messages
message = message if isinstance(message, str) else "\n\n".join(m["text"] for m in message)
@@ -638,7 +676,11 @@ def generate_chatml_messages_with_context(
chat_message, chat.images if role == "user" else [], model_type, vision_enabled
)
reconstructed_message = ChatMessage(content=message_content, role=role)
reconstructed_message = ChatMessage(
content=message_content,
role=role,
additional_kwargs={"message_type": chat.intent.type if chat.intent else None},
)
chatml_messages.insert(0, reconstructed_message)
if len(chatml_messages) >= 3 * lookback_turns:
@@ -737,10 +779,21 @@ def count_tokens(
message_content_parts: list[str] = []
# Collate message content into single string to ease token counting
for part in message_content:
if isinstance(part, dict) and part.get("type") == "text":
message_content_parts.append(part["text"])
elif isinstance(part, dict) and part.get("type") == "image_url":
if isinstance(part, dict) and part.get("type") == "image_url":
image_count += 1
elif isinstance(part, dict) and part.get("type") == "text":
message_content_parts.append(part["text"])
elif isinstance(part, dict) and hasattr(part, "model_dump"):
message_content_parts.append(json.dumps(part.model_dump()))
elif isinstance(part, dict) and hasattr(part, "__dict__"):
message_content_parts.append(json.dumps(part.__dict__))
elif isinstance(part, dict):
# If part is a dict but not a recognized type, convert to JSON string
try:
message_content_parts.append(json.dumps(part))
except (TypeError, ValueError) as e:
logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
image_count += 1 # Treat as an image/binary if serialization fails
elif isinstance(part, str):
message_content_parts.append(part)
else:
@@ -753,6 +806,15 @@ def count_tokens(
return len(encoder.encode(json.dumps(message_content)))
def count_total_tokens(messages: list[ChatMessage], encoder, system_message: Optional[ChatMessage]) -> Tuple[int, int]:
"""Count total tokens in messages including system message"""
system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
message_tokens = sum([count_tokens(message.content, encoder) for message in messages])
# Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
total_tokens = message_tokens + system_message_tokens + 4 * len(messages)
return total_tokens, system_message_tokens
def truncate_messages(
messages: list[ChatMessage],
max_prompt_size: int,
@@ -771,23 +833,30 @@ def truncate_messages(
break
# Drop older messages until under max supported prompt size by model
# Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
tokens = sum([count_tokens(message.content, encoder) for message in messages])
total_tokens = tokens + system_message_tokens + 4 * len(messages)
total_tokens, system_message_tokens = count_total_tokens(messages, encoder, system_message)
while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
if len(messages[-1].content) > 1:
# If the last message has more than one content part, pop the oldest content part.
# For tool calls, the whole message should dropped, assistant's tool call content being truncated annoys AI APIs.
if len(messages[-1].content) > 1 and messages[-1].additional_kwargs.get("message_type") != "tool_call":
# The oldest content part is earlier in content list. So pop from the front.
messages[-1].content.pop(0)
# Otherwise, pop the last message if it has only one content part or is a tool call.
else:
# The oldest message is the last one. So pop from the back.
messages.pop()
tokens = sum([count_tokens(message.content, encoder) for message in messages])
total_tokens = tokens + system_message_tokens + 4 * len(messages)
dropped_message = messages.pop()
# Drop tool result pair of tool call, if tool call message has been removed
if (
dropped_message.additional_kwargs.get("message_type") == "tool_call"
and messages
and messages[-1].additional_kwargs.get("message_type") == "tool_result"
):
messages.pop()
total_tokens, _ = count_total_tokens(messages, encoder, system_message)
# Truncate current message if still over max supported prompt size by model
total_tokens = tokens + system_message_tokens + 4 * len(messages)
total_tokens, _ = count_total_tokens(messages, encoder, system_message)
if total_tokens > max_prompt_size:
# At this point, a single message with a single content part of type dict should remain
assert (
@@ -1149,13 +1218,15 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
class JsonSupport(int, Enum):
class StructuredOutputSupport(int, Enum):
NONE = 0
OBJECT = 1
SCHEMA = 2
TOOL = 3
class ResponseWithThought:
def __init__(self, response: str = None, thought: str = None):
self.response = response
def __init__(self, text: str = None, thought: str = None, raw_content: list = None):
self.text = text
self.thought = thought
self.raw_content = raw_content
@@ -73,7 +73,7 @@ class GroundingAgent:
grounding_user_prompt = self.get_instruction(instruction, self.environment_type)
screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
grounding_messages_content = construct_structured_message(
grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
grounding_user_prompt, screenshots, self.model.model_type, vision_enabled=True
)
return [{"role": "user", "content": grounding_messages_content}]
@@ -121,7 +121,7 @@ class BinaryOperatorAgent(OperatorAgent):
# Construct input for visual reasoner history
visual_reasoner_history = self._format_message_for_api(self.messages)
try:
natural_language_action = await send_message_to_model_wrapper(
raw_response = await send_message_to_model_wrapper(
query=query_text,
query_images=query_screenshot,
system_message=reasoning_system_prompt,
@@ -129,6 +129,7 @@ class BinaryOperatorAgent(OperatorAgent):
agent_chat_model=self.reasoning_model,
tracer=self.tracer,
)
natural_language_action = raw_response.text
if not isinstance(natural_language_action, str) or not natural_language_action.strip():
raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")
@@ -255,10 +256,10 @@ class BinaryOperatorAgent(OperatorAgent):
# Append summary messages to history
trigger_summary = AgentMessage(role="user", content=summarize_prompt)
summary_message = AgentMessage(role="assistant", content=summary)
summary_message = AgentMessage(role="assistant", content=summary.text)
self.messages.extend([trigger_summary, summary_message])
return summary
return summary.text
def _compile_response(self, response_content: str | List) -> str:
"""Compile response content into a string, handling OpenAI message structures."""
+18
View File
@@ -390,7 +390,25 @@ async def read_webpages(
query_files=query_files,
tracer=tracer,
)
async for result in read_webpages_content(
query,
urls,
user,
send_status_func=send_status_func,
agent=agent,
tracer=tracer,
):
yield result
async def read_webpages_content(
query: str,
urls: List[str],
user: KhojUser,
send_status_func: Optional[Callable] = None,
agent: Agent = None,
tracer: dict = {},
):
logger.info(f"Reading web pages at: {urls}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(urls))
+1 -1
View File
@@ -161,7 +161,7 @@ async def generate_python_code(
)
# Extract python code wrapped in markdown code blocks from the response
code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response, re.DOTALL)
code_blocks = re.findall(r"```(?:python)?\n(.*?)```", response.text, re.DOTALL)
if not code_blocks:
raise ValueError("No Python code blocks found in response")
+1 -1
View File
@@ -1390,7 +1390,7 @@ async def chat(
continue
if cancellation_event.is_set():
break
message = item.response
message = item.text
full_response += message if message else ""
if item.thought:
async for result in send_event(ChatEvent.THOUGHT, item.thought):
+293 -26
View File
@@ -1,6 +1,7 @@
import asyncio
import base64
import concurrent.futures
import fnmatch
import hashlib
import json
import logging
@@ -120,6 +121,7 @@ from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
LRU,
ConversationCommand,
ToolDefinition,
get_file_type,
in_debug_mode,
is_none_or_empty,
@@ -303,7 +305,7 @@ async def acreate_title_from_history(
with timer("Chat actor: Generate title from conversation history", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
return response.strip()
return response.text.strip()
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
@@ -315,7 +317,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
with timer("Chat actor: Generate title from query", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
return response.strip()
return response.text.strip()
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: bool = False) -> Tuple[bool, str]:
@@ -339,7 +341,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
)
response = response.strip()
response = response.text.strip()
try:
response = json.loads(clean_json(response))
is_safe = str(response.get("safe", "true")).lower() == "true"
@@ -418,7 +420,7 @@ async def aget_data_sources_and_output_format(
output: str
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
raw_response = await send_message_to_model_wrapper(
relevant_tools_prompt,
response_type="json_object",
response_schema=PickTools,
@@ -429,7 +431,7 @@ async def aget_data_sources_and_output_format(
)
try:
response = clean_json(response)
response = clean_json(raw_response.text)
response = json.loads(response)
chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()]
@@ -506,7 +508,7 @@ async def infer_webpage_urls(
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
raw_response = await send_message_to_model_wrapper(
online_queries_prompt,
query_images=query_images,
response_type="json_object",
@@ -519,7 +521,7 @@ async def infer_webpage_urls(
# Validate that the response is a non-empty, JSON-serializable list of URLs
try:
response = clean_json(response)
response = clean_json(raw_response.text)
urls = json.loads(response)
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
if is_none_or_empty(valid_unique_urls):
@@ -571,7 +573,7 @@ async def generate_online_subqueries(
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
raw_response = await send_message_to_model_wrapper(
online_queries_prompt,
query_images=query_images,
response_type="json_object",
@@ -584,7 +586,7 @@ async def generate_online_subqueries(
# Validate that the response is a non-empty, JSON-serializable list
try:
response = clean_json(response)
response = clean_json(raw_response.text)
response = pyjson5.loads(response)
response = {q.strip() for q in response["queries"] if q.strip()}
if not isinstance(response, set) or not response or len(response) == 0:
@@ -645,7 +647,7 @@ async def aschedule_query(
# Validate that the response is a non-empty, JSON-serializable list
try:
raw_response = raw_response.strip()
raw_response = raw_response.text.strip()
response: Dict[str, str] = json.loads(clean_json(raw_response))
if not response or not isinstance(response, Dict) or len(response) != 3:
raise AssertionError(f"Invalid response for scheduling query : {response}")
@@ -683,7 +685,7 @@ async def extract_relevant_info(
agent_chat_model=agent_chat_model,
tracer=tracer,
)
return response.strip()
return response.text.strip()
async def extract_relevant_summary(
@@ -726,7 +728,7 @@ async def extract_relevant_summary(
agent_chat_model=agent_chat_model,
tracer=tracer,
)
return response.strip()
return response.text.strip()
async def generate_summary_from_files(
@@ -897,7 +899,7 @@ async def generate_better_diagram_description(
agent_chat_model=agent_chat_model,
tracer=tracer,
)
response = response.strip()
response = response.text.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@@ -925,10 +927,10 @@ async def generate_excalidraw_diagram_from_description(
raw_response = await send_message_to_model_wrapper(
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
)
raw_response = clean_json(raw_response)
raw_response_text = clean_json(raw_response.text)
try:
# Expect response to have `elements` and `scratchpad` keys
response: Dict[str, str] = json.loads(raw_response)
response: Dict[str, str] = json.loads(raw_response_text)
if (
not response
or not isinstance(response, Dict)
@@ -937,7 +939,7 @@ async def generate_excalidraw_diagram_from_description(
):
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {response}")
except Exception:
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response}")
raise AssertionError(f"Invalid response for generating Excalidraw diagram: {raw_response_text}")
if not response or not isinstance(response["elements"], List) or not isinstance(response["elements"][0], Dict):
# TODO Some additional validation here that it's a valid Excalidraw diagram
raise AssertionError(f"Invalid response for improving diagram description: {response}")
@@ -1048,11 +1050,11 @@ async def generate_better_mermaidjs_diagram_description(
agent_chat_model=agent_chat_model,
tracer=tracer,
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
response_text = response.text.strip()
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
response_text = response_text[1:-1]
return response
return response_text
async def generate_mermaidjs_diagram_from_description(
@@ -1076,7 +1078,7 @@ async def generate_mermaidjs_diagram_from_description(
raw_response = await send_message_to_model_wrapper(
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
)
return clean_mermaidjs(raw_response.strip())
return clean_mermaidjs(raw_response.text.strip())
async def generate_better_image_prompt(
@@ -1151,11 +1153,11 @@ async def generate_better_image_prompt(
agent_chat_model=agent_chat_model,
tracer=tracer,
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
response_text = response.text.strip()
if response_text.startswith(('"', "'")) and response_text.endswith(('"', "'")):
response_text = response_text[1:-1]
return response
return response_text
async def search_documents(
@@ -1329,7 +1331,7 @@ async def extract_questions(
# Extract questions from the response
try:
response = clean_json(raw_response)
response = clean_json(raw_response.text)
response = pyjson5.loads(response)
queries = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(queries, list) or not queries:
@@ -1439,6 +1441,7 @@ async def send_message_to_model_wrapper(
system_message: str = "",
response_type: str = "text",
response_schema: BaseModel = None,
tools: List[ToolDefinition] = None,
deepthought: bool = False,
user: KhojUser = None,
query_images: List[str] = None,
@@ -1506,6 +1509,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
@@ -1517,6 +1521,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
@@ -1528,6 +1533,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
tools=tools,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
@@ -2796,3 +2802,264 @@ def get_notion_auth_url(user: KhojUser):
if not NOTION_OAUTH_CLIENT_ID or not NOTION_OAUTH_CLIENT_SECRET or not NOTION_REDIRECT_URI:
return None
return f"https://api.notion.com/v1/oauth/authorize?client_id={NOTION_OAUTH_CLIENT_ID}&redirect_uri={NOTION_REDIRECT_URI}&response_type=code&state={user.uuid}"
async def view_file_content(
path: str,
start_line: Optional[int] = None,
end_line: Optional[int] = None,
user: KhojUser = None,
):
"""
View the contents of a file from the user's document database with optional line range specification.
"""
query = f"View file: {path}"
if start_line and end_line:
query += f" (lines {start_line}-{end_line})"
try:
# Get the file object from the database by name
file_objects = await FileObjectAdapters.aget_file_objects_by_name(user, path)
if not file_objects:
error_msg = f"File '{path}' not found in user documents"
logger.warning(error_msg)
yield [{"query": query, "file": path, "compiled": error_msg}]
return
# Use the first file object if multiple exist
file_object = file_objects[0]
raw_text = file_object.raw_text
# Apply line range filtering if specified
if start_line is None and end_line is None:
filtered_text = raw_text
else:
lines = raw_text.split("\n")
start_line = start_line or 1
end_line = end_line or len(lines)
# Validate line range
if start_line < 1 or end_line < 1 or start_line > end_line:
error_msg = f"Invalid line range: {start_line}-{end_line}"
logger.warning(error_msg)
yield [{"query": query, "file": path, "compiled": error_msg}]
return
if start_line > len(lines):
error_msg = f"Start line {start_line} exceeds total number of lines {len(lines)}"
logger.warning(error_msg)
yield [{"query": query, "file": path, "compiled": error_msg}]
return
# Convert from 1-based to 0-based indexing and ensure bounds
start_idx = max(0, start_line - 1)
end_idx = min(len(lines), end_line)
selected_lines = lines[start_idx:end_idx]
filtered_text = "\n".join(selected_lines)
# Truncate the text if it's too long
if len(filtered_text) > 10000:
filtered_text = filtered_text[:10000] + "\n\n[Truncated. Use line numbers to view specific sections.]"
# Format the result as a document reference
document_results = [
{
"query": query,
"file": path,
"compiled": filtered_text,
}
]
yield document_results
except Exception as e:
error_msg = f"Error viewing file {path}: {str(e)}"
logger.error(error_msg, exc_info=True)
# Return an error result in the expected format
yield [{"query": query, "file": path, "compiled": error_msg}]
async def grep_files(
regex_pattern: str,
path_prefix: Optional[str] = None,
lines_before: Optional[int] = None,
lines_after: Optional[int] = None,
user: KhojUser = None,
):
"""
Search for a regex pattern in files with an optional path prefix and context lines.
"""
# Construct the query string based on provided parameters
def _generate_query(line_count, doc_count, path, pattern, lines_before, lines_after, max_results=1000):
query = f"**Found {line_count} matches for '{pattern}' in {doc_count} documents**"
if path:
query += f" in {path}"
if lines_before or lines_after or line_count > max_results:
query += " Showing"
if lines_before or lines_after:
context_info = []
if lines_before:
context_info.append(f"{lines_before} lines before")
if lines_after:
context_info.append(f"{lines_after} lines after")
query += f" {' and '.join(context_info)}"
if line_count > max_results:
if lines_before or lines_after:
query += f" for"
query += f" first {max_results} results"
return query
# Validate regex pattern
path_prefix = path_prefix or ""
lines_before = lines_before or 0
lines_after = lines_after or 0
try:
regex = re.compile(regex_pattern, re.IGNORECASE)
except re.error as e:
yield {
"query": _generate_query(0, 0, path_prefix, regex_pattern, lines_before, lines_after),
"file": path_prefix,
"compiled": f"Invalid regex pattern: {e}",
}
return
try:
file_matches = await FileObjectAdapters.aget_file_objects_by_regex(user, regex_pattern, path_prefix)
line_matches = []
for file_object in file_matches:
lines = file_object.raw_text.split("\n")
matched_line_numbers = []
# Find all matching line numbers first
for i, line in enumerate(lines, 1):
if regex.search(line):
matched_line_numbers.append(i)
# Build context for each match
for line_num in matched_line_numbers:
context_lines = []
# Calculate start and end indices for context (0-based)
start_idx = max(0, line_num - 1 - lines_before)
end_idx = min(len(lines), line_num + lines_after)
# Add context lines with line numbers
for idx in range(start_idx, end_idx):
current_line_num = idx + 1
line_content = lines[idx]
if current_line_num == line_num:
# This is the matching line, mark it
context_lines.append(f"{file_object.file_name}:{current_line_num}:> {line_content}")
else:
# This is a context line
context_lines.append(f"{file_object.file_name}:{current_line_num}: {line_content}")
# Add separator between matches if showing context
if lines_before > 0 or lines_after > 0:
context_lines.append("--")
line_matches.extend(context_lines)
# Remove the last separator if it exists
if line_matches and line_matches[-1] == "--":
line_matches.pop()
# Check if no results found
max_results = 1000
query = _generate_query(
len([m for m in line_matches if ":>" in m]),
len(file_matches),
path_prefix,
regex_pattern,
lines_before,
lines_after,
max_results,
)
if not line_matches:
yield {"query": query, "file": path_prefix, "compiled": "No matches found."}
return
# Truncate matched lines list if too long
if len(line_matches) > max_results:
line_matches = line_matches[:max_results] + [
f"... {len(line_matches) - max_results} more results found. Use stricter regex or path to narrow down results."
]
yield {"query": query, "file": path_prefix or "", "compiled": "\n".join(line_matches)}
except Exception as e:
error_msg = f"Error using grep files tool: {str(e)}"
logger.error(error_msg, exc_info=True)
yield [
{
"query": _generate_query(0, 0, path_prefix or "", regex_pattern, lines_before, lines_after),
"file": path_prefix,
"compiled": error_msg,
}
]
async def list_files(
path: Optional[str] = None,
pattern: Optional[str] = None,
user: KhojUser = None,
):
"""
List files under a given path or glob pattern from the user's document database.
"""
# Construct the query string based on provided parameters
def _generate_query(doc_count, path, pattern):
query = f"**Found {doc_count} files**"
if path:
query += f" in {path}"
if pattern:
query += f" filtered by {pattern}"
return query
try:
# Get user files by path prefix when specified
path = path or ""
if path in ["", "/", ".", "./", "~", "~/"]:
file_objects = await FileObjectAdapters.aget_all_file_objects(user, limit=10000)
else:
file_objects = await FileObjectAdapters.aget_file_objects_by_path_prefix(user, path)
if not file_objects:
yield {"query": _generate_query(0, path, pattern), "file": path, "compiled": "No files found."}
return
# Extract file names from file objects
files = [f.file_name for f in file_objects]
# Convert to relative file path (similar to ls)
if path:
files = [f[len(path) :] for f in files]
# Apply glob pattern filtering if specified
if pattern:
files = [f for f in files if fnmatch.fnmatch(f, pattern)]
query = _generate_query(len(files), path, pattern)
if not files:
yield {"query": query, "file": path, "compiled": "No files found."}
return
# Truncate the list if it's too long
max_files = 100
if len(files) > max_files:
files = files[:max_files] + [
f"... {len(files) - max_files} more files found. Use glob pattern to narrow down results."
]
yield {"query": query, "file": path, "compiled": "\n- ".join(files)}
except Exception as e:
error_msg = f"Error listing files in {path}: {str(e)}"
logger.error(error_msg, exc_info=True)
yield {"query": query, "file": path, "compiled": error_msg}
+169 -155
View File
@@ -3,11 +3,9 @@ import logging
import os
from copy import deepcopy
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, List, Optional, Type
from typing import Callable, Dict, List, Optional
import yaml
from pydantic import BaseModel, Field
from khoj.database.adapters import AgentAdapters, EntryAdapters
from khoj.database.models import Agent, ChatMessageModel, KhojUser
@@ -15,25 +13,31 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
OperatorRun,
ResearchIteration,
ToolCall,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
)
from khoj.processor.operator import operate_environment
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.online_search import read_webpages_content, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import (
ChatEvent,
generate_summary_from_files,
grep_files,
list_files,
search_documents,
send_message_to_model_wrapper,
view_file_content,
)
from khoj.utils.helpers import (
ConversationCommand,
ToolDefinition,
dict_to_tuple,
is_none_or_empty,
is_operator_enabled,
timer,
tool_description_for_research_llm,
tools_for_research_llm,
truncate_code_context,
)
from khoj.utils.rawconfig import LocationData
@@ -41,47 +45,6 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
class PlanningResponse(BaseModel):
"""
Schema for the response from planning agent when deciding the next tool to pick.
"""
scratchpad: str = Field(..., description="Scratchpad to reason about which tool to use next")
class Config:
arbitrary_types_allowed = True
@classmethod
def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
"""
Factory method that creates a customized PlanningResponse model
with a properly typed tool field based on available tools.
The tool field is dynamically generated based on available tools.
The query field should be filled by the model after the tool field for a more logical reasoning flow.
Args:
tool_options: Dictionary mapping tool names to values
Returns:
A customized PlanningResponse class
"""
# Create dynamic enum from tool options
tool_enum = Enum("ToolEnum", tool_options) # type: ignore
# Create and return a customized response model with the enum
class PlanningResponseWithTool(PlanningResponse):
"""
Use the scratchpad to reason about which tool to use next and the query to send to the tool.
Pick tool from provided options and your query to send to the tool.
"""
tool: tool_enum = Field(..., description="Name of the tool to use")
query: str = Field(..., description="Detailed query for the selected tool")
return PlanningResponseWithTool
async def apick_next_tool(
query: str,
conversation_history: List[ChatMessageModel],
@@ -104,12 +67,13 @@ async def apick_next_tool(
# Continue with previous iteration if a multi-step tool use is in progress
if (
previous_iterations
and previous_iterations[-1].tool == ConversationCommand.Operator
and previous_iterations[-1].query
and isinstance(previous_iterations[-1].query, ToolCall)
and previous_iterations[-1].query.name == ConversationCommand.Operator
and not previous_iterations[-1].summarizedResult
):
previous_iteration = previous_iterations[-1]
yield ResearchIteration(
tool=previous_iteration.tool,
query=query,
context=previous_iteration.context,
onlineContext=previous_iteration.onlineContext,
@@ -120,30 +84,40 @@ async def apick_next_tool(
return
# Construct tool options for the agent to choose from
tool_options = dict()
tools = []
tool_options_str = ""
agent_tools = agent.input_tools if agent else []
user_has_entries = await EntryAdapters.auser_has_entries(user)
for tool, description in tool_description_for_research_llm.items():
for tool, tool_data in tools_for_research_llm.items():
# Skip showing operator tool as an option if not enabled
if tool == ConversationCommand.Operator and not is_operator_enabled():
continue
# Skip showing Notes tool as an option if user has no entries
if tool == ConversationCommand.Notes:
if not user_has_entries:
continue
description = description.format(max_search_queries=max_document_searches)
if tool == ConversationCommand.Webpage:
description = description.format(max_webpages_to_read=max_webpages_to_read)
if tool == ConversationCommand.Online:
description = description.format(max_search_queries=max_online_searches)
# Skip showing document related tools if user has no documents
if (
tool == ConversationCommand.SemanticSearchFiles
or tool == ConversationCommand.RegexSearchFiles
or tool == ConversationCommand.ViewFile
or tool == ConversationCommand.ListFiles
) and not user_has_entries:
continue
if tool == ConversationCommand.SemanticSearchFiles:
description = tool_data.description.format(max_search_queries=max_document_searches)
elif tool == ConversationCommand.Webpage:
description = tool_data.description.format(max_webpages_to_read=max_webpages_to_read)
elif tool == ConversationCommand.Online:
description = tool_data.description.format(max_search_queries=max_online_searches)
else:
description = tool_data.description
# Add tool if agent does not have any tools defined or the tool is supported by the agent.
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options[tool.name] = tool.value
tool_options_str += f'- "{tool.value}": "{description}"\n'
# Create planning reponse model with dynamically populated tool enum class
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
tools.append(
ToolDefinition(
name=tool.value,
description=description,
schema=tool_data.schema,
)
)
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
@@ -162,24 +136,17 @@ async def apick_next_tool(
max_iterations=max_iterations,
)
if query_images:
query = f"[placeholder for user attached images]\n{query}"
# Construct chat history with user and iteration history with researcher agent for context
iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
iteration_chat_history = construct_iteration_history(previous_iterations, query, query_images, query_files)
chat_and_research_history = conversation_history + iteration_chat_history
# Plan function execution for the next tool
query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query
try:
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
query=query,
query="",
system_message=function_planning_prompt,
chat_history=chat_and_research_history,
response_type="json_object",
response_schema=planning_response_model,
tools=tools,
deepthought=True,
user=user,
query_images=query_images,
@@ -190,48 +157,38 @@ async def apick_next_tool(
except Exception as e:
logger.error(f"Failed to infer information sources to refer: {e}", exc_info=True)
yield ResearchIteration(
tool=None,
query=None,
warning="Failed to infer information sources to refer. Skipping iteration. Try again.",
)
return
try:
response = load_complex_json(response)
if not isinstance(response, dict):
raise ValueError(f"Expected dict response, got {type(response).__name__}: {response}")
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
warning = None
logger.info(f"Response for determining relevant tools: {response}")
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations if i.warning is None}
if (selected_tool, generated_query) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration
elif send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += (
f"{selected_tool}({generated_query})." if selected_tool != ConversationCommand.Text else "respond."
)
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
yield ResearchIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)
# Try parse the response as function call response to infer next tool to use.
# TODO: Handle multiple tool calls.
response_text = response.text
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield ResearchIteration(
tool=None,
query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
)
# Otherwise assume the model has decided to end the research run and respond to the user.
parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)
# If we have a valid response, extract the tool and query.
warning = None
logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})")
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {
(i.query.name, dict_to_tuple(i.query.args))
for i in previous_iterations
if i.warning is None and isinstance(i.query, ToolCall)
}
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration and model has thoughts to share.
elif send_status_func and not is_none_or_empty(response.thought):
async for event in send_status_func(response.thought):
yield {ChatEvent.STATUS: event}
yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content)
async def research(
@@ -257,10 +214,10 @@ async def research(
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
# Incorporate previous partial research into current research chat history
research_conversation_history = deepcopy(conversation_history)
research_conversation_history = [chat for chat in deepcopy(conversation_history) if chat.message]
if current_iteration := len(previous_iterations) > 0:
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
previous_iterations_history = construct_iteration_history(previous_iterations)
research_conversation_history += previous_iterations_history
while current_iteration < MAX_ITERATIONS:
@@ -273,7 +230,7 @@ async def research(
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
operator_results: OperatorRun = None
this_iteration = ResearchIteration(tool=None, query=query)
this_iteration = ResearchIteration(query=query)
async for result in apick_next_tool(
query,
@@ -303,26 +260,30 @@ async def research(
logger.warning(f"Research mode: {this_iteration.warning}.")
# Terminate research if selected text tool or query, tool not set for next iteration
elif not this_iteration.query or not this_iteration.tool or this_iteration.tool == ConversationCommand.Text:
elif (
not this_iteration.query
or isinstance(this_iteration.query, str)
or this_iteration.query.name == ConversationCommand.Text
):
current_iteration = MAX_ITERATIONS
elif this_iteration.tool == ConversationCommand.Notes:
elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles:
this_iteration.context = []
document_results = []
previous_inferred_queries = {
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
}
async for result in search_documents(
this_iteration.query,
max_document_searches,
None,
user,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
conversation_id,
[ConversationCommand.Default],
location,
send_status_func,
query_images,
**this_iteration.query.args,
n=max_document_searches,
d=None,
user=user,
chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
conversation_id=conversation_id,
conversation_commands=[ConversationCommand.Default],
location_data=location,
send_status_func=send_status_func,
query_images=query_images,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
@@ -350,7 +311,7 @@ async def research(
else:
this_iteration.warning = "No matching document references found"
elif this_iteration.tool == ConversationCommand.Online:
elif this_iteration.query.name == ConversationCommand.SearchWeb:
previous_subqueries = {
subquery
for iteration in previous_iterations
@@ -359,12 +320,12 @@ async def research(
}
try:
async for result in search_online(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
location,
user,
send_status_func,
[],
**this_iteration.query.args,
conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
location=location,
user=user,
send_status_func=send_status_func,
custom_filters=[],
max_online_searches=max_online_searches,
max_webpages_to_read=0,
query_images=query_images,
@@ -383,19 +344,15 @@ async def research(
this_iteration.warning = f"Error searching online: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.tool == ConversationCommand.Webpage:
elif this_iteration.query.name == ConversationCommand.ReadWebpage:
try:
async for result in read_webpages(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
location,
user,
send_status_func,
max_webpages_to_read=max_webpages_to_read,
query_images=query_images,
async for result in read_webpages_content(
**this_iteration.query.args,
user=user,
send_status_func=send_status_func,
# max_webpages_to_read=max_webpages_to_read,
agent=agent,
tracer=tracer,
query_files=query_files,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -416,15 +373,15 @@ async def research(
this_iteration.warning = f"Error reading webpages: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.tool == ConversationCommand.Code:
elif this_iteration.query.name == ConversationCommand.RunCode:
try:
async for result in run_code(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
"",
location,
user,
send_status_func,
**this_iteration.query.args,
conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
context="",
location_data=location,
user=user,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
query_files=query_files,
@@ -441,14 +398,14 @@ async def research(
this_iteration.warning = f"Error running code: {e}"
logger.warning(this_iteration.warning, exc_info=True)
elif this_iteration.tool == ConversationCommand.Operator:
elif this_iteration.query.name == ConversationCommand.OperateComputer:
try:
async for result in operate_environment(
this_iteration.query,
user,
construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location,
previous_iterations[-1].operatorContext if previous_iterations else None,
**this_iteration.query.args,
user=user,
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location_data=location,
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
@@ -474,6 +431,63 @@ async def research(
this_iteration.warning = f"Error operating browser: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ViewFile:
try:
async for result in view_file_content(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = result # type: ignore
this_iteration.context += document_results
async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"):
yield result
except Exception as e:
this_iteration.warning = f"Error viewing file: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ListFiles:
try:
async for result in list_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error listing files: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
try:
async for result in grep_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error searching with regex: {e}"
logger.error(this_iteration.warning, exc_info=True)
else:
# No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS
@@ -481,7 +495,7 @@ async def research(
current_iteration += 1
if document_results or online_results or code_results or operator_results or this_iteration.warning:
results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
results_data = f"\n<iteration_{current_iteration}_results>"
if document_results:
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
if online_results:
@@ -494,7 +508,7 @@ async def research(
)
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += "\n</results>\n</iteration>"
results_data += f"\n</results>\n</iteration_{current_iteration}_results>"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data
+284 -8
View File
@@ -12,6 +12,7 @@ import random
import urllib.parse
import uuid
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from functools import lru_cache
from importlib import import_module
@@ -19,8 +20,9 @@ from importlib.metadata import version
from itertools import islice
from os import path
from pathlib import Path
from textwrap import dedent
from time import perf_counter
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Type, Union
from urllib.parse import ParseResult, urlparse
import anthropic
@@ -36,6 +38,7 @@ from google.auth.credentials import Credentials
from google.oauth2 import service_account
from magika import Magika
from PIL import Image
from pydantic import BaseModel
from pytz import country_names, country_timezones
from khoj.utils import constants
@@ -334,6 +337,85 @@ def is_e2b_code_sandbox_enabled():
return not is_none_or_empty(os.getenv("E2B_API_KEY"))
class ToolDefinition:
def __init__(self, name: str, description: str, schema: dict):
self.name = name
self.description = description
self.schema = schema
def create_tool_definition(
schema: Type[BaseModel],
name: str = None,
description: Optional[str] = None,
) -> ToolDefinition:
"""
Converts a response schema BaseModel class into a normalized tool definition.
A standard AI provider agnostic tool format to specify tools the model can use.
Common logic used across models is kept here. AI provider specific adaptations
should be handled in provider code.
Args:
response_schema: The Pydantic BaseModel class to convert.
This class defines the response schema for the tool.
tool_name: The name for the AI model tool (e.g., "get_weather", "plan_next_step").
tool_description: Optional description for the AI model tool.
If None, it attempts to use the Pydantic model's docstring.
If that's also missing, a fallback description is generated.
Returns:
A normalized tool definition for AI model APIs.
"""
raw_schema_dict = schema.model_json_schema()
name = name or schema.__name__.lower()
description = description
if description is None:
docstring = schema.__doc__
if docstring:
description = dedent(docstring).strip()
else:
# Fallback description if no explicit one or docstring is provided
description = f"Tool named '{name}' accepts specified parameters."
# Process properties to inline enums and remove $defs dependency
processed_properties = {}
original_properties = raw_schema_dict.get("properties", {})
defs = raw_schema_dict.get("$defs", {})
for prop_name, prop_schema in original_properties.items():
current_prop_schema = deepcopy(prop_schema) # Work on a copy
# Check for enums defined directly in the property for simpler direct enum definitions.
if "$ref" in current_prop_schema:
ref_path = current_prop_schema["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path.split("/")[-1]
if def_name in defs and "enum" in defs[def_name]:
enum_def = defs[def_name]
current_prop_schema["enum"] = enum_def["enum"]
current_prop_schema["type"] = enum_def.get("type", "string")
if "description" not in current_prop_schema and "description" in enum_def:
current_prop_schema["description"] = enum_def["description"]
del current_prop_schema["$ref"] # Remove the $ref as it's been inlined
processed_properties[prop_name] = current_prop_schema
# Generate the compiled schema dictionary for the tool definition.
compiled_schema = {
"type": "object",
"properties": processed_properties,
# Generate content in the order in which the schema properties were defined
"property_ordering": list(schema.model_fields.keys()),
}
# Include 'required' fields if specified in the Pydantic model
if "required" in raw_schema_dict and raw_schema_dict["required"]:
compiled_schema["required"] = raw_schema_dict["required"]
return ToolDefinition(name=name, description=description, schema=compiled_schema)
class ConversationCommand(str, Enum):
Default = "default"
General = "general"
@@ -347,6 +429,14 @@ class ConversationCommand(str, Enum):
Diagram = "diagram"
Research = "research"
Operator = "operator"
ViewFile = "view_file"
ListFiles = "list_files"
RegexSearchFiles = "regex_search_files"
SemanticSearchFiles = "semantic_search_files"
SearchWeb = "search_web"
ReadWebpage = "read_webpage"
RunCode = "run_code"
OperateComputer = "operate_computer"
command_descriptions = {
@@ -360,6 +450,9 @@ command_descriptions = {
ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
ConversationCommand.ViewFile: "View the contents of a file with optional line range specification.",
ConversationCommand.ListFiles: "List files under a given path with optional glob pattern.",
ConversationCommand.RegexSearchFiles: "Search for lines in files matching regex pattern with an optional path prefix.",
}
command_descriptions_for_agent = {
@@ -385,13 +478,186 @@ tool_descriptions_for_llm = {
ConversationCommand.Operator: "To use when you need to operate a computer to complete the task.",
}
tool_description_for_research_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.",
ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.",
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.",
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.",
ConversationCommand.Operator: "To operate a computer to complete the task.",
tools_for_research_llm = {
ConversationCommand.SearchWeb: ToolDefinition(
name="search_web",
description="To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.",
schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search on the internet.",
},
},
"required": ["query"],
},
),
ConversationCommand.ReadWebpage: ToolDefinition(
name="read_webpage",
description="To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.",
schema={
"type": "object",
"properties": {
"urls": {
"type": "array",
"items": {
"type": "string",
},
"description": "The webpage URLs to extract information from.",
},
"query": {
"type": "string",
"description": "The query to extract information from the webpages.",
},
},
"required": ["urls", "query"],
},
),
ConversationCommand.RunCode: ToolDefinition(
name="run_code",
description=e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Detailed query and all input data required to generate, execute code in the sandbox.",
},
},
"required": ["query"],
},
),
ConversationCommand.OperateComputer: ToolDefinition(
name="operate_computer",
description="To operate a computer to complete the task.",
schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The task to perform on the computer.",
},
},
"required": ["query"],
},
),
ConversationCommand.ViewFile: ToolDefinition(
name="view_file",
description=dedent(
"""
To view the contents of specific note or document in the user's personal knowledge base.
Especially helpful if the question expects context from the user's notes or documents.
It can be used after finding the document path with the document search tool.
Optionally specify a line range to view only specific sections of large files.
"""
).strip(),
schema={
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The file path to view (can be absolute or relative).",
},
"start_line": {
"type": "integer",
"description": "Optional starting line number for viewing a specific range (1-indexed).",
},
"end_line": {
"type": "integer",
"description": "Optional ending line number for viewing a specific range (1-indexed).",
},
},
"required": ["path"],
},
),
ConversationCommand.ListFiles: ToolDefinition(
name="list_files",
description=dedent(
"""
To list files in the user's knowledge base.
Use the path parameter to only show files under the specified path.
"""
).strip(),
schema={
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The directory path to list files from.",
},
"pattern": {
"type": "string",
"description": "Optional glob pattern to filter files (e.g., '*.md').",
},
},
},
),
ConversationCommand.SemanticSearchFiles: ToolDefinition(
name="semantic_search_files",
description=dedent(
"""
To have the tool AI semantic search through the user's knowledge base.
Helpful to answer questions for which finding some relevant notes or documents can complete the search. Example: "When was Tom born?"
This tool AI cannot find all relevant notes or documents, only a subset of them.
It is a good starting point to find keywords, discover similar topics or related concepts and some relevant notes or documents.
The tool AI can perform a maximum of {max_search_queries} semantic search queries per iteration.
"""
).strip(),
schema={
"type": "object",
"properties": {
"q": {
"type": "string",
"description": "Your natural language query for the tool to search in the user's knowledge base.",
},
},
"required": ["q"],
},
),
ConversationCommand.RegexSearchFiles: ToolDefinition(
name="regex_search_files",
description=dedent(
"""
To search through the user's knowledge base using regex patterns. Returns all lines matching the pattern.
Helpful to answer questions for which all relevant notes or documents are needed to complete the search. Example: "Notes that mention Tom".
You need to know all the correct keywords or regex patterns for this tool to be useful.
REMEMBER:
- The regex pattern will ONLY match content on a single line. Multi-line matches are NOT supported (even if you use \\n).
An optional path prefix can restrict search to specific files/directories.
Use lines_before, lines_after to show context around matches.
"""
).strip(),
schema={
"type": "object",
"properties": {
"regex_pattern": {
"type": "string",
"description": "The regex pattern to search for content in the user's files.",
},
"path_prefix": {
"type": "string",
"description": "Optional path prefix to limit the search to files under a specified path.",
},
"lines_before": {
"type": "integer",
"description": "Optional number of lines to show before each line match for context.",
"minimum": 0,
"maximum": 20,
},
"lines_after": {
"type": "integer",
"description": "Optional number of lines to show after each line match for context.",
"minimum": 0,
"maximum": 20,
},
},
"required": ["regex_pattern"],
},
),
}
mode_descriptions_for_llm = {
@@ -850,3 +1116,13 @@ def clean_object_for_db(data):
return [clean_object_for_db(item) for item in data]
else:
return data
def dict_to_tuple(d):
# Recursively convert dicts to sorted tuples for hashability
if isinstance(d, dict):
return tuple(sorted((k, dict_to_tuple(v)) for k, v in d.items()))
elif isinstance(d, list):
return tuple(dict_to_tuple(i) for i in d)
else:
return d
+12 -9
View File
@@ -48,17 +48,18 @@ class TestTruncateMessage:
big_chat_message = ChatMessage(role="user", content=content_list)
copy_big_chat_message = deepcopy(big_chat_message)
chat_history = [big_chat_message]
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_with_content_list(self):
# Arrange
@@ -68,11 +69,11 @@ class TestTruncateMessage:
big_chat_message = ChatMessage(role="user", content=content_list)
copy_big_chat_message = deepcopy(big_chat_message)
chat_history.insert(0, big_chat_message)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
@@ -83,7 +84,8 @@ class TestTruncateMessage:
copy_big_chat_message.content
), "message content list should be modified"
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_first_large(self):
# Arrange
@@ -91,11 +93,11 @@ class TestTruncateMessage:
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.copy()
chat_history.insert(0, big_chat_message)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
@@ -104,7 +106,8 @@ class TestTruncateMessage:
), "Only most recent message should be present as it itself is larger than context size"
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_large_system_message_first(self):
# Arrange
+11 -11
View File
@@ -189,7 +189,7 @@ async def test_chat_with_no_chat_history_or_retrieved_content():
user_query="Hello, my name is Testatron. Who are you?",
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
expected_responses = ["Khoj", "khoj"]
@@ -217,7 +217,7 @@ async def test_answer_from_chat_history_and_no_content():
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
expected_responses = ["Testatron", "testatron"]
@@ -250,7 +250,7 @@ async def test_answer_from_chat_history_and_previously_retrieved_content():
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
assert len(response) > 0
@@ -279,7 +279,7 @@ async def test_answer_from_chat_history_and_currently_retrieved_content():
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
assert len(response) > 0
@@ -305,7 +305,7 @@ async def test_refuse_answering_unanswerable_question():
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
expected_responses = [
@@ -359,7 +359,7 @@ Expenses:Food:Dining 10.00 USD""",
user_query="What did I have for Dinner today?",
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
expected_responses = ["tacos", "Tacos"]
@@ -405,7 +405,7 @@ Expenses:Food:Dining 10.00 USD""",
user_query="How much did I spend on dining this year?",
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
assert len(response) > 0
@@ -432,7 +432,7 @@ async def test_answer_general_question_not_in_chat_history_or_retrieved_content(
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
expected_responses = ["test", "bug", "code"]
@@ -473,7 +473,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
user_query="How many kids does my older sister have?",
api_key=api_key,
)
response = "".join([response_chunk.response async for response_chunk in response_gen])
response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert
expected_responses = [
@@ -508,14 +508,14 @@ async def test_agent_prompt_should_be_used(openai_agent):
user_query="What did I buy?",
api_key=api_key,
)
no_agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
no_agent_response = "".join([response_chunk.text async for response_chunk in response_gen])
response_gen = converse_openai(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
api_key=api_key,
agent=openai_agent,
)
agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
agent_response = "".join([response_chunk.text async for response_chunk in response_gen])
# Assert that the model without the agent prompt does not include the summary of purchases
assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (