Pass deep typed chat history for more ergonomic, readable, safe code

The chat dictionary is an artifact from earlier non-db chat history
storage. We've been ensuring new chat messages have valid type before
being written to DB for more than 6 months now.

Moving to the deeply typed chat history helps avoid null refs,
and makes code more readable and easier to reason about.

Next Steps:
The current update entangles chat_history written to DB
with any virtual chat history message generated for intermediate
steps. The chat message type written to DB should be decoupled from
type that can be passed to AI model APIs (maybe?).

For now we've made the ChatMessage.message type looser to allow
for the list[dict] type (apart from string). But later it may be a good idea
to decouple the chat_history received by send_message_to_model from
the chat_history saved to DB (which can then have its stricter type check).
This commit is contained in:
Debanjum
2025-06-03 15:28:06 -07:00
parent 430459a338
commit 05d4e19cb8
20 changed files with 271 additions and 248 deletions

View File

@@ -37,6 +37,7 @@ from torch import Tensor
from khoj.database.models import (
Agent,
AiModelApi,
ChatMessageModel,
ChatModel,
ClientApplication,
Conversation,
@@ -1419,7 +1420,7 @@ class ConversationAdapters:
@require_valid_user
async def save_conversation(
user: KhojUser,
conversation_log: dict,
chat_history: List[ChatMessageModel],
client_application: ClientApplication = None,
conversation_id: str = None,
user_message: str = None,
@@ -1434,6 +1435,7 @@ class ConversationAdapters:
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
)
conversation_log = {"chat": [msg.model_dump() for msg in chat_history]}
cleaned_conversation_log = clean_object_for_db(conversation_log)
if conversation:
conversation.conversation_log = cleaned_conversation_log

View File

@@ -91,7 +91,7 @@ class OnlineContext(PydanticBaseModel):
class Intent(PydanticBaseModel):
type: str
query: str
memory_type: str = Field(alias="memory-type")
memory_type: Optional[str] = Field(alias="memory-type", default=None)
inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
@@ -100,20 +100,20 @@ class TrainOfThought(PydanticBaseModel):
data: str
class ChatMessage(PydanticBaseModel):
message: str
class ChatMessageModel(PydanticBaseModel):
by: str
message: str | list[dict]
trainOfThought: List[TrainOfThought] = []
context: List[Context] = []
onlineContext: Dict[str, OnlineContext] = {}
codeContext: Dict[str, CodeContextData] = {}
researchContext: Optional[List] = None
operatorContext: Optional[List] = None
created: str
created: Optional[str] = None
images: Optional[List[str]] = None
queryFiles: Optional[List[Dict]] = None
excalidrawDiagram: Optional[List[Dict]] = None
mermaidjsDiagram: str = None
by: str
mermaidjsDiagram: Optional[str] = None
turnId: Optional[str] = None
intent: Optional[Intent] = None
automationId: Optional[str] = None
@@ -634,7 +634,7 @@ class Conversation(DbBaseModel):
try:
messages = self.conversation_log.get("chat", [])
for msg in messages:
ChatMessage.model_validate(msg)
ChatMessageModel.model_validate(msg)
except Exception as e:
raise ValidationError(f"Invalid conversation_log format: {str(e)}")
@@ -643,7 +643,7 @@ class Conversation(DbBaseModel):
super().save(*args, **kwargs)
@property
def messages(self) -> List[ChatMessage]:
def messages(self) -> List[ChatMessageModel]:
"""Type-hinted accessor for conversation messages"""
validated_messages = []
for msg in self.conversation_log.get("chat", []):
@@ -654,7 +654,7 @@ class Conversation(DbBaseModel):
q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str)
]
msg["message"] = str(msg.get("message", ""))
validated_messages.append(ChatMessage.model_validate(msg))
validated_messages.append(ChatMessageModel.model_validate(msg))
except ValidationError as e:
logger.warning(f"Skipping invalid message in conversation: {e}")
continue

View File

