diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 2dded5ad..8a5beaf0 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -90,7 +90,7 @@ class OnlineContext(PydanticBaseModel):
 
 class Intent(PydanticBaseModel):
     type: str
-    query: str
+    query: Optional[str] = None
     memory_type: Optional[str] = Field(alias="memory-type", default=None)
     inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
 
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 90b85e47..a9b9ec02 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -186,7 +186,7 @@ def construct_iteration_history(
         iteration_history.append(
             ChatMessageModel(
                 by="khoj",
-                intent={"type": "remember", "query": query},
+                intent=Intent(type="remember", query=query),
                 message=previous_iteration_messages,
             )
         )
@@ -196,16 +196,16 @@
 def construct_chat_history(chat_history: list[ChatMessageModel], n: int = 4, agent_name="AI") -> str:
     chat_history_str = ""
     for chat in chat_history[-n:]:
-        if chat.by == "khoj" and chat.intent.type in ["remember", "reminder", "summarize"]:
-            if chat.intent.inferred_queries:
-                chat_history_str += f'{agent_name}: {{"queries": {chat.intent.inferred_queries}}}\n'
+        intent_type = chat.intent.type if chat.intent and chat.intent.type else ""
+        inferred_queries = chat.intent.inferred_queries if chat.intent else None
+        if chat.by == "khoj" and intent_type in ["remember", "reminder", "summarize"]:
+            if inferred_queries:
+                chat_history_str += f'{agent_name}: {{"queries": {inferred_queries}}}\n'
             chat_history_str += f"{agent_name}: {chat.message}\n\n"
         elif chat.by == "khoj" and chat.images:
-            chat_history_str += f"User: {chat.intent.query}\n"
             chat_history_str += f"{agent_name}: [generated image redacted for space]\n"
-        elif chat.by == "khoj" and ("excalidraw" in chat.intent.type):
-            chat_history_str += f"User: {chat.intent.query}\n"
-            chat_history_str += f"{agent_name}: {chat.intent.inferred_queries[0]}\n"
+        elif chat.by == "khoj" and ("excalidraw" in intent_type) and inferred_queries:
+            chat_history_str += f"{agent_name}: {inferred_queries[0]}\n"
         elif chat.by == "you":
             chat_history_str += f"User: {chat.message}\n"
         raw_query_files = chat.queryFiles
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 46b989d6..86f1cde1 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -53,11 +53,11 @@ async def text_to_image(
     text2image_model = text_to_image_config.model_name
     chat_history_str = ""
     for chat in chat_history[-4:]:
-        if chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
-            chat_history_str += f"Q: {chat.intent.query or ''}\n"
+        if chat.by == "you":
+            chat_history_str += f"Q: {chat.message}\n"
+        elif chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
             chat_history_str += f"A: {chat.message}\n"
-        elif chat.by == "khoj" and chat.images:
-            chat_history_str += f"Q: {chat.intent.query}\n"
+        elif chat.by == "khoj" and chat.images and chat.intent and chat.intent.inferred_queries:
             chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n"
 
     if send_status_func:
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 3564474e..d971bf4b 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -960,7 +960,11 @@ async def chat(
             online_results = {key: val.model_dump() for key, val in last_message.onlineContext.items() or []}
             code_results = {key: val.model_dump() for key, val in last_message.codeContext.items() or []}
             compiled_references = [ref.model_dump() for ref in last_message.context or []]
-            research_results = [ResearchIteration(**iter_dict) for iter_dict in last_message.researchContext or []]
+            research_results = [
+                ResearchIteration(**iter_dict)
+                for iter_dict in last_message.researchContext or []
+                if iter_dict.get("summarizedResult")
+            ]
            operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
             train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
             # Drop the interrupted message from conversation history
@@ -1011,7 +1015,7 @@ async def chat(
                 user=user,
                 query=defiltered_query,
                 conversation_id=conversation_id,
-                conversation_history=conversation.messages,
+                conversation_history=chat_history,
                 previous_iterations=list(research_results),
                 query_images=uploaded_images,
                 agent=agent,