diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a6c4cd57..22fb4f03 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import time from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional @@ -22,11 +23,7 @@ from khoj.database.adapters import ( aget_user_name, ) from khoj.database.models import KhojUser -from khoj.processor.conversation.prompts import ( - help_message, - no_entries_found, - no_notes_found, -) +from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online @@ -34,7 +31,6 @@ from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ApiUserRateLimiter, CommonQueryParams, - CommonQueryParamsClass, ConversationCommandRateLimiter, agenerate_chat_response, aget_relevant_information_sources, @@ -547,22 +543,27 @@ async def chat( ), ): async def event_generator(q: str): + start_time = time.perf_counter() + ttft = None + chat_metadata: dict = {} connection_alive = True user: KhojUser = request.user.object q = unquote(q) async def send_event(event_type: str, data: str | dict): - nonlocal connection_alive + nonlocal connection_alive, ttft if not connection_alive or await request.is_disconnected(): connection_alive = False logger.warn(f"User {user} disconnected from {common.client} client") return try: + if event_type == "end_llm_response": + collect_telemetry() + if event_type == "start_llm_response": + ttft = time.perf_counter() - start_time if event_type == "message": yield data - elif event_type == "references": - yield json.dumps({"type": event_type, "data": data}) - elif stream: + elif event_type == "references" or stream: yield json.dumps({"type": event_type, "data": data}) except asyncio.CancelledError: connection_alive = False @@ -581,12 +582,36 @@ async def chat( async for result in send_event("end_llm_response", ""): yield result + def collect_telemetry(): + # Gather chat response telemetry + nonlocal chat_metadata + latency = time.perf_counter() - start_time + cmd_set = set([cmd.value for cmd in conversation_commands]) + chat_metadata = chat_metadata or {} + chat_metadata["conversation_command"] = cmd_set + chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + chat_metadata["latency"] = f"{latency:.3f}" + chat_metadata["ttft_latency"] = f"{ttft:.3f}" + + logger.info(f"Chat response time to first token: {ttft:.3f} seconds") + logger.info(f"Chat response total time: {latency:.3f} seconds") + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + metadata=chat_metadata, + ) + conversation = await ConversationAdapters.aget_conversation_by_user( user, client_application=request.user.client_app, conversation_id=conversation_id, title=title ) if not conversation: - async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"): + async for result in send_llm_response(f"Conversation {conversation_id} not found"): yield result + return await is_ready_to_chat(user) @@ -684,12 +709,6 @@ async def chat( client_application=request.user.client_app, conversation_id=conversation_id, ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) return custom_filters = [] @@ -732,17 +751,6 @@ async def chat( inferred_queries=[query_to_run], automation_id=automation.id, ) - common = CommonQueryParamsClass( - client=request.user.client_app, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - ) - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - **common.__dict__, - ) async for result in send_llm_response(llm_response): yield result return @@ -839,12 +847,6 @@ async def chat( # Generate Output ## Generate Image Output if ConversationCommand.Image in conversation_commands: - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata={"conversation_command": conversation_commands[0].value}, - ) async for result in text_to_image( q, user, @@ -913,17 +915,6 @@ async def chat( user_name, ) - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - metadata=chat_metadata, - ) - # Send Response async for result in send_event("start_llm_response", ""): yield result diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d23df6f0..7b8af5d9 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -780,18 +780,17 @@ async def text_to_image( chat_history += f"Q: Prompt: {chat['intent']['query']}\n" chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" - with timer("Improve the original user query", logger): - if send_status_func: - async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): - yield {"status": event} - improved_image_prompt = await generate_better_image_prompt( - message, - chat_history, - location_data=location_data, - note_references=references, - online_results=online_results, - model_type=text_to_image_config.model_type, - ) + if send_status_func: + async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): + yield {"status": event} + improved_image_prompt = await generate_better_image_prompt( + message, + chat_history, + location_data=location_data, + note_references=references, + online_results=online_results, + model_type=text_to_image_config.model_type, + ) if send_status_func: async for event in send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}"):