Commit conversation traces using user, chat, message branch hierarchy

- Message train of thought forks from, and merges back into, its conversation branch
- Conversation branches from user branch
- User branches from root commit on the main branch

- Weave chat tracer metadata from the API endpoint through all chat actors
  and commit it to the prompt trace
This commit is contained in:
Debanjum Singh Solanky
2024-10-23 20:02:28 -07:00
parent a3022b7556
commit ea0712424b
6 changed files with 114 additions and 21 deletions

View File

@@ -23,7 +23,7 @@ from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__)
model_to_prompt_size = {
@@ -119,6 +119,7 @@ def save_to_conversation_log(
conversation_id: str = None,
automation_id: str = None,
query_images: List[str] = None,
tracer: Dict[str, Any] = {},
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@@ -144,6 +145,9 @@ def save_to_conversation_log(
user_message=q,
)
if in_debug_mode() or state.verbose > 1:
merge_message_into_conversation_trace(q, chat_response, tracer)
logger.info(
f"""
Saved Conversation Turn

View File

@@ -28,6 +28,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
status_code = 200
image = None
@@ -68,6 +69,7 @@ async def text_to_image(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:

View File

@@ -64,6 +64,7 @@ async def search_online(
custom_filters: List[str] = [],
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
query += " ".join(custom_filters)
if not is_internet_connected():
@@ -73,7 +74,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
)
response_dict = {}
@@ -111,7 +112,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@@ -153,6 +154,7 @@ async def read_webpages(
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
@@ -166,7 +168,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@@ -192,7 +194,12 @@ async def read_webpage(
async def read_webpage_and_extract_content(
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
subqueries: set[str],
url: str,
content: str = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
# Extract relevant information from the web page
if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
extracted_info = await extract_relevant_info(
subqueries, content, user=user, agent=agent, tracer=tracer
)
# If we successfully extracted information, break the loop
if not is_none_or_empty(extracted_info):

View File

@@ -350,6 +350,7 @@ async def extract_references_and_questions(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
user = request.user.object if request.user.is_authenticated else None
@@ -425,6 +426,7 @@ async def extract_references_and_questions(
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
@@ -442,6 +444,7 @@ async def extract_references_and_questions(
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@@ -456,6 +459,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@@ -471,6 +475,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
# Collate search results as context for GPT

View File

@@ -3,6 +3,7 @@ import base64
import json
import logging
import time
import uuid
from datetime import datetime
from functools import partial
from typing import Dict, Optional
@@ -563,6 +564,12 @@ async def chat(
event_delimiter = "␃🔚␗"
q = unquote(q)
nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"cid": conversation_id,
"uid": user.id,
"khoj_version": state.khoj_version,
}
uploaded_images: list[str] = []
if images:
@@ -682,6 +689,7 @@ async def chat(
user=user,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@@ -689,7 +697,9 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
mode = await aget_relevant_output_modes(
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@@ -755,6 +765,7 @@ async def chat(
query_images=uploaded_images,
user=user,
agent=agent,
tracer=tracer,
)
response_log = str(response)
async for result in send_llm_response(response_log):
@@ -774,6 +785,7 @@ async def chat(
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
tracer=tracer,
)
return
@@ -795,7 +807,7 @@ async def chat(
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
q, timezone, user, request.url, meta_log, tracer=tracer
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@@ -817,6 +829,7 @@ async def chat(
inferred_queries=[query_to_run],
automation_id=automation.id,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
yield result
@@ -838,6 +851,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -882,6 +896,7 @@ async def chat(
custom_filters,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -906,6 +921,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -956,6 +972,7 @@ async def chat(
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -986,6 +1003,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
content_obj = {
"intentType": intent_type,
@@ -1014,6 +1032,7 @@ async def chat(
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1041,6 +1060,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(json.dumps(content_obj)):
@@ -1064,6 +1084,7 @@ async def chat(
location,
user_name,
uploaded_images,
tracer,
)
# Send Response

View File

@@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
relevant_tools_prompt,
response_type="json_object",
user=user,
tracer=tracer,
)
try:
@@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
user: KhojUser = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
)
with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
response = await send_message_to_model_wrapper(
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
)
try:
response = response.strip()
@@ -444,6 +449,7 @@ async def infer_webpage_urls(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Infer webpage links from the given query
@@ -468,7 +474,11 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -490,6 +500,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Generate subqueries from the given query
@@ -514,7 +525,11 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list
@@ -533,7 +548,7 @@ async def generate_online_subqueries(
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -546,7 +561,7 @@ async def schedule_query(
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, query_images=query_images, response_type="json_object", user=user
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
)
# Validate that the response is a non-empty, JSON-serializable list
@@ -561,7 +576,7 @@ async def schedule_query(
async def extract_relevant_info(
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -584,6 +599,7 @@ async def extract_relevant_info(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
tracer=tracer,
)
return response.strip()
@@ -595,6 +611,7 @@ async def extract_relevant_summary(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -622,6 +639,7 @@ async def extract_relevant_summary(
prompts.system_prompt_extract_relevant_summary,
user=user,
query_images=query_images,
tracer=tracer,
)
return response.strip()
@@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:
@@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
q=better_diagram_description_prompt,
user=user,
agent=agent,
tracer=tracer,
)
yield better_diagram_description_prompt, excalidraw_diagram_description
@@ -674,6 +695,7 @@ async def generate_better_diagram_description(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a diagram description from the given query and context
@@ -711,7 +733,7 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, query_images=query_images, user=user
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
q: str,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
)
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
raw_response = await send_message_to_model_wrapper(
message=excalidraw_diagram_generation, user=user, tracer=tracer
)
raw_response = raw_response.strip()
raw_response = remove_json_codeblock(raw_response)
response: Dict[str, str] = json.loads(raw_response)
@@ -756,6 +781,7 @@ async def generate_better_image_prompt(
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a better image prompt from the given query
@@ -802,7 +828,9 @@ async def generate_better_image_prompt(
)
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
response_type: str = "text",
user: KhojUser = None,
query_images: List[str] = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
@@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.OPENAI:
@@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
)
return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
)
openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
return openai_response
@@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -1032,6 +1072,7 @@ def generate_chat_response(
location_data: LocationData = None,
user_name: Optional[str] = None,
query_images: Optional[List[str]] = None,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@@ -1051,6 +1092,7 @@ def generate_chat_response(
client_application=client_application,
conversation_id=conversation_id,
query_images=query_images,
tracer=tracer,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
@@ -1077,6 +1119,7 @@ def generate_chat_response(
location_data=location_data,
user_name=user_name,
agent=agent,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@@ -1100,6 +1143,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@@ -1120,6 +1164,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@@ -1139,6 +1184,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
metadata.update({"chat_model": conversation_config.chat_model})
@@ -1495,9 +1541,15 @@ def scheduled_chat(
async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
q: str,
timezone: str,
user: KhojUser,
calling_url: URL,
meta_log: dict = {},
conversation_id: str = None,
tracer: dict = {},
):
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject