From daec439d5250f4440ddf6006eb2804ef08b185a3 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 22 Jul 2024 20:29:45 +0530
Subject: [PATCH] Replace old chat router with new chat router with advanced streaming

- Details
  Only return notes refs, online refs, inferred queries and the generated
  response in non-streaming mode. Do not return the train of thought and
  other status messages.
  Incorporate missing logic from the old chat API router into the new one.
- Motivation
  So we can halve the chat API code by getting rid of the duplicate logic
  for the websocket router.
  The deduplicated code:
  - Avoids inadvertent logic drift between the 2 routers
  - Improves dev velocity
---
 src/khoj/interface/web/chat.html |  47 ++---
 src/khoj/routers/api_chat.py     | 333 +++++--------------------------
 2 files changed, 61 insertions(+), 319 deletions(-)

diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index 00139232..6855c196 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -709,27 +709,11 @@ To get started, just start typing below. You can also type / to see a list of co
                     rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
                 }
             }
-            let references = {};
-            if (imageJson.context && imageJson.context.length > 0) {
-                const rawReferenceAsJson = imageJson.context;
-                if (rawReferenceAsJson instanceof Array) {
-                    references["notes"] = rawReferenceAsJson;
-                } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
-                    references["online"] = rawReferenceAsJson;
-                }
-            }
-            if (imageJson.detail) {
-                // If response has detail field, response is an error message.
-                rawResponse += imageJson.detail;
-            }
-            return { rawResponse, references };
-        }
-
-        function addMessageToChatBody(rawResponse, newResponseElement, references) {
-            newResponseElement.innerHTML = "";
-            newResponseElement.appendChild(formatHTMLMessage(rawResponse));
+            // If response has detail field, response is an error message.
+            if (imageJson.detail) rawResponse += imageJson.detail;

-            finalizeChatBodyResponse(references, newResponseElement);
+            return rawResponse;
         }

         function finalizeChatBodyResponse(references, newResponseElement) {
@@ -743,7 +727,6 @@ To get started, just start typing below. You can also type / to see a list of co
         function collectJsonsInBufferedMessageChunk(chunk) {
             // Collect list of JSON objects and raw strings in the chunk
             // Return the list of objects and the remaining raw string
-            console.log("Raw Chunk:", chunk);
             let startIndex = chunk.indexOf('{');
             if (startIndex === -1) return { objects: [chunk], remainder: '' };
             const objects = [chunk.slice(0, startIndex)];
@@ -819,11 +802,13 @@ To get started, just start typing below. You can also type / to see a list of co
                     isVoice: false,
                 }
             } else if (chunk.type === "references") {
-                const rawReferenceAsJson = JSON.parse(chunk.data);
-                chatMessageState.references = {"notes": rawReferenceAsJson.context, "online": rawReferenceAsJson.online_results};
+                chatMessageState.references = {"notes": chunk.data.context, "online": chunk.data.online_results};
             } else if (chunk.type === 'message') {
                 const chunkData = chunk.data;
-                if (chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
+                if (typeof chunkData === 'object' && chunkData !== null) {
+                    // If chunkData is already a JSON object
+                    handleJsonResponse(chunkData);
+                } else if (typeof chunkData === 'string' && chunkData.trim()?.startsWith("{") && chunkData.trim()?.endsWith("}")) {
                     // Try process chunk data as if it is a JSON object
                     try {
                         const jsonData = JSON.parse(chunkData.trim());
@@ -841,17 +826,15 @@ To get started, just start typing below. You can also type / to see a list of co

         function handleJsonResponse(jsonData) {
             if (jsonData.image || jsonData.detail) {
-                let { rawResponse, references } = handleImageResponse(jsonData, chatMessageState.rawResponse);
-                chatMessageState.rawResponse = rawResponse;
-                chatMessageState.references = references;
+                chatMessageState.rawResponse = handleImageResponse(jsonData, chatMessageState.rawResponse);
             } else if (jsonData.response) {
                 chatMessageState.rawResponse = jsonData.response;
-                chatMessageState.references = {
-                    notes: jsonData.context || {},
-                    online: jsonData.online_results || {}
-                };
             }
-            addMessageToChatBody(chatMessageState.rawResponse, chatMessageState.newResponseTextEl, chatMessageState.references);
+
+            if (chatMessageState.newResponseTextEl) {
+                chatMessageState.newResponseTextEl.innerHTML = "";
+                chatMessageState.newResponseTextEl.appendChild(formatHTMLMessage(chatMessageState.rawResponse));
+            }
         }

         async function sendMessageStream(query) {
@@ -866,7 +849,7 @@ To get started, just start typing below. You can also type / to see a list of co
                 refreshChatSessionsPanel();
             }

-            let chatStreamUrl = `/api/chat/stream?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&client=web`;
+            let chatStreamUrl = `/api/chat?q=${encodeURIComponent(query)}&conversation_id=${conversationId}&stream=true&client=web`;
             chatStreamUrl += (!!region && !!city && !!countryName && !!timezone) ? `&region=${region}&city=${city}&country=${countryName}&timezone=${timezone}` : '';

diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 34879b86..d8826264 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -525,13 +525,14 @@ async def set_conversation_title(
     )


-@api_chat.get("/stream")
-async def stream_chat(
+@api_chat.get("")
+async def chat(
     request: Request,
     common: CommonQueryParams,
     q: str,
     n: int = 7,
     d: float = 0.18,
+    stream: Optional[bool] = False,
     title: Optional[str] = None,
     conversation_id: Optional[int] = None,
     city: Optional[str] = None,
@@ -550,7 +551,7 @@ async def stream_chat(
         user: KhojUser = request.user.object
         q = unquote(q)

-        async def send_event(event_type: str, data: str):
+        async def send_event(event_type: str, data: str | dict):
             nonlocal connection_alive
             if not connection_alive or await request.is_disconnected():
                 connection_alive = False
@@ -559,7 +560,9 @@ async def stream_chat(
             try:
                 if event_type == "message":
                     yield data
-                else:
+                elif event_type == "references":
+                    yield json.dumps({"type": event_type, "data": data})
+                elif stream:
                     yield json.dumps({"type": event_type, "data": data})
             except asyncio.CancelledError:
                 connection_alive = False
@@ -744,6 +747,8 @@ async def stream_chat(
                 yield result
                 return

+        # Gather Context
+        ## Extract Document References
         compiled_references, inferred_queries, defiltered_query = [], [], None
         async for result in extract_references_and_questions(
             request,
@@ -778,6 +783,7 @@ async def stream_chat(
         if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
             conversation_commands.remove(ConversationCommand.Notes)

+        ## Gather Online References
         if ConversationCommand.Online in conversation_commands:
             try:
                 async for result in search_online(
@@ -794,6 +800,7 @@ async def stream_chat(
                     yield result
                     return

+        ## Gather Webpage References
         if ConversationCommand.Webpage in conversation_commands:
             try:
                 async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
@@ -818,6 +825,19 @@ async def stream_chat(
                     exc_info=True,
                 )

+        ## Send Gathered References
+        async for result in send_event(
+            "references",
+            {
+                "inferredQueries": inferred_queries,
+                "context": compiled_references,
+                "online_results": online_results,
+            },
+        ):
+            yield result
+
+        # Generate Output
+        ## Generate Image Output
         if ConversationCommand.Image in conversation_commands:
             update_telemetry_state(
                 request=request,
@@ -875,11 +895,7 @@ async def stream_chat(
                 yield result
                 return

-        async for result in send_event(
-            "references", json.dumps({"context": compiled_references, "online_results": online_results})
-        ):
-            yield result
-
+        ## Generate Text Output
         async for result in send_event("status", f"**💭 Generating a well-informed response**"):
             yield result
         llm_response, chat_metadata = await agenerate_chat_response(
@@ -897,6 +913,8 @@ async def stream_chat(
             user_name,
         )

+        cmd_set = set([cmd.value for cmd in conversation_commands])
+        chat_metadata["conversation_command"] = cmd_set
         chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None

         update_telemetry_state(
@@ -905,12 +923,13 @@ async def stream_chat(
             api="chat",
             metadata=chat_metadata,
         )
-        iterator = AsyncIteratorWrapper(llm_response)

+        # Send Response
         async for result in send_event("start_llm_response", ""):
             yield result

         continue_stream = True
+        iterator = AsyncIteratorWrapper(llm_response)
         async for item in iterator:
             if item is None:
""): @@ -931,282 +950,22 @@ async def stream_chat( continue_stream = False logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") - return StreamingResponse(event_generator(q), media_type="text/plain") - - -@api_chat.get("", response_class=Response) -@requires(["authenticated"]) -async def chat( - request: Request, - common: CommonQueryParams, - q: str, - n: Optional[int] = 5, - d: Optional[float] = 0.22, - stream: Optional[bool] = False, - title: Optional[str] = None, - conversation_id: Optional[int] = None, - city: Optional[str] = None, - region: Optional[str] = None, - country: Optional[str] = None, - timezone: Optional[str] = None, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), -) -> Response: - user: KhojUser = request.user.object - q = unquote(q) - if is_query_empty(q): - return Response( - content="It seems like your query is incomplete. Could you please provide more details or specify what you need help with?", - media_type="text/plain", - status_code=400, - ) - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - logger.info(f"Chat request by {user.username}: {q}") - - await is_ready_to_chat(user) - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - _custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - help_str = "/" + ConversationCommand.Help - if q.strip() == help_str: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) - # Adding specification to search online specifically on khoj.dev pages. - _custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, request.user.client_app, conversation_id, title - ) - conversation_id = conversation.id if conversation else None - - if not conversation: - return Response( - content=f"No conversation found with requested id, title", media_type="text/plain", status_code=400 - ) - else: - meta_log = conversation.conversation_log - - if ConversationCommand.Summarize in conversation_commands: - file_filters = conversation.file_filters - llm_response = "" - if len(file_filters) == 0: - llm_response = "No files selected for summarization. Please add files using the section on the left." - elif len(file_filters) > 1: - llm_response = "Only one file can be selected for summarization." - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." 
-                    return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
-                contextual_data = " ".join([file.raw_text for file in file_object])
-                summarizeStr = "/" + ConversationCommand.Summarize
-                if q.strip() == summarizeStr:
-                    q = "Create a general summary of the file"
-                response = await extract_relevant_summary(q, contextual_data)
-                llm_response = str(response)
-            except Exception as e:
-                logger.error(f"Error summarizing file for {user.email}: {e}")
-                llm_response = "Error summarizing file."
-        await sync_to_async(save_to_conversation_log)(
-            q,
-            llm_response,
-            user,
-            conversation.conversation_log,
-            user_message_time,
-            intent_type="summarize",
-            client_application=request.user.client_app,
-            conversation_id=conversation_id,
-        )
-        update_telemetry_state(
-            request=request,
-            telemetry_type="api",
-            api="chat",
-            metadata={"conversation_command": conversation_commands[0].value},
-            **common.__dict__,
-        )
-        return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
-
-    is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
-
-    if conversation_commands == [ConversationCommand.Default] or is_automated_task:
-        conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
-        mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
-        if mode not in conversation_commands:
-            conversation_commands.append(mode)
-
-    for cmd in conversation_commands:
-        await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
-        q = q.replace(f"/{cmd.value}", "").strip()
-
-    location = None
-
-    if city or region or country:
-        location = LocationData(city=city, region=region, country=country)
-
-    user_name = await aget_user_name(user)
-
-    if ConversationCommand.Automation in conversation_commands:
-        try:
-            automation, crontime, query_to_run, subject = await create_automation(
-                q, timezone, user, request.url, meta_log
-            )
-        except Exception as e:
-            logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
-            return Response(
-                content=f"Unable to create automation. Ensure the automation doesn't already exist.",
-                media_type="text/plain",
-                status_code=500,
-            )
-
-        llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
-        await sync_to_async(save_to_conversation_log)(
-            q,
-            llm_response,
-            user,
-            meta_log,
-            user_message_time,
-            intent_type="automation",
-            client_application=request.user.client_app,
-            conversation_id=conversation_id,
-            inferred_queries=[query_to_run],
-            automation_id=automation.id,
-        )
-
-        if stream:
-            return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
-        else:
-            return Response(content=llm_response, media_type="text/plain", status_code=200)
-
-    compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
-        request, meta_log, q, (n or 5), (d or math.inf), conversation_id, conversation_commands, location
-    )
-    online_results: Dict[str, Dict] = {}
-
-    if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
-        no_entries_found_format = no_entries_found.format()
-        if stream:
-            return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
-        else:
-            response_obj = {"response": no_entries_found_format}
-            return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
-
-    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
-        no_notes_found_format = no_notes_found.format()
-        if stream:
-            return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
-        else:
-            response_obj = {"response": no_notes_found_format}
-            return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
-
-    if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
-        conversation_commands.remove(ConversationCommand.Notes)
-
-    if ConversationCommand.Online in conversation_commands:
-        try:
-            online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
-        except ValueError as e:
-            logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
-
-    if ConversationCommand.Webpage in conversation_commands:
-        try:
-            online_results = await read_webpages(defiltered_query, meta_log, location)
-        except ValueError as e:
-            logger.warning(
-                f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
-            )
-
-    if ConversationCommand.Image in conversation_commands:
-        update_telemetry_state(
-            request=request,
-            telemetry_type="api",
-            api="chat",
-            metadata={"conversation_command": conversation_commands[0].value},
-            **common.__dict__,
-        )
-        image, status_code, improved_image_prompt, intent_type = await text_to_image(
-            q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
-        )
-        if image is None:
-            content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
-            return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
-
-        await sync_to_async(save_to_conversation_log)(
-            q,
-            image,
-            user,
-            meta_log,
-            user_message_time,
-            intent_type=intent_type,
-            inferred_queries=[improved_image_prompt],
-            client_application=request.user.client_app,
-            conversation_id=conversation.id,
-            compiled_references=compiled_references,
-            online_results=online_results,
-        )
-        content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results}  # type: ignore
-        return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
-
-    # Get the (streamed) chat response from the LLM of choice.
-    llm_response, chat_metadata = await agenerate_chat_response(
-        defiltered_query,
-        meta_log,
-        conversation,
-        compiled_references,
-        online_results,
-        inferred_queries,
-        conversation_commands,
-        user,
-        request.user.client_app,
-        conversation.id,
-        location,
-        user_name,
-    )
-
-    cmd_set = set([cmd.value for cmd in conversation_commands])
-    chat_metadata["conversation_command"] = cmd_set
-    chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
-
-    update_telemetry_state(
-        request=request,
-        telemetry_type="api",
-        api="chat",
-        metadata=chat_metadata,
-        **common.__dict__,
-    )
-
-    if llm_response is None:
-        return Response(content=llm_response, media_type="text/plain", status_code=500)
-
+    ## Stream Text Response
     if stream:
-        return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
+        return StreamingResponse(event_generator(q), media_type="text/plain")
+    ## Non-Streaming Text Response
+    else:
+        # Get the full response from the generator if the stream is not requested.
+        response_obj = {}
+        actual_response = ""
+        iterator = event_generator(q)
+        async for item in iterator:
+            try:
+                item_json = json.loads(item)
+                if "type" in item_json and item_json["type"] == "references":
+                    response_obj = item_json["data"]
+            except:
+                actual_response += item
+        response_obj["response"] = actual_response

-    iterator = AsyncIteratorWrapper(llm_response)
-
-    # Get the full response from the generator if the stream is not requested.
-    aggregated_gpt_response = ""
-    async for item in iterator:
-        if item is None:
-            break
-        aggregated_gpt_response += item
-
-    actual_response = aggregated_gpt_response.split("### compiled references:")[0]
-
-    response_obj = {
-        "response": actual_response,
-        "inferredQueries": inferred_queries,
-        "context": compiled_references,
-        "online_results": online_results,
-    }
-
-    return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
+        return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
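
For reviewers who want to try the consolidated endpoint, here is a minimal client sketch. It is not part of the patch: the base URL, the bearer-token header, and the two helper function names are assumptions for illustration. The query parameters (q, stream, client) and the non-streaming response fields (response, context, online_results, inferredQueries) are taken from the diff above.

    # Illustrative client sketch, not part of this patch. Assumes a local Khoj
    # server at http://localhost:42110 and bearer-token auth; adjust as needed.
    import json

    import requests

    KHOJ_URL = "http://localhost:42110"  # assumed base URL
    HEADERS = {"Authorization": "Bearer <your-khoj-api-token>"}  # assumed auth header


    def chat_non_streaming(query: str) -> dict:
        # stream=false (the default): the endpoint drains its own event generator
        # and returns one JSON object with the generated response and references.
        response = requests.get(
            f"{KHOJ_URL}/api/chat",
            params={"q": query, "client": "web", "stream": "false"},
            headers=HEADERS,
            timeout=60,
        )
        response.raise_for_status()
        # Expected keys: response, context, online_results, inferredQueries
        return response.json()


    def chat_streaming(query: str) -> str:
        # stream=true: a text/plain stream interleaving raw response text with
        # JSON events of the form {"type": ..., "data": ...} (references, status).
        message = ""
        with requests.get(
            f"{KHOJ_URL}/api/chat",
            params={"q": query, "client": "web", "stream": "true"},
            headers=HEADERS,
            stream=True,
            timeout=60,
        ) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
                # Naive handling: assumes each chunk is either a complete JSON
                # event or plain message text.
                try:
                    event = json.loads(chunk)
                except (json.JSONDecodeError, TypeError):
                    event = None
                if isinstance(event, dict) and "type" in event:
                    print(f"[{event['type']}] event received")
                else:
                    message += chunk
        return message

Note the naive per-chunk parse in the streaming helper: network chunk boundaries need not line up with event boundaries, which is why the web client above buffers chunks and splits out complete JSON objects first (see collectJsonsInBufferedMessageChunk in chat.html).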