Log total, ttft chat response time on start, end llm_response events

- Deduplicate chat telemetry collection code by relying on the
  end_llm_response event
- Log time to first token and total chat response time for latency
  analysis of Khoj as an agent, not just the latency of the LLM
  (sketched below)
- Remove duplicate timer in the image generation path
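
Below is a minimal, illustrative sketch of the timing approach described in the commit message, independent of the Khoj codebase: a wrapper around an async stream of (event_type, data) pairs records time to first token when a "start_llm_response" event arrives and total latency when "end_llm_response" arrives. All names here (event_generator, fake_llm_events, collect_telemetry) are hypothetical stand-ins, not the actual Khoj functions.

import asyncio
import time


async def event_generator(llm_events):
    # Wrap an async stream of (event_type, data) pairs and time the response.
    start_time = time.perf_counter()
    ttft = None

    def collect_telemetry():
        # Called once the full LLM response has streamed out.
        latency = time.perf_counter() - start_time
        print(f"Time to first token: {ttft:.3f}s, total response time: {latency:.3f}s")

    async for event_type, data in llm_events:
        if event_type == "start_llm_response":
            ttft = time.perf_counter() - start_time
        if event_type == "end_llm_response":
            collect_telemetry()
        if event_type == "message":
            yield data


async def fake_llm_events():
    # Hypothetical stand-in for a real LLM event stream.
    yield "start_llm_response", ""
    await asyncio.sleep(0.1)
    yield "message", "Hello"
    yield "end_llm_response", ""


async def main():
    async for chunk in event_generator(fake_llm_events()):
        print(chunk)


asyncio.run(main())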
Debanjum Singh Solanky
2024-07-23 22:02:45 +05:30
parent b36a7833a6
commit 70201e8db8
2 changed files with 47 additions and 57 deletions

View File

@@ -1,6 +1,7 @@
 import asyncio
 import json
 import logging
+import time
 from datetime import datetime
 from functools import partial
 from typing import Any, Dict, List, Optional
@@ -22,11 +23,7 @@ from khoj.database.adapters import (
     aget_user_name,
 )
 from khoj.database.models import KhojUser
-from khoj.processor.conversation.prompts import (
-    help_message,
-    no_entries_found,
-    no_notes_found,
-)
+from khoj.processor.conversation.prompts import help_message, no_entries_found
 from khoj.processor.conversation.utils import save_to_conversation_log
 from khoj.processor.speech.text_to_speech import generate_text_to_speech
 from khoj.processor.tools.online_search import read_webpages, search_online
@@ -34,7 +31,6 @@ from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
     CommonQueryParams,
-    CommonQueryParamsClass,
     ConversationCommandRateLimiter,
     agenerate_chat_response,
     aget_relevant_information_sources,
@@ -547,22 +543,27 @@ async def chat(
     ),
 ):
     async def event_generator(q: str):
+        start_time = time.perf_counter()
+        ttft = None
+        chat_metadata: dict = {}
         connection_alive = True
         user: KhojUser = request.user.object
         q = unquote(q)

         async def send_event(event_type: str, data: str | dict):
-            nonlocal connection_alive
+            nonlocal connection_alive, ttft
             if not connection_alive or await request.is_disconnected():
                 connection_alive = False
                 logger.warn(f"User {user} disconnected from {common.client} client")
                 return
             try:
+                if event_type == "end_llm_response":
+                    collect_telemetry()
+                if event_type == "start_llm_response":
+                    ttft = time.perf_counter() - start_time
                 if event_type == "message":
                     yield data
-                elif event_type == "references":
-                    yield json.dumps({"type": event_type, "data": data})
-                elif stream:
+                elif event_type == "references" or stream:
                     yield json.dumps({"type": event_type, "data": data})
             except asyncio.CancelledError:
                 connection_alive = False
@@ -581,12 +582,36 @@ async def chat(
async for result in send_event("end_llm_response", ""): async for result in send_event("end_llm_response", ""):
yield result yield result
def collect_telemetry():
# Gather chat response telemetry
nonlocal chat_metadata
latency = time.perf_counter() - start_time
cmd_set = set([cmd.value for cmd in conversation_commands])
chat_metadata = chat_metadata or {}
chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
chat_metadata["latency"] = f"{latency:.3f}"
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
logger.info(f"Chat response total time: {latency:.3f} seconds")
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
client=request.user.client_app,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
metadata=chat_metadata,
)
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
) )
if not conversation: if not conversation:
async for result in send_llm_response(f"No Conversation id: {conversation_id} not found"): async for result in send_llm_response(f"Conversation {conversation_id} not found"):
yield result yield result
return
await is_ready_to_chat(user) await is_ready_to_chat(user)
@@ -684,12 +709,6 @@ async def chat(
                 client_application=request.user.client_app,
                 conversation_id=conversation_id,
             )
-            update_telemetry_state(
-                request=request,
-                telemetry_type="api",
-                api="chat",
-                metadata={"conversation_command": conversation_commands[0].value},
-            )
             return

         custom_filters = []
@@ -732,17 +751,6 @@ async def chat(
                 inferred_queries=[query_to_run],
                 automation_id=automation.id,
             )
-            common = CommonQueryParamsClass(
-                client=request.user.client_app,
-                user_agent=request.headers.get("user-agent"),
-                host=request.headers.get("host"),
-            )
-            update_telemetry_state(
-                request=request,
-                telemetry_type="api",
-                api="chat",
-                **common.__dict__,
-            )
             async for result in send_llm_response(llm_response):
                 yield result
             return
@@ -839,12 +847,6 @@ async def chat(
# Generate Output
         ## Generate Image Output
         if ConversationCommand.Image in conversation_commands:
-            update_telemetry_state(
-                request=request,
-                telemetry_type="api",
-                api="chat",
-                metadata={"conversation_command": conversation_commands[0].value},
-            )
             async for result in text_to_image(
                 q,
                 user,
@@ -913,17 +915,6 @@ async def chat(
             user_name,
         )

-        cmd_set = set([cmd.value for cmd in conversation_commands])
-        chat_metadata["conversation_command"] = cmd_set
-        chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
-
-        update_telemetry_state(
-            request=request,
-            telemetry_type="api",
-            api="chat",
-            metadata=chat_metadata,
-        )
-
         # Send Response
         async for result in send_event("start_llm_response", ""):
             yield result

View File

@@ -780,7 +780,6 @@ async def text_to_image(
chat_history += f"Q: Prompt: {chat['intent']['query']}\n" chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n" chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
with timer("Improve the original user query", logger):
if send_status_func: if send_status_func:
async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"): async for event in send_status_func("**✍🏽 Enhancing the Painting Prompt**"):
yield {"status": event} yield {"status": event}