@@ -6,7 +6,7 @@ from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain_core.messages.chat import ChatMessage
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
def extract_questions_anthropic(
text,
model: Optional[str] = "claude-3-7-sonnet-latest",
conversation_log={},
chat_history: List[ChatMessageModel] = [],
api_key=None,
api_base_url=None,
location_data: LocationData = None,
@@ -54,8 +54,8 @@ def extract_questions_anthropic(
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
# Extract Past User Message and Inferred Questions from Chat History
chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
# Get dates relative to today for prompt creation
today = datetime.today()
@@ -76,7 +76,7 @@ def extract_questions_anthropic(
)
prompt = prompts.extract_questions_anthropic_user_message.format(
chat_history=chat_history,
chat_history=chat_history_str,
text=text,
)
@@ -142,7 +142,7 @@ async def converse_anthropic(
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
chat_history: List[ChatMessageModel] = [],
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
@@ -225,7 +225,7 @@ async def converse_anthropic(
messages = generate_chatml_messages_with_context(
user_query,
context_message=context_message,
conversation_log=conversation_log,
chat_history=chat_history,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,

View File

@@ -7,7 +7,7 @@ import pyjson5
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, Field
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
gemini_chat_completion_with_backoff,
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
def extract_questions_gemini(
text,
model: Optional[str] = "gemini-2.0-flash",
conversation_log={},
chat_history: List[ChatMessageModel] = [],
api_key=None,
api_base_url=None,
max_tokens=None,
@@ -54,8 +54,8 @@ def extract_questions_gemini(
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
# Extract Past User Message and Inferred Questions from Chat History
chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
# Get dates relative to today for prompt creation
today = datetime.today()
@@ -76,7 +76,7 @@ def extract_questions_gemini(
)
prompt = prompts.extract_questions_anthropic_user_message.format(
chat_history=chat_history,
chat_history=chat_history_str,
text=text,
)
@@ -163,7 +163,7 @@ async def converse_gemini(
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
chat_history: List[ChatMessageModel] = [],
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
@@ -248,7 +248,7 @@ async def converse_gemini(
messages = generate_chatml_messages_with_context(
user_query,
context_message=context_message,
conversation_log=conversation_log,
chat_history=chat_history,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,

View File

@@ -10,7 +10,7 @@ import pyjson5
from langchain_core.messages.chat import ChatMessage
from llama_cpp import Llama
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
@@ -38,7 +38,7 @@ def extract_questions_offline(
text: str,
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
conversation_log={},
chat_history: List[ChatMessageModel] = [],
use_history: bool = True,
should_extract_questions: bool = True,
location_data: LocationData = None,
@@ -65,7 +65,7 @@ def extract_questions_offline(
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = construct_question_history(conversation_log, include_query=False) if use_history else ""
chat_history_str = construct_question_history(chat_history, include_query=False) if use_history else ""
# Get dates relative to today for prompt creation
today = datetime.today()
@@ -73,7 +73,7 @@ def extract_questions_offline(
last_year = today.year - 1
example_questions = prompts.extract_questions_offline.format(
query=text,
chat_history=chat_history,
chat_history=chat_history_str,
current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"),
current_month=today.strftime("%Y-%m"),
@@ -147,7 +147,7 @@ async def converse_offline(
references: list[dict] = [],
online_results={},
code_results={},
conversation_log={},
chat_history: list[ChatMessageModel] = [],
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
completion_func=None,
@@ -227,7 +227,7 @@ async def converse_offline(
messages = generate_chatml_messages_with_context(
user_query,
system_prompt,
conversation_log,
chat_history,
context_message=context_message,
model_name=model_name,
loaded_model=offline_chat_model,

View File

@@ -8,7 +8,7 @@ from langchain_core.messages.chat import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
@@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
def extract_questions(
text,
model: Optional[str] = "gpt-4o-mini",
conversation_log={},
chat_history: list[ChatMessageModel] = [],
api_key=None,
api_base_url=None,
location_data: LocationData = None,
@@ -56,8 +56,8 @@ def extract_questions(
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = construct_question_history(conversation_log)
# Extract Past User Message and Inferred Questions from Chat History
chat_history_str = construct_question_history(chat_history)
# Get dates relative to today for prompt creation
today = datetime.today()
@@ -73,7 +73,7 @@ def extract_questions(
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
bob_tom_age_difference={current_new_year.year - 1984 - 30},
bob_age={current_new_year.year - 1984},
chat_history=chat_history,
chat_history=chat_history_str,
text=text,
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
@@ -166,7 +166,7 @@ async def converse_openai(
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
conversation_log={},
chat_history: list[ChatMessageModel] = [],
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
@@ -251,7 +251,7 @@ async def converse_openai(
messages = generate_chatml_messages_with_context(
user_query,
system_prompt,
conversation_log,
chat_history,
context_message=context_message,
model_name=model,
max_prompt_size=max_prompt_size,

View File

@@ -24,7 +24,13 @@ from pydantic import BaseModel
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModel, ClientApplication, KhojUser
from khoj.database.models import (
ChatMessageModel,
ChatModel,
ClientApplication,
Intent,
KhojUser,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.search_filter.base_filter import BaseFilter
@@ -161,8 +167,8 @@ def construct_iteration_history(
previous_iterations: List[ResearchIteration],
previous_iteration_prompt: str,
query: str = None,
) -> list[dict]:
iteration_history: list[dict] = []
) -> list[ChatMessageModel]:
iteration_history: list[ChatMessageModel] = []
previous_iteration_messages: list[dict] = []
for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format(
@@ -176,46 +182,46 @@ def construct_iteration_history(
if previous_iteration_messages:
if query:
iteration_history.append({"by": "you", "message": query})
iteration_history.append(ChatMessageModel(by="you", message=query))
iteration_history.append(
{
"by": "khoj",
"intent": {"type": "remember", "query": query},
"message": previous_iteration_messages,
}
ChatMessageModel(
by="khoj",
intent={"type": "remember", "query": query},
message=previous_iteration_messages,
)
)
return iteration_history
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
if chat["intent"].get("inferred-queries"):
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
chat_history += f"{agent_name}: {chat['message']}\n\n"
elif chat["by"] == "khoj" and chat.get("images"):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: [generated image redacted for space]\n"
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
elif chat["by"] == "you":
chat_history += f"User: {chat['message']}\n"
raw_query_files = chat.get("queryFiles")
def construct_chat_history(chat_history: list[ChatMessageModel], n: int = 4, agent_name="AI") -> str:
chat_history_str = ""
for chat in chat_history[-n:]:
if chat.by == "khoj" and chat.intent.type in ["remember", "reminder", "summarize"]:
if chat.intent.inferred_queries:
chat_history_str += f'{agent_name}: {{"queries": {chat.intent.inferred_queries}}}\n'
chat_history_str += f"{agent_name}: {chat.message}\n\n"
elif chat.by == "khoj" and chat.images:
chat_history_str += f"User: {chat.intent.query}\n"
chat_history_str += f"{agent_name}: [generated image redacted for space]\n"
elif chat.by == "khoj" and ("excalidraw" in chat.intent.type):
chat_history_str += f"User: {chat.intent.query}\n"
chat_history_str += f"{agent_name}: {chat.intent.inferred_queries[0]}\n"
elif chat.by == "you":
chat_history_str += f"User: {chat.message}\n"
raw_query_files = chat.queryFiles
if raw_query_files:
query_files: Dict[str, str] = {}
for file in raw_query_files:
query_files[file["name"]] = file["content"]
query_file_context = gather_raw_query_files(query_files)
chat_history += f"User: {query_file_context}\n"
chat_history_str += f"User: {query_file_context}\n"
return chat_history
return chat_history_str
def construct_question_history(
conversation_log: dict,
conversation_log: list[ChatMessageModel],
include_query: bool = True,
lookback: int = 6,
query_prefix: str = "Q",
@@ -226,16 +232,16 @@ def construct_question_history(
"""
history_parts = ""
original_query = None
for chat in conversation_log.get("chat", [])[-lookback:]:
if chat["by"] == "you":
original_query = chat.get("message")
for chat in conversation_log[-lookback:]:
if chat.by == "you":
original_query = json.dumps(chat.message)
history_parts += f"{query_prefix}: {original_query}\n"
if chat["by"] == "khoj":
if chat.by == "khoj":
if original_query is None:
continue
message = chat.get("message", "")
inferred_queries_list = chat.get("intent", {}).get("inferred-queries")
message = chat.message
inferred_queries_list = chat.intent.inferred_queries or []
# Ensure inferred_queries_list is a list, defaulting to the original query in a list
if not inferred_queries_list:
@@ -246,7 +252,7 @@ def construct_question_history(
if include_query:
# Ensure 'type' exists and is a string before checking 'to-image'
intent_type = chat.get("intent", {}).get("type", "")
intent_type = chat.intent.type if chat.intent and chat.intent.type else ""
if "to-image" not in intent_type:
history_parts += f'{agent_name}: {{"queries": {inferred_queries_list}}}\n'
history_parts += f"A: {message}\n\n"
@@ -259,7 +265,7 @@ def construct_question_history(
return history_parts
def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]:
def construct_chat_history_for_operator(conversation_history: List[ChatMessageModel], n: int = 6) -> list[AgentMessage]:
"""
Construct chat history for operator agent in conversation log.
Only include last n completed turns (i.e with user and khoj message).
@@ -267,22 +273,22 @@ def construct_chat_history_for_operator(conversation_history: dict, n: int = 6)
chat_history: list[AgentMessage] = []
user_message: Optional[AgentMessage] = None
for chat in conversation_history.get("chat", []):
for chat in conversation_history:
if len(chat_history) >= n:
break
if chat["by"] == "you" and chat.get("message"):
content = [{"type": "text", "text": chat["message"]}]
for file in chat.get("queryFiles", []):
if chat.by == "you" and chat.message:
content = [{"type": "text", "text": chat.message}]
for file in chat.queryFiles or []:
content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
user_message = AgentMessage(role="user", content=content)
elif chat["by"] == "khoj" and chat.get("message"):
chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])]
elif chat.by == "khoj" and chat.message:
chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)]
return chat_history
def construct_tool_chat_history(
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
) -> Dict[str, list]:
) -> List[ChatMessageModel]:
"""
Construct chat history from previous iterations for a specific tool
@@ -313,22 +319,23 @@ def construct_tool_chat_history(
tool or ConversationCommand(iteration.tool), base_extractor
)
chat_history += [
{
"by": "you",
"message": iteration.query,
},
{
"by": "khoj",
"intent": {
"type": "remember",
"inferred-queries": inferred_query_extractor(iteration),
"query": iteration.query,
},
"message": iteration.summarizedResult,
},
ChatMessageModel(
by="you",
message=iteration.query,
),
ChatMessageModel(
by="khoj",
intent=Intent(
type="remember",
query=iteration.query,
inferred_queries=inferred_query_extractor(iteration),
memory_type="notes",
),
message=iteration.summarizedResult,
),
]
return {"chat": chat_history}
return chat_history
class ChatEvent(Enum):
@@ -349,8 +356,8 @@ def message_to_log(
chat_response,
user_message_metadata={},
khoj_message_metadata={},
conversation_log=[],
):
chat_history: List[ChatMessageModel] = [],
) -> List[ChatMessageModel]:
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
@@ -369,15 +376,17 @@ def message_to_log(
khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log)
conversation_log.extend([human_log, khoj_log])
return conversation_log
human_message = ChatMessageModel(**human_log)
khoj_message = ChatMessageModel(**khoj_log)
chat_history.extend([human_message, khoj_message])
return chat_history
async def save_to_conversation_log(
q: str,
chat_response: str,
user: KhojUser,
meta_log: Dict,
chat_history: List[ChatMessageModel],
user_message_time: str = None,
compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {},
@@ -427,11 +436,11 @@ async def save_to_conversation_log(
chat_response=chat_response,
user_message_metadata=user_message_metadata,
khoj_message_metadata=khoj_message_metadata,
conversation_log=meta_log.get("chat", []),
chat_history=chat_history,
)
await ConversationAdapters.save_conversation(
user,
{"chat": updated_conversation},
updated_conversation,
client_application=client_application,
conversation_id=conversation_id,
user_message=q,
@@ -502,7 +511,7 @@ def gather_raw_query_files(
def generate_chatml_messages_with_context(
user_message: str,
system_message: str = None,
conversation_log={},
chat_history: list[ChatMessageModel] = [],
model_name="gpt-4o-mini",
loaded_model: Optional[Llama] = None,
max_prompt_size=None,
@@ -529,21 +538,21 @@ def generate_chatml_messages_with_context(
# Extract Chat History for Context
chatml_messages: List[ChatMessage] = []
for chat in conversation_log.get("chat", []):
for chat in chat_history:
message_context = []
message_attached_files = ""
generated_assets = {}
chat_message = chat.get("message")
role = "user" if chat["by"] == "you" else "assistant"
chat_message = chat.message
role = "user" if chat.by == "you" else "assistant"
# Legacy code to handle excalidraw diagrams prior to Dec 2024
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
chat_message = chat["intent"].get("inferred-queries")[0]
if chat.by == "khoj" and "excalidraw" in chat.intent.type or "":
chat_message = (chat.intent.inferred_queries or [])[0]
if chat.get("queryFiles"):
raw_query_files = chat.get("queryFiles")
if chat.queryFiles:
raw_query_files = chat.queryFiles
query_files_dict = dict()
for file in raw_query_files:
query_files_dict[file["name"]] = file["content"]
@@ -551,24 +560,24 @@ def generate_chatml_messages_with_context(
message_attached_files = gather_raw_query_files(query_files_dict)
chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
if not is_none_or_empty(chat.get("onlineContext")):
if not is_none_or_empty(chat.onlineContext):
message_context += [
{
"type": "text",
"text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}",
"text": f"{prompts.online_search_conversation.format(online_results=chat.onlineContext)}",
}
]
if not is_none_or_empty(chat.get("codeContext")):
if not is_none_or_empty(chat.codeContext):
message_context += [
{
"type": "text",
"text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}",
"text": f"{prompts.code_executed_context.format(code_results=chat.codeContext)}",
}
]
if not is_none_or_empty(chat.get("operatorContext")):
operator_context = chat.get("operatorContext")
if not is_none_or_empty(chat.operatorContext):
operator_context = chat.operatorContext
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
message_context += [
{
@@ -577,13 +586,9 @@ def generate_chatml_messages_with_context(
}
]
if not is_none_or_empty(chat.get("context")):
if not is_none_or_empty(chat.context):
references = "\n\n".join(
{
f"# File: {item['file']}\n## {item['compiled']}\n"
for item in chat.get("context") or []
if isinstance(item, dict)
}
{f"# File: {item.file}\n## {item.compiled}\n" for item in chat.context or [] if isinstance(item, dict)}
)
message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}]
@@ -591,14 +596,14 @@ def generate_chatml_messages_with_context(
reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message)
if not is_none_or_empty(chat.get("images")) and role == "assistant":
if not is_none_or_empty(chat.images) and role == "assistant":
generated_assets["image"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
"query": (chat.intent.inferred_queries or [user_message])[0],
}
if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant":
if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant":
generated_assets["diagram"] = {
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
"query": (chat.intent.inferred_queries or [user_message])[0],
}
if not is_none_or_empty(generated_assets):
@@ -610,7 +615,7 @@ def generate_chatml_messages_with_context(
)
message_content = construct_structured_message(
chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
chat_message, chat.images if role == "user" else [], model_type, vision_enabled
)
reconstructed_message = ChatMessage(content=message_content, role=role)

View File

@@ -10,7 +10,12 @@ from google import genai
from google.genai import types as gtypes
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.database.models import (
Agent,
ChatMessageModel,
KhojUser,
TextToImageModelConfig,
)
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_generated_image_to_bucket
from khoj.utils import state
@@ -23,7 +28,7 @@ logger = logging.getLogger(__name__)
async def text_to_image(
message: str,
user: KhojUser,
conversation_log: dict,
chat_history: List[ChatMessageModel],
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
@@ -46,14 +51,14 @@ async def text_to_image(
return
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and chat.get("images"):
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
chat_history_str = ""
for chat in chat_history[-4:]:
if chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
chat_history_str += f"Q: {chat.intent.query or ''}\n"
chat_history_str += f"A: {chat.message}\n"
elif chat.by == "khoj" and chat.images:
chat_history_str += f"Q: {chat.intent.query}\n"
chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n"
if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
@@ -63,7 +68,7 @@ async def text_to_image(
# Use the user's message, chat history, and other context
image_prompt = await generate_better_image_prompt(
message,
chat_history,
chat_history_str,
location_data=location_data,
note_references=references,
online_results=online_results,

View File

@@ -5,10 +5,9 @@ import os
from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.processor.conversation.utils import (
OperatorRun,
construct_chat_history,
construct_chat_history_for_operator,
)
from khoj.processor.operator.operator_actions import *
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
async def operate_environment(
query: str,
user: KhojUser,
conversation_log: dict,
conversation_log: List[ChatMessageModel],
location_data: LocationData,
previous_trajectory: Optional[OperatorRun] = None,
environment_type: EnvironmentType = EnvironmentType.COMPUTER,

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from textwrap import dedent
from typing import List, Optional
from khoj.database.models import ChatModel
from khoj.database.models import ChatMessageModel, ChatModel
from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun,
@@ -119,13 +119,13 @@ class BinaryOperatorAgent(OperatorAgent):
query_screenshot = self._get_message_images(current_message)
# Construct input for visual reasoner history
visual_reasoner_history = {"chat": self._format_message_for_api(self.messages)}
visual_reasoner_history = self._format_message_for_api(self.messages)
try:
natural_language_action = await send_message_to_model_wrapper(
query=query_text,
query_images=query_screenshot,
system_message=reasoning_system_prompt,
conversation_log=visual_reasoner_history,
chat_history=visual_reasoner_history,
agent_chat_model=self.reasoning_model,
tracer=self.tracer,
)
@@ -238,11 +238,11 @@ class BinaryOperatorAgent(OperatorAgent):
async def summarize(self, env_state: EnvState, summarize_prompt: str = None) -> str:
summarize_prompt = summarize_prompt or self.summarize_prompt
conversation_history = {"chat": self._format_message_for_api(self.messages)}
conversation_history = self._format_message_for_api(self.messages)
try:
summary = await send_message_to_model_wrapper(
query=summarize_prompt,
conversation_log=conversation_history,
chat_history=conversation_history,
agent_chat_model=self.reasoning_model,
tracer=self.tracer,
)
@@ -296,14 +296,14 @@ class BinaryOperatorAgent(OperatorAgent):
images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"]
return images
def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]:
def _format_message_for_api(self, messages: list[AgentMessage]) -> List[ChatMessageModel]:
"""Format operator agent messages into the Khoj conversation history format."""
formatted_messages = [
{
"message": self._get_message_text(message),
"images": self._get_message_images(message),
"by": "you" if message.role in ["user", "environment"] else message.role,
}
ChatMessageModel(
message=self._get_message_text(message),
images=self._get_message_images(message),
by="you" if message.role in ["user", "environment"] else message.role,
)
for message in messages
]
return formatted_messages

View File

@@ -10,7 +10,13 @@ from bs4 import BeautifulSoup
from markdownify import markdownify
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import Agent, KhojUser, ServerChatSettings, WebScraper
from khoj.database.models import (
Agent,
ChatMessageModel,
KhojUser,
ServerChatSettings,
WebScraper,
)
from khoj.processor.conversation import prompts
from khoj.routers.helpers import (
ChatEvent,
@@ -59,7 +65,7 @@ OLOSTEP_QUERY_PARAMS = {
async def search_online(
query: str,
conversation_history: dict,
conversation_history: List[ChatMessageModel],
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
@@ -361,7 +367,7 @@ async def search_with_serper(query: str, location: LocationData) -> Tuple[str, D
async def read_webpages(
query: str,
conversation_history: dict,
conversation_history: List[ChatMessageModel],
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,

View File

@@ -20,7 +20,7 @@ from tenacity import (
)
from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import Agent, FileObject, KhojUser
from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
@@ -50,7 +50,7 @@ class GeneratedCode(NamedTuple):
async def run_code(
query: str,
conversation_history: dict,
conversation_history: List[ChatMessageModel],
context: str,
location_data: LocationData,
user: KhojUser,
@@ -116,7 +116,7 @@ async def run_code(
async def generate_python_code(
q: str,
conversation_history: dict,
chat_history: List[ChatMessageModel],
context: str,
location_data: LocationData,
user: KhojUser,
@@ -127,7 +127,7 @@ async def generate_python_code(
) -> GeneratedCode:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
personality_context = (
@@ -143,7 +143,7 @@ async def generate_python_code(
code_generation_prompt = prompts.python_code_generation_prompt.format(
query=q,
chat_history=chat_history,
chat_history=chat_history_str,
context=context,
has_network_access=network_access_context,
current_date=utc_date,

View File

@@ -29,7 +29,13 @@ from khoj.database.adapters import (
get_default_search_model,
get_user_photo,
)
from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions
from khoj.database.models import (
Agent,
ChatMessageModel,
ChatModel,
KhojUser,
SpeechToTextModelOptions,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
@@ -353,7 +359,7 @@ def set_user_name(
async def extract_references_and_questions(
user: KhojUser,
meta_log: dict,
chat_history: list[ChatMessageModel],
q: str,
n: int,
d: float,
@@ -432,7 +438,7 @@ async def extract_references_and_questions(
defiltered_query,
model=chat_model,
loaded_model=loaded_model,
conversation_log=meta_log,
chat_history=chat_history,
should_extract_questions=True,
location_data=location_data,
user=user,
@@ -450,7 +456,7 @@ async def extract_references_and_questions(
model=chat_model_name,
api_key=api_key,
api_base_url=base_url,
conversation_log=meta_log,
chat_history=chat_history,
location_data=location_data,
user=user,
query_images=query_images,
@@ -469,7 +475,7 @@ async def extract_references_and_questions(
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log,
chat_history=chat_history,
location_data=location_data,
user=user,
vision_enabled=vision_enabled,
@@ -487,7 +493,7 @@ async def extract_references_and_questions(
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
conversation_log=meta_log,
chat_history=chat_history,
location_data=location_data,
max_tokens=chat_model.max_prompt_size,
user=user,
@@ -606,7 +612,7 @@ def post_automation(
return Response(content="Invalid crontime", status_code=400)
# Infer subject, query to run
_, query_to_run, generated_subject = schedule_query(q, conversation_history={}, user=user)
_, query_to_run, generated_subject = schedule_query(q, chat_history=[], user=user)
subject = subject or generated_subject
# Normalize query parameters
@@ -712,7 +718,7 @@ def edit_job(
return Response(content="Invalid automation", status_code=403)
# Infer subject, query to run
_, query_to_run, _ = schedule_query(q, conversation_history={}, user=user)
_, query_to_run, _ = schedule_query(q, chat_history=[], user=user)
subject = subject
# Normalize query parameters

View File

@@ -752,7 +752,7 @@ async def chat(
q,
chat_response="",
user=user,
meta_log=meta_log,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
@@ -918,7 +918,7 @@ async def chat(
if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code)
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log
chat_history = conversation.messages
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
if interrupt_flag:
@@ -964,14 +964,14 @@ async def chat(
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
# Drop the interrupted message from conversation history
meta_log["chat"].pop()
chat_history.pop()
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
if conversation_commands == [ConversationCommand.Default]:
try:
chosen_io = await aget_data_sources_and_output_format(
q,
meta_log,
chat_history,
is_automated_task,
user=user,
query_images=uploaded_images,
@@ -1011,7 +1011,7 @@ async def chat(
user=user,
query=defiltered_query,
conversation_id=conversation_id,
conversation_history=meta_log,
conversation_history=conversation.messages,
previous_iterations=list(research_results),
query_images=uploaded_images,
agent=agent,
@@ -1078,7 +1078,7 @@ async def chat(
q=q,
user=user,
file_filters=file_filters,
meta_log=meta_log,
chat_history=conversation.messages,
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1123,7 +1123,7 @@ async def chat(
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log, tracer=tracer
q, timezone, user, request.url, chat_history, tracer=tracer
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@@ -1139,7 +1139,7 @@ async def chat(
q,
llm_response,
user,
meta_log,
chat_history,
user_message_time,
intent_type="automation",
client_application=request.user.client_app,
@@ -1163,7 +1163,7 @@ async def chat(
try:
async for result in extract_references_and_questions(
user,
meta_log,
chat_history,
q,
(n or 7),
d,
@@ -1212,7 +1212,7 @@ async def chat(
try:
async for result in search_online(
defiltered_query,
meta_log,
chat_history,
location,
user,
partial(send_event, ChatEvent.STATUS),
@@ -1240,7 +1240,7 @@ async def chat(
try:
async for result in read_webpages(
defiltered_query,
meta_log,
chat_history,
location,
user,
partial(send_event, ChatEvent.STATUS),
@@ -1281,7 +1281,7 @@ async def chat(
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
async for result in run_code(
defiltered_query,
meta_log,
chat_history,
context,
location,
user,
@@ -1306,7 +1306,7 @@ async def chat(
async for result in operate_environment(
defiltered_query,
user,
meta_log,
chat_history,
location,
list(operator_results)[-1] if operator_results else None,
query_images=uploaded_images,
@@ -1356,7 +1356,7 @@ async def chat(
async for result in text_to_image(
defiltered_query,
user,
meta_log,
chat_history,
location_data=location,
references=compiled_references,
online_results=online_results,
@@ -1400,7 +1400,7 @@ async def chat(
async for result in generate_mermaidjs_diagram(
q=defiltered_query,
conversation_history=meta_log,
chat_history=chat_history,
location_data=location,
note_references=compiled_references,
online_results=online_results,
@@ -1456,7 +1456,7 @@ async def chat(
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
chat_history,
conversation,
compiled_references,
online_results,

View File

@@ -55,6 +55,7 @@ from khoj.database.adapters import (
)
from khoj.database.models import (
Agent,
ChatMessageModel,
ChatModel,
ClientApplication,
Conversation,
@@ -285,7 +286,7 @@ async def acreate_title_from_history(
"""
Create a title from the given conversation history
"""
chat_history = construct_chat_history(conversation.conversation_log)
chat_history = construct_chat_history(conversation.messages)
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
@@ -345,7 +346,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
async def aget_data_sources_and_output_format(
query: str,
conversation_history: dict,
chat_history: list[ChatMessageModel],
is_task: bool,
user: KhojUser,
query_images: List[str] = None,
@@ -386,7 +387,7 @@ async def aget_data_sources_and_output_format(
if len(agent_outputs) == 0 or output.value in agent_outputs:
output_options_str += f'- "{output.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history, n=6)
chat_history_str = construct_chat_history(chat_history, n=6)
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
@@ -399,7 +400,7 @@ async def aget_data_sources_and_output_format(
query=query,
sources=source_options_str,
outputs=output_options_str,
chat_history=chat_history,
chat_history=chat_history_str,
personality_context=personality_context,
)
@@ -462,7 +463,7 @@ async def aget_data_sources_and_output_format(
async def infer_webpage_urls(
q: str,
max_webpages: int,
conversation_history: dict,
chat_history: List[ChatMessageModel],
location_data: LocationData,
user: KhojUser,
query_images: List[str] = None,
@@ -475,7 +476,7 @@ async def infer_webpage_urls(
"""
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
personality_context = (
@@ -485,7 +486,7 @@ async def infer_webpage_urls(
online_queries_prompt = prompts.infer_webpages_to_read.format(
query=q,
max_webpages=max_webpages,
chat_history=chat_history,
chat_history=chat_history_str,
current_date=utc_date,
location=location,
username=username,
@@ -526,7 +527,7 @@ async def infer_webpage_urls(
async def generate_online_subqueries(
q: str,
conversation_history: dict,
chat_history: List[ChatMessageModel],
location_data: LocationData,
user: KhojUser,
query_images: List[str] = None,
@@ -540,7 +541,7 @@ async def generate_online_subqueries(
"""
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
personality_context = (
@@ -549,7 +550,7 @@ async def generate_online_subqueries(
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
query=q,
chat_history=chat_history,
chat_history=chat_history_str,
max_queries=max_queries,
current_date=utc_date,
location=location,
@@ -591,16 +592,16 @@ async def generate_online_subqueries(
def schedule_query(
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, str, str]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
"""
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
crontime_prompt = prompts.crontime_prompt.format(
query=q,
chat_history=chat_history,
chat_history=chat_history_str,
)
raw_response = send_message_to_model_wrapper_sync(
@@ -619,16 +620,16 @@ def schedule_query(
async def aschedule_query(
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, str, str]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
"""
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
crontime_prompt = prompts.crontime_prompt.format(
query=q,
chat_history=chat_history,
chat_history=chat_history_str,
)
raw_response = await send_message_to_model_wrapper(
@@ -681,7 +682,7 @@ async def extract_relevant_info(
async def extract_relevant_summary(
q: str,
corpus: str,
conversation_history: dict,
chat_history: List[ChatMessageModel] = [],
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
@@ -698,11 +699,11 @@ async def extract_relevant_summary(
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
extract_relevant_information = prompts.extract_relevant_summary.format(
query=q,
chat_history=chat_history,
chat_history=chat_history_str,
corpus=corpus.strip(),
personality_context=personality_context,
)
@@ -725,7 +726,7 @@ async def generate_summary_from_files(
q: str,
user: KhojUser,
file_filters: List[str],
meta_log: dict,
chat_history: List[ChatMessageModel] = [],
query_images: List[str] = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@@ -766,7 +767,7 @@ async def generate_summary_from_files(
response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
chat_history=chat_history,
query_images=query_images,
user=user,
agent=agent,
@@ -782,7 +783,7 @@ async def generate_summary_from_files(
async def generate_excalidraw_diagram(
q: str,
conversation_history: Dict[str, Any],
chat_history: List[ChatMessageModel],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
@@ -799,7 +800,7 @@ async def generate_excalidraw_diagram(
better_diagram_description_prompt = await generate_better_diagram_description(
q=q,
conversation_history=conversation_history,
chat_history=chat_history,
location_data=location_data,
note_references=note_references,
online_results=online_results,
@@ -834,7 +835,7 @@ async def generate_excalidraw_diagram(
async def generate_better_diagram_description(
q: str,
conversation_history: Dict[str, Any],
chat_history: List[ChatMessageModel],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
@@ -857,7 +858,7 @@ async def generate_better_diagram_description(
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
simplified_online_results = {}
@@ -870,7 +871,7 @@ async def generate_better_diagram_description(
improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
query=q,
chat_history=chat_history,
chat_history=chat_history_str,
location=location,
current_date=today_date,
references=user_references,
@@ -939,7 +940,7 @@ async def generate_excalidraw_diagram_from_description(
async def generate_mermaidjs_diagram(
q: str,
conversation_history: Dict[str, Any],
chat_history: List[ChatMessageModel],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
@@ -956,7 +957,7 @@ async def generate_mermaidjs_diagram(
better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description(
q=q,
conversation_history=conversation_history,
chat_history=chat_history,
location_data=location_data,
note_references=note_references,
online_results=online_results,
@@ -985,7 +986,7 @@ async def generate_mermaidjs_diagram(
async def generate_better_mermaidjs_diagram_description(
q: str,
conversation_history: Dict[str, Any],
chat_history: List[ChatMessageModel],
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
@@ -1008,7 +1009,7 @@ async def generate_better_mermaidjs_diagram_description(
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
chat_history = construct_chat_history(conversation_history)
chat_history_str = construct_chat_history(chat_history)
simplified_online_results = {}
@@ -1021,7 +1022,7 @@ async def generate_better_mermaidjs_diagram_description(
improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(
query=q,
chat_history=chat_history,
chat_history=chat_history_str,
location=location,
current_date=today_date,
references=user_references,
@@ -1160,7 +1161,7 @@ async def send_message_to_model_wrapper(
query_images: List[str] = None,
context: str = "",
query_files: str = None,
conversation_log: dict = {},
chat_history: list[ChatMessageModel] = [],
agent_chat_model: ChatModel = None,
tracer: dict = {},
):
@@ -1193,7 +1194,7 @@ async def send_message_to_model_wrapper(
user_message=query,
context_message=context,
system_message=system_message,
conversation_log=conversation_log,
chat_history=chat_history,
model_name=chat_model_name,
loaded_model=loaded_model,
tokenizer_name=tokenizer,
@@ -1260,7 +1261,7 @@ def send_message_to_model_wrapper_sync(
user: KhojUser = None,
query_images: List[str] = None,
query_files: str = "",
conversation_log: dict = {},
chat_history: List[ChatMessageModel] = [],
tracer: dict = {},
):
chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)
@@ -1284,7 +1285,7 @@ def send_message_to_model_wrapper_sync(
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
conversation_log=conversation_log,
chat_history=chat_history,
model_name=chat_model_name,
loaded_model=loaded_model,
max_prompt_size=max_tokens,
@@ -1342,7 +1343,7 @@ def send_message_to_model_wrapper_sync(
async def agenerate_chat_response(
q: str,
meta_log: dict,
chat_history: List[ChatMessageModel],
conversation: Conversation,
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
@@ -1379,7 +1380,7 @@ async def agenerate_chat_response(
save_to_conversation_log,
q,
user=user,
meta_log=meta_log,
chat_history=chat_history,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
@@ -1424,7 +1425,7 @@ async def agenerate_chat_response(
references=compiled_references,
online_results=online_results,
loaded_model=loaded_model,
conversation_log=meta_log,
chat_history=chat_history,
completion_func=partial_completion,
conversation_commands=conversation_commands,
model_name=chat_model.name,
@@ -1450,7 +1451,7 @@ async def agenerate_chat_response(
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
conversation_log=meta_log,
chat_history=chat_history,
model=chat_model_name,
api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
@@ -1480,7 +1481,7 @@ async def agenerate_chat_response(
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
conversation_log=meta_log,
chat_history=chat_history,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
@@ -1508,7 +1509,7 @@ async def agenerate_chat_response(
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
conversation_log=meta_log,
chat_history=chat_history,
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
@@ -2005,11 +2006,11 @@ async def create_automation(
timezone: str,
user: KhojUser,
calling_url: URL,
meta_log: dict = {},
chat_history: List[ChatMessageModel] = [],
conversation_id: str = None,
tracer: dict = {},
):
crontime, query_to_run, subject = await aschedule_query(q, meta_log, user, tracer=tracer)
crontime, query_to_run, subject = await aschedule_query(q, chat_history, user, tracer=tracer)
job = await aschedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject

View File

@@ -10,7 +10,7 @@ import yaml
from pydantic import BaseModel, Field
from khoj.database.adapters import AgentAdapters, EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Agent, ChatMessageModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
OperatorRun,
@@ -84,7 +84,7 @@ class PlanningResponse(BaseModel):
async def apick_next_tool(
query: str,
conversation_history: dict,
conversation_history: List[ChatMessageModel],
user: KhojUser = None,
location: LocationData = None,
user_name: str = None,
@@ -166,18 +166,18 @@ async def apick_next_tool(
query = f"[placeholder for user attached images]\n{query}"
# Construct chat history with user and iteration history with researcher agent for context
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
chat_and_research_history = conversation_history + iteration_chat_history
# Plan function execution for the next tool
query = prompts.plan_function_execution_next_tool.format(query=query) if previous_iterations_history else query
query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query
try:
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
query=query,
system_message=function_planning_prompt,
conversation_log=iteration_chat_log,
chat_history=chat_and_research_history,
response_type="json_object",
response_schema=planning_response_model,
deepthought=True,
@@ -238,7 +238,7 @@ async def research(
user: KhojUser,
query: str,
conversation_id: str,
conversation_history: dict,
conversation_history: List[ChatMessageModel],
previous_iterations: List[ResearchIteration],
query_images: List[str],
agent: Agent = None,
@@ -261,9 +261,7 @@ async def research(
if current_iteration := len(previous_iterations) > 0:
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
research_conversation_history["chat"] = (
research_conversation_history.get("chat", []) + previous_iterations_history
)
research_conversation_history += previous_iterations_history
while current_iteration < MAX_ITERATIONS:
# Check for cancellation at the start of each iteration

View File

@@ -6,6 +6,7 @@ from django.utils.timezone import make_aware
from khoj.database.models import (
AiModelApi,
ChatMessageModel,
ChatModel,
Conversation,
KhojApiUser,
@@ -46,15 +47,15 @@ def get_chat_api_key(provider: ChatModel.ModelType = None):
def generate_chat_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
chat_history: list[ChatMessageModel] = []
for user_message, chat_response, context in message_list:
message_to_log(
user_message,
chat_response,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
conversation_log=conversation_log.get("chat", []),
chat_history=chat_history,
)
return conversation_log
return chat_history
class UserFactory(factory.django.DjangoModelFactory):

View File

@@ -135,7 +135,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# Act
response = extract_questions_offline(
query,
conversation_log=generate_chat_history(message_list),
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)
@@ -181,7 +181,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Act
response = extract_questions_offline(
"Is she a Doctor?",
conversation_log=generate_chat_history(message_list),
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)
@@ -210,7 +210,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
# Act
response = extract_questions_offline(
"What was the Pizza place we ate at over there?",
conversation_log=generate_chat_history(message_list),
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
@@ -336,7 +336,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model)
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=generate_chat_history(message_list),
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
@@ -363,7 +363,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
{"compiled": "Testatron was born on 1st April 1984 in Testville."}
], # Assume context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=generate_chat_history(message_list),
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
@@ -388,7 +388,7 @@ def test_refuse_answering_unanswerable_question(loaded_model):
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=generate_chat_history(message_list),
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
@@ -501,7 +501,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
conversation_log=generate_chat_history(message_list),
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])

View File

@@ -28,7 +28,7 @@ def generate_history(message_list):
user_message,
gpt_message,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
conversation_log=conversation_log.get("chat", []),
chat_history=conversation_log.get("chat", []),
)
return conversation_log

View File

@@ -708,6 +708,6 @@ def populate_chat_history(message_list):
"context": context,
"intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
},
conversation_log=[],
chat_history=[],
)
return conversation_log