import asyncio
import json
import logging
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional

# NOTE(review): the remaining module-level imports and definitions that sit between the
# import block and this endpoint are outside the visible excerpt and are not reproduced here.


@api_chat.get("/stream")
async def stream_chat(
    request: Request,
    common: CommonQueryParams,
    q: str,
    n: int = 7,
    d: float = 0.18,
    title: Optional[str] = None,
    conversation_id: Optional[int] = None,
    city: Optional[str] = None,
    region: Optional[str] = None,
    country: Optional[str] = None,
    timezone: Optional[str] = None,
    rate_limiter_per_minute=Depends(
        ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
    ),
    rate_limiter_per_day=Depends(
        ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
    ),
):
    """Stream the chat response for query `q` as a sequence of plain-text events.

    Events of type "message" are emitted as raw text; every other event is emitted as a
    JSON object of the form {"type": <event_type>, "data": <data>}.

    Args:
        request: Incoming request; supplies the authenticated user and client app.
        common: Injected common query params; `common.client` is used in disconnect logs.
        q: URL-encoded user query. It is unquoted before processing.
        n: Forwarded to `extract_references_and_questions` (falls back to 7 when falsy).
            Presumably the number of notes to retrieve - TODO confirm.
        d: Forwarded to `extract_references_and_questions` (falls back to 0.18 when falsy).
            Presumably the maximum search distance - TODO confirm.
        title: Conversation title forwarded to the conversation lookup.
        conversation_id: Id of the conversation forwarded to the conversation lookup.
        city, region, country: When any is set, combined into a `LocationData`.
        timezone: Forwarded to `create_automation`.
        rate_limiter_per_minute, rate_limiter_per_day: Rate limit dependencies.

    Returns:
        A `StreamingResponse` with media type "text/plain" wrapping the event generator.
    """

    async def event_generator(q: str):
        """Run the chat pipeline for `q`, yielding events for the client as they occur."""
        connection_alive = True
        user: KhojUser = request.user.object
        q = unquote(q)

        async def send_event(event_type: str, data: str):
            """Yield one serialized event, or nothing once the client is disconnected."""
            nonlocal connection_alive
            if not connection_alive or await request.is_disconnected():
                connection_alive = False
                logger.warning(f"User {user} disconnected from {common.client} client")
                return
            try:
                if event_type == "message":
                    yield data
                else:
                    yield json.dumps({"type": event_type, "data": data})
            except asyncio.CancelledError:
                connection_alive = False
                logger.warning(f"User {user} disconnected from {common.client} client")
                return
            except Exception as e:
                connection_alive = False
                logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True)
                return

        async def send_llm_response(response: str):
            """Emit a complete response wrapped in start/end response events."""
            async for result in send_event("start_llm_response", ""):
                yield result
            # NOTE(review): the excerpt elides the lines between the start and end events.
            # Forwarding `response` as a "message" event is inferred from the streaming
            # loop at the end of this generator - confirm against the full file.
            async for result in send_event("message", response):
                yield result
            async for result in send_event("end_llm_response", ""):
                yield result

        conversation = await ConversationAdapters.aget_conversation_by_user(
            user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
        )
        if not conversation:
            async for result in send_llm_response(f"Conversation {conversation_id} not found"):
                yield result
            # Stop here: everything below dereferences `conversation`.
            return

        await is_ready_to_chat(user)

        user_name = await aget_user_name(user)
        location = None
        if city or region or country:
            location = LocationData(city=city, region=region, country=country)

        if is_query_empty(q):
            async for result in send_llm_response("Please ask your query to get started."):
                yield result
            return

        user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        conversation_commands = [get_conversation_command(query=q, any_references=True)]
        # Capture this from the parsed command before any inference rewrites the command
        # list, so it only reflects what was parsed directly from the query.
        used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]

        async for result in send_event("status", f"**👀 Understanding Query**: {q}"):
            yield result

        meta_log = conversation.conversation_log
        is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]

        if conversation_commands == [ConversationCommand.Default] or is_automated_task:
            conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
            conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
            async for result in send_event(
                "status", f"**🗃️ Chose Data Sources to Search:** {conversation_commands_str}"
            ):
                yield result

            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task)
            async for result in send_event("status", f"**🧑🏾‍💻 Decided Response Mode:** {mode.value}"):
                yield result
            if mode not in conversation_commands:
                conversation_commands.append(mode)

        for cmd in conversation_commands:
            await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
            q = q.replace(f"/{cmd.value}", "").strip()

        file_filters = conversation.file_filters if conversation else []
        # Skip trying to summarize if
        if (
            # summarization intent was inferred
            ConversationCommand.Summarize in conversation_commands
            # and not triggered via slash command
            and not used_slash_summarize
            # but we can't actually summarize
            and len(file_filters) != 1
        ):
            conversation_commands.remove(ConversationCommand.Summarize)
        elif ConversationCommand.Summarize in conversation_commands:
            response_log = ""
            if len(file_filters) == 0:
                response_log = "No files selected for summarization. Please add files using the section on the left."
                async for result in send_llm_response(response_log):
                    yield result
            elif len(file_filters) > 1:
                response_log = "Only one file can be selected for summarization."
                async for result in send_llm_response(response_log):
                    yield result
            else:
                try:
                    file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
                    if len(file_object) == 0:
                        response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
                        async for result in send_llm_response(response_log):
                            yield result
                        return
                    contextual_data = " ".join([file.raw_text for file in file_object])
                    if not q:
                        q = "Create a general summary of the file"
                    async for result in send_event(
                        "status", f"**🧑🏾‍💻 Constructing Summary Using:** {file_object[0].file_name}"
                    ):
                        yield result

                    response = await extract_relevant_summary(q, contextual_data)
                    response_log = str(response)
                    async for result in send_llm_response(response_log):
                        yield result
                except Exception as e:
                    response_log = "Error summarizing file."
                    logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
                    async for result in send_llm_response(response_log):
                        yield result
            await sync_to_async(save_to_conversation_log)(
                q,
                response_log,
                user,
                meta_log,
                user_message_time,
                intent_type="summarize",
                client_application=request.user.client_app,
                conversation_id=conversation_id,
            )
            update_telemetry_state(
                request=request,
                telemetry_type="api",
                api="chat",
                metadata={"conversation_command": conversation_commands[0].value},
            )
            return

        custom_filters = []
        if conversation_commands == [ConversationCommand.Help]:
            if not q:
                conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
                if conversation_config is None:
                    conversation_config = await ConversationAdapters.aget_default_conversation_config()
                model_type = conversation_config.model_type
                formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
                async for result in send_llm_response(formatted_help):
                    yield result
                return
            # Adding specification to search online specifically on khoj.dev pages.
            custom_filters.append("site:khoj.dev")
            conversation_commands.append(ConversationCommand.Online)

        if ConversationCommand.Automation in conversation_commands:
            try:
                automation, crontime, query_to_run, subject = await create_automation(
                    q, timezone, user, request.url, meta_log
                )
            except Exception as e:
                logger.error(f"Error scheduling task {q} for {user.email}: {e}")
                error_message = "Unable to create automation. Ensure the automation doesn't already exist."
                async for result in send_llm_response(error_message):
                    yield result
                return

            llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject)
            await sync_to_async(save_to_conversation_log)(
                q,
                llm_response,
                user,
                meta_log,
                user_message_time,
                intent_type="automation",
                client_application=request.user.client_app,
                conversation_id=conversation_id,
                inferred_queries=[query_to_run],
                automation_id=automation.id,
            )
            # Deliberately NOT named `common`: assigning `common` anywhere in this generator
            # would make it a local of `event_generator`, leaving the `common.client`
            # lookups inside `send_event` unbound (NameError) until this branch has run.
            telemetry_params = CommonQueryParamsClass(
                client=request.user.client_app,
                user_agent=request.headers.get("user-agent"),
                host=request.headers.get("host"),
            )
            update_telemetry_state(
                request=request,
                telemetry_type="api",
                api="chat",
                **telemetry_params.__dict__,
            )
            async for result in send_llm_response(llm_response):
                yield result
            return

        compiled_references, inferred_queries, defiltered_query = [], [], None
        async for result in extract_references_and_questions(
            request,
            meta_log,
            q,
            (n or 7),
            (d or 0.18),
            conversation_id,
            conversation_commands,
            location,
            partial(send_event, "status"),
        ):
            # Status updates arrive as {"status": ...}; anything else is the final result tuple.
            if isinstance(result, dict) and "status" in result:
                yield result["status"]
            else:
                compiled_references.extend(result[0])
                inferred_queries.extend(result[1])
                defiltered_query = result[2]

        if not is_none_or_empty(compiled_references):
            headings = "\n- " + "\n- ".join({c.get("compiled", c).split("\n")[0] for c in compiled_references})
            async for result in send_event("status", f"**📜 Found Relevant Notes**: {headings}"):
                yield result

        online_results: Dict = dict()

        if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
            async for result in send_llm_response(no_entries_found.format()):
                yield result
            return

        if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
            conversation_commands.remove(ConversationCommand.Notes)

        if ConversationCommand.Online in conversation_commands:
            try:
                async for result in search_online(
                    defiltered_query, meta_log, location, partial(send_event, "status"), custom_filters
                ):
                    if isinstance(result, dict) and "status" in result:
                        yield result["status"]
                    else:
                        online_results = result
            except ValueError as e:
                error_message = f"Error searching online: {e}. Attempting to respond without online results"
                logger.warning(error_message)
                async for result in send_llm_response(error_message):
                    yield result
                return

        if ConversationCommand.Webpage in conversation_commands:
            try:
                async for result in read_webpages(defiltered_query, meta_log, location, partial(send_event, "status")):
                    if isinstance(result, dict) and "status" in result:
                        yield result["status"]
                    else:
                        direct_web_pages = result
                webpages = []
                # Merge directly-read pages into the online results, keyed by query.
                for query in direct_web_pages:
                    if online_results.get(query):
                        online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
                    else:
                        online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}

                    for webpage in direct_web_pages[query]["webpages"]:
                        webpages.append(webpage["link"])
                async for result in send_event("status", f"**📚 Read web pages**: {webpages}"):
                    yield result
            except ValueError as e:
                logger.warning(
                    f"Error directly reading webpages: {e}. Attempting to respond without online results",
                    exc_info=True,
                )

        if ConversationCommand.Image in conversation_commands:
            update_telemetry_state(
                request=request,
                telemetry_type="api",
                api="chat",
                metadata={"conversation_command": conversation_commands[0].value},
            )
            async for result in text_to_image(
                q,
                user,
                meta_log,
                location_data=location,
                references=compiled_references,
                online_results=online_results,
                send_status_func=partial(send_event, "status"),
            ):
                if isinstance(result, dict) and "status" in result:
                    yield result["status"]
                else:
                    image, status_code, improved_image_prompt, intent_type = result

            if image is None or status_code != 200:
                content_obj = {
                    "content-type": "application/json",
                    "intentType": intent_type,
                    "detail": improved_image_prompt,
                    "image": image,
                }
                async for result in send_llm_response(json.dumps(content_obj)):
                    yield result
                return

            await sync_to_async(save_to_conversation_log)(
                q,
                image,
                user,
                meta_log,
                user_message_time,
                intent_type=intent_type,
                inferred_queries=[improved_image_prompt],
                client_application=request.user.client_app,
                conversation_id=conversation_id,
                compiled_references=compiled_references,
                online_results=online_results,
            )
            content_obj = {
                "content-type": "application/json",
                "intentType": intent_type,
                "context": compiled_references,
                "online_results": online_results,
                "inferredQueries": [improved_image_prompt],
                "image": image,
            }
            async for result in send_llm_response(json.dumps(content_obj)):
                yield result
            return

        async for result in send_event(
            "references", json.dumps({"context": compiled_references, "online_results": online_results})
        ):
            yield result

        async for result in send_event("status", "**💭 Generating a well-informed response**"):
            yield result
        llm_response, chat_metadata = await agenerate_chat_response(
            defiltered_query,
            meta_log,
            conversation,
            compiled_references,
            online_results,
            inferred_queries,
            conversation_commands,
            user,
            request.user.client_app,
            conversation_id,
            location,
            user_name,
        )

        chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None

        update_telemetry_state(
            request=request,
            telemetry_type="api",
            api="chat",
            metadata=chat_metadata,
        )
        iterator = AsyncIteratorWrapper(llm_response)

        async for result in send_event("start_llm_response", ""):
            yield result

        continue_stream = True
        async for item in iterator:
            # A None item marks the end of the response.
            if item is None:
                async for result in send_event("end_llm_response", ""):
                    yield result
                logger.debug("Finished streaming response")
                return
            # Keep draining the iterator without emitting once streaming has stopped.
            if not connection_alive or not continue_stream:
                continue
            # Stop streaming after compiled references section of response starts
            # References are being processed via the references event rather than the message event
            if "### compiled references:" in item:
                continue_stream = False
                item = item.split("### compiled references:")[0]
            try:
                async for result in send_event("message", f"{item}"):
                    yield result
            except Exception as e:
                continue_stream = False
                logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}")

    return StreamingResponse(event_generator(q), media_type="text/plain")