From 99a230524645b1e43a8a93d36118f048657b36fe Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 10 May 2025 16:15:12 -0600 Subject: [PATCH 1/8] Improve tool chat history constructor and fix its usage during research. Code tool should see code context and webpage tool should see online context during research runs Fix to include code context from past conversations to answer queries. Add all queries to tool chat history when no specific tool to limit extracting inferred queries for provided. --- src/khoj/processor/conversation/utils.py | 37 ++++++++++++++++++------ src/khoj/routers/research.py | 4 +-- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e86834f9..7901f29c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -152,19 +152,35 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A def construct_tool_chat_history( previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None ) -> Dict[str, list]: + """ + Construct chat history from previous iterations for a specific tool + + If a tool is provided, only the inferred queries for that tool is added. + If no tool is provided inferred query for all tools used are added. + """ chat_history: list = [] - inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] - if tool == ConversationCommand.Notes: - inferred_query_extractor = ( + base_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: [] + extract_inferred_query_map: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = { + ConversationCommand.Notes: ( lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] - ) - elif tool == ConversationCommand.Online: - inferred_query_extractor = ( + ), + ConversationCommand.Online: ( lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else [] - ) - elif tool == ConversationCommand.Code: - inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] + ), + ConversationCommand.Webpage: ( + lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else [] + ), + ConversationCommand.Code: ( + lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] + ), + } for iteration in previous_iterations: + # If a tool is provided use the inferred query extractor for that tool if available + # If no tool is provided, use inferred query extractor for the tool used in the iteration + # Fallback to base extractor if the tool does not have an inferred query extractor + inferred_query_extractor = extract_inferred_query_map.get( + tool or ConversationCommand(iteration.tool), base_extractor + ) chat_history += [ { "by": "you", @@ -409,6 +425,9 @@ def generate_chatml_messages_with_context( if not is_none_or_empty(chat.get("onlineContext")): message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" + if not is_none_or_empty(chat.get("codeContext")): + message_context += f"{prompts.code_executed_context.format(online_results=chat.get('codeContext'))}" + if not is_none_or_empty(message_context): reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 9fb7c229..86f63b2d 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -361,7 +361,7 @@ async def execute_information_collection( try: async for result in run_code( this_iteration.query, - construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage), + construct_tool_chat_history(previous_iterations, ConversationCommand.Code), "", location, user, @@ -388,7 +388,7 @@ async def execute_information_collection( this_iteration.query, user, file_filters, - construct_tool_chat_history(previous_iterations), + construct_tool_chat_history(previous_iterations, ConversationCommand.Summarize), query_images=query_images, agent=agent, send_status_func=send_status_func, From 0f53a67837d4dd9a50b6d88e5b2fc8b1f5562451 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 11 May 2025 15:29:38 -0600 Subject: [PATCH 2/8] Prompt web page reader to extract quantitative data as is from pages Previously the research agent would have a hard time getting quantitative data extracted by the web page reader tool AI. This change aims to encourage the web page reader tool to extract relevant data in verbatim form for higher granularity research and responses. --- src/khoj/processor/conversation/prompts.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index b0cec27b..6ec2376b 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -666,21 +666,25 @@ As a professional analyst, your job is to extract all pertinent information from You will be provided raw text directly from within the document. Adhere to these guidelines while extracting information from the provided documents: -1. Extract all relevant text and links from the document that can assist with further research or answer the user's query. +1. Extract all relevant text and links from the document that can assist with further research or answer the target query. 2. Craft a comprehensive but compact report with all the necessary data from the document to generate an informed response. 3. Rely strictly on the provided text to generate your summary, without including external information. 4. Provide specific, important snippets from the document in your report to establish trust in your summary. +5. Verbatim quote all necessary text, code or data from the provided document to answer the target query. """.strip() extract_relevant_information = PromptTemplate.from_template( """ {personality_context} -Target Query: {query} + +{query} + -Document: + {corpus} + -Collate only relevant information from the document to answer the target query. +Collate all relevant information from the document to answer the target query. """.strip() ) From a337d9e4b828c994a94718e4c09006c9449fbb79 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 10 May 2025 15:44:50 -0600 Subject: [PATCH 3/8] Structure research iteration msgs for more granular context management Previously research iterations and conversation logs were added to a single user message. This prevented truncating each past iteration separately on hitting context limits. So the whole past research context had to be dropped on hitting context limits. This change splits each research iteration into a separate item in a message content list. It uses the ability for message content to be a list, that is supported by all major ai model apis like openai, anthropic and gemini. The change in message format seen by pick next tool chat actor: - New Format - System: System Message - User/Assistant: Chat History - User: Raw Query - Assistant: Iteration History - Iteration 1 - Iteration 2 - User: Query with Pick Next Tool Nudge - Old Format - User: System + Chat History + Previous Iterations Message - User: Query - Collateral Changes The construct_structured_message function has been updated to always return a list[dict[str, Any]]. Previously it'd only use list if attached_file_context or vision model with images for wider compatibility with other openai compatible api --- src/khoj/processor/conversation/prompts.py | 23 +++++++------ src/khoj/processor/conversation/utils.py | 38 +++++++++++++++++----- src/khoj/routers/research.py | 22 +++++++------ 3 files changed, 54 insertions(+), 29 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 6ec2376b..41426294 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -762,29 +762,32 @@ Assuming you can search the user's notes and the internet. - User Name: {username} # Available Tool AIs -Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs: +You decide which of the tool AIs listed below would you use to answer the user's question. You **only** have access to the following tool AIs: {tools} -# Previous Iterations -{previous_iterations} - -# Chat History: -{chat_history} - -Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else. +Your response should always be a valid JSON object. Do not say anything else. Response format: {{"scratchpad": "", "tool": "", "query": ""}} """.strip() ) +plan_function_execution_next_tool = PromptTemplate.from_template( + """ +Given the results of your previous iterations, which tool AI will you use next to answer the target query? + +# Target Query: +{query} +""".strip() +) + previous_iteration = PromptTemplate.from_template( """ -## Iteration {index}: +# Iteration {index}: - tool: {tool} - query: {query} - result: {result} -""" +""".strip() ) pick_relevant_tools = PromptTemplate.from_template( diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 7901f29c..6e4b62ab 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -105,9 +105,9 @@ class InformationCollectionIteration: def construct_iteration_history( - previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str -) -> str: - previous_iterations_history = "" + query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str +) -> list[dict]: + previous_iterations_history = [] for idx, iteration in enumerate(previous_iterations): iteration_data = previous_iteration_prompt.format( tool=iteration.tool, @@ -116,8 +116,23 @@ def construct_iteration_history( index=idx + 1, ) - previous_iterations_history += iteration_data - return previous_iterations_history + previous_iterations_history.append(iteration_data) + + return ( + [ + { + "by": "you", + "message": query, + }, + { + "by": "khoj", + "intent": {"type": "remember", "query": query}, + "message": previous_iterations_history, + }, + ] + if previous_iterations_history + else [] + ) def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: @@ -316,7 +331,11 @@ Khoj: "{chat_response}" def construct_structured_message( - message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None + message: list[str] | str, + images: list[str], + model_type: str, + vision_enabled: bool, + attached_file_context: str = None, ): """ Format messages into appropriate multimedia format for supported chat model types @@ -326,10 +345,11 @@ def construct_structured_message( ChatModel.ModelType.GOOGLE, ChatModel.ModelType.ANTHROPIC, ]: - if not attached_file_context and not (vision_enabled and images): - return message + message = [message] if isinstance(message, str) else message - constructed_messages: List[Any] = [{"type": "text", "text": message}] + constructed_messages: List[dict[str, Any]] = [ + {"type": "text", "text": message_part} for message_part in message + ] if not is_none_or_empty(attached_file_context): constructed_messages.append({"type": "text", "text": attached_file_context}) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 86f63b2d..62f24282 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -108,13 +108,6 @@ async def apick_next_tool( # Create planning reponse model with dynamically populated tool enum class planning_response_model = PlanningResponse.create_model_with_enum(tool_options) - # Construct chat history with user and iteration history with researcher agent for context - chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") - previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) - - if query_images: - query = f"[placeholder for user attached images]\n{query}" - today = datetime.today() location_data = f"{location}" if location else "Unknown" agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None @@ -124,21 +117,30 @@ async def apick_next_tool( function_planning_prompt = prompts.plan_function_execution.format( tools=tool_options_str, - chat_history=chat_history, personality_context=personality_context, current_date=today.strftime("%Y-%m-%d"), day_of_week=today.strftime("%A"), username=user_name or "Unknown", location=location_data, - previous_iterations=previous_iterations_history, max_iterations=max_iterations, ) + if query_images: + query = f"[placeholder for user attached images]\n{query}" + + # Construct chat history with user and iteration history with researcher agent for context + previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration) + iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history} + + # Plan function execution for the next tool + query = prompts.plan_function_execution_next_tool.format(query=query) if previous_iterations_history else query + try: with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( query=query, - context=function_planning_prompt, + system_message=function_planning_prompt, + conversation_log=iteration_chat_log, response_type="json_object", response_schema=planning_response_model, deepthought=True, From 2694734d22c813a8cfd0ba5656e00c83e3a8f10b Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 10 May 2025 18:41:16 -0600 Subject: [PATCH 4/8] Update truncation logic to handle multi-part message content --- .../processor/conversation/anthropic/utils.py | 5 +- .../processor/conversation/google/utils.py | 5 +- src/khoj/processor/conversation/utils.py | 123 ++++++++++++---- src/khoj/utils/state.py | 3 +- tests/test_conversation_utils.py | 137 +++++++++++++----- 5 files changed, 207 insertions(+), 66 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index baf8fade..e436ecda 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -203,7 +203,10 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st system_prompt = system_prompt or "" for message in messages.copy(): if message.role == "system": - system_prompt += message.content + if isinstance(message.content, list): + system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"]) + else: + system_prompt += message.content messages.remove(message) system_prompt = None if is_none_or_empty(system_prompt) else system_prompt diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 9f2be46c..ed37a0b3 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -301,7 +301,10 @@ def format_messages_for_gemini( messages = deepcopy(original_messages) for message in messages.copy(): if message.role == "system": - system_prompt += message.content + if isinstance(message.content, list): + system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"]) + else: + system_prompt += message.content messages.remove(message) system_prompt = None if is_none_or_empty(system_prompt) else system_prompt diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 6e4b62ab..6aaa48e9 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -4,14 +4,12 @@ import logging import math import mimetypes import os -import queue import re import uuid from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from time import perf_counter from typing import Any, Callable, Dict, List, Optional import PIL.Image @@ -20,8 +18,9 @@ import requests import tiktoken import yaml from langchain.schema import ChatMessage +from llama_cpp import LlamaTokenizer from llama_cpp.llama import Llama -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.database.adapters import ConversationAdapters from khoj.database.models import ChatModel, ClientApplication, KhojUser @@ -382,7 +381,7 @@ def gather_raw_query_files( def generate_chatml_messages_with_context( user_message, - system_message=None, + system_message: str = None, conversation_log={}, model_name="gpt-4o-mini", loaded_model: Optional[Llama] = None, @@ -480,7 +479,7 @@ def generate_chatml_messages_with_context( if len(chatml_messages) >= 3 * lookback_turns: break - messages = [] + messages: list[ChatMessage] = [] if not is_none_or_empty(generated_asset_results): messages.append( @@ -517,6 +516,11 @@ def generate_chatml_messages_with_context( if not is_none_or_empty(system_message): messages.append(ChatMessage(content=system_message, role="system")) + # Normalize message content to list of chatml dictionaries + for message in messages: + if isinstance(message.content, str): + message.content = [{"type": "text", "text": message.content}] + # Truncate oldest messages from conversation history until under max supported prompt size by model messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name) @@ -524,14 +528,11 @@ def generate_chatml_messages_with_context( return messages[::-1] -def truncate_messages( - messages: list[ChatMessage], - max_prompt_size: int, +def get_encoder( model_name: str, loaded_model: Optional[Llama] = None, tokenizer_name=None, -) -> list[ChatMessage]: - """Truncate messages to fit within max prompt size supported by model""" +) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer: default_tokenizer = "gpt-4o" try: @@ -554,6 +555,48 @@ def truncate_messages( logger.debug( f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing." ) + return encoder + + +def count_tokens( + message_content: str | list[str | dict], + encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding, +) -> int: + """ + Count the total number of tokens in a list of messages. + + Assumes each images takes 500 tokens for approximation. + """ + if isinstance(message_content, list): + image_count = 0 + message_content_parts: list[str] = [] + # Collate message content into single string to ease token counting + for part in message_content: + if isinstance(part, dict) and part.get("type") == "text": + message_content_parts.append(part["text"]) + elif isinstance(part, dict) and part.get("type") == "image_url": + image_count += 1 + elif isinstance(part, str): + message_content_parts.append(part) + else: + logger.warning(f"Unknown message type: {part}. Skipping.") + message_content = "\n".join(message_content_parts).rstrip() + return len(encoder.encode(message_content)) + image_count * 500 + elif isinstance(message_content, str): + return len(encoder.encode(message_content)) + else: + return len(encoder.encode(json.dumps(message_content))) + + +def truncate_messages( + messages: list[ChatMessage], + max_prompt_size: int, + model_name: str, + loaded_model: Optional[Llama] = None, + tokenizer_name=None, +) -> list[ChatMessage]: + """Truncate messages to fit within max prompt size supported by model""" + encoder = get_encoder(model_name, loaded_model, tokenizer_name) # Extract system message from messages system_message = None @@ -562,35 +605,55 @@ def truncate_messages( system_message = messages.pop(idx) break - # TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string - system_message_tokens = ( - len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0 - ) - - tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) - # Drop older messages until under max supported prompt size by model # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.) - while (tokens + system_message_tokens + 4 * len(messages)) > max_prompt_size and len(messages) > 1: - messages.pop() - tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) + system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0 + tokens = sum([count_tokens(message.content, encoder) for message in messages]) + total_tokens = tokens + system_message_tokens + 4 * len(messages) + + while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1): + if len(messages[-1].content) > 1: + # The oldest content part is earlier in content list. So pop from the front. + messages[-1].content.pop(0) + else: + # The oldest message is the last one. So pop from the back. + messages.pop() + tokens = sum([count_tokens(message.content, encoder) for message in messages]) + total_tokens = tokens + system_message_tokens + 4 * len(messages) # Truncate current message if still over max supported prompt size by model - if (tokens + system_message_tokens) > max_prompt_size: - current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" - original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" - original_question = f"\n{original_question}" - original_question_tokens = len(encoder.encode(original_question)) + total_tokens = tokens + system_message_tokens + 4 * len(messages) + if total_tokens > max_prompt_size: + # At this point, a single message with a single content part of type dict should remain + assert ( + len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict) + ), "Expected a single message with a single content part remaining at this point in truncation" + + # Collate message content into single string to ease truncation + part = messages[0].content[0] + message_content: str = part["text"] if part["type"] == "text" else json.dumps(part) + message_role = messages[0].role + + remaining_context = "\n".join(message_content.split("\n")[:-1]) + original_question = "\n" + "\n".join(message_content.split("\n")[-1:]) + + original_question_tokens = count_tokens(original_question, encoder) remaining_tokens = max_prompt_size - system_message_tokens if remaining_tokens > original_question_tokens: remaining_tokens -= original_question_tokens - truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip() - messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)] + truncated_context = encoder.decode(encoder.encode(remaining_context)[:remaining_tokens]).strip() + truncated_content = truncated_context + original_question else: - truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip() - messages = [ChatMessage(content=truncated_message, role=messages[0].role)] + truncated_content = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip() + messages = [ChatMessage(content=[{"type": "text", "text": truncated_content}], role=message_role)] + + truncated_snippet = ( + f"{truncated_content[:1000]}\n...\n{truncated_content[-1000:]}" + if len(truncated_content) > 2000 + else truncated_content + ) logger.debug( - f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..." + f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}" ) if system_message: diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 1673dbe3..f96409c2 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List from apscheduler.schedulers.background import BackgroundScheduler from openai import OpenAI +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from whisper import Whisper from khoj.database.models import ProcessLock @@ -40,7 +41,7 @@ khoj_version: str = None device = get_device() chat_on_gpu: bool = True anonymous_mode: bool = False -pretrained_tokenizers: Dict[str, Any] = dict() +pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict() billing_enabled: bool = ( os.getenv("STRIPE_API_KEY") is not None and os.getenv("STRIPE_SIGNING_SECRET") is not None diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py index 54fe2a7f..43f805b2 100644 --- a/tests/test_conversation_utils.py +++ b/tests/test_conversation_utils.py @@ -1,3 +1,5 @@ +from copy import deepcopy + import tiktoken from langchain.schema import ChatMessage @@ -5,7 +7,7 @@ from khoj.processor.conversation import utils class TestTruncateMessage: - max_prompt_size = 10 + max_prompt_size = 40 model_name = "gpt-4o-mini" encoder = tiktoken.encoding_for_model(model_name) @@ -15,45 +17,108 @@ class TestTruncateMessage: # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties assert len(chat_history) < 50 - assert len(chat_history) > 1 + assert len(chat_history) > 5 assert tokens <= self.max_prompt_size + def test_truncate_message_only_oldest_big(self): + # Arrange + chat_history = generate_chat_history(5) + big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?")) + chat_history.append(big_chat_message) + + # Act + truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + + # Assert + # The original object has been modified. Verify certain properties + assert len(chat_history) == 5 + assert tokens <= self.max_prompt_size + + def test_truncate_message_with_image(self): + # Arrange + image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}} + content_list = [{"type": "text", "text": f"{index}"} for index in range(100)] + content_list += [image_content_item, {"type": "text", "text": "Question?"}] + big_chat_message = ChatMessage(role="user", content=content_list) + copy_big_chat_message = deepcopy(big_chat_message) + chat_history = [big_chat_message] + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + + # Act + truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + + # Assert + # The original object has been modified. Verify certain properties + assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" + assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" + assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + + def test_truncate_message_with_content_list(self): + # Arrange + chat_history = generate_chat_history(5) + content_list = [{"type": "text", "text": f"{index}"} for index in range(100)] + content_list += [{"type": "text", "text": "Question?"}] + big_chat_message = ChatMessage(role="user", content=content_list) + copy_big_chat_message = deepcopy(big_chat_message) + chat_history.insert(0, big_chat_message) + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) + + # Act + truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) + + # Assert + # The original object has been modified. Verify certain properties + assert ( + len(chat_history) == 1 + ), "Only most recent message should be present as it itself is larger than context size" + assert len(truncated_chat_history[0].content) < len( + copy_big_chat_message.content + ), "message content list should be modified" + assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" + assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" + def test_truncate_message_first_large(self): # Arrange chat_history = generate_chat_history(5) - big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?") + big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?")) copy_big_chat_message = big_chat_message.copy() chat_history.insert(0, big_chat_message) - tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history]) + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties - assert len(chat_history) == 1 - assert truncated_chat_history[0] != copy_big_chat_message - assert tokens <= self.max_prompt_size + assert ( + len(chat_history) == 1 + ), "Only most recent message should be present as it itself is larger than context size" + assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" + assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" + assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" - def test_truncate_message_last_large(self): + def test_truncate_message_large_system_message_first(self): # Arrange chat_history = generate_chat_history(5) chat_history[0].role = "system" # Mark the first message as system message - big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?") + big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?")) copy_big_chat_message = big_chat_message.copy() chat_history.insert(0, big_chat_message) - initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history]) # Act truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name) - final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties. @@ -62,46 +127,52 @@ class TestTruncateMessage: ) # Because the system_prompt is popped off from the chat_messages list assert len(truncated_chat_history) < 10 assert len(truncated_chat_history) > 1 - assert truncated_chat_history[0] != copy_big_chat_message - assert initial_tokens > self.max_prompt_size - assert final_tokens <= self.max_prompt_size + assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" + assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size" def test_truncate_single_large_non_system_message(self): # Arrange - big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?") + big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?")) copy_big_chat_message = big_chat_message.copy() chat_messages = [big_chat_message] - initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages]) # Act truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) - final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties - assert initial_tokens > self.max_prompt_size - assert final_tokens <= self.max_prompt_size - assert len(chat_messages) == 1 - assert truncated_chat_history[0] != copy_big_chat_message + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size" + assert ( + len(chat_messages) == 1 + ), "Only most recent message should be present as it itself is larger than context size" + assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" + assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" def test_truncate_single_large_question(self): # Arrange - big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1)) + big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}] big_chat_message = ChatMessage(role="user", content=big_chat_message_content) copy_big_chat_message = big_chat_message.copy() chat_messages = [big_chat_message] - initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) + initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages]) # Act truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) - final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history]) + final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history]) # Assert # The original object has been modified. Verify certain properties - assert initial_tokens > self.max_prompt_size - assert final_tokens <= self.max_prompt_size - assert len(chat_messages) == 1 - assert truncated_chat_history[0] != copy_big_chat_message + assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" + assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size" + assert ( + len(chat_messages) == 1 + ), "Only most recent message should be present as it itself is larger than context size" + assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" def test_load_complex_raw_json_string(): @@ -116,12 +187,12 @@ def test_load_complex_raw_json_string(): assert parsed_json == expeced_json -def generate_content(count): - return " ".join([f"{index}" for index, _ in enumerate(range(count))]) +def generate_content(count, suffix=""): + return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}] def generate_chat_history(count): return [ - ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}") + ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}]) for index, _ in enumerate(range(count)) ] From e125e299a783c7c0b876851f53e2cadbad356eb7 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 13 May 2025 12:14:19 -0600 Subject: [PATCH 5/8] Ensure time to first token logged only once per chat response Time to first token Log lines were shown multiple times if new chunk bein streamed was empty for some reason. This change makes the logic robust to empty chunks being recieved. --- src/khoj/processor/conversation/anthropic/utils.py | 4 +++- src/khoj/processor/conversation/google/utils.py | 4 +++- src/khoj/processor/conversation/openai/utils.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index e436ecda..c2db6a72 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -144,6 +144,7 @@ async def anthropic_chat_completion_with_backoff( formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) aggregated_response = "" + response_started = False final_message = None start_time = perf_counter() async with client.messages.stream( @@ -157,7 +158,8 @@ async def anthropic_chat_completion_with_backoff( ) as stream: async for chunk in stream: # Log the time taken to start response - if aggregated_response == "": + if not response_started: + response_started = True logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") # Skip empty chunks if chunk.type != "content_block_delta": diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index ed37a0b3..d465cbda 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -195,13 +195,15 @@ async def gemini_chat_completion_with_backoff( aggregated_response = "" final_chunk = None + response_started = False start_time = perf_counter() chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream( model=model_name, config=config, contents=formatted_messages ) async for chunk in chat_stream: # Log the time taken to start response - if final_chunk is None: + if not response_started: + response_started = True logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") # Keep track of the last chunk for usage data final_chunk = chunk diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 7b1c11db..77dee0c4 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -226,6 +226,7 @@ async def chat_completion_with_backoff( aggregated_response = "" final_chunk = None + response_started = False start_time = perf_counter() chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( messages=formatted_messages, # type: ignore @@ -237,7 +238,8 @@ async def chat_completion_with_backoff( ) async for chunk in stream_processor(chat_stream): # Log the time taken to start response - if final_chunk is None: + if not response_started: + response_started = True logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") # Keep track of the last chunk for usage data final_chunk = chunk From 417ab42206b49eb7e59a43c3cd27ded0011528ec Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 13 May 2025 13:03:35 -0600 Subject: [PATCH 6/8] Track gemini 2.0 flash lite cost. Reduce max prompt size for 4o-mini --- src/khoj/processor/conversation/utils.py | 2 +- src/khoj/utils/constants.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 6aaa48e9..da871097 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -51,7 +51,7 @@ except ImportError: model_to_prompt_size = { # OpenAI Models "gpt-4o": 60000, - "gpt-4o-mini": 120000, + "gpt-4o-mini": 60000, "gpt-4.1": 60000, "gpt-4.1-mini": 120000, "gpt-4.1-nano": 120000, diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 68ab00f1..af67e0a1 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -52,6 +52,7 @@ model_to_cost: Dict[str, Dict[str, float]] = { "gemini-1.5-pro": {"input": 1.25, "output": 5.00}, "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00}, "gemini-2.0-flash": {"input": 0.10, "output": 0.40}, + "gemini-2.0-flash-lite": {"input": 0.0075, "output": 0.30}, "gemini-2.5-flash-preview-04-17": {"input": 0.15, "output": 0.60, "thought": 3.50}, "gemini-2.5-pro-preview-03-25": {"input": 1.25, "output": 10.0}, # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api From 988bde651ce9d7932ede18a4ce62f84209004dae Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 13 May 2025 12:00:06 -0600 Subject: [PATCH 7/8] Make researcher aware of no. of web, doc queries allowed per iteration - Construct tool description dynamically based on configurable query count - Inform the researcher how many webpage reads, online searches and document searches it can perform per iteration when it has to decide which next tool to use and the query to send to the tool AI. - Pass the query counts to perform from the research AI down to the tool AIs --- src/khoj/processor/conversation/prompts.py | 3 +- src/khoj/processor/tools/online_search.py | 6 ++-- src/khoj/routers/api_chat.py | 3 +- src/khoj/routers/helpers.py | 4 +-- src/khoj/routers/research.py | 40 +++++++++++++++------- src/khoj/utils/helpers.py | 8 ++--- 6 files changed, 40 insertions(+), 24 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 41426294..d3935faa 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -865,8 +865,7 @@ infer_webpages_to_read = PromptTemplate.from_template( You are Khoj, an advanced web page reading assistant. You are to construct **up to {max_webpages}, valid** webpage urls to read before answering the user's question. - You will receive the conversation history as context. - Add as much context from the previous questions and answers as required to construct the webpage urls. -- Use multiple web page urls if required to retrieve the relevant information. -- You have access to the the whole internet to retrieve information. +- You have access to the whole internet to retrieve information. {personality_context} Which webpages will you need to read to answer the user's question? Provide web page links as a list of strings in a JSON object. diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index a99ac811..8b39cc18 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -64,11 +64,12 @@ async def search_online( user: KhojUser, send_status_func: Optional[Callable] = None, custom_filters: List[str] = [], + max_online_searches: int = 3, max_webpages_to_read: int = 1, query_images: List[str] = None, + query_files: str = None, previous_subqueries: Set = set(), agent: Agent = None, - query_files: str = None, tracer: dict = {}, ): query += " ".join(custom_filters) @@ -84,9 +85,10 @@ async def search_online( location, user, query_images=query_images, + query_files=query_files, + max_queries=max_online_searches, agent=agent, tracer=tracer, - query_files=query_files, ) subqueries = list(new_subqueries - previous_subqueries) response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {} diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 27db70f5..f7980721 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1129,9 +1129,10 @@ async def chat( user, partial(send_event, ChatEvent.STATUS), custom_filters, + max_online_searches=3, query_images=uploaded_images, - agent=agent, query_files=attached_file_context, + agent=agent, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a290c85b..e4d46739 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -523,8 +523,9 @@ async def generate_online_subqueries( location_data: LocationData, user: KhojUser, query_images: List[str] = None, - agent: Agent = None, query_files: str = None, + max_queries: int = 3, + agent: Agent = None, tracer: dict = {}, ) -> Set[str]: """ @@ -534,7 +535,6 @@ async def generate_online_subqueries( username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) - max_queries = 3 utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d") personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 62f24282..4f3252b4 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -6,7 +6,6 @@ from enum import Enum from typing import Callable, Dict, List, Optional, Type import yaml -from fastapi import Request from pydantic import BaseModel, Field from khoj.database.adapters import AgentAdapters, EntryAdapters @@ -14,7 +13,6 @@ from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( InformationCollectionIteration, - construct_chat_history, construct_iteration_history, construct_tool_chat_history, load_complex_json, @@ -29,9 +27,9 @@ from khoj.routers.helpers import ( ) from khoj.utils.helpers import ( ConversationCommand, - function_calling_description_for_llm, is_none_or_empty, timer, + tool_description_for_research_llm, truncate_code_context, ) from khoj.utils.rawconfig import LocationData @@ -79,15 +77,18 @@ async def apick_next_tool( query: str, conversation_history: dict, user: KhojUser = None, - query_images: List[str] = [], location: LocationData = None, user_name: str = None, agent: Agent = None, previous_iterations: List[InformationCollectionIteration] = [], max_iterations: int = 5, + query_images: List[str] = [], + query_files: str = None, + max_document_searches: int = 7, + max_online_searches: int = 3, + max_webpages_to_read: int = 1, send_status_func: Optional[Callable] = None, tracer: dict = {}, - query_files: str = None, ): """Given a query, determine which of the available tools the agent should use in order to answer appropriately.""" @@ -96,10 +97,16 @@ async def apick_next_tool( tool_options_str = "" agent_tools = agent.input_tools if agent else [] user_has_entries = await EntryAdapters.auser_has_entries(user) - for tool, description in function_calling_description_for_llm.items(): + for tool, description in tool_description_for_research_llm.items(): # Skip showing Notes tool as an option if user has no entries - if tool == ConversationCommand.Notes and not user_has_entries: - continue + if tool == ConversationCommand.Notes: + if not user_has_entries: + continue + description = description.format(max_search_queries=max_document_searches) + if tool == ConversationCommand.Webpage: + description = description.format(max_webpages_to_read=max_webpages_to_read) + if tool == ConversationCommand.Online: + description = description.format(max_search_queries=max_online_searches) # Add tool if agent does not have any tools defined or the tool is supported by the agent. if len(agent_tools) == 0 or tool.value in agent_tools: tool_options[tool.name] = tool.value @@ -210,6 +217,9 @@ async def execute_information_collection( query_files: str = None, cancellation_event: Optional[asyncio.Event] = None, ): + max_document_searches = 7 + max_online_searches = 3 + max_webpages_to_read = 1 current_iteration = 0 MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5)) previous_iterations: List[InformationCollectionIteration] = [] @@ -229,15 +239,18 @@ async def execute_information_collection( query, conversation_history, user, - query_images, location, user_name, agent, previous_iterations, MAX_ITERATIONS, - send_status_func, - tracer=tracer, + query_images=query_images, query_files=query_files, + max_document_searches=max_document_searches, + max_online_searches=max_online_searches, + max_webpages_to_read=max_webpages_to_read, + send_status_func=send_status_func, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -262,7 +275,7 @@ async def execute_information_collection( user, construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), this_iteration.query, - 7, + max_document_searches, None, conversation_id, [ConversationCommand.Default], @@ -309,6 +322,7 @@ async def execute_information_collection( user, send_status_func, [], + max_online_searches=max_online_searches, max_webpages_to_read=0, query_images=query_images, previous_subqueries=previous_subqueries, @@ -334,7 +348,7 @@ async def execute_information_collection( location, user, send_status_func, - max_webpages_to_read=1, + max_webpages_to_read=max_webpages_to_read, query_images=query_images, agent=agent, tracer=tracer, diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 4a756dcb..2530b4fb 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -386,10 +386,10 @@ tool_descriptions_for_llm = { ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, } -function_calling_description_for_llm = { - ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", - ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.", - ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.", +tool_description_for_research_llm = { + ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents. Max {max_search_queries} search queries allowed per iteration.", + ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed. Max {max_search_queries} search queries allowed per iteration.", + ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share upto {max_webpages_to_read} webpage links and what information to extract from them in your query.", ConversationCommand.Code: e2b_tool_description if is_e2b_code_sandbox_enabled() else terrarium_tool_description, ConversationCommand.Text: "To respond to the user once you've completed your research and have the required information.", } From fd591c6e6c41ecb66e27e0815117fe2a97a524f0 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 17 May 2025 14:21:47 -0700 Subject: [PATCH 8/8] Upgrade tenacity to respect min time for exponential backoff Fix for issue is in tenacity 9.0.0. But older langchain required tenacity <0.9.0. Explicitly pin version of langchain sub packages to avoid indexing and doc parsing breakage. --- pyproject.toml | 7 +++---- src/khoj/processor/content/text_to_entries.py | 2 +- .../processor/conversation/anthropic/anthropic_chat.py | 2 +- src/khoj/processor/conversation/anthropic/utils.py | 2 +- src/khoj/processor/conversation/google/gemini_chat.py | 2 +- src/khoj/processor/conversation/google/utils.py | 2 +- src/khoj/processor/conversation/offline/chat_model.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 2 +- src/khoj/processor/conversation/prompts.py | 2 +- src/khoj/processor/conversation/utils.py | 2 +- tests/test_conversation_utils.py | 2 +- 11 files changed, 13 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 127cce9c..01d9400a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "jinja2 == 3.1.6", "openai >= 1.0.0", "tiktoken >= 0.3.2", - "tenacity >= 8.2.2", + "tenacity >= 9.0.0", "magika ~= 0.5.1", "pillow ~= 10.0.0", "pydantic[email] >= 2.0.0", @@ -57,10 +57,9 @@ dependencies = [ "torch == 2.6.0", "uvicorn == 0.30.6", "aiohttp ~= 3.9.0", - "langchain == 0.2.5", - "langchain-community == 0.2.5", + "langchain-text-splitters == 0.3.1", + "langchain-community == 0.3.3", "requests >= 2.26.0", - "tenacity == 8.3.0", "anyio ~= 4.8.0", "pymupdf == 1.24.11", "django == 5.1.8", diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 2c27c5a3..8e0b3322 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from itertools import repeat from typing import Any, Callable, List, Set, Tuple -from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_text_splitters import RecursiveCharacterTextSplitter from tqdm import tqdm from khoj.database.adapters import ( diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index aba69dfb..bfd74a53 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional import pyjson5 -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index c2db6a72..915c082b 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -3,7 +3,7 @@ from time import perf_counter from typing import Dict, List import anthropic -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from tenacity import ( before_sleep_log, retry, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index e9993a39..5f45f69e 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional import pyjson5 -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from pydantic import BaseModel, Field from khoj.database.models import Agent, ChatModel, KhojUser diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index d465cbda..c527bf72 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -9,7 +9,7 @@ import httpx from google import genai from google.genai import errors as gerrors from google.genai import types as gtypes -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from pydantic import BaseModel from tenacity import ( before_sleep_log, diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index e2da460e..2a0512f9 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -7,7 +7,7 @@ from time import perf_counter from typing import Any, AsyncGenerator, Dict, List, Optional, Union import pyjson5 -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from llama_cpp import Llama from khoj.database.models import Agent, ChatModel, KhojUser diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 65b2d83f..913bd90c 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional import pyjson5 -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from openai.lib._pydantic import _ensure_strict_json_schema from pydantic import BaseModel diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index d3935faa..15477c83 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1,4 +1,4 @@ -from langchain.prompts import PromptTemplate +from langchain_core.prompts import PromptTemplate ## Personality ## -- diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index da871097..72c0fd1e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -17,7 +17,7 @@ import pyjson5 import requests import tiktoken import yaml -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from llama_cpp import LlamaTokenizer from llama_cpp.llama import Llama from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py index 43f805b2..b1fdad30 100644 --- a/tests/test_conversation_utils.py +++ b/tests/test_conversation_utils.py @@ -1,7 +1,7 @@ from copy import deepcopy import tiktoken -from langchain.schema import ChatMessage +from langchain_core.messages.chat import ChatMessage from khoj.processor.conversation import utils