mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Pass deep typed chat history for more ergonomic, readable, safe code
The chat dictionary is an artifact from earlier non-db chat history storage. We've been ensuring new chat messages have valid type before being written to DB for more than 6 months now. Move to using the deeply typed chat history helps avoids null refs, makes code more readable and easier to reason about. Next Steps: The current update entangles chat_history written to DB with any virtual chat history message generated for intermediate steps. The chat message type written to DB should be decoupled from type that can be passed to AI model APIs (maybe?). For now we've made the ChatMessage.message type looser to allow for list[dict] type (apart from string). But later maybe a good idea to decouple the chat_history recieved by send_message_to_model from the chat_history saved to DB (which can then have its stricter type check)
This commit is contained in:
@@ -37,6 +37,7 @@ from torch import Tensor
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
AiModelApi,
|
||||
ChatMessageModel,
|
||||
ChatModel,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
@@ -1419,7 +1420,7 @@ class ConversationAdapters:
|
||||
@require_valid_user
|
||||
async def save_conversation(
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
chat_history: List[ChatMessageModel],
|
||||
client_application: ClientApplication = None,
|
||||
conversation_id: str = None,
|
||||
user_message: str = None,
|
||||
@@ -1434,6 +1435,7 @@ class ConversationAdapters:
|
||||
await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
|
||||
)
|
||||
|
||||
conversation_log = {"chat": [msg.model_dump() for msg in chat_history]}
|
||||
cleaned_conversation_log = clean_object_for_db(conversation_log)
|
||||
if conversation:
|
||||
conversation.conversation_log = cleaned_conversation_log
|
||||
|
||||
@@ -91,7 +91,7 @@ class OnlineContext(PydanticBaseModel):
|
||||
class Intent(PydanticBaseModel):
|
||||
type: str
|
||||
query: str
|
||||
memory_type: str = Field(alias="memory-type")
|
||||
memory_type: Optional[str] = Field(alias="memory-type", default=None)
|
||||
inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
|
||||
|
||||
|
||||
@@ -100,20 +100,20 @@ class TrainOfThought(PydanticBaseModel):
|
||||
data: str
|
||||
|
||||
|
||||
class ChatMessage(PydanticBaseModel):
|
||||
message: str
|
||||
class ChatMessageModel(PydanticBaseModel):
|
||||
by: str
|
||||
message: str | list[dict]
|
||||
trainOfThought: List[TrainOfThought] = []
|
||||
context: List[Context] = []
|
||||
onlineContext: Dict[str, OnlineContext] = {}
|
||||
codeContext: Dict[str, CodeContextData] = {}
|
||||
researchContext: Optional[List] = None
|
||||
operatorContext: Optional[List] = None
|
||||
created: str
|
||||
created: Optional[str] = None
|
||||
images: Optional[List[str]] = None
|
||||
queryFiles: Optional[List[Dict]] = None
|
||||
excalidrawDiagram: Optional[List[Dict]] = None
|
||||
mermaidjsDiagram: str = None
|
||||
by: str
|
||||
mermaidjsDiagram: Optional[str] = None
|
||||
turnId: Optional[str] = None
|
||||
intent: Optional[Intent] = None
|
||||
automationId: Optional[str] = None
|
||||
@@ -634,7 +634,7 @@ class Conversation(DbBaseModel):
|
||||
try:
|
||||
messages = self.conversation_log.get("chat", [])
|
||||
for msg in messages:
|
||||
ChatMessage.model_validate(msg)
|
||||
ChatMessageModel.model_validate(msg)
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Invalid conversation_log format: {str(e)}")
|
||||
|
||||
@@ -643,7 +643,7 @@ class Conversation(DbBaseModel):
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[ChatMessage]:
|
||||
def messages(self) -> List[ChatMessageModel]:
|
||||
"""Type-hinted accessor for conversation messages"""
|
||||
validated_messages = []
|
||||
for msg in self.conversation_log.get("chat", []):
|
||||
@@ -654,7 +654,7 @@ class Conversation(DbBaseModel):
|
||||
q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str)
|
||||
]
|
||||
msg["message"] = str(msg.get("message", ""))
|
||||
validated_messages.append(ChatMessage.model_validate(msg))
|
||||
validated_messages.append(ChatMessageModel.model_validate(msg))
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Skipping invalid message in conversation: {e}")
|
||||
continue
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import AsyncGenerator, Dict, List, Optional
|
||||
import pyjson5
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.utils import (
|
||||
anthropic_chat_completion_with_backoff,
|
||||
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
|
||||
def extract_questions_anthropic(
|
||||
text,
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
conversation_log={},
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
location_data: LocationData = None,
|
||||
@@ -54,8 +54,8 @@ def extract_questions_anthropic(
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
|
||||
# Extract Past User Message and Inferred Questions from Chat History
|
||||
chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
@@ -76,7 +76,7 @@ def extract_questions_anthropic(
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
text=text,
|
||||
)
|
||||
|
||||
@@ -142,7 +142,7 @@ async def converse_anthropic(
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
conversation_log={},
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
@@ -225,7 +225,7 @@ async def converse_anthropic(
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_query,
|
||||
context_message=context_message,
|
||||
conversation_log=conversation_log,
|
||||
chat_history=chat_history,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
|
||||
@@ -7,7 +7,7 @@ import pyjson5
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
gemini_chat_completion_with_backoff,
|
||||
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
def extract_questions_gemini(
|
||||
text,
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
conversation_log={},
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
max_tokens=None,
|
||||
@@ -54,8 +54,8 @@ def extract_questions_gemini(
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = construct_question_history(conversation_log, query_prefix="User", agent_name="Assistant")
|
||||
# Extract Past User Message and Inferred Questions from Chat History
|
||||
chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
@@ -76,7 +76,7 @@ def extract_questions_gemini(
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
text=text,
|
||||
)
|
||||
|
||||
@@ -163,7 +163,7 @@ async def converse_gemini(
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
conversation_log={},
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
@@ -248,7 +248,7 @@ async def converse_gemini(
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_query,
|
||||
context_message=context_message,
|
||||
conversation_log=conversation_log,
|
||||
chat_history=chat_history,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
|
||||
@@ -10,7 +10,7 @@ import pyjson5
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from llama_cpp import Llama
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
@@ -38,7 +38,7 @@ def extract_questions_offline(
|
||||
text: str,
|
||||
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
conversation_log={},
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
use_history: bool = True,
|
||||
should_extract_questions: bool = True,
|
||||
location_data: LocationData = None,
|
||||
@@ -65,7 +65,7 @@ def extract_questions_offline(
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = construct_question_history(conversation_log, include_query=False) if use_history else ""
|
||||
chat_history_str = construct_question_history(chat_history, include_query=False) if use_history else ""
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
@@ -73,7 +73,7 @@ def extract_questions_offline(
|
||||
last_year = today.year - 1
|
||||
example_questions = prompts.extract_questions_offline.format(
|
||||
query=text,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
@@ -147,7 +147,7 @@ async def converse_offline(
|
||||
references: list[dict] = [],
|
||||
online_results={},
|
||||
code_results={},
|
||||
conversation_log={},
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
completion_func=None,
|
||||
@@ -227,7 +227,7 @@ async def converse_offline(
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_query,
|
||||
system_prompt,
|
||||
conversation_log,
|
||||
chat_history,
|
||||
context_message=context_message,
|
||||
model_name=model_name,
|
||||
loaded_model=offline_chat_model,
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain_core.messages.chat import ChatMessage
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
@@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
|
||||
def extract_questions(
|
||||
text,
|
||||
model: Optional[str] = "gpt-4o-mini",
|
||||
conversation_log={},
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
location_data: LocationData = None,
|
||||
@@ -56,8 +56,8 @@ def extract_questions(
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = construct_question_history(conversation_log)
|
||||
# Extract Past User Message and Inferred Questions from Chat History
|
||||
chat_history_str = construct_question_history(chat_history)
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
@@ -73,7 +73,7 @@ def extract_questions(
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
bob_tom_age_difference={current_new_year.year - 1984 - 30},
|
||||
bob_age={current_new_year.year - 1984},
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
text=text,
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
@@ -166,7 +166,7 @@ async def converse_openai(
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
operator_results: Optional[List[OperatorRun]] = None,
|
||||
conversation_log={},
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
model: str = "gpt-4o-mini",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
@@ -251,7 +251,7 @@ async def converse_openai(
|
||||
messages = generate_chatml_messages_with_context(
|
||||
user_query,
|
||||
system_prompt,
|
||||
conversation_log,
|
||||
chat_history,
|
||||
context_message=context_message,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
|
||||
@@ -24,7 +24,13 @@ from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import ChatModel, ClientApplication, KhojUser
|
||||
from khoj.database.models import (
|
||||
ChatMessageModel,
|
||||
ChatModel,
|
||||
ClientApplication,
|
||||
Intent,
|
||||
KhojUser,
|
||||
)
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
@@ -161,8 +167,8 @@ def construct_iteration_history(
|
||||
previous_iterations: List[ResearchIteration],
|
||||
previous_iteration_prompt: str,
|
||||
query: str = None,
|
||||
) -> list[dict]:
|
||||
iteration_history: list[dict] = []
|
||||
) -> list[ChatMessageModel]:
|
||||
iteration_history: list[ChatMessageModel] = []
|
||||
previous_iteration_messages: list[dict] = []
|
||||
for idx, iteration in enumerate(previous_iterations):
|
||||
iteration_data = previous_iteration_prompt.format(
|
||||
@@ -176,46 +182,46 @@ def construct_iteration_history(
|
||||
|
||||
if previous_iteration_messages:
|
||||
if query:
|
||||
iteration_history.append({"by": "you", "message": query})
|
||||
iteration_history.append(ChatMessageModel(by="you", message=query))
|
||||
iteration_history.append(
|
||||
{
|
||||
"by": "khoj",
|
||||
"intent": {"type": "remember", "query": query},
|
||||
"message": previous_iteration_messages,
|
||||
}
|
||||
ChatMessageModel(
|
||||
by="khoj",
|
||||
intent={"type": "remember", "query": query},
|
||||
message=previous_iteration_messages,
|
||||
)
|
||||
)
|
||||
return iteration_history
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||
if chat["intent"].get("inferred-queries"):
|
||||
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
||||
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
||||
elif chat["by"] == "khoj" and chat.get("images"):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||
elif chat["by"] == "you":
|
||||
chat_history += f"User: {chat['message']}\n"
|
||||
raw_query_files = chat.get("queryFiles")
|
||||
def construct_chat_history(chat_history: list[ChatMessageModel], n: int = 4, agent_name="AI") -> str:
|
||||
chat_history_str = ""
|
||||
for chat in chat_history[-n:]:
|
||||
if chat.by == "khoj" and chat.intent.type in ["remember", "reminder", "summarize"]:
|
||||
if chat.intent.inferred_queries:
|
||||
chat_history_str += f'{agent_name}: {{"queries": {chat.intent.inferred_queries}}}\n'
|
||||
chat_history_str += f"{agent_name}: {chat.message}\n\n"
|
||||
elif chat.by == "khoj" and chat.images:
|
||||
chat_history_str += f"User: {chat.intent.query}\n"
|
||||
chat_history_str += f"{agent_name}: [generated image redacted for space]\n"
|
||||
elif chat.by == "khoj" and ("excalidraw" in chat.intent.type):
|
||||
chat_history_str += f"User: {chat.intent.query}\n"
|
||||
chat_history_str += f"{agent_name}: {chat.intent.inferred_queries[0]}\n"
|
||||
elif chat.by == "you":
|
||||
chat_history_str += f"User: {chat.message}\n"
|
||||
raw_query_files = chat.queryFiles
|
||||
if raw_query_files:
|
||||
query_files: Dict[str, str] = {}
|
||||
for file in raw_query_files:
|
||||
query_files[file["name"]] = file["content"]
|
||||
|
||||
query_file_context = gather_raw_query_files(query_files)
|
||||
chat_history += f"User: {query_file_context}\n"
|
||||
chat_history_str += f"User: {query_file_context}\n"
|
||||
|
||||
return chat_history
|
||||
return chat_history_str
|
||||
|
||||
|
||||
def construct_question_history(
|
||||
conversation_log: dict,
|
||||
conversation_log: list[ChatMessageModel],
|
||||
include_query: bool = True,
|
||||
lookback: int = 6,
|
||||
query_prefix: str = "Q",
|
||||
@@ -226,16 +232,16 @@ def construct_question_history(
|
||||
"""
|
||||
history_parts = ""
|
||||
original_query = None
|
||||
for chat in conversation_log.get("chat", [])[-lookback:]:
|
||||
if chat["by"] == "you":
|
||||
original_query = chat.get("message")
|
||||
for chat in conversation_log[-lookback:]:
|
||||
if chat.by == "you":
|
||||
original_query = json.dumps(chat.message)
|
||||
history_parts += f"{query_prefix}: {original_query}\n"
|
||||
if chat["by"] == "khoj":
|
||||
if chat.by == "khoj":
|
||||
if original_query is None:
|
||||
continue
|
||||
|
||||
message = chat.get("message", "")
|
||||
inferred_queries_list = chat.get("intent", {}).get("inferred-queries")
|
||||
message = chat.message
|
||||
inferred_queries_list = chat.intent.inferred_queries or []
|
||||
|
||||
# Ensure inferred_queries_list is a list, defaulting to the original query in a list
|
||||
if not inferred_queries_list:
|
||||
@@ -246,7 +252,7 @@ def construct_question_history(
|
||||
|
||||
if include_query:
|
||||
# Ensure 'type' exists and is a string before checking 'to-image'
|
||||
intent_type = chat.get("intent", {}).get("type", "")
|
||||
intent_type = chat.intent.type if chat.intent and chat.intent.type else ""
|
||||
if "to-image" not in intent_type:
|
||||
history_parts += f'{agent_name}: {{"queries": {inferred_queries_list}}}\n'
|
||||
history_parts += f"A: {message}\n\n"
|
||||
@@ -259,7 +265,7 @@ def construct_question_history(
|
||||
return history_parts
|
||||
|
||||
|
||||
def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]:
|
||||
def construct_chat_history_for_operator(conversation_history: List[ChatMessageModel], n: int = 6) -> list[AgentMessage]:
|
||||
"""
|
||||
Construct chat history for operator agent in conversation log.
|
||||
Only include last n completed turns (i.e with user and khoj message).
|
||||
@@ -267,22 +273,22 @@ def construct_chat_history_for_operator(conversation_history: dict, n: int = 6)
|
||||
chat_history: list[AgentMessage] = []
|
||||
user_message: Optional[AgentMessage] = None
|
||||
|
||||
for chat in conversation_history.get("chat", []):
|
||||
for chat in conversation_history:
|
||||
if len(chat_history) >= n:
|
||||
break
|
||||
if chat["by"] == "you" and chat.get("message"):
|
||||
content = [{"type": "text", "text": chat["message"]}]
|
||||
for file in chat.get("queryFiles", []):
|
||||
if chat.by == "you" and chat.message:
|
||||
content = [{"type": "text", "text": chat.message}]
|
||||
for file in chat.queryFiles or []:
|
||||
content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
|
||||
user_message = AgentMessage(role="user", content=content)
|
||||
elif chat["by"] == "khoj" and chat.get("message"):
|
||||
chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])]
|
||||
elif chat.by == "khoj" and chat.message:
|
||||
chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)]
|
||||
return chat_history
|
||||
|
||||
|
||||
def construct_tool_chat_history(
|
||||
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
|
||||
) -> Dict[str, list]:
|
||||
) -> List[ChatMessageModel]:
|
||||
"""
|
||||
Construct chat history from previous iterations for a specific tool
|
||||
|
||||
@@ -313,22 +319,23 @@ def construct_tool_chat_history(
|
||||
tool or ConversationCommand(iteration.tool), base_extractor
|
||||
)
|
||||
chat_history += [
|
||||
{
|
||||
"by": "you",
|
||||
"message": iteration.query,
|
||||
},
|
||||
{
|
||||
"by": "khoj",
|
||||
"intent": {
|
||||
"type": "remember",
|
||||
"inferred-queries": inferred_query_extractor(iteration),
|
||||
"query": iteration.query,
|
||||
},
|
||||
"message": iteration.summarizedResult,
|
||||
},
|
||||
ChatMessageModel(
|
||||
by="you",
|
||||
message=iteration.query,
|
||||
),
|
||||
ChatMessageModel(
|
||||
by="khoj",
|
||||
intent=Intent(
|
||||
type="remember",
|
||||
query=iteration.query,
|
||||
inferred_queries=inferred_query_extractor(iteration),
|
||||
memory_type="notes",
|
||||
),
|
||||
message=iteration.summarizedResult,
|
||||
),
|
||||
]
|
||||
|
||||
return {"chat": chat_history}
|
||||
return chat_history
|
||||
|
||||
|
||||
class ChatEvent(Enum):
|
||||
@@ -349,8 +356,8 @@ def message_to_log(
|
||||
chat_response,
|
||||
user_message_metadata={},
|
||||
khoj_message_metadata={},
|
||||
conversation_log=[],
|
||||
):
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
) -> List[ChatMessageModel]:
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
|
||||
@@ -369,15 +376,17 @@ def message_to_log(
|
||||
khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
|
||||
khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log)
|
||||
|
||||
conversation_log.extend([human_log, khoj_log])
|
||||
return conversation_log
|
||||
human_message = ChatMessageModel(**human_log)
|
||||
khoj_message = ChatMessageModel(**khoj_log)
|
||||
chat_history.extend([human_message, khoj_message])
|
||||
return chat_history
|
||||
|
||||
|
||||
async def save_to_conversation_log(
|
||||
q: str,
|
||||
chat_response: str,
|
||||
user: KhojUser,
|
||||
meta_log: Dict,
|
||||
chat_history: List[ChatMessageModel],
|
||||
user_message_time: str = None,
|
||||
compiled_references: List[Dict[str, Any]] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
@@ -427,11 +436,11 @@ async def save_to_conversation_log(
|
||||
chat_response=chat_response,
|
||||
user_message_metadata=user_message_metadata,
|
||||
khoj_message_metadata=khoj_message_metadata,
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
chat_history=chat_history,
|
||||
)
|
||||
await ConversationAdapters.save_conversation(
|
||||
user,
|
||||
{"chat": updated_conversation},
|
||||
updated_conversation,
|
||||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
user_message=q,
|
||||
@@ -502,7 +511,7 @@ def gather_raw_query_files(
|
||||
def generate_chatml_messages_with_context(
|
||||
user_message: str,
|
||||
system_message: str = None,
|
||||
conversation_log={},
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
model_name="gpt-4o-mini",
|
||||
loaded_model: Optional[Llama] = None,
|
||||
max_prompt_size=None,
|
||||
@@ -529,21 +538,21 @@ def generate_chatml_messages_with_context(
|
||||
|
||||
# Extract Chat History for Context
|
||||
chatml_messages: List[ChatMessage] = []
|
||||
for chat in conversation_log.get("chat", []):
|
||||
for chat in chat_history:
|
||||
message_context = []
|
||||
message_attached_files = ""
|
||||
|
||||
generated_assets = {}
|
||||
|
||||
chat_message = chat.get("message")
|
||||
role = "user" if chat["by"] == "you" else "assistant"
|
||||
chat_message = chat.message
|
||||
role = "user" if chat.by == "you" else "assistant"
|
||||
|
||||
# Legacy code to handle excalidraw diagrams prior to Dec 2024
|
||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||
chat_message = chat["intent"].get("inferred-queries")[0]
|
||||
if chat.by == "khoj" and "excalidraw" in chat.intent.type or "":
|
||||
chat_message = (chat.intent.inferred_queries or [])[0]
|
||||
|
||||
if chat.get("queryFiles"):
|
||||
raw_query_files = chat.get("queryFiles")
|
||||
if chat.queryFiles:
|
||||
raw_query_files = chat.queryFiles
|
||||
query_files_dict = dict()
|
||||
for file in raw_query_files:
|
||||
query_files_dict[file["name"]] = file["content"]
|
||||
@@ -551,24 +560,24 @@ def generate_chatml_messages_with_context(
|
||||
message_attached_files = gather_raw_query_files(query_files_dict)
|
||||
chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
|
||||
|
||||
if not is_none_or_empty(chat.get("onlineContext")):
|
||||
if not is_none_or_empty(chat.onlineContext):
|
||||
message_context += [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}",
|
||||
"text": f"{prompts.online_search_conversation.format(online_results=chat.onlineContext)}",
|
||||
}
|
||||
]
|
||||
|
||||
if not is_none_or_empty(chat.get("codeContext")):
|
||||
if not is_none_or_empty(chat.codeContext):
|
||||
message_context += [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{prompts.code_executed_context.format(code_results=chat.get('codeContext'))}",
|
||||
"text": f"{prompts.code_executed_context.format(code_results=chat.codeContext)}",
|
||||
}
|
||||
]
|
||||
|
||||
if not is_none_or_empty(chat.get("operatorContext")):
|
||||
operator_context = chat.get("operatorContext")
|
||||
if not is_none_or_empty(chat.operatorContext):
|
||||
operator_context = chat.operatorContext
|
||||
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
|
||||
message_context += [
|
||||
{
|
||||
@@ -577,13 +586,9 @@ def generate_chatml_messages_with_context(
|
||||
}
|
||||
]
|
||||
|
||||
if not is_none_or_empty(chat.get("context")):
|
||||
if not is_none_or_empty(chat.context):
|
||||
references = "\n\n".join(
|
||||
{
|
||||
f"# File: {item['file']}\n## {item['compiled']}\n"
|
||||
for item in chat.get("context") or []
|
||||
if isinstance(item, dict)
|
||||
}
|
||||
{f"# File: {item.file}\n## {item.compiled}\n" for item in chat.context or [] if isinstance(item, dict)}
|
||||
)
|
||||
message_context += [{"type": "text", "text": f"{prompts.notes_conversation.format(references=references)}"}]
|
||||
|
||||
@@ -591,14 +596,14 @@ def generate_chatml_messages_with_context(
|
||||
reconstructed_context_message = ChatMessage(content=message_context, role="user")
|
||||
chatml_messages.insert(0, reconstructed_context_message)
|
||||
|
||||
if not is_none_or_empty(chat.get("images")) and role == "assistant":
|
||||
if not is_none_or_empty(chat.images) and role == "assistant":
|
||||
generated_assets["image"] = {
|
||||
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||
"query": (chat.intent.inferred_queries or [user_message])[0],
|
||||
}
|
||||
|
||||
if not is_none_or_empty(chat.get("mermaidjsDiagram")) and role == "assistant":
|
||||
if not is_none_or_empty(chat.mermaidjsDiagram) and role == "assistant":
|
||||
generated_assets["diagram"] = {
|
||||
"query": chat.get("intent", {}).get("inferred-queries", [user_message])[0],
|
||||
"query": (chat.intent.inferred_queries or [user_message])[0],
|
||||
}
|
||||
|
||||
if not is_none_or_empty(generated_assets):
|
||||
@@ -610,7 +615,7 @@ def generate_chatml_messages_with_context(
|
||||
)
|
||||
|
||||
message_content = construct_structured_message(
|
||||
chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
|
||||
chat_message, chat.images if role == "user" else [], model_type, vision_enabled
|
||||
)
|
||||
|
||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||
|
||||
@@ -10,7 +10,12 @@ from google import genai
|
||||
from google.genai import types as gtypes
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatMessageModel,
|
||||
KhojUser,
|
||||
TextToImageModelConfig,
|
||||
)
|
||||
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
|
||||
from khoj.routers.storage import upload_generated_image_to_bucket
|
||||
from khoj.utils import state
|
||||
@@ -23,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
async def text_to_image(
|
||||
message: str,
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
chat_history: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
references: List[Dict[str, Any]],
|
||||
online_results: Dict[str, Any],
|
||||
@@ -46,14 +51,14 @@ async def text_to_image(
|
||||
return
|
||||
|
||||
text2image_model = text_to_image_config.model_name
|
||||
chat_history = ""
|
||||
for chat in conversation_log.get("chat", [])[-4:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
|
||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||
chat_history += f"A: {chat['message']}\n"
|
||||
elif chat["by"] == "khoj" and chat.get("images"):
|
||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
||||
chat_history_str = ""
|
||||
for chat in chat_history[-4:]:
|
||||
if chat.by == "khoj" and chat.intent and chat.intent.type in ["remember", "reminder"]:
|
||||
chat_history_str += f"Q: {chat.intent.query or ''}\n"
|
||||
chat_history_str += f"A: {chat.message}\n"
|
||||
elif chat.by == "khoj" and chat.images:
|
||||
chat_history_str += f"Q: {chat.intent.query}\n"
|
||||
chat_history_str += f"A: Improved Prompt: {chat.intent.inferred_queries[0]}\n"
|
||||
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Painting Prompt**"):
|
||||
@@ -63,7 +68,7 @@ async def text_to_image(
|
||||
# Use the user's message, chat history, and other context
|
||||
image_prompt = await generate_better_image_prompt(
|
||||
message,
|
||||
chat_history,
|
||||
chat_history_str,
|
||||
location_data=location_data,
|
||||
note_references=references,
|
||||
online_results=online_results,
|
||||
|
||||
@@ -5,10 +5,9 @@ import os
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
construct_chat_history,
|
||||
construct_chat_history_for_operator,
|
||||
)
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
async def operate_environment(
|
||||
query: str,
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
conversation_log: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
previous_trajectory: Optional[OperatorRun] = None,
|
||||
environment_type: EnvironmentType = EnvironmentType.COMPUTER,
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from textwrap import dedent
|
||||
from typing import List, Optional
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.database.models import ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation.utils import (
|
||||
AgentMessage,
|
||||
OperatorRun,
|
||||
@@ -119,13 +119,13 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
query_screenshot = self._get_message_images(current_message)
|
||||
|
||||
# Construct input for visual reasoner history
|
||||
visual_reasoner_history = {"chat": self._format_message_for_api(self.messages)}
|
||||
visual_reasoner_history = self._format_message_for_api(self.messages)
|
||||
try:
|
||||
natural_language_action = await send_message_to_model_wrapper(
|
||||
query=query_text,
|
||||
query_images=query_screenshot,
|
||||
system_message=reasoning_system_prompt,
|
||||
conversation_log=visual_reasoner_history,
|
||||
chat_history=visual_reasoner_history,
|
||||
agent_chat_model=self.reasoning_model,
|
||||
tracer=self.tracer,
|
||||
)
|
||||
@@ -238,11 +238,11 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
|
||||
async def summarize(self, env_state: EnvState, summarize_prompt: str = None) -> str:
|
||||
summarize_prompt = summarize_prompt or self.summarize_prompt
|
||||
conversation_history = {"chat": self._format_message_for_api(self.messages)}
|
||||
conversation_history = self._format_message_for_api(self.messages)
|
||||
try:
|
||||
summary = await send_message_to_model_wrapper(
|
||||
query=summarize_prompt,
|
||||
conversation_log=conversation_history,
|
||||
chat_history=conversation_history,
|
||||
agent_chat_model=self.reasoning_model,
|
||||
tracer=self.tracer,
|
||||
)
|
||||
@@ -296,14 +296,14 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"]
|
||||
return images
|
||||
|
||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]:
|
||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> List[ChatMessageModel]:
|
||||
"""Format operator agent messages into the Khoj conversation history format."""
|
||||
formatted_messages = [
|
||||
{
|
||||
"message": self._get_message_text(message),
|
||||
"images": self._get_message_images(message),
|
||||
"by": "you" if message.role in ["user", "environment"] else message.role,
|
||||
}
|
||||
ChatMessageModel(
|
||||
message=self._get_message_text(message),
|
||||
images=self._get_message_images(message),
|
||||
by="you" if message.role in ["user", "environment"] else message.role,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
return formatted_messages
|
||||
|
||||
@@ -10,7 +10,13 @@ from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import Agent, KhojUser, ServerChatSettings, WebScraper
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatMessageModel,
|
||||
KhojUser,
|
||||
ServerChatSettings,
|
||||
WebScraper,
|
||||
)
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
@@ -59,7 +65,7 @@ OLOSTEP_QUERY_PARAMS = {
|
||||
|
||||
async def search_online(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
conversation_history: List[ChatMessageModel],
|
||||
location: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
@@ -361,7 +367,7 @@ async def search_with_serper(query: str, location: LocationData) -> Tuple[str, D
|
||||
|
||||
async def read_webpages(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
conversation_history: List[ChatMessageModel],
|
||||
location: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
|
||||
@@ -20,7 +20,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.database.adapters import FileObjectAdapters
|
||||
from khoj.database.models import Agent, FileObject, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
@@ -50,7 +50,7 @@ class GeneratedCode(NamedTuple):
|
||||
|
||||
async def run_code(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
conversation_history: List[ChatMessageModel],
|
||||
context: str,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
@@ -116,7 +116,7 @@ async def run_code(
|
||||
|
||||
async def generate_python_code(
|
||||
q: str,
|
||||
conversation_history: dict,
|
||||
chat_history: List[ChatMessageModel],
|
||||
context: str,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
@@ -127,7 +127,7 @@ async def generate_python_code(
|
||||
) -> GeneratedCode:
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
|
||||
personality_context = (
|
||||
@@ -143,7 +143,7 @@ async def generate_python_code(
|
||||
|
||||
code_generation_prompt = prompts.python_code_generation_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
context=context,
|
||||
has_network_access=network_access_context,
|
||||
current_date=utc_date,
|
||||
|
||||
@@ -29,7 +29,13 @@ from khoj.database.adapters import (
|
||||
get_default_search_model,
|
||||
get_user_photo,
|
||||
)
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatMessageModel,
|
||||
ChatModel,
|
||||
KhojUser,
|
||||
SpeechToTextModelOptions,
|
||||
)
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||
extract_questions_anthropic,
|
||||
@@ -353,7 +359,7 @@ def set_user_name(
|
||||
|
||||
async def extract_references_and_questions(
|
||||
user: KhojUser,
|
||||
meta_log: dict,
|
||||
chat_history: list[ChatMessageModel],
|
||||
q: str,
|
||||
n: int,
|
||||
d: float,
|
||||
@@ -432,7 +438,7 @@ async def extract_references_and_questions(
|
||||
defiltered_query,
|
||||
model=chat_model,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
should_extract_questions=True,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
@@ -450,7 +456,7 @@ async def extract_references_and_questions(
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=base_url,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
@@ -469,7 +475,7 @@ async def extract_references_and_questions(
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
@@ -487,7 +493,7 @@ async def extract_references_and_questions(
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
max_tokens=chat_model.max_prompt_size,
|
||||
user=user,
|
||||
@@ -606,7 +612,7 @@ def post_automation(
|
||||
return Response(content="Invalid crontime", status_code=400)
|
||||
|
||||
# Infer subject, query to run
|
||||
_, query_to_run, generated_subject = schedule_query(q, conversation_history={}, user=user)
|
||||
_, query_to_run, generated_subject = schedule_query(q, chat_history=[], user=user)
|
||||
subject = subject or generated_subject
|
||||
|
||||
# Normalize query parameters
|
||||
@@ -712,7 +718,7 @@ def edit_job(
|
||||
return Response(content="Invalid automation", status_code=403)
|
||||
|
||||
# Infer subject, query to run
|
||||
_, query_to_run, _ = schedule_query(q, conversation_history={}, user=user)
|
||||
_, query_to_run, _ = schedule_query(q, chat_history=[], user=user)
|
||||
subject = subject
|
||||
|
||||
# Normalize query parameters
|
||||
|
||||
@@ -752,7 +752,7 @@ async def chat(
|
||||
q,
|
||||
chat_response="",
|
||||
user=user,
|
||||
meta_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
@@ -918,7 +918,7 @@ async def chat(
|
||||
if city or region or country or country_code:
|
||||
location = LocationData(city=city, region=region, country=country, country_code=country_code)
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
meta_log = conversation.conversation_log
|
||||
chat_history = conversation.messages
|
||||
|
||||
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
|
||||
if interrupt_flag:
|
||||
@@ -964,14 +964,14 @@ async def chat(
|
||||
operator_results = [OperatorRun(**iter_dict) for iter_dict in last_message.operatorContext or []]
|
||||
train_of_thought = [thought.model_dump() for thought in last_message.trainOfThought or []]
|
||||
# Drop the interrupted message from conversation history
|
||||
meta_log["chat"].pop()
|
||||
chat_history.pop()
|
||||
logger.info(f"Loaded interrupted partial context from conversation {conversation_id}.")
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
try:
|
||||
chosen_io = await aget_data_sources_and_output_format(
|
||||
q,
|
||||
meta_log,
|
||||
chat_history,
|
||||
is_automated_task,
|
||||
user=user,
|
||||
query_images=uploaded_images,
|
||||
@@ -1011,7 +1011,7 @@ async def chat(
|
||||
user=user,
|
||||
query=defiltered_query,
|
||||
conversation_id=conversation_id,
|
||||
conversation_history=meta_log,
|
||||
conversation_history=conversation.messages,
|
||||
previous_iterations=list(research_results),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
@@ -1078,7 +1078,7 @@ async def chat(
|
||||
q=q,
|
||||
user=user,
|
||||
file_filters=file_filters,
|
||||
meta_log=meta_log,
|
||||
chat_history=conversation.messages,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
@@ -1123,7 +1123,7 @@ async def chat(
|
||||
if ConversationCommand.Automation in conversation_commands:
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, request.url, meta_log, tracer=tracer
|
||||
q, timezone, user, request.url, chat_history, tracer=tracer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
@@ -1139,7 +1139,7 @@ async def chat(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
meta_log,
|
||||
chat_history,
|
||||
user_message_time,
|
||||
intent_type="automation",
|
||||
client_application=request.user.client_app,
|
||||
@@ -1163,7 +1163,7 @@ async def chat(
|
||||
try:
|
||||
async for result in extract_references_and_questions(
|
||||
user,
|
||||
meta_log,
|
||||
chat_history,
|
||||
q,
|
||||
(n or 7),
|
||||
d,
|
||||
@@ -1212,7 +1212,7 @@ async def chat(
|
||||
try:
|
||||
async for result in search_online(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
chat_history,
|
||||
location,
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
@@ -1240,7 +1240,7 @@ async def chat(
|
||||
try:
|
||||
async for result in read_webpages(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
chat_history,
|
||||
location,
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
@@ -1281,7 +1281,7 @@ async def chat(
|
||||
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
|
||||
async for result in run_code(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
chat_history,
|
||||
context,
|
||||
location,
|
||||
user,
|
||||
@@ -1306,7 +1306,7 @@ async def chat(
|
||||
async for result in operate_environment(
|
||||
defiltered_query,
|
||||
user,
|
||||
meta_log,
|
||||
chat_history,
|
||||
location,
|
||||
list(operator_results)[-1] if operator_results else None,
|
||||
query_images=uploaded_images,
|
||||
@@ -1356,7 +1356,7 @@ async def chat(
|
||||
async for result in text_to_image(
|
||||
defiltered_query,
|
||||
user,
|
||||
meta_log,
|
||||
chat_history,
|
||||
location_data=location,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
@@ -1400,7 +1400,7 @@ async def chat(
|
||||
|
||||
async for result in generate_mermaidjs_diagram(
|
||||
q=defiltered_query,
|
||||
conversation_history=meta_log,
|
||||
chat_history=chat_history,
|
||||
location_data=location,
|
||||
note_references=compiled_references,
|
||||
online_results=online_results,
|
||||
@@ -1456,7 +1456,7 @@ async def chat(
|
||||
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
chat_history,
|
||||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
|
||||
@@ -55,6 +55,7 @@ from khoj.database.adapters import (
|
||||
)
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatMessageModel,
|
||||
ChatModel,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
@@ -285,7 +286,7 @@ async def acreate_title_from_history(
|
||||
"""
|
||||
Create a title from the given conversation history
|
||||
"""
|
||||
chat_history = construct_chat_history(conversation.conversation_log)
|
||||
chat_history = construct_chat_history(conversation.messages)
|
||||
|
||||
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
|
||||
|
||||
@@ -345,7 +346,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
|
||||
|
||||
async def aget_data_sources_and_output_format(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
chat_history: list[ChatMessageModel],
|
||||
is_task: bool,
|
||||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
@@ -386,7 +387,7 @@ async def aget_data_sources_and_output_format(
|
||||
if len(agent_outputs) == 0 or output.value in agent_outputs:
|
||||
output_options_str += f'- "{output.value}": "{description}"\n'
|
||||
|
||||
chat_history = construct_chat_history(conversation_history, n=6)
|
||||
chat_history_str = construct_chat_history(chat_history, n=6)
|
||||
|
||||
if query_images:
|
||||
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||
@@ -399,7 +400,7 @@ async def aget_data_sources_and_output_format(
|
||||
query=query,
|
||||
sources=source_options_str,
|
||||
outputs=output_options_str,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
@@ -462,7 +463,7 @@ async def aget_data_sources_and_output_format(
|
||||
async def infer_webpage_urls(
|
||||
q: str,
|
||||
max_webpages: int,
|
||||
conversation_history: dict,
|
||||
chat_history: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
@@ -475,7 +476,7 @@ async def infer_webpage_urls(
|
||||
"""
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
personality_context = (
|
||||
@@ -485,7 +486,7 @@ async def infer_webpage_urls(
|
||||
online_queries_prompt = prompts.infer_webpages_to_read.format(
|
||||
query=q,
|
||||
max_webpages=max_webpages,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
current_date=utc_date,
|
||||
location=location,
|
||||
username=username,
|
||||
@@ -526,7 +527,7 @@ async def infer_webpage_urls(
|
||||
|
||||
async def generate_online_subqueries(
|
||||
q: str,
|
||||
conversation_history: dict,
|
||||
chat_history: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
@@ -540,7 +541,7 @@ async def generate_online_subqueries(
|
||||
"""
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
utc_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
personality_context = (
|
||||
@@ -549,7 +550,7 @@ async def generate_online_subqueries(
|
||||
|
||||
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
max_queries=max_queries,
|
||||
current_date=utc_date,
|
||||
location=location,
|
||||
@@ -591,16 +592,16 @@ async def generate_online_subqueries(
|
||||
|
||||
|
||||
def schedule_query(
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||
q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||
) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
"""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
crontime_prompt = prompts.crontime_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
)
|
||||
|
||||
raw_response = send_message_to_model_wrapper_sync(
|
||||
@@ -619,16 +620,16 @@ def schedule_query(
|
||||
|
||||
|
||||
async def aschedule_query(
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||
q: str, chat_history: List[ChatMessageModel], user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||
) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
"""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
crontime_prompt = prompts.crontime_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
)
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
@@ -681,7 +682,7 @@ async def extract_relevant_info(
|
||||
async def extract_relevant_summary(
|
||||
q: str,
|
||||
corpus: str,
|
||||
conversation_history: dict,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
@@ -698,11 +699,11 @@ async def extract_relevant_summary(
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
extract_relevant_information = prompts.extract_relevant_summary.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
corpus=corpus.strip(),
|
||||
personality_context=personality_context,
|
||||
)
|
||||
@@ -725,7 +726,7 @@ async def generate_summary_from_files(
|
||||
q: str,
|
||||
user: KhojUser,
|
||||
file_filters: List[str],
|
||||
meta_log: dict,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
@@ -766,7 +767,7 @@ async def generate_summary_from_files(
|
||||
response = await extract_relevant_summary(
|
||||
q,
|
||||
contextual_data,
|
||||
conversation_history=meta_log,
|
||||
chat_history=chat_history,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
@@ -782,7 +783,7 @@ async def generate_summary_from_files(
|
||||
|
||||
async def generate_excalidraw_diagram(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
chat_history: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
@@ -799,7 +800,7 @@ async def generate_excalidraw_diagram(
|
||||
|
||||
better_diagram_description_prompt = await generate_better_diagram_description(
|
||||
q=q,
|
||||
conversation_history=conversation_history,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
note_references=note_references,
|
||||
online_results=online_results,
|
||||
@@ -834,7 +835,7 @@ async def generate_excalidraw_diagram(
|
||||
|
||||
async def generate_better_diagram_description(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
chat_history: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
@@ -857,7 +858,7 @@ async def generate_better_diagram_description(
|
||||
|
||||
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
simplified_online_results = {}
|
||||
|
||||
@@ -870,7 +871,7 @@ async def generate_better_diagram_description(
|
||||
|
||||
improve_diagram_description_prompt = prompts.improve_excalidraw_diagram_description_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
location=location,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
@@ -939,7 +940,7 @@ async def generate_excalidraw_diagram_from_description(
|
||||
|
||||
async def generate_mermaidjs_diagram(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
chat_history: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
@@ -956,7 +957,7 @@ async def generate_mermaidjs_diagram(
|
||||
|
||||
better_diagram_description_prompt = await generate_better_mermaidjs_diagram_description(
|
||||
q=q,
|
||||
conversation_history=conversation_history,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
note_references=note_references,
|
||||
online_results=online_results,
|
||||
@@ -985,7 +986,7 @@ async def generate_mermaidjs_diagram(
|
||||
|
||||
async def generate_better_mermaidjs_diagram_description(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
chat_history: List[ChatMessageModel],
|
||||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
@@ -1008,7 +1009,7 @@ async def generate_better_mermaidjs_diagram_description(
|
||||
|
||||
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
chat_history_str = construct_chat_history(chat_history)
|
||||
|
||||
simplified_online_results = {}
|
||||
|
||||
@@ -1021,7 +1022,7 @@ async def generate_better_mermaidjs_diagram_description(
|
||||
|
||||
improve_diagram_description_prompt = prompts.improve_mermaid_js_diagram_description_prompt.format(
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
chat_history=chat_history_str,
|
||||
location=location,
|
||||
current_date=today_date,
|
||||
references=user_references,
|
||||
@@ -1160,7 +1161,7 @@ async def send_message_to_model_wrapper(
|
||||
query_images: List[str] = None,
|
||||
context: str = "",
|
||||
query_files: str = None,
|
||||
conversation_log: dict = {},
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
agent_chat_model: ChatModel = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
@@ -1193,7 +1194,7 @@ async def send_message_to_model_wrapper(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
conversation_log=conversation_log,
|
||||
chat_history=chat_history,
|
||||
model_name=chat_model_name,
|
||||
loaded_model=loaded_model,
|
||||
tokenizer_name=tokenizer,
|
||||
@@ -1260,7 +1261,7 @@ def send_message_to_model_wrapper_sync(
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
query_files: str = "",
|
||||
conversation_log: dict = {},
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
tracer: dict = {},
|
||||
):
|
||||
chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)
|
||||
@@ -1284,7 +1285,7 @@ def send_message_to_model_wrapper_sync(
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
conversation_log=conversation_log,
|
||||
chat_history=chat_history,
|
||||
model_name=chat_model_name,
|
||||
loaded_model=loaded_model,
|
||||
max_prompt_size=max_tokens,
|
||||
@@ -1342,7 +1343,7 @@ def send_message_to_model_wrapper_sync(
|
||||
|
||||
async def agenerate_chat_response(
|
||||
q: str,
|
||||
meta_log: dict,
|
||||
chat_history: List[ChatMessageModel],
|
||||
conversation: Conversation,
|
||||
compiled_references: List[Dict] = [],
|
||||
online_results: Dict[str, Dict] = {},
|
||||
@@ -1379,7 +1380,7 @@ async def agenerate_chat_response(
|
||||
save_to_conversation_log,
|
||||
q,
|
||||
user=user,
|
||||
meta_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
@@ -1424,7 +1425,7 @@ async def agenerate_chat_response(
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
model_name=chat_model.name,
|
||||
@@ -1450,7 +1451,7 @@ async def agenerate_chat_response(
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=openai_chat_config.api_base_url,
|
||||
@@ -1480,7 +1481,7 @@ async def agenerate_chat_response(
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
@@ -1508,7 +1509,7 @@ async def agenerate_chat_response(
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
operator_results=operator_results,
|
||||
conversation_log=meta_log,
|
||||
chat_history=chat_history,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
@@ -2005,11 +2006,11 @@ async def create_automation(
|
||||
timezone: str,
|
||||
user: KhojUser,
|
||||
calling_url: URL,
|
||||
meta_log: dict = {},
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
conversation_id: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
crontime, query_to_run, subject = await aschedule_query(q, meta_log, user, tracer=tracer)
|
||||
crontime, query_to_run, subject = await aschedule_query(q, chat_history, user, tracer=tracer)
|
||||
job = await aschedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||
return job, crontime, query_to_run, subject
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
@@ -84,7 +84,7 @@ class PlanningResponse(BaseModel):
|
||||
|
||||
async def apick_next_tool(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
conversation_history: List[ChatMessageModel],
|
||||
user: KhojUser = None,
|
||||
location: LocationData = None,
|
||||
user_name: str = None,
|
||||
@@ -166,18 +166,18 @@ async def apick_next_tool(
|
||||
query = f"[placeholder for user attached images]\n{query}"
|
||||
|
||||
# Construct chat history with user and iteration history with researcher agent for context
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
|
||||
iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history}
|
||||
iteration_chat_history = construct_iteration_history(previous_iterations, prompts.previous_iteration, query)
|
||||
chat_and_research_history = conversation_history + iteration_chat_history
|
||||
|
||||
# Plan function execution for the next tool
|
||||
query = prompts.plan_function_execution_next_tool.format(query=query) if previous_iterations_history else query
|
||||
query = prompts.plan_function_execution_next_tool.format(query=query) if iteration_chat_history else query
|
||||
|
||||
try:
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
query=query,
|
||||
system_message=function_planning_prompt,
|
||||
conversation_log=iteration_chat_log,
|
||||
chat_history=chat_and_research_history,
|
||||
response_type="json_object",
|
||||
response_schema=planning_response_model,
|
||||
deepthought=True,
|
||||
@@ -238,7 +238,7 @@ async def research(
|
||||
user: KhojUser,
|
||||
query: str,
|
||||
conversation_id: str,
|
||||
conversation_history: dict,
|
||||
conversation_history: List[ChatMessageModel],
|
||||
previous_iterations: List[ResearchIteration],
|
||||
query_images: List[str],
|
||||
agent: Agent = None,
|
||||
@@ -261,9 +261,7 @@ async def research(
|
||||
if current_iteration := len(previous_iterations) > 0:
|
||||
logger.info(f"Continuing research with the previous {len(previous_iterations)} iteration results.")
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
research_conversation_history["chat"] = (
|
||||
research_conversation_history.get("chat", []) + previous_iterations_history
|
||||
)
|
||||
research_conversation_history += previous_iterations_history
|
||||
|
||||
while current_iteration < MAX_ITERATIONS:
|
||||
# Check for cancellation at the start of each iteration
|
||||
|
||||
@@ -6,6 +6,7 @@ from django.utils.timezone import make_aware
|
||||
|
||||
from khoj.database.models import (
|
||||
AiModelApi,
|
||||
ChatMessageModel,
|
||||
ChatModel,
|
||||
Conversation,
|
||||
KhojApiUser,
|
||||
@@ -46,15 +47,15 @@ def get_chat_api_key(provider: ChatModel.ModelType = None):
|
||||
|
||||
def generate_chat_history(message_list):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
chat_history: list[ChatMessageModel] = []
|
||||
for user_message, chat_response, context in message_list:
|
||||
message_to_log(
|
||||
user_message,
|
||||
chat_response,
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
conversation_log=conversation_log.get("chat", []),
|
||||
chat_history=chat_history,
|
||||
)
|
||||
return conversation_log
|
||||
return chat_history
|
||||
|
||||
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
|
||||
@@ -135,7 +135,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
query,
|
||||
conversation_log=generate_chat_history(message_list),
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
use_history=True,
|
||||
)
|
||||
@@ -181,7 +181,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
"Is she a Doctor?",
|
||||
conversation_log=generate_chat_history(message_list),
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
use_history=True,
|
||||
)
|
||||
@@ -210,7 +210,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
"What was the Pizza place we ate at over there?",
|
||||
conversation_log=generate_chat_history(message_list),
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
|
||||
@@ -336,7 +336,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model)
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=generate_chat_history(message_list),
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
@@ -363,7 +363,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
|
||||
{"compiled": "Testatron was born on 1st April 1984 in Testville."}
|
||||
], # Assume context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=generate_chat_history(message_list),
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
@@ -388,7 +388,7 @@ def test_refuse_answering_unanswerable_question(loaded_model):
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=generate_chat_history(message_list),
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
@@ -501,7 +501,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Write a haiku about unit testing in 3 lines",
|
||||
conversation_log=generate_chat_history(message_list),
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
@@ -28,7 +28,7 @@ def generate_history(message_list):
|
||||
user_message,
|
||||
gpt_message,
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
conversation_log=conversation_log.get("chat", []),
|
||||
chat_history=conversation_log.get("chat", []),
|
||||
)
|
||||
return conversation_log
|
||||
|
||||
|
||||
@@ -708,6 +708,6 @@ def populate_chat_history(message_list):
|
||||
"context": context,
|
||||
"intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
|
||||
},
|
||||
conversation_log=[],
|
||||
chat_history=[],
|
||||
)
|
||||
return conversation_log
|
||||
|
||||
Reference in New Issue
Block a user