From 05d4e19cb820d54043967e3a100681922c76cbe8 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 3 Jun 2025 15:28:06 -0700 Subject: [PATCH] Pass deep typed chat history for more ergonomic, readable, safe code The chat dictionary is an artifact from earlier non-db chat history storage. We've been ensuring new chat messages have valid type before being written to DB for more than 6 months now. Moving to the deeply typed chat history helps avoid null refs, makes code more readable and easier to reason about. Next Steps: The current update entangles chat_history written to DB with any virtual chat history message generated for intermediate steps. The chat message type written to DB should be decoupled from type that can be passed to AI model APIs (maybe?). For now we've made the ChatMessage.message type looser to allow for list[dict] type (apart from string). But later maybe a good idea to decouple the chat_history received by send_message_to_model from the chat_history saved to DB (which can then have its stricter type check) --- src/khoj/database/adapters/__init__.py | 4 +- src/khoj/database/models/__init__.py | 18 +- .../conversation/anthropic/anthropic_chat.py | 14 +- .../conversation/google/gemini_chat.py | 14 +- .../conversation/offline/chat_model.py | 12 +- src/khoj/processor/conversation/openai/gpt.py | 14 +- src/khoj/processor/conversation/utils.py | 183 +++++++++--------- src/khoj/processor/image/generate.py | 27 +-- src/khoj/processor/operator/__init__.py | 5 +- .../operator/operator_agent_binary.py | 22 +-- src/khoj/processor/tools/online_search.py | 12 +- src/khoj/processor/tools/run_code.py | 10 +- src/khoj/routers/api.py | 22 ++- src/khoj/routers/api_chat.py | 32 +-- src/khoj/routers/helpers.py | 87 +++++---- src/khoj/routers/research.py | 18 +- tests/helpers.py | 7 +- tests/test_offline_chat_actors.py | 14 +- tests/test_offline_chat_director.py | 2 +- tests/test_online_chat_actors.py | 2 +- 20 files changed, 271 insertions(+), 248 deletions(-) diff 
--git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 3080d0f0..39c94f8a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -37,6 +37,7 @@ from torch import Tensor from khoj.database.models import ( Agent, AiModelApi, + ChatMessageModel, ChatModel, ClientApplication, Conversation, @@ -1419,7 +1420,7 @@ class ConversationAdapters: @require_valid_user async def save_conversation( user: KhojUser, - conversation_log: dict, + chat_history: List[ChatMessageModel], client_application: ClientApplication = None, conversation_id: str = None, user_message: str = None, @@ -1434,6 +1435,7 @@ class ConversationAdapters: await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() ) + conversation_log = {"chat": [msg.model_dump() for msg in chat_history]} cleaned_conversation_log = clean_object_for_db(conversation_log) if conversation: conversation.conversation_log = cleaned_conversation_log diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 84948015..2dded5ad 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -91,7 +91,7 @@ class OnlineContext(PydanticBaseModel): class Intent(PydanticBaseModel): type: str query: str - memory_type: str = Field(alias="memory-type") + memory_type: Optional[str] = Field(alias="memory-type", default=None) inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries") @@ -100,20 +100,20 @@ class TrainOfThought(PydanticBaseModel): data: str -class ChatMessage(PydanticBaseModel): - message: str +class ChatMessageModel(PydanticBaseModel): + by: str + message: str | list[dict] trainOfThought: List[TrainOfThought] = [] context: List[Context] = [] onlineContext: Dict[str, OnlineContext] = {} codeContext: Dict[str, CodeContextData] = {} researchContext: Optional[List] = None operatorContext: Optional[List] = None - 
created: str + created: Optional[str] = None images: Optional[List[str]] = None queryFiles: Optional[List[Dict]] = None excalidrawDiagram: Optional[List[Dict]] = None - mermaidjsDiagram: str = None - by: str + mermaidjsDiagram: Optional[str] = None turnId: Optional[str] = None intent: Optional[Intent] = None automationId: Optional[str] = None @@ -634,7 +634,7 @@ class Conversation(DbBaseModel): try: messages = self.conversation_log.get("chat", []) for msg in messages: - ChatMessage.model_validate(msg) + ChatMessageModel.model_validate(msg) except Exception as e: raise ValidationError(f"Invalid conversation_log format: {str(e)}") @@ -643,7 +643,7 @@ class Conversation(DbBaseModel): super().save(*args, **kwargs) @property - def messages(self) -> List[ChatMessage]: + def messages(self) -> List[ChatMessageModel]: """Type-hinted accessor for conversation messages""" validated_messages = [] for msg in self.conversation_log.get("chat", []): @@ -654,7 +654,7 @@ class Conversation(DbBaseModel): q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str) ] msg["message"] = str(msg.get("message", "")) - validated_messages.append(ChatMessage.model_validate(msg)) + validated_messages.append(ChatMessageModel.model_validate(msg)) except ValidationError as e: logger.warning(f"Skipping invalid message in conversation: {e}") continue diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 3d34573e..775b8b99 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -6,7 +6,7 @@ from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain_core.messages.chat import ChatMessage -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from 
khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) def extract_questions_anthropic( text, model: Optional[str] = "claude-3-7-sonnet-latest", - conversation_log={}, + chat_history: List[ChatMessageModel] = [], api_key=None, api_base_url=None, location_data: LocationData = None, @@ -54,8 +54,8 @@ def extract_questions_anthropic( location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" - # Extract Past User Message and Inferred Questions from Conversation Log - chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant") + # Extract Past User Message and Inferred Questions from Chat History + chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant") # Get dates relative to today for prompt creation today = datetime.today() @@ -76,7 +76,7 @@ def extract_questions_anthropic( ) prompt = prompts.extract_questions_anthropic_user_message.format( - chat_history=chat_history, + chat_history=chat_history_str, text=text, ) @@ -142,7 +142,7 @@ async def converse_anthropic( online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - conversation_log={}, + chat_history: List[ChatMessageModel] = [], model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, api_base_url: Optional[str] = None, @@ -225,7 +225,7 @@ async def converse_anthropic( messages = generate_chatml_messages_with_context( user_query, context_message=context_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py 
b/src/khoj/processor/conversation/google/gemini_chat.py index f9219c28..b2f48c81 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -7,7 +7,7 @@ import pyjson5 from langchain_core.messages.chat import ChatMessage from pydantic import BaseModel, Field -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( gemini_chat_completion_with_backoff, @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) def extract_questions_gemini( text, model: Optional[str] = "gemini-2.0-flash", - conversation_log={}, + chat_history: List[ChatMessageModel] = [], api_key=None, api_base_url=None, max_tokens=None, @@ -54,8 +54,8 @@ def extract_questions_gemini( location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" - # Extract Past User Message and Inferred Questions from Conversation Log - chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant") + # Extract Past User Message and Inferred Questions from Chat History + chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant") # Get dates relative to today for prompt creation today = datetime.today() @@ -76,7 +76,7 @@ def extract_questions_gemini( ) prompt = prompts.extract_questions_anthropic_user_message.format( - chat_history=chat_history, + chat_history=chat_history_str, text=text, ) @@ -163,7 +163,7 @@ async def converse_gemini( online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - conversation_log={}, + chat_history: List[ChatMessageModel] = [], model: Optional[str] = 
"gemini-2.0-flash", api_key: Optional[str] = None, api_base_url: Optional[str] = None, @@ -248,7 +248,7 @@ async def converse_gemini( messages = generate_chatml_messages_with_context( user_query, context_message=context_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 0ecf62fd..27cd9a9e 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -10,7 +10,7 @@ import pyjson5 from langchain_core.messages.chat import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -38,7 +38,7 @@ def extract_questions_offline( text: str, model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", loaded_model: Union[Any, None] = None, - conversation_log={}, + chat_history: List[ChatMessageModel] = [], use_history: bool = True, should_extract_questions: bool = True, location_data: LocationData = None, @@ -65,7 +65,7 @@ def extract_questions_offline( username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log - chat_history = construct_question_history(conversation_log, include_query=False) if use_history else "" + chat_history_str = construct_question_history(chat_history, include_query=False) if use_history else "" # Get dates relative to today for prompt creation today = datetime.today() @@ -73,7 +73,7 @@ def extract_questions_offline( last_year = today.year - 1 
example_questions = prompts.extract_questions_offline.format( query=text, - chat_history=chat_history, + chat_history=chat_history_str, current_date=today.strftime("%Y-%m-%d"), day_of_week=today.strftime("%A"), current_month=today.strftime("%Y-%m"), @@ -147,7 +147,7 @@ async def converse_offline( references: list[dict] = [], online_results={}, code_results={}, - conversation_log={}, + chat_history: list[ChatMessageModel] = [], model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", loaded_model: Union[Any, None] = None, completion_func=None, @@ -227,7 +227,7 @@ async def converse_offline( messages = generate_chatml_messages_with_context( user_query, system_prompt, - conversation_log, + chat_history, context_message=context_message, model_name=model_name, loaded_model=offline_chat_model, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index b3c440ff..65cbbfca 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -8,7 +8,7 @@ from langchain_core.messages.chat import ChatMessage from openai.lib._pydantic import _ensure_strict_json_schema from pydantic import BaseModel -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) def extract_questions( text, model: Optional[str] = "gpt-4o-mini", - conversation_log={}, + chat_history: list[ChatMessageModel] = [], api_key=None, api_base_url=None, location_data: LocationData = None, @@ -56,8 +56,8 @@ def extract_questions( location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" - # Extract Past User Message and Inferred Questions 
from Conversation Log - chat_history = construct_question_history(conversation_log) + # Extract Past User Message and Inferred Questions from Chat History + chat_history_str = construct_question_history(chat_history) # Get dates relative to today for prompt creation today = datetime.today() @@ -73,7 +73,7 @@ def extract_questions( current_new_year_date=current_new_year.strftime("%Y-%m-%d"), bob_tom_age_difference={current_new_year.year - 1984 - 30}, bob_age={current_new_year.year - 1984}, - chat_history=chat_history, + chat_history=chat_history_str, text=text, yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, @@ -166,7 +166,7 @@ async def converse_openai( online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - conversation_log={}, + chat_history: list[ChatMessageModel] = [], model: str = "gpt-4o-mini", api_key: Optional[str] = None, api_base_url: Optional[str] = None, @@ -251,7 +251,7 @@ async def converse_openai( messages = generate_chatml_messages_with_context( user_query, system_prompt, - conversation_log, + chat_history, context_message=context_message, model_name=model, max_prompt_size=max_prompt_size, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 71ca6526..90b85e47 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -24,7 +24,13 @@ from pydantic import BaseModel from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.database.adapters import ConversationAdapters -from khoj.database.models import ChatModel, ClientApplication, KhojUser +from khoj.database.models import ( + ChatMessageModel, + ChatModel, + ClientApplication, + Intent, + KhojUser, +) from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from 
khoj.search_filter.base_filter import BaseFilter @@ -161,8 +167,8 @@ def construct_iteration_history( previous_iterations: List[ResearchIteration], previous_iteration_prompt: str, query: str = None, -) -> list[dict]: - iteration_history: list[dict] = [] +) -> list[ChatMessageModel]: + iteration_history: list[ChatMessageModel] = [] previous_iteration_messages: list[dict] = [] for idx, iteration in enumerate(previous_iterations): iteration_data = previous_iteration_prompt.format( @@ -176,46 +182,46 @@ def construct_iteration_history( if previous_iteration_messages: if query: - iteration_history.append({"by": "you", "message": query}) + iteration_history.append(ChatMessageModel(by="you", message=query)) iteration_history.append( - { - "by": "khoj", - "intent": {"type": "remember", "query": query}, - "message": previous_iteration_messages, - } + ChatMessageModel( + by="khoj", + intent={"type": "remember", "query": query}, + message=previous_iteration_messages, + ) ) return iteration_history -def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: - chat_history = "" - for chat in conversation_history.get("chat", [])[-n:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: - if chat["intent"].get("inferred-queries"): - chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' - chat_history += f"{agent_name}: {chat['message']}\n\n" - elif chat["by"] == "khoj" and chat.get("images"): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" - elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" - elif chat["by"] == "you": - chat_history += f"User: {chat['message']}\n" - raw_query_files = chat.get("queryFiles") +def 
construct_chat_history(chat_history: list[ChatMessageModel], n: int = 4, agent_name="AI") -> str: + chat_history_str = "" + for chat in chat_history[-n:]: + if chat.by == "khoj" and chat.intent.type in ["remember", "reminder", "summarize"]: + if chat.intent.inferred_queries: + chat_history_str += f'{agent_name}: {{"queries": {chat.intent.inferred_queries}}}\n' + chat_history_str += f"{agent_name}: {chat.message}\n\n" + elif chat.by == "khoj" and chat.images: + chat_history_str += f"User: {chat.intent.query}\n" + chat_history_str += f"{agent_name}: [generated image redacted for space]\n" + elif chat.by == "khoj" and ("excalidraw" in chat.intent.type): + chat_history_str += f"User: {chat.intent.query}\n" + chat_history_str += f"{agent_name}: {chat.intent.inferred_queries[0]}\n" + elif chat.by == "you": + chat_history_str += f"User: {chat.message}\n" + raw_query_files = chat.queryFiles if raw_query_files: query_files: Dict[str, str] = {} for file in raw_query_files: query_files[file["name"]] = file["content"] query_file_context = gather_raw_query_files(query_files) - chat_history += f"User: {query_file_context}\n" + chat_history_str += f"User: {query_file_context}\n" - return chat_history + return chat_history_str def construct_question_history( - conversation_log: dict, + conversation_log: list[ChatMessageModel], include_query: bool = True, lookback: int = 6, query_prefix: str = "Q", @@ -226,16 +232,16 @@ def construct_question_history( """ history_parts = "" original_query = None - for chat in conversation_log.get("chat", [])[-lookback:]: - if chat["by"] == "you": - original_query = chat.get("message") + for chat in conversation_log[-lookback:]: + if chat.by == "you": + original_query = json.dumps(chat.message) history_parts += f"{query_prefix}: {original_query}\n" - if chat["by"] == "khoj": + if chat.by == "khoj": if original_query is None: continue - message = chat.get("message", "") - inferred_queries_list = chat.get("intent", {}).get("inferred-queries") + 
message = chat.message + inferred_queries_list = chat.intent.inferred_queries or [] # Ensure inferred_queries_list is a list, defaulting to the original query in a list if not inferred_queries_list: @@ -246,7 +252,7 @@ def construct_question_history( if include_query: # Ensure 'type' exists and is a string before checking 'to-image' - intent_type = chat.get("intent", {}).get("type", "") + intent_type = chat.intent.type if chat.intent and chat.intent.type else "" if "to-image" not in intent_type: history_parts += f'{agent_name}: {{"queries": {inferred_queries_list}}}\n' history_parts += f"A: {message}\n\n" @@ -259,7 +265,7 @@ def construct_question_history( return history_parts -def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]: +def construct_chat_history_for_operator(conversation_history: List[ChatMessageModel], n: int = 6) -> list[AgentMessage]: """ Construct chat history for operator agent in conversation log. Only include last n completed turns (i.e with user and khoj message). 
@@ -267,22 +273,22 @@ def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) chat_history: list[AgentMessage] = [] user_message: Optional[AgentMessage] = None - for chat in conversation_history.get("chat", []): + for chat in conversation_history: if len(chat_history) >= n: break - if chat["by"] == "you" and chat.get("message"): - content = [{"type": "text", "text": chat["message"]}] - for file in chat.get("queryFiles", []): + if chat.by == "you" and chat.message: + content = [{"type": "text", "text": chat.message}] + for file in chat.queryFiles or []: content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}] user_message = AgentMessage(role="user", content=content) - elif chat["by"] == "khoj" and chat.get("message"): - chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])] + elif chat.by == "khoj" and chat.message: + chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)] return chat_history def construct_tool_chat_history( previous_iterations: List[ResearchIteration], tool: ConversationCommand = None -) -> Dict[str, list]: +) -> List[ChatMessageModel]: """ Construct chat history from previous iterations for a specific tool @@ -313,22 +319,23 @@ def construct_tool_chat_history( tool or ConversationCommand(iteration.tool), base_extractor ) chat_history += [ - { - "by": "you", - "message": iteration.query, - }, - { - "by": "khoj", - "intent": { - "type": "remember", - "inferred-queries": inferred_query_extractor(iteration), - "query": iteration.query, - }, - "message": iteration.summarizedResult, - }, + ChatMessageModel( + by="you", + message=iteration.query, + ), + ChatMessageModel( + by="khoj", + intent=Intent( + type="remember", + query=iteration.query, + inferred_queries=inferred_query_extractor(iteration), + memory_type="notes", + ), + message=iteration.summarizedResult, + ), ] - return {"chat": chat_history} + return chat_history class 
ChatEvent(Enum): @@ -349,8 +356,8 @@ def message_to_log( chat_response, user_message_metadata={}, khoj_message_metadata={}, - conversation_log=[], -): + chat_history: List[ChatMessageModel] = [], +) -> List[ChatMessageModel]: """Create json logs from messages, metadata for conversation log""" default_khoj_message_metadata = { "intent": {"type": "remember", "memory-type": "notes", "query": user_message}, @@ -369,15 +376,17 @@ def message_to_log( khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata) khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log) - conversation_log.extend([human_log, khoj_log]) - return conversation_log + human_message = ChatMessageModel(**human_log) + khoj_message = ChatMessageModel(**khoj_log) + chat_history.extend([human_message, khoj_message]) + return chat_history async def save_to_conversation_log( q: str, chat_response: str, user: KhojUser, - meta_log: Dict, + chat_history: List[ChatMessageModel], user_message_time: str = None, compiled_references: List[Dict[str, Any]] = [], online_results: Dict[str, Any] = {}, @@ -427,11 +436,11 @@ async def save_to_conversation_log( chat_response=chat_response, user_message_metadata=user_message_metadata, khoj_message_metadata=khoj_message_metadata, - conversation_log=meta_log.get("chat", []), + chat_history=chat_history, ) await ConversationAdapters.save_conversation( user, - {"chat": updated_conversation}, + updated_conversation, client_application=client_application, conversation_id=conversation_id, user_message=q, @@ -502,7 +511,7 @@ def gather_raw_query_files( def generate_chatml_messages_with_context( user_message: str, system_message: str = None, - conversation_log={}, + chat_history: list[ChatMessageModel] = [], model_name="gpt-4o-mini", loaded_model: Optional[Llama] = None, max_prompt_size=None, @@ -529,21 +538,21 @@ def generate_chatml_messages_with_context( # Extract Chat History for Context chatml_messages: 
List[ChatMessage] = [] - for chat in conversation_log.get("chat", []): + for chat in chat_history: message_context = [] message_attached_files = "" generated_assets = {} - chat_message = chat.get("message") - role = "user" if chat["by"] == "you" else "assistant" + chat_message = chat.message + role = "user" if chat.by == "you" else "assistant" # Legacy code to handle excalidraw diagrams prior to Dec 2024 - if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): - chat_message = chat["intent"].get("inferred-queries")[0] + if chat.by == "khoj" and "excalidraw" in chat.intent.type or "": + chat_message = (chat.intent.inferred_queries or [])[0] - if chat.get("queryFiles"): - raw_query_files = chat.get("queryFiles") + if chat.queryFiles: + raw_query_files = chat.queryFiles query_files_dict = dict() for file in raw_query_files: query_files_dict[file["name"]] = file["content"] @@ -551,24 +560,24 @@ def generate_chatml_messages_with_context( message_attached_files = gather_raw_query_files(query_files_dict) chatml_messages.append(ChatMessage(content=message_attached_files, role=role)) - if not is_none_or_empty(chat.get("onlineContext")): + if not is_none_or_empty(chat.onlineContext): message_context += [ { "type": "text", - "text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}", + "text": f"{prompts.online_search_conversation.format(online_results=chat.onlineContext)}", } ] - if not is_none_or_empty(chat.get("codeContext")): + if not is_none_or_empty(chat.codeContext): message_context += [ { "type": "text", - "text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}", + "text": f"{prompts.code_executed_context.format(code_results=chat.codeContext)}", } ] - if not is_none_or_empty(chat.get("operatorContext")): - operator_context = chat.get("operatorContext") + if not is_none_or_empty(chat.operatorContext): + operator_context = chat.operatorContext operator_content = "\n\n".join([f'## 
Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context]) message_context += [ { @@ -577,13 +586,9 @@ def generate_chatml_messages_with_context( } ] - if not is_none_or_empty(chat.get("context")): + if not is_none_or_empty(chat.context): references = "\n\n".join( - { - f"# File: {item['file']}\n## {item['compiled']}\n" - for item in chat.get("context") or [] - if isinstance(item, dict) - } + {f"# File: {item.file}\n## {item.compiled}\n" for item in chat.context or [] if isinstance(item, dict)} ) message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}] @@ -591,14 +596,14 @@ def generate_chatml_messages_with_context( reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) - if not is_none_or_empty(chat.get("images")) and role == "assistant": + if not is_none_or_empty(chat.images) and role == "assistant": generated_assets["image"] = { - "query": chat.get("intent", {}).get("inferred-queries", [user_message])[0], + "query": (chat.intent.inferred_queries or [user_message])[0], } - if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant": + if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant": generated_assets["diagram"] = { - "query": chat.get("intent", {}).get("inferred-queries", [user_message])[0], + "query": (chat.intent.inferred_queries or [user_message])[0], } if not is_none_or_empty(generated_assets): @@ -610,7 +615,7 @@ def generate_chatml_messages_with_context( ) message_content = construct_structured_message( - chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled + chat_message, chat.images if role == "user" else [], model_type, vision_enabled ) reconstructed_message = ChatMessage(content=message_content, role=role) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index f1b84431..46b989d6 100644 --- 
a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -10,7 +10,12 @@ from google import genai from google.genai import types as gtypes from khoj.database.adapters import ConversationAdapters -from khoj.database.models import Agent, KhojUser, TextToImageModelConfig +from khoj.database.models import ( + Agent, + ChatMessageModel, + KhojUser, + TextToImageModelConfig, +) from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.storage import upload_generated_image_to_bucket from khoj.utils import state @@ -23,7 +28,7 @@ logger = logging.getLogger(__name__) async def text_to_image( message: str, user: KhojUser, - conversation_log: dict, + chat_history: List[ChatMessageModel], location_data: LocationData, references: List[Dict[str, Any]], online_results: Dict[str, Any], @@ -46,14 +51,14 @@ async def text_to_image( return text2image_model = text_to_image_config.model_name - chat_history = "" - for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: - chat_history += f"Q: {chat['intent']['query']}\n" - chat_history += f"A: {chat['message']}\n" - elif chat["by"] == "khoj" and chat.get("images"): - chat_history += f"Q: {chat['intent']['query']}\n" - chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" + chat_history_str = "" + for chat in chat_history[-4:]: + if chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]: + chat_history_str += f"Q: {chat.intent.query or ''}\n" + chat_history_str += f"A: {chat.message}\n" + elif chat.by == "khoj" and chat.images: + chat_history_str += f"Q: {chat.intent.query}\n" + chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n" if send_status_func: async for event in send_status_func("**Enhancing the Painting Prompt**"): @@ -63,7 +68,7 @@ async def text_to_image( # Use the user's message, chat history, and other 
context image_prompt = await generate_better_image_prompt( message, - chat_history, + chat_history_str, location_data=location_data, note_references=references, online_results=online_results, diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 138a2696..9b4ad80f 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -5,10 +5,9 @@ import os from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation.utils import ( OperatorRun, - construct_chat_history, construct_chat_history_for_operator, ) from khoj.processor.operator.operator_actions import * @@ -34,7 +33,7 @@ logger = logging.getLogger(__name__) async def operate_environment( query: str, user: KhojUser, - conversation_log: dict, + conversation_log: List[ChatMessageModel], location_data: LocationData, previous_trajectory: Optional[OperatorRun] = None, environment_type: EnvironmentType = EnvironmentType.COMPUTER, diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index f4df22fe..8106e9cc 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -4,7 +4,7 @@ from datetime import datetime from textwrap import dedent from typing import List, Optional -from khoj.database.models import ChatModel +from khoj.database.models import ChatMessageModel, ChatModel from khoj.processor.conversation.utils import ( AgentMessage, OperatorRun, @@ -119,13 +119,13 @@ class BinaryOperatorAgent(OperatorAgent): query_screenshot = self._get_message_images(current_message) # Construct input for visual reasoner history - visual_reasoner_history = {"chat": 
self._format_message_for_api(self.messages)} + visual_reasoner_history = self._format_message_for_api(self.messages) try: natural_language_action = await send_message_to_model_wrapper( query=query_text, query_images=query_screenshot, system_message=reasoning_system_prompt, - conversation_log=visual_reasoner_history, + chat_history=visual_reasoner_history, agent_chat_model=self.reasoning_model, tracer=self.tracer, ) @@ -238,11 +238,11 @@ class BinaryOperatorAgent(OperatorAgent): async def summarize(self, env_state: EnvState, summarize_prompt: str = None) -> str: summarize_prompt = summarize_prompt or self.summarize_prompt - conversation_history = {"chat": self._format_message_for_api(self.messages)} + conversation_history = self._format_message_for_api(self.messages) try: summary = await send_message_to_model_wrapper( query=summarize_prompt, - conversation_log=conversation_history, + chat_history=conversation_history, agent_chat_model=self.reasoning_model, tracer=self.tracer, ) @@ -296,14 +296,14 @@ class BinaryOperatorAgent(OperatorAgent): images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"] return images - def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]: + def _format_message_for_api(self, messages: list[AgentMessage]) -> List[ChatMessageModel]: """Format operator agent messages into the Khoj conversation history format.""" formatted_messages = [ - { - "message": self._get_message_text(message), - "images": self._get_message_images(message), - "by": "you" if message.role in ["user", "environment"] else message.role, - } + ChatMessageModel( + message=self._get_message_text(message), + images=self._get_message_images(message), + by="you" if message.role in ["user", "environment"] else message.role, + ) for message in messages ] return formatted_messages diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 8b39cc18..8951bc46 100644 --- 
a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -10,7 +10,13 @@ from bs4 import BeautifulSoup from markdownify import markdownify from khoj.database.adapters import ConversationAdapters -from khoj.database.models import Agent, KhojUser, ServerChatSettings, WebScraper +from khoj.database.models import ( + Agent, + ChatMessageModel, + KhojUser, + ServerChatSettings, + WebScraper, +) from khoj.processor.conversation import prompts from khoj.routers.helpers import ( ChatEvent, @@ -59,7 +65,7 @@ OLOSTEP_QUERY_PARAMS = { async def search_online( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], location: LocationData, user: KhojUser, send_status_func: Optional[Callable] = None, @@ -361,7 +367,7 @@ async def search_with_serper(query: str, location: LocationData) -> Tuple[str, D async def read_webpages( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], location: LocationData, user: KhojUser, send_status_func: Optional[Callable] = None, diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 535edd0e..5b9c2d04 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -20,7 +20,7 @@ from tenacity import ( ) from khoj.database.adapters import FileObjectAdapters -from khoj.database.models import Agent, FileObject, KhojUser +from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( ChatEvent, @@ -50,7 +50,7 @@ class GeneratedCode(NamedTuple): async def run_code( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], context: str, location_data: LocationData, user: KhojUser, @@ -116,7 +116,7 @@ async def run_code( async def generate_python_code( q: str, - conversation_history: dict, + chat_history: List[ChatMessageModel], 
context: str, location_data: LocationData, user: KhojUser, @@ -127,7 +127,7 @@ async def generate_python_code( ) -> GeneratedCode: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d") personality_context = ( @@ -143,7 +143,7 @@ async def generate_python_code( code_generation_prompt = prompts.python_code_generation_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, context=context, has_network_access=network_access_context, current_date=utc_date, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index efba17ec..89a95642 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -29,7 +29,13 @@ from khoj.database.adapters import ( get_default_search_model, get_user_photo, ) -from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions +from khoj.database.models import ( + Agent, + ChatMessageModel, + ChatModel, + KhojUser, + SpeechToTextModelOptions, +) from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.anthropic_chat import ( extract_questions_anthropic, @@ -353,7 +359,7 @@ def set_user_name( async def extract_references_and_questions( user: KhojUser, - meta_log: dict, + chat_history: list[ChatMessageModel], q: str, n: int, d: float, @@ -432,7 +438,7 @@ async def extract_references_and_questions( defiltered_query, model=chat_model, loaded_model=loaded_model, - conversation_log=meta_log, + chat_history=chat_history, should_extract_questions=True, location_data=location_data, user=user, @@ -450,7 +456,7 @@ async def extract_references_and_questions( model=chat_model_name, api_key=api_key, api_base_url=base_url, - conversation_log=meta_log, + 
chat_history=chat_history, location_data=location_data, user=user, query_images=query_images, @@ -469,7 +475,7 @@ async def extract_references_and_questions( model=chat_model_name, api_key=api_key, api_base_url=api_base_url, - conversation_log=meta_log, + chat_history=chat_history, location_data=location_data, user=user, vision_enabled=vision_enabled, @@ -487,7 +493,7 @@ async def extract_references_and_questions( model=chat_model_name, api_key=api_key, api_base_url=api_base_url, - conversation_log=meta_log, + chat_history=chat_history, location_data=location_data, max_tokens=chat_model.max_prompt_size, user=user, @@ -606,7 +612,7 @@ def post_automation( return Response(content="Invalid crontime", status_code=400) # Infer subject, query to run - _, query_to_run, generated_subject = schedule_query(q, conversation_history={}, user=user) + _, query_to_run, generated_subject = schedule_query(q, chat_history=[], user=user) subject = subject or generated_subject # Normalize query parameters @@ -712,7 +718,7 @@ def edit_job( return Response(content="Invalid automation", status_code=403) # Infer subject, query to run - _, query_to_run, _ = schedule_query(q, conversation_history={}, user=user) + _, query_to_run, _ = schedule_query(q, chat_history=[], user=user) subject = subject # Normalize query parameters diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index cf3d1207..df3245fe 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -752,7 +752,7 @@ async def chat( q, chat_response="", user=user, - meta_log=meta_log, + chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -918,7 +918,7 @@ async def chat( if city or region or country or country_code: location = LocationData(city=city, region=region, country=country, country_code=country_code) user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - meta_log = conversation.conversation_log + 
chat_history = conversation.messages # If interrupt flag is set, wait for the previous turn to be saved before proceeding if interrupt_flag: @@ -964,14 +964,14 @@ async def chat( operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] # Drop the interrupted message from conversation history - meta_log["chat"].pop() + chat_history.pop() logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") if conversation_commands == [ConversationCommand.Default]: try: chosen_io = await aget_data_sources_and_output_format( q, - meta_log, + chat_history, is_automated_task, user=user, query_images=uploaded_images, @@ -1011,7 +1011,7 @@ async def chat( user=user, query=defiltered_query, conversation_id=conversation_id, - conversation_history=meta_log, + conversation_history=conversation.messages, previous_iterations=list(research_results), query_images=uploaded_images, agent=agent, @@ -1078,7 +1078,7 @@ async def chat( q=q, user=user, file_filters=file_filters, - meta_log=meta_log, + chat_history=conversation.messages, query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1123,7 +1123,7 @@ async def chat( if ConversationCommand.Automation in conversation_commands: try: automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log, tracer=tracer + q, timezone, user, request.url, chat_history, tracer=tracer ) except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") @@ -1139,7 +1139,7 @@ async def chat( q, llm_response, user, - meta_log, + chat_history, user_message_time, intent_type="automation", client_application=request.user.client_app, @@ -1163,7 +1163,7 @@ async def chat( try: async for result in extract_references_and_questions( user, - meta_log, + chat_history, q, (n or 7), d, @@ -1212,7 
+1212,7 @@ async def chat( try: async for result in search_online( defiltered_query, - meta_log, + chat_history, location, user, partial(send_event, ChatEvent.STATUS), @@ -1240,7 +1240,7 @@ async def chat( try: async for result in read_webpages( defiltered_query, - meta_log, + chat_history, location, user, partial(send_event, ChatEvent.STATUS), @@ -1281,7 +1281,7 @@ async def chat( context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" async for result in run_code( defiltered_query, - meta_log, + chat_history, context, location, user, @@ -1306,7 +1306,7 @@ async def chat( async for result in operate_environment( defiltered_query, user, - meta_log, + chat_history, location, list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, @@ -1356,7 +1356,7 @@ async def chat( async for result in text_to_image( defiltered_query, user, - meta_log, + chat_history, location_data=location, references=compiled_references, online_results=online_results, @@ -1400,7 +1400,7 @@ async def chat( async for result in generate_mermaidjs_diagram( q=defiltered_query, - conversation_history=meta_log, + chat_history=chat_history, location_data=location, note_references=compiled_references, online_results=online_results, @@ -1456,7 +1456,7 @@ async def chat( llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, - meta_log, + chat_history, conversation, compiled_references, online_results, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 392f5025..3eaefd8c 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -55,6 +55,7 @@ from khoj.database.adapters import ( ) from khoj.database.models import ( Agent, + ChatMessageModel, ChatModel, ClientApplication, Conversation, @@ -285,7 +286,7 @@ async def acreate_title_from_history( """ Create a title from the given conversation history """ - chat_history = 
construct_chat_history(conversation.conversation_log) + chat_history = construct_chat_history(conversation.messages) title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history) @@ -345,7 +346,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: async def aget_data_sources_and_output_format( query: str, - conversation_history: dict, + chat_history: list[ChatMessageModel], is_task: bool, user: KhojUser, query_images: List[str] = None, @@ -386,7 +387,7 @@ async def aget_data_sources_and_output_format( if len(agent_outputs) == 0 or output.value in agent_outputs: output_options_str += f'- "{output.value}": "{description}"\n' - chat_history = construct_chat_history(conversation_history, n=6) + chat_history_str = construct_chat_history(chat_history, n=6) if query_images: query = f"[placeholder for {len(query_images)} user attached images]\n{query}" @@ -399,7 +400,7 @@ async def aget_data_sources_and_output_format( query=query, sources=source_options_str, outputs=output_options_str, - chat_history=chat_history, + chat_history=chat_history_str, personality_context=personality_context, ) @@ -462,7 +463,7 @@ async def aget_data_sources_and_output_format( async def infer_webpage_urls( q: str, max_webpages: int, - conversation_history: dict, + chat_history: List[ChatMessageModel], location_data: LocationData, user: KhojUser, query_images: List[str] = None, @@ -475,7 +476,7 @@ async def infer_webpage_urls( """ location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d") personality_context = ( @@ -485,7 +486,7 @@ async def infer_webpage_urls( online_queries_prompt = prompts.infer_webpages_to_read.format( query=q, max_webpages=max_webpages, 
- chat_history=chat_history, + chat_history=chat_history_str, current_date=utc_date, location=location, username=username, @@ -526,7 +527,7 @@ async def infer_webpage_urls( async def generate_online_subqueries( q: str, - conversation_history: dict, + chat_history: List[ChatMessageModel], location_data: LocationData, user: KhojUser, query_images: List[str] = None, @@ -540,7 +541,7 @@ async def generate_online_subqueries( """ location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d") personality_context = ( @@ -549,7 +550,7 @@ async def generate_online_subqueries( online_queries_prompt = prompts.online_search_conversation_subqueries.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, max_queries=max_queries, current_date=utc_date, location=location, @@ -591,16 +592,16 @@ async def generate_online_subqueries( def schedule_query( - q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {} + q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {} ) -> Tuple[str, str, str]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. 
""" - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) crontime_prompt = prompts.crontime_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, ) raw_response = send_message_to_model_wrapper_sync( @@ -619,16 +620,16 @@ def schedule_query( async def aschedule_query( - q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {} + q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {} ) -> Tuple[str, str, str]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. """ - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) crontime_prompt = prompts.crontime_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, ) raw_response = await send_message_to_model_wrapper( @@ -681,7 +682,7 @@ async def extract_relevant_info( async def extract_relevant_summary( q: str, corpus: str, - conversation_history: dict, + chat_history: List[ChatMessageModel] = [], query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, @@ -698,11 +699,11 @@ async def extract_relevant_summary( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) extract_relevant_information = prompts.extract_relevant_summary.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, corpus=corpus.strip(), personality_context=personality_context, ) @@ -725,7 +726,7 @@ async def generate_summary_from_files( q: str, user: KhojUser, file_filters: List[str], - meta_log: dict, + chat_history: List[ChatMessageModel] = [], query_images: List[str] = None, agent: Agent = None, send_status_func: 
Optional[Callable] = None, @@ -766,7 +767,7 @@ async def generate_summary_from_files( response = await extract_relevant_summary( q, contextual_data, - conversation_history=meta_log, + chat_history=chat_history, query_images=query_images, user=user, agent=agent, @@ -782,7 +783,7 @@ async def generate_summary_from_files( async def generate_excalidraw_diagram( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -799,7 +800,7 @@ async def generate_excalidraw_diagram( better_diagram_description_prompt = await generate_better_diagram_description( q=q, - conversation_history=conversation_history, + chat_history=chat_history, location_data=location_data, note_references=note_references, online_results=online_results, @@ -834,7 +835,7 @@ async def generate_excalidraw_diagram( async def generate_better_diagram_description( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -857,7 +858,7 @@ async def generate_better_diagram_description( user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references]) - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) simplified_online_results = {} @@ -870,7 +871,7 @@ async def generate_better_diagram_description( improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, location=location, current_date=today_date, references=user_references, @@ -939,7 +940,7 @@ async def generate_excalidraw_diagram_from_description( async def generate_mermaidjs_diagram( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: 
LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -956,7 +957,7 @@ async def generate_mermaidjs_diagram( better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description( q=q, - conversation_history=conversation_history, + chat_history=chat_history, location_data=location_data, note_references=note_references, online_results=online_results, @@ -985,7 +986,7 @@ async def generate_mermaidjs_diagram( async def generate_better_mermaidjs_diagram_description( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -1008,7 +1009,7 @@ async def generate_better_mermaidjs_diagram_description( user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references]) - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) simplified_online_results = {} @@ -1021,7 +1022,7 @@ async def generate_better_mermaidjs_diagram_description( improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, location=location, current_date=today_date, references=user_references, @@ -1160,7 +1161,7 @@ async def send_message_to_model_wrapper( query_images: List[str] = None, context: str = "", query_files: str = None, - conversation_log: dict = {}, + chat_history: list[ChatMessageModel] = [], agent_chat_model: ChatModel = None, tracer: dict = {}, ): @@ -1193,7 +1194,7 @@ async def send_message_to_model_wrapper( user_message=query, context_message=context, system_message=system_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=chat_model_name, loaded_model=loaded_model, tokenizer_name=tokenizer, @@ -1260,7 +1261,7 @@ def send_message_to_model_wrapper_sync( user: KhojUser = 
None, query_images: List[str] = None, query_files: str = "", - conversation_log: dict = {}, + chat_history: List[ChatMessageModel] = [], tracer: dict = {}, ): chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user) @@ -1284,7 +1285,7 @@ def send_message_to_model_wrapper_sync( truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=chat_model_name, loaded_model=loaded_model, max_prompt_size=max_tokens, @@ -1342,7 +1343,7 @@ def send_message_to_model_wrapper_sync( async def agenerate_chat_response( q: str, - meta_log: dict, + chat_history: List[ChatMessageModel], conversation: Conversation, compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, @@ -1379,7 +1380,7 @@ async def agenerate_chat_response( save_to_conversation_log, q, user=user, - meta_log=meta_log, + chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -1424,7 +1425,7 @@ async def agenerate_chat_response( references=compiled_references, online_results=online_results, loaded_model=loaded_model, - conversation_log=meta_log, + chat_history=chat_history, completion_func=partial_completion, conversation_commands=conversation_commands, model_name=chat_model.name, @@ -1450,7 +1451,7 @@ async def agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, - conversation_log=meta_log, + chat_history=chat_history, model=chat_model_name, api_key=api_key, api_base_url=openai_chat_config.api_base_url, @@ -1480,7 +1481,7 @@ async def agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, - conversation_log=meta_log, + chat_history=chat_history, model=chat_model.name, api_key=api_key, api_base_url=api_base_url, @@ -1508,7 +1509,7 @@ async def 
agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, - conversation_log=meta_log, + chat_history=chat_history, model=chat_model.name, api_key=api_key, api_base_url=api_base_url, @@ -2005,11 +2006,11 @@ async def create_automation( timezone: str, user: KhojUser, calling_url: URL, - meta_log: dict = {}, + chat_history: List[ChatMessageModel] = [], conversation_id: str = None, tracer: dict = {}, ): - crontime, query_to_run, subject = await aschedule_query(q, meta_log, user, tracer=tracer) + crontime, query_to_run, subject = await aschedule_query(q, chat_history, user, tracer=tracer) job = await aschedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) return job, crontime, query_to_run, subject diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index e72b25a5..2dc63efb 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -10,7 +10,7 @@ import yaml from pydantic import BaseModel, Field from khoj.database.adapters import AgentAdapters, EntryAdapters -from khoj.database.models import Agent, KhojUser +from khoj.database.models import Agent, ChatMessageModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( OperatorRun, @@ -84,7 +84,7 @@ class PlanningResponse(BaseModel): async def apick_next_tool( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], user: KhojUser = None, location: LocationData = None, user_name: str = None, @@ -166,18 +166,18 @@ async def apick_next_tool( query = f"[placeholder for user attached images]\n{query}" # Construct chat history with user and iteration history with researcher agent for context - previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) - iteration_chat_log = {"chat": conversation_history.get("chat", []) + 
previous_iterations_history} + iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) + chat_and_research_history = conversation_history + iteration_chat_history # Plan function execution for the next tool - query = prompts.plan_function_execution_next_tool.format(query=query) if previous_iterations_history else query + query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query try: with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( query=query, system_message=function_planning_prompt, - conversation_log=iteration_chat_log, + chat_history=chat_and_research_history, response_type="json_object", response_schema=planning_response_model, deepthought=True, @@ -238,7 +238,7 @@ async def research( user: KhojUser, query: str, conversation_id: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], previous_iterations: List[ResearchIteration], query_images: List[str], agent: Agent = None, @@ -261,9 +261,7 @@ async def research( if current_iteration := len(previous_iterations) > 0: logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.") previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) - research_conversation_history["chat"] = ( - research_conversation_history.get("chat", []) + previous_iterations_history - ) + research_conversation_history += previous_iterations_history while current_iteration < MAX_ITERATIONS: # Check for cancellation at the start of each iteration diff --git a/tests/helpers.py b/tests/helpers.py index b2c6a3b1..fb738015 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,6 +6,7 @@ from django.utils.timezone import make_aware from khoj.database.models import ( AiModelApi, + ChatMessageModel, ChatModel, Conversation, KhojApiUser, @@ -46,15 +47,15 @@ def 
get_chat_api_key(provider: ChatModel.ModelType = None): def generate_chat_history(message_list): # Generate conversation logs - conversation_log = {"chat": []} + chat_history: list[ChatMessageModel] = [] for user_message, chat_response, context in message_list: message_to_log( user_message, chat_response, {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - conversation_log=conversation_log.get("chat", []), + chat_history=chat_history, ) - return conversation_log + return chat_history class UserFactory(factory.django.DjangoModelFactory): diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py index 6f18f658..7c2571dc 100644 --- a/tests/test_offline_chat_actors.py +++ b/tests/test_offline_chat_actors.py @@ -135,7 +135,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): # Act response = extract_questions_offline( query, - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) @@ -181,7 +181,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model): # Act response = extract_questions_offline( "Is she a Doctor?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) @@ -210,7 +210,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo # Act response = extract_questions_offline( "What was the Pizza place we ate at over there?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) @@ -336,7 +336,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model) response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - 
conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -363,7 +363,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model): {"compiled": "Testatron was born on 1st April 1984 in Testville."} ], # Assume context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -388,7 +388,7 @@ def test_refuse_answering_unanswerable_question(loaded_model): response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -501,7 +501,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Write a haiku about unit testing in 3 lines", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index b681d38c..0eb4a0dc 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -28,7 +28,7 @@ def generate_history(message_list): user_message, gpt_message, {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - 
conversation_log=conversation_log.get("chat", []), + chat_history=conversation_log.get("chat", []), ) return conversation_log diff --git a/tests/test_online_chat_actors.py b/tests/test_online_chat_actors.py index 2b08fc74..fc69b149 100644 --- a/tests/test_online_chat_actors.py +++ b/tests/test_online_chat_actors.py @@ -708,6 +708,6 @@ def populate_chat_history(message_list): "context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}, }, - conversation_log=[], + chat_history=[], ) return conversation_log