diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 3080d0f0..39c94f8a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -37,6 +37,7 @@ from torch import Tensor from khoj.database.models import ( Agent, AiModelApi, + ChatMessageModel, ChatModel, ClientApplication, Conversation, @@ -1419,7 +1420,7 @@ class ConversationAdapters: @require_valid_user async def save_conversation( user: KhojUser, - conversation_log: dict, + chat_history: List[ChatMessageModel], client_application: ClientApplication = None, conversation_id: str = None, user_message: str = None, @@ -1434,6 +1435,7 @@ class ConversationAdapters: await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() ) + conversation_log = {"chat": [msg.model_dump() for msg in chat_history]} cleaned_conversation_log = clean_object_for_db(conversation_log) if conversation: conversation.conversation_log = cleaned_conversation_log diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 84948015..2dded5ad 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -91,7 +91,7 @@ class OnlineContext(PydanticBaseModel): class Intent(PydanticBaseModel): type: str query: str - memory_type: str = Field(alias="memory-type") + memory_type: Optional[str] = Field(alias="memory-type", default=None) inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries") @@ -100,20 +100,20 @@ class TrainOfThought(PydanticBaseModel): data: str -class ChatMessage(PydanticBaseModel): - message: str +class ChatMessageModel(PydanticBaseModel): + by: str + message: str | list[dict] trainOfThought: List[TrainOfThought] = [] context: List[Context] = [] onlineContext: Dict[str, OnlineContext] = {} codeContext: Dict[str, CodeContextData] = {} researchContext: Optional[List] = None operatorContext: Optional[List] = None - created: str + created: Optional[str] = None images: Optional[List[str]] = None queryFiles: Optional[List[Dict]] = None excalidrawDiagram: Optional[List[Dict]] = None - mermaidjsDiagram: str = None - by: str + mermaidjsDiagram: Optional[str] = None turnId: Optional[str] = None intent: Optional[Intent] = None automationId: Optional[str] = None @@ -634,7 +634,7 @@ class Conversation(DbBaseModel): try: messages = self.conversation_log.get("chat", []) for msg in messages: - ChatMessage.model_validate(msg) + ChatMessageModel.model_validate(msg) except Exception as e: raise ValidationError(f"Invalid conversation_log format: {str(e)}") @@ -643,7 +643,7 @@ class Conversation(DbBaseModel): super().save(*args, **kwargs) @property - def messages(self) -> List[ChatMessage]: + def messages(self) -> List[ChatMessageModel]: """Type-hinted accessor for conversation messages""" validated_messages = [] for msg in self.conversation_log.get("chat", []): @@ -654,7 +654,7 @@ class Conversation(DbBaseModel): q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str) ] msg["message"] = str(msg.get("message", "")) - validated_messages.append(ChatMessage.model_validate(msg)) + validated_messages.append(ChatMessageModel.model_validate(msg)) except ValidationError as e: logger.warning(f"Skipping invalid message in conversation: {e}") continue diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 3d34573e..775b8b99 100644 --- 
a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -6,7 +6,7 @@ from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain_core.messages.chat import ChatMessage -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) def extract_questions_anthropic( text, model: Optional[str] = "claude-3-7-sonnet-latest", - conversation_log={}, + chat_history: List[ChatMessageModel] = [], api_key=None, api_base_url=None, location_data: LocationData = None, @@ -54,8 +54,8 @@ def extract_questions_anthropic( location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" - # Extract Past User Message and Inferred Questions from Conversation Log - chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant") + # Extract Past User Message and Inferred Questions from Chat History + chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant") # Get dates relative to today for prompt creation today = datetime.today() @@ -76,7 +76,7 @@ def extract_questions_anthropic( ) prompt = prompts.extract_questions_anthropic_user_message.format( - chat_history=chat_history, + chat_history=chat_history_str, text=text, ) @@ -142,7 +142,7 @@ async def converse_anthropic( online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - conversation_log={}, + chat_history: List[ChatMessageModel] = [], model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, api_base_url: Optional[str] = None, @@ -225,7 +225,7 @@ async def converse_anthropic( messages = generate_chatml_messages_with_context( user_query, context_message=context_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index f9219c28..b2f48c81 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -7,7 +7,7 @@ import pyjson5 from langchain_core.messages.chat import ChatMessage from pydantic import BaseModel, Field -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( gemini_chat_completion_with_backoff, @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) def extract_questions_gemini( text, model: Optional[str] = "gemini-2.0-flash", - conversation_log={}, + chat_history: List[ChatMessageModel] = [], api_key=None, api_base_url=None, max_tokens=None, @@ -54,8 +54,8 @@ def extract_questions_gemini( location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" - # Extract Past User Message and Inferred 
Questions from Conversation Log - chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant") + # Extract Past User Message and Inferred Questions from Chat History + chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant") # Get dates relative to today for prompt creation today = datetime.today() @@ -76,7 +76,7 @@ def extract_questions_gemini( ) prompt = prompts.extract_questions_anthropic_user_message.format( - chat_history=chat_history, + chat_history=chat_history_str, text=text, ) @@ -163,7 +163,7 @@ async def converse_gemini( online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - conversation_log={}, + chat_history: List[ChatMessageModel] = [], model: Optional[str] = "gemini-2.0-flash", api_key: Optional[str] = None, api_base_url: Optional[str] = None, @@ -248,7 +248,7 @@ async def converse_gemini( messages = generate_chatml_messages_with_context( user_query, context_message=context_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 0ecf62fd..27cd9a9e 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -10,7 +10,7 @@ import pyjson5 from langchain_core.messages.chat import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -38,7 +38,7 @@ def extract_questions_offline( text: str, model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", loaded_model: Union[Any, None] = None, - conversation_log={}, + chat_history: List[ChatMessageModel] = [], use_history: bool = True, should_extract_questions: bool = True, location_data: LocationData = None, @@ -65,7 +65,7 @@ def extract_questions_offline( username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log - chat_history = construct_question_history(conversation_log, include_query=False) if use_history else "" + chat_history_str = construct_question_history(chat_history, include_query=False) if use_history else "" # Get dates relative to today for prompt creation today = datetime.today() @@ -73,7 +73,7 @@ def extract_questions_offline( last_year = today.year - 1 example_questions = prompts.extract_questions_offline.format( query=text, - chat_history=chat_history, + chat_history=chat_history_str, current_date=today.strftime("%Y-%m-%d"), day_of_week=today.strftime("%A"), current_month=today.strftime("%Y-%m"), @@ -147,7 +147,7 @@ async def converse_offline( references: list[dict] = [], online_results={}, code_results={}, - conversation_log={}, + chat_history: list[ChatMessageModel] = [], model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", loaded_model: Union[Any, None] = None, completion_func=None, @@ -227,7 +227,7 @@ async def converse_offline( messages = generate_chatml_messages_with_context( user_query, system_prompt, - 
conversation_log, + chat_history, context_message=context_message, model_name=model_name, loaded_model=offline_chat_model, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index b3c440ff..65cbbfca 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -8,7 +8,7 @@ from langchain_core.messages.chat import ChatMessage from openai.lib._pydantic import _ensure_strict_json_schema from pydantic import BaseModel -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) def extract_questions( text, model: Optional[str] = "gpt-4o-mini", - conversation_log={}, + chat_history: list[ChatMessageModel] = [], api_key=None, api_base_url=None, location_data: LocationData = None, @@ -56,8 +56,8 @@ def extract_questions( location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" - # Extract Past User Message and Inferred Questions from Conversation Log - chat_history = construct_question_history(conversation_log) + # Extract Past User Message and Inferred Questions from Chat History + chat_history_str = construct_question_history(chat_history) # Get dates relative to today for prompt creation today = datetime.today() @@ -73,7 +73,7 @@ def extract_questions( current_new_year_date=current_new_year.strftime("%Y-%m-%d"), bob_tom_age_difference={current_new_year.year - 1984 - 30}, bob_age={current_new_year.year - 1984}, - chat_history=chat_history, + chat_history=chat_history_str, text=text, yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, @@ -166,7 +166,7 @@ async def converse_openai( online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[OperatorRun]] = None, - conversation_log={}, + chat_history: list[ChatMessageModel] = [], model: str = "gpt-4o-mini", api_key: Optional[str] = None, api_base_url: Optional[str] = None, @@ -251,7 +251,7 @@ async def converse_openai( messages = generate_chatml_messages_with_context( user_query, system_prompt, - conversation_log, + chat_history, context_message=context_message, model_name=model, max_prompt_size=max_prompt_size, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 71ca6526..90b85e47 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -24,7 +24,13 @@ from pydantic import BaseModel from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.database.adapters import ConversationAdapters -from khoj.database.models import ChatModel, ClientApplication, KhojUser +from khoj.database.models import ( + ChatMessageModel, + ChatModel, + ClientApplication, + Intent, + KhojUser, +) from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.search_filter.base_filter import BaseFilter @@ -161,8 +167,8 @@ def construct_iteration_history( previous_iterations: List[ResearchIteration], previous_iteration_prompt: str, query: str = None, -) -> list[dict]: - iteration_history: 
list[dict] = [] +) -> list[ChatMessageModel]: + iteration_history: list[ChatMessageModel] = [] previous_iteration_messages: list[dict] = [] for idx, iteration in enumerate(previous_iterations): iteration_data = previous_iteration_prompt.format( @@ -176,46 +182,46 @@ def construct_iteration_history( if previous_iteration_messages: if query: - iteration_history.append({"by": "you", "message": query}) + iteration_history.append(ChatMessageModel(by="you", message=query)) iteration_history.append( - { - "by": "khoj", - "intent": {"type": "remember", "query": query}, - "message": previous_iteration_messages, - } + ChatMessageModel( + by="khoj", + intent={"type": "remember", "query": query}, + message=previous_iteration_messages, + ) ) return iteration_history -def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: - chat_history = "" - for chat in conversation_history.get("chat", [])[-n:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: - if chat["intent"].get("inferred-queries"): - chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' - chat_history += f"{agent_name}: {chat['message']}\n\n" - elif chat["by"] == "khoj" and chat.get("images"): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" - elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" - elif chat["by"] == "you": - chat_history += f"User: {chat['message']}\n" - raw_query_files = chat.get("queryFiles") +def construct_chat_history(chat_history: list[ChatMessageModel], n: int = 4, agent_name="AI") -> str: + chat_history_str = "" + for chat in chat_history[-n:]: + if chat.by == "khoj" and chat.intent.type in ["remember", "reminder", "summarize"]: + if chat.intent.inferred_queries: + chat_history_str += f'{agent_name}: {{"queries": {chat.intent.inferred_queries}}}\n' + chat_history_str += f"{agent_name}: {chat.message}\n\n" + elif chat.by == "khoj" and chat.images: + chat_history_str += f"User: {chat.intent.query}\n" + chat_history_str += f"{agent_name}: [generated image redacted for space]\n" + elif chat.by == "khoj" and ("excalidraw" in chat.intent.type): + chat_history_str += f"User: {chat.intent.query}\n" + chat_history_str += f"{agent_name}: {chat.intent.inferred_queries[0]}\n" + elif chat.by == "you": + chat_history_str += f"User: {chat.message}\n" + raw_query_files = chat.queryFiles if raw_query_files: query_files: Dict[str, str] = {} for file in raw_query_files: query_files[file["name"]] = file["content"] query_file_context = gather_raw_query_files(query_files) - chat_history += f"User: {query_file_context}\n" + chat_history_str += f"User: {query_file_context}\n" - return chat_history + return chat_history_str def construct_question_history( - conversation_log: dict, + conversation_log: list[ChatMessageModel], include_query: bool = True, lookback: int = 6, query_prefix: str = "Q", @@ -226,16 +232,16 @@ def construct_question_history( """ history_parts = "" original_query = None - for chat in conversation_log.get("chat", [])[-lookback:]: - if chat["by"] == "you": - original_query = chat.get("message") + for chat in conversation_log[-lookback:]: + if chat.by == "you": + original_query = json.dumps(chat.message) history_parts += f"{query_prefix}: {original_query}\n" - if chat["by"] 
== "khoj": + if chat.by == "khoj": if original_query is None: continue - message = chat.get("message", "") - inferred_queries_list = chat.get("intent", {}).get("inferred-queries") + message = chat.message + inferred_queries_list = chat.intent.inferred_queries or [] # Ensure inferred_queries_list is a list, defaulting to the original query in a list if not inferred_queries_list: @@ -246,7 +252,7 @@ def construct_question_history( if include_query: # Ensure 'type' exists and is a string before checking 'to-image' - intent_type = chat.get("intent", {}).get("type", "") + intent_type = chat.intent.type if chat.intent and chat.intent.type else "" if "to-image" not in intent_type: history_parts += f'{agent_name}: {{"queries": {inferred_queries_list}}}\n' history_parts += f"A: {message}\n\n" @@ -259,7 +265,7 @@ def construct_question_history( return history_parts -def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]: +def construct_chat_history_for_operator(conversation_history: List[ChatMessageModel], n: int = 6) -> list[AgentMessage]: """ Construct chat history for operator agent in conversation log. Only include last n completed turns (i.e with user and khoj message). @@ -267,22 +273,22 @@ def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) chat_history: list[AgentMessage] = [] user_message: Optional[AgentMessage] = None - for chat in conversation_history.get("chat", []): + for chat in conversation_history: if len(chat_history) >= n: break - if chat["by"] == "you" and chat.get("message"): - content = [{"type": "text", "text": chat["message"]}] - for file in chat.get("queryFiles", []): + if chat.by == "you" and chat.message: + content = [{"type": "text", "text": chat.message}] + for file in chat.queryFiles or []: content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}] user_message = AgentMessage(role="user", content=content) - elif chat["by"] == "khoj" and chat.get("message"): - chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])] + elif chat.by == "khoj" and chat.message: + chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)] return chat_history def construct_tool_chat_history( previous_iterations: List[ResearchIteration], tool: ConversationCommand = None -) -> Dict[str, list]: +) -> List[ChatMessageModel]: """ Construct chat history from previous iterations for a specific tool @@ -313,22 +319,23 @@ def construct_tool_chat_history( tool or ConversationCommand(iteration.tool), base_extractor ) chat_history += [ - { - "by": "you", - "message": iteration.query, - }, - { - "by": "khoj", - "intent": { - "type": "remember", - "inferred-queries": inferred_query_extractor(iteration), - "query": iteration.query, - }, - "message": iteration.summarizedResult, - }, + ChatMessageModel( + by="you", + message=iteration.query, + ), + ChatMessageModel( + by="khoj", + intent=Intent( + type="remember", + query=iteration.query, + inferred_queries=inferred_query_extractor(iteration), + memory_type="notes", + ), + message=iteration.summarizedResult, + ), ] - return {"chat": chat_history} + return chat_history class ChatEvent(Enum): @@ -349,8 +356,8 @@ def message_to_log( chat_response, user_message_metadata={}, khoj_message_metadata={}, - conversation_log=[], -): + chat_history: List[ChatMessageModel] = [], +) -> List[ChatMessageModel]: """Create json logs from messages, metadata for conversation log""" default_khoj_message_metadata = { 
"intent": {"type": "remember", "memory-type": "notes", "query": user_message}, @@ -369,15 +376,17 @@ def message_to_log( khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata) khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log) - conversation_log.extend([human_log, khoj_log]) - return conversation_log + human_message = ChatMessageModel(**human_log) + khoj_message = ChatMessageModel(**khoj_log) + chat_history.extend([human_message, khoj_message]) + return chat_history async def save_to_conversation_log( q: str, chat_response: str, user: KhojUser, - meta_log: Dict, + chat_history: List[ChatMessageModel], user_message_time: str = None, compiled_references: List[Dict[str, Any]] = [], online_results: Dict[str, Any] = {}, @@ -427,11 +436,11 @@ async def save_to_conversation_log( chat_response=chat_response, user_message_metadata=user_message_metadata, khoj_message_metadata=khoj_message_metadata, - conversation_log=meta_log.get("chat", []), + chat_history=chat_history, ) await ConversationAdapters.save_conversation( user, - {"chat": updated_conversation}, + updated_conversation, client_application=client_application, conversation_id=conversation_id, user_message=q, @@ -502,7 +511,7 @@ def gather_raw_query_files( def generate_chatml_messages_with_context( user_message: str, system_message: str = None, - conversation_log={}, + chat_history: list[ChatMessageModel] = [], model_name="gpt-4o-mini", loaded_model: Optional[Llama] = None, max_prompt_size=None, @@ -529,21 +538,21 @@ def generate_chatml_messages_with_context( # Extract Chat History for Context chatml_messages: List[ChatMessage] = [] - for chat in conversation_log.get("chat", []): + for chat in chat_history: message_context = [] message_attached_files = "" generated_assets = {} - chat_message = chat.get("message") - role = "user" if chat["by"] == "you" else "assistant" + chat_message = chat.message + role = "user" if chat.by == "you" else "assistant" # Legacy code to handle excalidraw diagrams prior to Dec 2024 - if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): - chat_message = chat["intent"].get("inferred-queries")[0] + if chat.by == "khoj" and "excalidraw" in chat.intent.type or "": + chat_message = (chat.intent.inferred_queries or [])[0] - if chat.get("queryFiles"): - raw_query_files = chat.get("queryFiles") + if chat.queryFiles: + raw_query_files = chat.queryFiles query_files_dict = dict() for file in raw_query_files: query_files_dict[file["name"]] = file["content"] @@ -551,24 +560,24 @@ def generate_chatml_messages_with_context( message_attached_files = gather_raw_query_files(query_files_dict) chatml_messages.append(ChatMessage(content=message_attached_files, role=role)) - if not is_none_or_empty(chat.get("onlineContext")): + if not is_none_or_empty(chat.onlineContext): message_context += [ { "type": "text", - "text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}", + "text": f"{prompts.online_search_conversation.format(online_results=chat.onlineContext)}", } ] - if not is_none_or_empty(chat.get("codeContext")): + if not is_none_or_empty(chat.codeContext): message_context += [ { "type": "text", - "text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}", + "text": f"{prompts.code_executed_context.format(code_results=chat.codeContext)}", } ] - if not is_none_or_empty(chat.get("operatorContext")): - operator_context = chat.get("operatorContext") + if not 
is_none_or_empty(chat.operatorContext): + operator_context = chat.operatorContext operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context]) message_context += [ { @@ -577,13 +586,9 @@ } ] - if not is_none_or_empty(chat.get("context")): + if not is_none_or_empty(chat.context): references = "\n\n".join( - { - f"# File: {item['file']}\n## {item['compiled']}\n" - for item in chat.get("context") or [] - if isinstance(item, dict) - } + {f"# File: {item.file}\n## {item.compiled}\n" for item in chat.context or []} ) message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}] @@ -591,14 +596,14 @@ reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) - if not is_none_or_empty(chat.get("images")) and role == "assistant": + if not is_none_or_empty(chat.images) and role == "assistant": generated_assets["image"] = { - "query": chat.get("intent", {}).get("inferred-queries", [user_message])[0], + "query": (chat.intent.inferred_queries or [user_message])[0], } - if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant": + if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant": generated_assets["diagram"] = { - "query": chat.get("intent", {}).get("inferred-queries", [user_message])[0], + "query": (chat.intent.inferred_queries or [user_message])[0], } if not is_none_or_empty(generated_assets): @@ -610,7 +615,7 @@ ) message_content = construct_structured_message( - chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled + chat_message, chat.images if role == "user" else [], model_type, vision_enabled ) reconstructed_message = ChatMessage(content=message_content, role=role) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index f1b84431..46b989d6 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -10,7 +10,12 @@ from google import genai from google.genai import types as gtypes from khoj.database.adapters import ConversationAdapters -from khoj.database.models import Agent, KhojUser, TextToImageModelConfig +from khoj.database.models import ( Agent, ChatMessageModel, KhojUser, TextToImageModelConfig, ) from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.storage import upload_generated_image_to_bucket from khoj.utils import state @@ -23,7 +28,7 @@ logger = logging.getLogger(__name__) async def text_to_image( message: str, user: KhojUser, - conversation_log: dict, + chat_history: List[ChatMessageModel], location_data: LocationData, references: List[Dict[str, Any]], online_results: Dict[str, Any], @@ -46,14 +51,14 @@ async def text_to_image( return text2image_model = text_to_image_config.model_name - chat_history = "" - for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: - chat_history += f"Q: {chat['intent']['query']}\n" - chat_history += f"A: {chat['message']}\n" - elif chat["by"] == "khoj" and chat.get("images"): - chat_history += f"Q: {chat['intent']['query']}\n" - chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" + chat_history_str = "" + for chat in chat_history[-4:]: 
+ if chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]: + chat_history_str += f"Q: {chat.intent.query or ''}\n" + chat_history_str += f"A: {chat.message}\n" + elif chat.by == "khoj" and chat.images: + chat_history_str += f"Q: {chat.intent.query}\n" + chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n" if send_status_func: async for event in send_status_func("**Enhancing the Painting Prompt**"): @@ -63,7 +68,7 @@ async def text_to_image( # Use the user's message, chat history, and other context image_prompt = await generate_better_image_prompt( message, - chat_history, + chat_history_str, location_data=location_data, note_references=references, online_results=online_results, diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 138a2696..9b4ad80f 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -5,10 +5,9 @@ import os from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser from khoj.processor.conversation.utils import ( OperatorRun, - construct_chat_history, construct_chat_history_for_operator, ) from khoj.processor.operator.operator_actions import * @@ -34,7 +33,7 @@ logger = logging.getLogger(__name__) async def operate_environment( query: str, user: KhojUser, - conversation_log: dict, + conversation_log: List[ChatMessageModel], location_data: LocationData, previous_trajectory: Optional[OperatorRun] = None, environment_type: EnvironmentType = EnvironmentType.COMPUTER, diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index f4df22fe..8106e9cc 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -4,7 +4,7 @@ from datetime import datetime from textwrap import dedent from typing import List, Optional -from khoj.database.models import ChatModel +from khoj.database.models import ChatMessageModel, ChatModel from khoj.processor.conversation.utils import ( AgentMessage, OperatorRun, @@ -119,13 +119,13 @@ class BinaryOperatorAgent(OperatorAgent): query_screenshot = self._get_message_images(current_message) # Construct input for visual reasoner history - visual_reasoner_history = {"chat": self._format_message_for_api(self.messages)} + visual_reasoner_history = self._format_message_for_api(self.messages) try: natural_language_action = await send_message_to_model_wrapper( query=query_text, query_images=query_screenshot, system_message=reasoning_system_prompt, - conversation_log=visual_reasoner_history, + chat_history=visual_reasoner_history, agent_chat_model=self.reasoning_model, tracer=self.tracer, ) @@ -238,11 +238,11 @@ class BinaryOperatorAgent(OperatorAgent): async def summarize(self, env_state: EnvState, summarize_prompt: str = None) -> str: summarize_prompt = summarize_prompt or self.summarize_prompt - conversation_history = {"chat": self._format_message_for_api(self.messages)} + conversation_history = self._format_message_for_api(self.messages) try: summary = await send_message_to_model_wrapper( query=summarize_prompt, - conversation_log=conversation_history, + chat_history=conversation_history, agent_chat_model=self.reasoning_model, tracer=self.tracer, ) @@ -296,14 +296,14 @@ class 
BinaryOperatorAgent(OperatorAgent): images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"] return images - def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]: + def _format_message_for_api(self, messages: list[AgentMessage]) -> List[ChatMessageModel]: """Format operator agent messages into the Khoj conversation history format.""" formatted_messages = [ - { - "message": self._get_message_text(message), - "images": self._get_message_images(message), - "by": "you" if message.role in ["user", "environment"] else message.role, - } + ChatMessageModel( + message=self._get_message_text(message), + images=self._get_message_images(message), + by="you" if message.role in ["user", "environment"] else message.role, + ) for message in messages ] return formatted_messages diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 8b39cc18..8951bc46 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -10,7 +10,13 @@ from bs4 import BeautifulSoup from markdownify import markdownify from khoj.database.adapters import ConversationAdapters -from khoj.database.models import Agent, KhojUser, ServerChatSettings, WebScraper +from khoj.database.models import ( + Agent, + ChatMessageModel, + KhojUser, + ServerChatSettings, + WebScraper, +) from khoj.processor.conversation import prompts from khoj.routers.helpers import ( ChatEvent, @@ -59,7 +65,7 @@ OLOSTEP_QUERY_PARAMS = { async def search_online( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], location: LocationData, user: KhojUser, send_status_func: Optional[Callable] = None, @@ -361,7 +367,7 @@ async def search_with_serper(query: str, location: LocationData) -> Tuple[str, D async def read_webpages( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], location: LocationData, user: KhojUser, send_status_func: Optional[Callable] = None, diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 535edd0e..5b9c2d04 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -20,7 +20,7 @@ from tenacity import ( ) from khoj.database.adapters import FileObjectAdapters -from khoj.database.models import Agent, FileObject, KhojUser +from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( ChatEvent, @@ -50,7 +50,7 @@ class GeneratedCode(NamedTuple): async def run_code( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], context: str, location_data: LocationData, user: KhojUser, @@ -116,7 +116,7 @@ async def run_code( async def generate_python_code( q: str, - conversation_history: dict, + chat_history: List[ChatMessageModel], context: str, location_data: LocationData, user: KhojUser, @@ -127,7 +127,7 @@ async def generate_python_code( ) -> GeneratedCode: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d") personality_context = ( @@ -143,7 +143,7 @@ async def generate_python_code( code_generation_prompt = 
prompts.python_code_generation_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, context=context, has_network_access=network_access_context, current_date=utc_date, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index efba17ec..89a95642 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -29,7 +29,13 @@ from khoj.database.adapters import ( get_default_search_model, get_user_photo, ) -from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions +from khoj.database.models import ( + Agent, + ChatMessageModel, + ChatModel, + KhojUser, + SpeechToTextModelOptions, +) from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.anthropic_chat import ( extract_questions_anthropic, @@ -353,7 +359,7 @@ def set_user_name( async def extract_references_and_questions( user: KhojUser, - meta_log: dict, + chat_history: list[ChatMessageModel], q: str, n: int, d: float, @@ -432,7 +438,7 @@ async def extract_references_and_questions( defiltered_query, model=chat_model, loaded_model=loaded_model, - conversation_log=meta_log, + chat_history=chat_history, should_extract_questions=True, location_data=location_data, user=user, @@ -450,7 +456,7 @@ async def extract_references_and_questions( model=chat_model_name, api_key=api_key, api_base_url=base_url, - conversation_log=meta_log, + chat_history=chat_history, location_data=location_data, user=user, query_images=query_images, @@ -469,7 +475,7 @@ async def extract_references_and_questions( model=chat_model_name, api_key=api_key, api_base_url=api_base_url, - conversation_log=meta_log, + chat_history=chat_history, location_data=location_data, user=user, vision_enabled=vision_enabled, @@ -487,7 +493,7 @@ async def extract_references_and_questions( model=chat_model_name, api_key=api_key, api_base_url=api_base_url, - conversation_log=meta_log, + chat_history=chat_history, location_data=location_data, max_tokens=chat_model.max_prompt_size, user=user, @@ -606,7 +612,7 @@ def post_automation( return Response(content="Invalid crontime", status_code=400) # Infer subject, query to run - _, query_to_run, generated_subject = schedule_query(q, conversation_history={}, user=user) + _, query_to_run, generated_subject = schedule_query(q, chat_history=[], user=user) subject = subject or generated_subject # Normalize query parameters @@ -712,7 +718,7 @@ def edit_job( return Response(content="Invalid automation", status_code=403) # Infer subject, query to run - _, query_to_run, _ = schedule_query(q, conversation_history={}, user=user) + _, query_to_run, _ = schedule_query(q, chat_history=[], user=user) subject = subject # Normalize query parameters diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index cf3d1207..df3245fe 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -752,7 +752,7 @@ async def chat( q, chat_response="", user=user, - meta_log=meta_log, + chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -918,7 +918,7 @@ async def chat( if city or region or country or country_code: location = LocationData(city=city, region=region, country=country, country_code=country_code) user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - meta_log = conversation.conversation_log + chat_history = conversation.messages # If interrupt flag is set, wait for the previous turn to be saved before proceeding if interrupt_flag: @@ -964,14 
+964,14 @@ async def chat( operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []] train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []] # Drop the interrupted message from conversation history - meta_log["chat"].pop() + chat_history.pop() logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.") if conversation_commands == [ConversationCommand.Default]: try: chosen_io = await aget_data_sources_and_output_format( q, - meta_log, + chat_history, is_automated_task, user=user, query_images=uploaded_images, @@ -1011,7 +1011,7 @@ async def chat( user=user, query=defiltered_query, conversation_id=conversation_id, - conversation_history=meta_log, + conversation_history=conversation.messages, previous_iterations=list(research_results), query_images=uploaded_images, agent=agent, @@ -1078,7 +1078,7 @@ async def chat( q=q, user=user, file_filters=file_filters, - meta_log=meta_log, + chat_history=conversation.messages, query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1123,7 +1123,7 @@ async def chat( if ConversationCommand.Automation in conversation_commands: try: automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log, tracer=tracer + q, timezone, user, request.url, chat_history, tracer=tracer ) except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") @@ -1139,7 +1139,7 @@ async def chat( q, llm_response, user, - meta_log, + chat_history, user_message_time, intent_type="automation", client_application=request.user.client_app, @@ -1163,7 +1163,7 @@ async def chat( try: async for result in extract_references_and_questions( user, - meta_log, + chat_history, q, (n or 7), d, @@ -1212,7 +1212,7 @@ async def chat( try: async for result in search_online( defiltered_query, - meta_log, + chat_history, location, user, partial(send_event, ChatEvent.STATUS), @@ -1240,7 +1240,7 @@ async def chat( try: async for result in read_webpages( defiltered_query, - meta_log, + chat_history, location, user, partial(send_event, ChatEvent.STATUS), @@ -1281,7 +1281,7 @@ async def chat( context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" async for result in run_code( defiltered_query, - meta_log, + chat_history, context, location, user, @@ -1306,7 +1306,7 @@ async def chat( async for result in operate_environment( defiltered_query, user, - meta_log, + chat_history, location, list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, @@ -1356,7 +1356,7 @@ async def chat( async for result in text_to_image( defiltered_query, user, - meta_log, + chat_history, location_data=location, references=compiled_references, online_results=online_results, @@ -1400,7 +1400,7 @@ async def chat( async for result in generate_mermaidjs_diagram( q=defiltered_query, - conversation_history=meta_log, + chat_history=chat_history, location_data=location, note_references=compiled_references, online_results=online_results, @@ -1456,7 +1456,7 @@ async def chat( llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, - meta_log, + chat_history, conversation, compiled_references, online_results, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 392f5025..3eaefd8c 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -55,6 +55,7 @@ from khoj.database.adapters import 
( ) from khoj.database.models import ( Agent, + ChatMessageModel, ChatModel, ClientApplication, Conversation, @@ -285,7 +286,7 @@ async def acreate_title_from_history( """ Create a title from the given conversation history """ - chat_history = construct_chat_history(conversation.conversation_log) + chat_history = construct_chat_history(conversation.messages) title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history) @@ -345,7 +346,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: async def aget_data_sources_and_output_format( query: str, - conversation_history: dict, + chat_history: list[ChatMessageModel], is_task: bool, user: KhojUser, query_images: List[str] = None, @@ -386,7 +387,7 @@ async def aget_data_sources_and_output_format( if len(agent_outputs) == 0 or output.value in agent_outputs: output_options_str += f'- "{output.value}": "{description}"\n' - chat_history = construct_chat_history(conversation_history, n=6) + chat_history_str = construct_chat_history(chat_history, n=6) if query_images: query = f"[placeholder for {len(query_images)} user attached images]\n{query}" @@ -399,7 +400,7 @@ async def aget_data_sources_and_output_format( query=query, sources=source_options_str, outputs=output_options_str, - chat_history=chat_history, + chat_history=chat_history_str, personality_context=personality_context, ) @@ -462,7 +463,7 @@ async def aget_data_sources_and_output_format( async def infer_webpage_urls( q: str, max_webpages: int, - conversation_history: dict, + chat_history: List[ChatMessageModel], location_data: LocationData, user: KhojUser, query_images: List[str] = None, @@ -475,7 +476,7 @@ async def infer_webpage_urls( """ location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d") personality_context = ( @@ -485,7 +486,7 @@ async def infer_webpage_urls( online_queries_prompt = prompts.infer_webpages_to_read.format( query=q, max_webpages=max_webpages, - chat_history=chat_history, + chat_history=chat_history_str, current_date=utc_date, location=location, username=username, @@ -526,7 +527,7 @@ async def infer_webpage_urls( async def generate_online_subqueries( q: str, - conversation_history: dict, + chat_history: List[ChatMessageModel], location_data: LocationData, user: KhojUser, query_images: List[str] = None, @@ -540,7 +541,7 @@ async def generate_online_subqueries( """ location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d") personality_context = ( @@ -549,7 +550,7 @@ async def generate_online_subqueries( online_queries_prompt = prompts.online_search_conversation_subqueries.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, max_queries=max_queries, current_date=utc_date, location=location, @@ -591,16 +592,16 @@ async def generate_online_subqueries( def schedule_query( - q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {} + q: str, chat_history: List[ChatMessageModel], user: KhojUser, 
query_images: List[str] = None, tracer: dict = {} ) -> Tuple[str, str, str]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. """ - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) crontime_prompt = prompts.crontime_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, ) raw_response = send_message_to_model_wrapper_sync( @@ -619,16 +620,16 @@ def schedule_query( async def aschedule_query( - q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {} + q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {} ) -> Tuple[str, str, str]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. """ - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) crontime_prompt = prompts.crontime_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, ) raw_response = await send_message_to_model_wrapper( @@ -681,7 +682,7 @@ async def extract_relevant_info( async def extract_relevant_summary( q: str, corpus: str, - conversation_history: dict, + chat_history: List[ChatMessageModel] = [], query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, @@ -698,11 +699,11 @@ async def extract_relevant_summary( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) extract_relevant_information = prompts.extract_relevant_summary.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, corpus=corpus.strip(), personality_context=personality_context, ) @@ -725,7 +726,7 @@ async def generate_summary_from_files( q: str, user: KhojUser, file_filters: List[str], - meta_log: dict, + chat_history: List[ChatMessageModel] = [], query_images: List[str] = None, agent: Agent = None, send_status_func: Optional[Callable] = None, @@ -766,7 +767,7 @@ async def generate_summary_from_files( response = await extract_relevant_summary( q, contextual_data, - conversation_history=meta_log, + chat_history=chat_history, query_images=query_images, user=user, agent=agent, @@ -782,7 +783,7 @@ async def generate_summary_from_files( async def generate_excalidraw_diagram( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -799,7 +800,7 @@ async def generate_excalidraw_diagram( better_diagram_description_prompt = await generate_better_diagram_description( q=q, - conversation_history=conversation_history, + chat_history=chat_history, location_data=location_data, note_references=note_references, online_results=online_results, @@ -834,7 +835,7 @@ async def generate_excalidraw_diagram( async def generate_better_diagram_description( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -857,7 +858,7 @@ async def generate_better_diagram_description( user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references]) - chat_history = construct_chat_history(conversation_history) + chat_history_str = 
construct_chat_history(chat_history) simplified_online_results = {} @@ -870,7 +871,7 @@ async def generate_better_diagram_description( improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, location=location, current_date=today_date, references=user_references, @@ -939,7 +940,7 @@ async def generate_excalidraw_diagram_from_description( async def generate_mermaidjs_diagram( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -956,7 +957,7 @@ async def generate_mermaidjs_diagram( better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description( q=q, - conversation_history=conversation_history, + chat_history=chat_history, location_data=location_data, note_references=note_references, online_results=online_results, @@ -985,7 +986,7 @@ async def generate_mermaidjs_diagram( async def generate_better_mermaidjs_diagram_description( q: str, - conversation_history: Dict[str, Any], + chat_history: List[ChatMessageModel], location_data: LocationData, note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, @@ -1008,7 +1009,7 @@ async def generate_better_mermaidjs_diagram_description( user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references]) - chat_history = construct_chat_history(conversation_history) + chat_history_str = construct_chat_history(chat_history) simplified_online_results = {} @@ -1021,7 +1022,7 @@ async def generate_better_mermaidjs_diagram_description( improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format( query=q, - chat_history=chat_history, + chat_history=chat_history_str, location=location, current_date=today_date, references=user_references, @@ -1160,7 +1161,7 @@ async def send_message_to_model_wrapper( query_images: List[str] = None, context: str = "", query_files: str = None, - conversation_log: dict = {}, + chat_history: list[ChatMessageModel] = [], agent_chat_model: ChatModel = None, tracer: dict = {}, ): @@ -1193,7 +1194,7 @@ async def send_message_to_model_wrapper( user_message=query, context_message=context, system_message=system_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=chat_model_name, loaded_model=loaded_model, tokenizer_name=tokenizer, @@ -1260,7 +1261,7 @@ def send_message_to_model_wrapper_sync( user: KhojUser = None, query_images: List[str] = None, query_files: str = "", - conversation_log: dict = {}, + chat_history: List[ChatMessageModel] = [], tracer: dict = {}, ): chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user) @@ -1284,7 +1285,7 @@ def send_message_to_model_wrapper_sync( truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, - conversation_log=conversation_log, + chat_history=chat_history, model_name=chat_model_name, loaded_model=loaded_model, max_prompt_size=max_tokens, @@ -1342,7 +1343,7 @@ def send_message_to_model_wrapper_sync( async def agenerate_chat_response( q: str, - meta_log: dict, + chat_history: List[ChatMessageModel], conversation: Conversation, compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, @@ -1379,7 +1380,7 @@ async def agenerate_chat_response( save_to_conversation_log, q, user=user, - meta_log=meta_log, + 
chat_history=chat_history, compiled_references=compiled_references, online_results=online_results, code_results=code_results, @@ -1424,7 +1425,7 @@ async def agenerate_chat_response( references=compiled_references, online_results=online_results, loaded_model=loaded_model, - conversation_log=meta_log, + chat_history=chat_history, completion_func=partial_completion, conversation_commands=conversation_commands, model_name=chat_model.name, @@ -1450,7 +1451,7 @@ async def agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, - conversation_log=meta_log, + chat_history=chat_history, model=chat_model_name, api_key=api_key, api_base_url=openai_chat_config.api_base_url, @@ -1480,7 +1481,7 @@ async def agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, - conversation_log=meta_log, + chat_history=chat_history, model=chat_model.name, api_key=api_key, api_base_url=api_base_url, @@ -1508,7 +1509,7 @@ async def agenerate_chat_response( online_results=online_results, code_results=code_results, operator_results=operator_results, - conversation_log=meta_log, + chat_history=chat_history, model=chat_model.name, api_key=api_key, api_base_url=api_base_url, @@ -2005,11 +2006,11 @@ async def create_automation( timezone: str, user: KhojUser, calling_url: URL, - meta_log: dict = {}, + chat_history: List[ChatMessageModel] = [], conversation_id: str = None, tracer: dict = {}, ): - crontime, query_to_run, subject = await aschedule_query(q, meta_log, user, tracer=tracer) + crontime, query_to_run, subject = await aschedule_query(q, chat_history, user, tracer=tracer) job = await aschedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) return job, crontime, query_to_run, subject diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index e72b25a5..2dc63efb 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -10,7 +10,7 @@ import yaml from pydantic import BaseModel, Field from khoj.database.adapters import AgentAdapters, EntryAdapters -from khoj.database.models import Agent, KhojUser +from khoj.database.models import Agent, ChatMessageModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( OperatorRun, @@ -84,7 +84,7 @@ class PlanningResponse(BaseModel): async def apick_next_tool( query: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], user: KhojUser = None, location: LocationData = None, user_name: str = None, @@ -166,18 +166,18 @@ async def apick_next_tool( query = f"[placeholder for user attached images]\n{query}" # Construct chat history with user and iteration history with researcher agent for context - previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) - iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history} + iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query) + chat_and_research_history = conversation_history + iteration_chat_history # Plan function execution for the next tool - query = prompts.plan_function_execution_next_tool.format(query=query) if previous_iterations_history else query + query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query try: with timer("Chat actor: Infer information 
sources to refer", logger): response = await send_message_to_model_wrapper( query=query, system_message=function_planning_prompt, - conversation_log=iteration_chat_log, + chat_history=chat_and_research_history, response_type="json_object", response_schema=planning_response_model, deepthought=True, @@ -238,7 +238,7 @@ async def research( user: KhojUser, query: str, conversation_id: str, - conversation_history: dict, + conversation_history: List[ChatMessageModel], previous_iterations: List[ResearchIteration], query_images: List[str], agent: Agent = None, @@ -261,9 +261,7 @@ async def research( if current_iteration := len(previous_iterations) > 0: logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.") previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) - research_conversation_history["chat"] = ( - research_conversation_history.get("chat", []) + previous_iterations_history - ) + research_conversation_history += previous_iterations_history while current_iteration < MAX_ITERATIONS: # Check for cancellation at the start of each iteration diff --git a/tests/helpers.py b/tests/helpers.py index b2c6a3b1..fb738015 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,6 +6,7 @@ from django.utils.timezone import make_aware from khoj.database.models import ( AiModelApi, + ChatMessageModel, ChatModel, Conversation, KhojApiUser, @@ -46,15 +47,15 @@ def get_chat_api_key(provider: ChatModel.ModelType = None): def generate_chat_history(message_list): # Generate conversation logs - conversation_log = {"chat": []} + chat_history: list[ChatMessageModel] = [] for user_message, chat_response, context in message_list: message_to_log( user_message, chat_response, {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - conversation_log=conversation_log.get("chat", []), + chat_history=chat_history, ) - return conversation_log + return chat_history class UserFactory(factory.django.DjangoModelFactory): diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py index 6f18f658..7c2571dc 100644 --- a/tests/test_offline_chat_actors.py +++ b/tests/test_offline_chat_actors.py @@ -135,7 +135,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): # Act response = extract_questions_offline( query, - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) @@ -181,7 +181,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model): # Act response = extract_questions_offline( "Is she a Doctor?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) @@ -210,7 +210,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo # Act response = extract_questions_offline( "What was the Pizza place we ate at over there?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) @@ -336,7 +336,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model) response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), 
loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -363,7 +363,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model): {"compiled": "Testatron was born on 1st April 1984 in Testville."} ], # Assume context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -388,7 +388,7 @@ def test_refuse_answering_unanswerable_question(loaded_model): response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -501,7 +501,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Write a haiku about unit testing in 3 lines", - conversation_log=generate_chat_history(message_list), + chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index b681d38c..0eb4a0dc 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -28,7 +28,7 @@ def generate_history(message_list): user_message, gpt_message, {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - conversation_log=conversation_log.get("chat", []), + chat_history=conversation_log.get("chat", []), ) return conversation_log diff --git a/tests/test_online_chat_actors.py b/tests/test_online_chat_actors.py index 2b08fc74..fc69b149 100644 --- a/tests/test_online_chat_actors.py +++ b/tests/test_online_chat_actors.py @@ -708,6 +708,6 @@ def populate_chat_history(message_list): "context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}, }, - conversation_log=[], + chat_history=[], ) return conversation_log
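
For illustration, a minimal sketch of how chat history is built and rendered after this refactor, assuming a configured Khoj environment where the post-refactor modules import cleanly; the sample messages and queries are hypothetical:

    # Sketch only: assumes khoj is installed and its Django settings are initialized.
    from khoj.database.models import ChatMessageModel, Intent
    from khoj.processor.conversation.utils import construct_chat_history

    # Chat history is a typed list of ChatMessageModel rather than a {"chat": [...]} dict.
    chat_history = [
        ChatMessageModel(by="you", message="Where was I born?"),
        ChatMessageModel(
            by="khoj",
            message="You were born in Testville.",
            intent=Intent(type="remember", query="Where was I born?", inferred_queries=["Where was Testatron born?"]),
        ),
    ]

    # Helpers such as construct_chat_history now take the typed list and return the
    # prompt string previously assembled from the raw conversation_log dict.
    print(construct_chat_history(chat_history))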