Improve Research Mode Context Management (#1179)

### Major
* Do more granular truncation on hitting context limits
* Pack research iterations as list of message content instead of
separate messages
* Update message truncation logic to truncate items in message content
list
* Make researcher aware of number of web, doc queries allowed per
iteration

### Minor
* Prompt web page reader to extract quantitative data as is from pages
* Track gemini 2.0 flash lite cost. Reduce max prompt size for 4o-mini
* Ensure time to first token logged only once per chat response
* Upgrade tenacity to respect min_time passed to exponential backoff
with jitter function
This commit is contained in:
Debanjum
2025-05-17 17:38:31 -07:00
committed by GitHub
19 changed files with 363 additions and 152 deletions

View File

@@ -44,7 +44,7 @@ dependencies = [
"jinja2 == 3.1.6", "jinja2 == 3.1.6",
"openai >= 1.0.0", "openai >= 1.0.0",
"tiktoken >= 0.3.2", "tiktoken >= 0.3.2",
"tenacity >= 8.2.2", "tenacity >= 9.0.0",
"magika ~= 0.5.1", "magika ~= 0.5.1",
"pillow ~= 10.0.0", "pillow ~= 10.0.0",
"pydantic[email] >= 2.0.0", "pydantic[email] >= 2.0.0",
@@ -57,10 +57,9 @@ dependencies = [
"torch == 2.6.0", "torch == 2.6.0",
"uvicorn == 0.30.6", "uvicorn == 0.30.6",
"aiohttp ~= 3.9.0", "aiohttp ~= 3.9.0",
"langchain == 0.2.5", "langchain-text-splitters == 0.3.1",
"langchain-community == 0.2.5", "langchain-community == 0.3.3",
"requests >= 2.26.0", "requests >= 2.26.0",
"tenacity == 8.3.0",
"anyio ~= 4.8.0", "anyio ~= 4.8.0",
"pymupdf == 1.24.11", "pymupdf == 1.24.11",
"django == 5.1.8", "django == 5.1.8",

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from itertools import repeat from itertools import repeat
from typing import Any, Callable, List, Set, Tuple from typing import Any, Callable, List, Set, Tuple
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm import tqdm from tqdm import tqdm
from khoj.database.adapters import ( from khoj.database.adapters import (

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
import pyjson5 import pyjson5
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from khoj.database.models import Agent, ChatModel, KhojUser from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts

View File

@@ -3,7 +3,7 @@ from time import perf_counter
from typing import Dict, List from typing import Dict, List
import anthropic import anthropic
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
@@ -144,6 +144,7 @@ async def anthropic_chat_completion_with_backoff(
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
aggregated_response = "" aggregated_response = ""
response_started = False
final_message = None final_message = None
start_time = perf_counter() start_time = perf_counter()
async with client.messages.stream( async with client.messages.stream(
@@ -157,7 +158,8 @@ async def anthropic_chat_completion_with_backoff(
) as stream: ) as stream:
async for chunk in stream: async for chunk in stream:
# Log the time taken to start response # Log the time taken to start response
if aggregated_response == "": if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Skip empty chunks # Skip empty chunks
if chunk.type != "content_block_delta": if chunk.type != "content_block_delta":
@@ -203,6 +205,9 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
system_prompt = system_prompt or "" system_prompt = system_prompt or ""
for message in messages.copy(): for message in messages.copy():
if message.role == "system": if message.role == "system":
if isinstance(message.content, list):
system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
else:
system_prompt += message.content system_prompt += message.content
messages.remove(message) messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
import pyjson5 import pyjson5
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from khoj.database.models import Agent, ChatModel, KhojUser from khoj.database.models import Agent, ChatModel, KhojUser

View File

@@ -9,7 +9,7 @@ import httpx
from google import genai from google import genai
from google.genai import errors as gerrors from google.genai import errors as gerrors
from google.genai import types as gtypes from google.genai import types as gtypes
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel from pydantic import BaseModel
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
@@ -195,13 +195,15 @@ async def gemini_chat_completion_with_backoff(
aggregated_response = "" aggregated_response = ""
final_chunk = None final_chunk = None
response_started = False
start_time = perf_counter() start_time = perf_counter()
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream( chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
model=model_name, config=config, contents=formatted_messages model=model_name, config=config, contents=formatted_messages
) )
async for chunk in chat_stream: async for chunk in chat_stream:
# Log the time taken to start response # Log the time taken to start response
if final_chunk is None: if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data # Keep track of the last chunk for usage data
final_chunk = chunk final_chunk = chunk
@@ -301,6 +303,9 @@ def format_messages_for_gemini(
messages = deepcopy(original_messages) messages = deepcopy(original_messages)
for message in messages.copy(): for message in messages.copy():
if message.role == "system": if message.role == "system":
if isinstance(message.content, list):
system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
else:
system_prompt += message.content system_prompt += message.content
messages.remove(message) messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

View File

@@ -7,7 +7,7 @@ from time import perf_counter
from typing import Any, AsyncGenerator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import pyjson5 import pyjson5
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from llama_cpp import Llama from llama_cpp import Llama
from khoj.database.models import Agent, ChatModel, KhojUser from khoj.database.models import Agent, ChatModel, KhojUser

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
import pyjson5 import pyjson5
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -226,6 +226,7 @@ async def chat_completion_with_backoff(
aggregated_response = "" aggregated_response = ""
final_chunk = None final_chunk = None
response_started = False
start_time = perf_counter() start_time = perf_counter()
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
messages=formatted_messages, # type: ignore messages=formatted_messages, # type: ignore
@@ -237,7 +238,8 @@ async def chat_completion_with_backoff(
) )
async for chunk in stream_processor(chat_stream): async for chunk in stream_processor(chat_stream):
# Log the time taken to start response # Log the time taken to start response
if final_chunk is None: if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data # Keep track of the last chunk for usage data
final_chunk = chunk final_chunk = chunk

View File

@@ -1,4 +1,4 @@
from langchain.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
## Personality ## Personality
## -- ## --
@@ -666,21 +666,25 @@ As a professional analyst, your job is to extract all pertinent information from
You will be provided raw text directly from within the document. You will be provided raw text directly from within the document.
Adhere to these guidelines while extracting information from the provided documents: Adhere to these guidelines while extracting information from the provided documents:
1. Extract all relevant text and links from the document that can assist with further research or answer the user's query. 1. Extract all relevant text and links from the document that can assist with further research or answer the target query.
2. Craft a comprehensive but compact report with all the necessary data from the document to generate an informed response. 2. Craft a comprehensive but compact report with all the necessary data from the document to generate an informed response.
3. Rely strictly on the provided text to generate your summary, without including external information. 3. Rely strictly on the provided text to generate your summary, without including external information.
4. Provide specific, important snippets from the document in your report to establish trust in your summary. 4. Provide specific, important snippets from the document in your report to establish trust in your summary.
5. Verbatim quote all necessary text, code or data from the provided document to answer the target query.
""".strip() """.strip()
extract_relevant_information = PromptTemplate.from_template( extract_relevant_information = PromptTemplate.from_template(
""" """
{personality_context} {personality_context}
Target Query: {query} <target_query>
{query}
</target_query>
Document: <document>
{corpus} {corpus}
</document>
Collate only relevant information from the document to answer the target query. Collate all relevant information from the document to answer the target query.
""".strip() """.strip()
) )
@@ -758,29 +762,32 @@ Assuming you can search the user's notes and the internet.
- User Name: {username} - User Name: {username}
# Available Tool AIs # Available Tool AIs
Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs: You decide which of the tool AIs listed below would you use to answer the user's question. You **only** have access to the following tool AIs:
{tools} {tools}
# Previous Iterations Your response should always be a valid JSON object. Do not say anything else.
{previous_iterations}
# Chat History:
{chat_history}
Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else.
Response format: Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}} {{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
""".strip() """.strip()
) )
plan_function_execution_next_tool = PromptTemplate.from_template(
"""
Given the results of your previous iterations, which tool AI will you use next to answer the target query?
# Target Query:
{query}
""".strip()
)
previous_iteration = PromptTemplate.from_template( previous_iteration = PromptTemplate.from_template(
""" """
## Iteration {index}: # Iteration {index}:
- tool: {tool} - tool: {tool}
- query: {query} - query: {query}
- result: {result} - result: {result}
""" """.strip()
) )
pick_relevant_tools = PromptTemplate.from_template( pick_relevant_tools = PromptTemplate.from_template(
@@ -858,8 +865,7 @@ infer_webpages_to_read = PromptTemplate.from_template(
You are Khoj, an advanced web page reading assistant. You are to construct **up to {max_webpages}, valid** webpage urls to read before answering the user's question. You are Khoj, an advanced web page reading assistant. You are to construct **up to {max_webpages}, valid** webpage urls to read before answering the user's question.
- You will receive the conversation history as context. - You will receive the conversation history as context.
- Add as much context from the previous questions and answers as required to construct the webpage urls. - Add as much context from the previous questions and answers as required to construct the webpage urls.
- Use multiple web page urls if required to retrieve the relevant information. - You have access to the whole internet to retrieve information.
- You have access to the the whole internet to retrieve information.
{personality_context} {personality_context}
Which webpages will you need to read to answer the user's question? Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object. Provide web page links as a list of strings in a JSON object.

View File

@@ -4,14 +4,12 @@ import logging
import math import math
import mimetypes import mimetypes
import os import os
import queue
import re import re
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from time import perf_counter
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import PIL.Image import PIL.Image
@@ -19,9 +17,10 @@ import pyjson5
import requests import requests
import tiktoken import tiktoken
import yaml import yaml
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from llama_cpp import LlamaTokenizer
from llama_cpp.llama import Llama from llama_cpp.llama import Llama
from transformers import AutoTokenizer from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModel, ClientApplication, KhojUser from khoj.database.models import ChatModel, ClientApplication, KhojUser
@@ -52,7 +51,7 @@ except ImportError:
model_to_prompt_size = { model_to_prompt_size = {
# OpenAI Models # OpenAI Models
"gpt-4o": 60000, "gpt-4o": 60000,
"gpt-4o-mini": 120000, "gpt-4o-mini": 60000,
"gpt-4.1": 60000, "gpt-4.1": 60000,
"gpt-4.1-mini": 120000, "gpt-4.1-mini": 120000,
"gpt-4.1-nano": 120000, "gpt-4.1-nano": 120000,
@@ -105,9 +104,9 @@ class InformationCollectionIteration:
def construct_iteration_history( def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
) -> str: ) -> list[dict]:
previous_iterations_history = "" previous_iterations_history = []
for idx, iteration in enumerate(previous_iterations): for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format( iteration_data = previous_iteration_prompt.format(
tool=iteration.tool, tool=iteration.tool,
@@ -116,8 +115,23 @@ def construct_iteration_history(
index=idx + 1, index=idx + 1,
) )
previous_iterations_history += iteration_data previous_iterations_history.append(iteration_data)
return previous_iterations_history
return (
[
{
"by": "you",
"message": query,
},
{
"by": "khoj",
"intent": {"type": "remember", "query": query},
"message": previous_iterations_history,
},
]
if previous_iterations_history
else []
)
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
@@ -152,19 +166,35 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
def construct_tool_chat_history( def construct_tool_chat_history(
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
) -> Dict[str, list]: ) -> Dict[str, list]:
"""
Construct chat history from previous iterations for a specific tool
If a tool is provided, only the inferred queries for that tool is added.
If no tool is provided inferred query for all tools used are added.
"""
chat_history: list = [] chat_history: list = []
inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
if tool == ConversationCommand.Notes: extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = {
inferred_query_extractor = ( ConversationCommand.Notes: (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
) ),
elif tool == ConversationCommand.Online: ConversationCommand.Online: (
inferred_query_extractor = (
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else [] lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
) ),
elif tool == ConversationCommand.Code: ConversationCommand.Webpage: (
inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
),
ConversationCommand.Code: (
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
),
}
for iteration in previous_iterations: for iteration in previous_iterations:
# If a tool is provided use the inferred query extractor for that tool if available
# If no tool is provided, use inferred query extractor for the tool used in the iteration
# Fallback to base extractor if the tool does not have an inferred query extractor
inferred_query_extractor = extract_inferred_query_map.get(
tool or ConversationCommand(iteration.tool), base_extractor
)
chat_history += [ chat_history += [
{ {
"by": "you", "by": "you",
@@ -300,7 +330,11 @@ Khoj: "{chat_response}"
def construct_structured_message( def construct_structured_message(
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None message: list[str] | str,
images: list[str],
model_type: str,
vision_enabled: bool,
attached_file_context: str = None,
): ):
""" """
Format messages into appropriate multimedia format for supported chat model types Format messages into appropriate multimedia format for supported chat model types
@@ -310,10 +344,11 @@ def construct_structured_message(
ChatModel.ModelType.GOOGLE, ChatModel.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC, ChatModel.ModelType.ANTHROPIC,
]: ]:
if not attached_file_context and not (vision_enabled and images): message = [message] if isinstance(message, str) else message
return message
constructed_messages: List[Any] = [{"type": "text", "text": message}] constructed_messages: List[dict[str, Any]] = [
{"type": "text", "text": message_part} for message_part in message
]
if not is_none_or_empty(attached_file_context): if not is_none_or_empty(attached_file_context):
constructed_messages.append({"type": "text", "text": attached_file_context}) constructed_messages.append({"type": "text", "text": attached_file_context})
@@ -346,7 +381,7 @@ def gather_raw_query_files(
def generate_chatml_messages_with_context( def generate_chatml_messages_with_context(
user_message, user_message,
system_message=None, system_message: str = None,
conversation_log={}, conversation_log={},
model_name="gpt-4o-mini", model_name="gpt-4o-mini",
loaded_model: Optional[Llama] = None, loaded_model: Optional[Llama] = None,
@@ -409,6 +444,9 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(chat.get("onlineContext")): if not is_none_or_empty(chat.get("onlineContext")):
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
if not is_none_or_empty(chat.get("codeContext")):
message_context += f"{prompts.code_executed_context.format(online_results=chat.get('codeContext'))}"
if not is_none_or_empty(message_context): if not is_none_or_empty(message_context):
reconstructed_context_message = ChatMessage(content=message_context, role="user") reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message) chatml_messages.insert(0, reconstructed_context_message)
@@ -441,7 +479,7 @@ def generate_chatml_messages_with_context(
if len(chatml_messages) >= 3 * lookback_turns: if len(chatml_messages) >= 3 * lookback_turns:
break break
messages = [] messages: list[ChatMessage] = []
if not is_none_or_empty(generated_asset_results): if not is_none_or_empty(generated_asset_results):
messages.append( messages.append(
@@ -478,6 +516,11 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(system_message): if not is_none_or_empty(system_message):
messages.append(ChatMessage(content=system_message, role="system")) messages.append(ChatMessage(content=system_message, role="system"))
# Normalize message content to list of chatml dictionaries
for message in messages:
if isinstance(message.content, str):
message.content = [{"type": "text", "text": message.content}]
# Truncate oldest messages from conversation history until under max supported prompt size by model # Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name) messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
@@ -485,14 +528,11 @@ def generate_chatml_messages_with_context(
return messages[::-1] return messages[::-1]
def truncate_messages( def get_encoder(
messages: list[ChatMessage],
max_prompt_size: int,
model_name: str, model_name: str,
loaded_model: Optional[Llama] = None, loaded_model: Optional[Llama] = None,
tokenizer_name=None, tokenizer_name=None,
) -> list[ChatMessage]: ) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
"""Truncate messages to fit within max prompt size supported by model"""
default_tokenizer = "gpt-4o" default_tokenizer = "gpt-4o"
try: try:
@@ -515,6 +555,48 @@ def truncate_messages(
logger.debug( logger.debug(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing." f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
) )
return encoder
def count_tokens(
message_content: str | list[str | dict],
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
) -> int:
"""
Count the total number of tokens in a list of messages.
Assumes each images takes 500 tokens for approximation.
"""
if isinstance(message_content, list):
image_count = 0
message_content_parts: list[str] = []
# Collate message content into single string to ease token counting
for part in message_content:
if isinstance(part, dict) and part.get("type") == "text":
message_content_parts.append(part["text"])
elif isinstance(part, dict) and part.get("type") == "image_url":
image_count += 1
elif isinstance(part, str):
message_content_parts.append(part)
else:
logger.warning(f"Unknown message type: {part}. Skipping.")
message_content = "\n".join(message_content_parts).rstrip()
return len(encoder.encode(message_content)) + image_count * 500
elif isinstance(message_content, str):
return len(encoder.encode(message_content))
else:
return len(encoder.encode(json.dumps(message_content)))
def truncate_messages(
messages: list[ChatMessage],
max_prompt_size: int,
model_name: str,
loaded_model: Optional[Llama] = None,
tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
encoder = get_encoder(model_name, loaded_model, tokenizer_name)
# Extract system message from messages # Extract system message from messages
system_message = None system_message = None
@@ -523,35 +605,55 @@ def truncate_messages(
system_message = messages.pop(idx) system_message = messages.pop(idx)
break break
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
)
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Drop older messages until under max supported prompt size by model # Drop older messages until under max supported prompt size by model
# Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.) # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
while (tokens + system_message_tokens + 4 * len(messages)) > max_prompt_size and len(messages) > 1: system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
tokens = sum([count_tokens(message.content, encoder) for message in messages])
total_tokens = tokens + system_message_tokens + 4 * len(messages)
while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
if len(messages[-1].content) > 1:
# The oldest content part is earlier in content list. So pop from the front.
messages[-1].content.pop(0)
else:
# The oldest message is the last one. So pop from the back.
messages.pop() messages.pop()
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) tokens = sum([count_tokens(message.content, encoder) for message in messages])
total_tokens = tokens + system_message_tokens + 4 * len(messages)
# Truncate current message if still over max supported prompt size by model # Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size: total_tokens = tokens + system_message_tokens + 4 * len(messages)
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" if total_tokens > max_prompt_size:
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" # At this point, a single message with a single content part of type dict should remain
original_question = f"\n{original_question}" assert (
original_question_tokens = len(encoder.encode(original_question)) len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict)
), "Expected a single message with a single content part remaining at this point in truncation"
# Collate message content into single string to ease truncation
part = messages[0].content[0]
message_content: str = part["text"] if part["type"] == "text" else json.dumps(part)
message_role = messages[0].role
remaining_context = "\n".join(message_content.split("\n")[:-1])
original_question = "\n" + "\n".join(message_content.split("\n")[-1:])
original_question_tokens = count_tokens(original_question, encoder)
remaining_tokens = max_prompt_size - system_message_tokens remaining_tokens = max_prompt_size - system_message_tokens
if remaining_tokens > original_question_tokens: if remaining_tokens > original_question_tokens:
remaining_tokens -= original_question_tokens remaining_tokens -= original_question_tokens
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip() truncated_context = encoder.decode(encoder.encode(remaining_context)[:remaining_tokens]).strip()
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)] truncated_content = truncated_context + original_question
else: else:
truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip() truncated_content = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
messages = [ChatMessage(content=truncated_message, role=messages[0].role)] messages = [ChatMessage(content=[{"type": "text", "text": truncated_content}], role=message_role)]
truncated_snippet = (
f"{truncated_content[:1000]}\n...\n{truncated_content[-1000:]}"
if len(truncated_content) > 2000
else truncated_content
)
logger.debug( logger.debug(
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..." f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
) )
if system_message: if system_message:

View File

@@ -64,11 +64,12 @@ async def search_online(
user: KhojUser, user: KhojUser,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [], custom_filters: List[str] = [],
max_online_searches: int = 3,
max_webpages_to_read: int = 1, max_webpages_to_read: int = 1,
query_images: List[str] = None, query_images: List[str] = None,
query_files: str = None,
previous_subqueries: Set = set(), previous_subqueries: Set = set(),
agent: Agent = None, agent: Agent = None,
query_files: str = None,
tracer: dict = {}, tracer: dict = {},
): ):
query += " ".join(custom_filters) query += " ".join(custom_filters)
@@ -84,9 +85,10 @@ async def search_online(
location, location,
user, user,
query_images=query_images, query_images=query_images,
query_files=query_files,
max_queries=max_online_searches,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
query_files=query_files,
) )
subqueries = list(new_subqueries - previous_subqueries) subqueries = list(new_subqueries - previous_subqueries)
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {} response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}

View File

@@ -1129,9 +1129,10 @@ async def chat(
user, user,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
custom_filters, custom_filters,
max_online_searches=3,
query_images=uploaded_images, query_images=uploaded_images,
agent=agent,
query_files=attached_file_context, query_files=attached_file_context,
agent=agent,
tracer=tracer, tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:

View File

@@ -523,8 +523,9 @@ async def generate_online_subqueries(
location_data: LocationData, location_data: LocationData,
user: KhojUser, user: KhojUser,
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None,
query_files: str = None, query_files: str = None,
max_queries: int = 3,
agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
) -> Set[str]: ) -> Set[str]:
""" """
@@ -534,7 +535,6 @@ async def generate_online_subqueries(
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
max_queries = 3
utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d") utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
personality_context = ( personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""

View File

@@ -6,7 +6,6 @@ from enum import Enum
from typing import Callable, Dict, List, Optional, Type from typing import Callable, Dict, List, Optional, Type
import yaml import yaml
from fastapi import Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from khoj.database.adapters import AgentAdapters, EntryAdapters from khoj.database.adapters import AgentAdapters, EntryAdapters
@@ -14,7 +13,6 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
InformationCollectionIteration, InformationCollectionIteration,
construct_chat_history,
construct_iteration_history, construct_iteration_history,
construct_tool_chat_history, construct_tool_chat_history,
load_complex_json, load_complex_json,
@@ -29,9 +27,9 @@ from khoj.routers.helpers import (
) )
from khoj.utils.helpers import ( from khoj.utils.helpers import (
ConversationCommand, ConversationCommand,
function_calling_description_for_llm,
is_none_or_empty, is_none_or_empty,
timer, timer,
tool_description_for_research_llm,
truncate_code_context, truncate_code_context,
) )
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@@ -79,15 +77,18 @@ async def apick_next_tool(
query: str, query: str,
conversation_history: dict, conversation_history: dict,
user: KhojUser = None, user: KhojUser = None,
query_images: List[str] = [],
location: LocationData = None, location: LocationData = None,
user_name: str = None, user_name: str = None,
agent: Agent = None, agent: Agent = None,
previous_iterations: List[InformationCollectionIteration] = [], previous_iterations: List[InformationCollectionIteration] = [],
max_iterations: int = 5, max_iterations: int = 5,
query_images: List[str] = [],
query_files: str = None,
max_document_searches: int = 7,
max_online_searches: int = 3,
max_webpages_to_read: int = 1,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
tracer: dict = {}, tracer: dict = {},
query_files: str = None,
): ):
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately.""" """Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
@@ -96,10 +97,16 @@ async def apick_next_tool(
tool_options_str = "" tool_options_str = ""
agent_tools = agent.input_tools if agent else [] agent_tools = agent.input_tools if agent else []
user_has_entries = await EntryAdapters.auser_has_entries(user) user_has_entries = await EntryAdapters.auser_has_entries(user)
for tool, description in function_calling_description_for_llm.items(): for tool, description in tool_description_for_research_llm.items():
# Skip showing Notes tool as an option if user has no entries # Skip showing Notes tool as an option if user has no entries
if tool == ConversationCommand.Notes and not user_has_entries: if tool == ConversationCommand.Notes:
if not user_has_entries:
continue continue
description = description.format(max_search_queries=max_document_searches)
if tool == ConversationCommand.Webpage:
description = description.format(max_webpages_to_read=max_webpages_to_read)
if tool == ConversationCommand.Online:
description = description.format(max_search_queries=max_online_searches)
# Add tool if agent does not have any tools defined or the tool is supported by the agent. # Add tool if agent does not have any tools defined or the tool is supported by the agent.
if len(agent_tools) == 0 or tool.value in agent_tools: if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options[tool.name] = tool.value tool_options[tool.name] = tool.value
@@ -108,13 +115,6 @@ async def apick_next_tool(
# Create planning reponse model with dynamically populated tool enum class # Create planning reponse model with dynamically populated tool enum class
planning_response_model = PlanningResponse.create_model_with_enum(tool_options) planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
# Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
if query_images:
query = f"[placeholder for user attached images]\n{query}"
today = datetime.today() today = datetime.today()
location_data = f"{location}" if location else "Unknown" location_data = f"{location}" if location else "Unknown"
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
@@ -124,21 +124,30 @@ async def apick_next_tool(
function_planning_prompt = prompts.plan_function_execution.format( function_planning_prompt = prompts.plan_function_execution.format(
tools=tool_options_str, tools=tool_options_str,
chat_history=chat_history,
personality_context=personality_context, personality_context=personality_context,
current_date=today.strftime("%Y-%m-%d"), current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"), day_of_week=today.strftime("%A"),
username=user_name or "Unknown", username=user_name or "Unknown",
location=location_data, location=location_data,
previous_iterations=previous_iterations_history,
max_iterations=max_iterations, max_iterations=max_iterations,
) )
if query_images:
query = f"[placeholder for user attached images]\n{query}"
# Construct chat history with user and iteration history with researcher agent for context
previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration)
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
# Plan function execution for the next tool
query = prompts.plan_function_execution_next_tool.format(query=query) if previous_iterations_history else query
try: try:
with timer("Chat actor: Infer information sources to refer", logger): with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
query=query, query=query,
context=function_planning_prompt, system_message=function_planning_prompt,
conversation_log=iteration_chat_log,
response_type="json_object", response_type="json_object",
response_schema=planning_response_model, response_schema=planning_response_model,
deepthought=True, deepthought=True,
@@ -208,6 +217,9 @@ async def execute_information_collection(
query_files: str = None, query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
): ):
max_document_searches = 7
max_online_searches = 3
max_webpages_to_read = 1
current_iteration = 0 current_iteration = 0
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
previous_iterations: List[InformationCollectionIteration] = [] previous_iterations: List[InformationCollectionIteration] = []
@@ -227,15 +239,18 @@ async def execute_information_collection(
query, query,
conversation_history, conversation_history,
user, user,
query_images,
location, location,
user_name, user_name,
agent, agent,
previous_iterations, previous_iterations,
MAX_ITERATIONS, MAX_ITERATIONS,
send_status_func, query_images=query_images,
tracer=tracer,
query_files=query_files, query_files=query_files,
max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
max_webpages_to_read=max_webpages_to_read,
send_status_func=send_status_func,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@@ -260,7 +275,7 @@ async def execute_information_collection(
user, user,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
this_iteration.query, this_iteration.query,
7, max_document_searches,
None, None,
conversation_id, conversation_id,
[ConversationCommand.Default], [ConversationCommand.Default],
@@ -307,6 +322,7 @@ async def execute_information_collection(
user, user,
send_status_func, send_status_func,
[], [],
max_online_searches=max_online_searches,
max_webpages_to_read=0, max_webpages_to_read=0,
query_images=query_images, query_images=query_images,
previous_subqueries=previous_subqueries, previous_subqueries=previous_subqueries,
@@ -332,7 +348,7 @@ async def execute_information_collection(
location, location,
user, user,
send_status_func, send_status_func,
max_webpages_to_read=1, max_webpages_to_read=max_webpages_to_read,
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
@@ -361,7 +377,7 @@ async def execute_information_collection(
try: try:
async for result in run_code( async for result in run_code(
this_iteration.query, this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage), construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
"", "",
location, location,
user, user,
@@ -388,7 +404,7 @@ async def execute_information_collection(
this_iteration.query, this_iteration.query,
user, user,
file_filters, file_filters,
construct_tool_chat_history(previous_iterations), construct_tool_chat_history(previous_iterations, ConversationCommand.Summarize),
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
send_status_func=send_status_func, send_status_func=send_status_func,

View File

@@ -52,6 +52,7 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"gemini-1.5-pro": {"input": 1.25, "output": 5.00}, "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00}, "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
"gemini-2.0-flash": {"input": 0.10, "output": 0.40}, "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
"gemini-2.0-flash-lite": {"input": 0.0075, "output": 0.30},
"gemini-2.5-flash-preview-04-17": {"input": 0.15, "output": 0.60, "thought": 3.50}, "gemini-2.5-flash-preview-04-17": {"input": 0.15, "output": 0.60, "thought": 3.50},
"gemini-2.5-pro-preview-03-25": {"input": 1.25, "output": 10.0}, "gemini-2.5-pro-preview-03-25": {"input": 1.25, "output": 10.0},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api

View File

@@ -386,10 +386,10 @@ tool_descriptions_for_llm = {
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
} }
function_calling_description_for_llm = { tool_description_for_research_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.",
ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.", ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.",
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.", ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.",
ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description,
ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.", ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.",
} }

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, List
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from openai import OpenAI from openai import OpenAI
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from whisper import Whisper from whisper import Whisper
from khoj.database.models import ProcessLock from khoj.database.models import ProcessLock
@@ -40,7 +41,7 @@ khoj_version: str = None
device = get_device() device = get_device()
chat_on_gpu: bool = True chat_on_gpu: bool = True
anonymous_mode: bool = False anonymous_mode: bool = False
pretrained_tokenizers: Dict[str, Any] = dict() pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
billing_enabled: bool = ( billing_enabled: bool = (
os.getenv("STRIPE_API_KEY") is not None os.getenv("STRIPE_API_KEY") is not None
and os.getenv("STRIPE_SIGNING_SECRET") is not None and os.getenv("STRIPE_SIGNING_SECRET") is not None

View File

@@ -1,11 +1,13 @@
from copy import deepcopy
import tiktoken import tiktoken
from langchain.schema import ChatMessage from langchain_core.messages.chat import ChatMessage
from khoj.processor.conversation import utils from khoj.processor.conversation import utils
class TestTruncateMessage: class TestTruncateMessage:
max_prompt_size = 10 max_prompt_size = 40
model_name = "gpt-4o-mini" model_name = "gpt-4o-mini"
encoder = tiktoken.encoding_for_model(model_name) encoder = tiktoken.encoding_for_model(model_name)
@@ -15,45 +17,108 @@ class TestTruncateMessage:
# Act # Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert # Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert len(chat_history) < 50 assert len(chat_history) < 50
assert len(chat_history) > 1 assert len(chat_history) > 5
assert tokens <= self.max_prompt_size assert tokens <= self.max_prompt_size
def test_truncate_message_only_oldest_big(self):
# Arrange
chat_history = generate_chat_history(5)
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
chat_history.append(big_chat_message)
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_history) == 5
assert tokens <= self.max_prompt_size
def test_truncate_message_with_image(self):
# Arrange
image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}}
content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
content_list += [image_content_item, {"type": "text", "text": "Question?"}]
big_chat_message = ChatMessage(role="user", content=content_list)
copy_big_chat_message = deepcopy(big_chat_message)
chat_history = [big_chat_message]
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_with_content_list(self):
# Arrange
chat_history = generate_chat_history(5)
content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
content_list += [{"type": "text", "text": "Question?"}]
big_chat_message = ChatMessage(role="user", content=content_list)
copy_big_chat_message = deepcopy(big_chat_message)
chat_history.insert(0, big_chat_message)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert (
len(chat_history) == 1
), "Only most recent message should be present as it itself is larger than context size"
assert len(truncated_chat_history[0].content) < len(
copy_big_chat_message.content
), "message content list should be modified"
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_first_large(self): def test_truncate_message_first_large(self):
# Arrange # Arrange
chat_history = generate_chat_history(5) chat_history = generate_chat_history(5)
big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?") big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.copy() copy_big_chat_message = big_chat_message.copy()
chat_history.insert(0, big_chat_message) chat_history.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history]) tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act # Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert # Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert len(chat_history) == 1 assert (
assert truncated_chat_history[0] != copy_big_chat_message len(chat_history) == 1
assert tokens <= self.max_prompt_size ), "Only most recent message should be present as it itself is larger than context size"
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_last_large(self): def test_truncate_message_large_system_message_first(self):
# Arrange # Arrange
chat_history = generate_chat_history(5) chat_history = generate_chat_history(5)
chat_history[0].role = "system" # Mark the first message as system message chat_history[0].role = "system" # Mark the first message as system message
big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?") big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.copy() copy_big_chat_message = big_chat_message.copy()
chat_history.insert(0, big_chat_message) chat_history.insert(0, big_chat_message)
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history]) initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act # Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert # Assert
# The original object has been modified. Verify certain properties. # The original object has been modified. Verify certain properties.
@@ -62,46 +127,52 @@ class TestTruncateMessage:
) # Because the system_prompt is popped off from the chat_messages list ) # Because the system_prompt is popped off from the chat_messages list
assert len(truncated_chat_history) < 10 assert len(truncated_chat_history) < 10
assert len(truncated_chat_history) > 1 assert len(truncated_chat_history) > 1
assert truncated_chat_history[0] != copy_big_chat_message assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert initial_tokens > self.max_prompt_size assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert final_tokens <= self.max_prompt_size assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
def test_truncate_single_large_non_system_message(self): def test_truncate_single_large_non_system_message(self):
# Arrange # Arrange
big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?") big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.copy() copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message] chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
# Act # Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert # Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
assert len(chat_messages) == 1 assert (
assert truncated_chat_history[0] != copy_big_chat_message len(chat_messages) == 1
), "Only most recent message should be present as it itself is larger than context size"
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
def test_truncate_single_large_question(self): def test_truncate_single_large_question(self):
# Arrange # Arrange
big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1)) big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}]
big_chat_message = ChatMessage(role="user", content=big_chat_message_content) big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
copy_big_chat_message = big_chat_message.copy() copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message] chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
# Act # Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert # Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
assert len(chat_messages) == 1 assert (
assert truncated_chat_history[0] != copy_big_chat_message len(chat_messages) == 1
), "Only most recent message should be present as it itself is larger than context size"
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
def test_load_complex_raw_json_string(): def test_load_complex_raw_json_string():
@@ -116,12 +187,12 @@ def test_load_complex_raw_json_string():
assert parsed_json == expeced_json assert parsed_json == expeced_json
def generate_content(count): def generate_content(count, suffix=""):
return " ".join([f"{index}" for index, _ in enumerate(range(count))]) return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}]
def generate_chat_history(count): def generate_chat_history(count):
return [ return [
ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}") ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}])
for index, _ in enumerate(range(count)) for index, _ in enumerate(range(count))
] ]