diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index fa62dbb2..16a9ff67 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -485,6 +485,47 @@ Khoj: """.strip() ) +plan_function_execution = PromptTemplate.from_template( + """ +You are Khoj, an extremely smart and helpful search assistant. +{personality_context} +- You have access to a variety of data sources to help you answer the user's question +- You can use the data sources listed below to collect more relevant information, one at a time +- You are given multiple iterations to interact with these data sources to answer the user's question +- You are provided with additional context. If you have enough context to answer the question, then exit execution + +If you already know the answer to the question, return an empty response, e.g., {{}}. + +Which of the data sources listed below would you use to answer the user's question? You **only** have access to the following data sources: + +{tools} + +Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else. + +Previous Iterations: +{previous_iterations} + +Response format: +{{"data_source": "", "query": ""}} + +Chat History: +{chat_history} + +Q: {query} +Khoj: +""".strip() ) + +previous_iteration = PromptTemplate.from_template( + """ +data_source: {data_source} +query: {query} +context: {context} +onlineContext: {onlineContext} +--- +""".strip() ) + pick_relevant_information_collection_tools = PromptTemplate.from_template( """ You are Khoj, an extremely smart and helpful search assistant. 
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 11ab1112..d26b7b5a 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -355,9 +355,9 @@ async def extract_references_and_questions( agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) if ( - not ConversationCommand.Notes in conversation_commands - and not ConversationCommand.Default in conversation_commands - and not agent_has_entries + ConversationCommand.Notes not in conversation_commands + and ConversationCommand.Default not in conversation_commands + and not agent_has_entries ): yield compiled_references, inferred_queries, q return diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4acefe30..af19a40c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -41,7 +41,8 @@ from khoj.routers.helpers import ( aget_relevant_output_modes, construct_automation_created_message, create_automation, - extract_relevant_summary, + extract_relevant_info, + generate_summary_from_files, get_conversation_command, is_query_empty, is_ready_to_chat, @@ -49,6 +50,10 @@ from khoj.routers.helpers import ( update_telemetry_state, validate_conversation_config, ) +from khoj.routers.research import ( + InformationCollectionIteration, + execute_information_collection, +) from khoj.routers.storage import upload_image_to_bucket from khoj.utils import state from khoj.utils.helpers import ( @@ -689,6 +694,522 @@ async def chat( meta_log = conversation.conversation_log is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] + pending_research = True + + researched_results = "" + online_results: Dict = dict() + ## Extract Document References + compiled_references, inferred_queries, defiltered_query = [], [], None + + if conversation_commands == [ConversationCommand.Default] or is_automated_task: + async for research_result in execute_information_collection( + request=request, + 
user=user, + query=q, + conversation_id=conversation_id, + conversation_history=meta_log, + subscribed=subscribed, + uploaded_image_url=uploaded_image_url, + agent=agent, + send_status_func=partial(send_event, ChatEvent.STATUS), + location=location, + file_filters=conversation.file_filters if conversation else [], + ): + if type(research_result) == InformationCollectionIteration: + pending_research = False + if research_result.onlineContext: + researched_results += str(research_result.onlineContext) + online_results.update(research_result.onlineContext) + + if research_result.context: + researched_results += str(research_result.context) + compiled_references.extend(research_result.context) + + else: + yield research_result + + researched_results = await extract_relevant_info(q, researched_results, agent) + + logger.info(f"Researched Results: {researched_results}") + + pending_research = False + + conversation_commands = await aget_relevant_information_sources( + q, + meta_log, + is_automated_task, + subscribed=subscribed, + uploaded_image_url=uploaded_image_url, + agent=agent, + ) + conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + async for result in send_event( + ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" + ): + yield result + + mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent) + async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): + yield result + if mode not in conversation_commands: + conversation_commands.append(mode) + + for cmd in conversation_commands: + await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + + used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] + file_filters = conversation.file_filters if conversation else [] + # Skip trying to summarize if + if ( + # summarization intent was 
inferred + ConversationCommand.Summarize in conversation_commands + # and not triggered via slash command + and not used_slash_summarize + # but we can't actually summarize + and len(file_filters) != 1 + # not pending research + and not pending_research + ): + conversation_commands.remove(ConversationCommand.Summarize) + elif ConversationCommand.Summarize in conversation_commands and pending_research: + response_log = "" + agent_has_entries = await EntryAdapters.aagent_has_entries(agent) + if len(file_filters) == 0 and not agent_has_entries: + response_log = "No files selected for summarization. Please add files using the section on the left." + async for result in send_llm_response(response_log): + yield result + elif len(file_filters) > 1 and not agent_has_entries: + response_log = "Only one file can be selected for summarization." + async for result in send_llm_response(response_log): + yield result + else: + response_log = await generate_summary_from_files( + q=q, + user=user, + file_filters=file_filters, + meta_log=meta_log, + subscribed=subscribed, + send_status_func=partial(send_event, ChatEvent.STATUS), + send_response_func=partial(send_llm_response), + ) + await sync_to_async(save_to_conversation_log)( + q, + response_log, + user, + meta_log, + user_message_time, + intent_type="summarize", + client_application=request.user.client_app, + conversation_id=conversation_id, + uploaded_image_url=uploaded_image_url, + ) + return + + custom_filters = [] + if conversation_commands == [ConversationCommand.Help]: + if not q: + conversation_config = await ConversationAdapters.aget_user_conversation_config(user) + if conversation_config is None: + conversation_config = await ConversationAdapters.aget_default_conversation_config() + model_type = conversation_config.model_type + formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) + async for result in send_llm_response(formatted_help): + yield result + return + # 
Adding specification to search online specifically on khoj.dev pages. + custom_filters.append("site:khoj.dev") + conversation_commands.append(ConversationCommand.Online) + + if ConversationCommand.Automation in conversation_commands: + try: + automation, crontime, query_to_run, subject = await create_automation( + q, timezone, user, request.url, meta_log + ) + except Exception as e: + logger.error(f"Error scheduling task {q} for {user.email}: {e}") + error_message = f"Unable to create automation. Ensure the automation doesn't already exist." + async for result in send_llm_response(error_message): + yield result + return + + llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) + await sync_to_async(save_to_conversation_log)( + q, + llm_response, + user, + meta_log, + user_message_time, + intent_type="automation", + client_application=request.user.client_app, + conversation_id=conversation_id, + inferred_queries=[query_to_run], + automation_id=automation.id, + uploaded_image_url=uploaded_image_url, + ) + async for result in send_llm_response(llm_response): + yield result + return + + # Gather Context + async for result in extract_references_and_questions( + request, + meta_log, + q, + (n or 7), + d, + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + + if not is_none_or_empty(compiled_references): + headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + # Strip only leading # from headings + headings = headings.replace("#", "") + async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): + yield result + + if 
conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): + async for result in send_llm_response(f"{no_entries_found.format()}"): + yield result + return + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + ## Gather Online References + if ConversationCommand.Online in conversation_commands and pending_research: + try: + async for result in search_online( + defiltered_query, + meta_log, + location, + user, + subscribed, + partial(send_event, ChatEvent.STATUS), + custom_filters, + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + online_results = result + except ValueError as e: + error_message = f"Error searching online: {e}. Attempting to respond without online results" + logger.warning(error_message) + async for result in send_llm_response(error_message): + yield result + return + + ## Gather Webpage References + if ConversationCommand.Webpage in conversation_commands and pending_research: + try: + async for result in read_webpages( + defiltered_query, + meta_log, + location, + user, + subscribed, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages = result + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): + yield result + except ValueError as e: + logger.warning( + 
f"Error directly reading webpages: {e}. Attempting to respond without online results", + exc_info=True, + ) + + ## Send Gathered References + async for result in send_event( + ChatEvent.REFERENCES, + { + "inferredQueries": inferred_queries, + "context": compiled_references, + "onlineContext": online_results, + }, + ): + yield result + + # Generate Output + ## Generate Image Output + if ConversationCommand.Image in conversation_commands: + async for result in text_to_image( + q, + user, + meta_log, + location_data=location, + references=compiled_references, + online_results=online_results, + subscribed=subscribed, + send_status_func=partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + image, status_code, improved_image_prompt, intent_type = result + + if image is None or status_code != 200: + content_obj = { + "content-type": "application/json", + "intentType": intent_type, + "detail": improved_image_prompt, + "image": image, + } + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return + + await sync_to_async(save_to_conversation_log)( + q, + image, + user, + meta_log, + user_message_time, + intent_type=intent_type, + inferred_queries=[improved_image_prompt], + client_application=request.user.client_app, + conversation_id=conversation_id, + compiled_references=compiled_references, + online_results=online_results, + uploaded_image_url=uploaded_image_url, + ) + content_obj = { + "intentType": intent_type, + "inferredQueries": [improved_image_prompt], + "image": image, + } + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return + + ## Generate Text Output + async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): + yield result + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + 
meta_log, + conversation, + researched_results, + compiled_references, + online_results, + inferred_queries, + conversation_commands, + user, + request.user.client_app, + conversation_id, + location, + user_name, + uploaded_image_url, + ) + + # Send Response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + + continue_stream = True + iterator = AsyncIteratorWrapper(llm_response) + async for item in iterator: + if item is None: + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") + return + if not connection_alive or not continue_stream: + continue + try: + async for result in send_event(ChatEvent.MESSAGE, f"{item}"): + yield result + except Exception as e: + continue_stream = False + logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + + ## Stream Text Response + if stream: + return StreamingResponse(event_generator(q, image=image), media_type="text/plain") + ## Non-Streaming Text Response + else: + response_iterator = event_generator(q, image=image) + response_data = await read_chat_stream(response_iterator) + return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) + + +# @api_chat.post("") +@requires(["authenticated"]) +async def old_chat( + request: Request, + common: CommonQueryParams, + body: ChatRequestBody, + rate_limiter_per_minute=Depends( + ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute") + ), + rate_limiter_per_day=Depends( + ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day") + ), +): + # Access the parameters from the body + q = body.q + n = body.n + d = body.d + stream = body.stream + title = body.title + conversation_id = body.conversation_id + city = body.city + region = body.region + country = body.country or get_country_name_from_timezone(body.timezone) + country_code = 
body.country_code or get_country_code_from_timezone(body.timezone) + timezone = body.timezone + image = body.image + + async def event_generator(q: str, image: str): + start_time = time.perf_counter() + ttft = None + chat_metadata: dict = {} + connection_alive = True + user: KhojUser = request.user.object + subscribed: bool = has_required_scope(request, ["premium"]) + event_delimiter = "␃🔚␗" + q = unquote(q) + nonlocal conversation_id + + uploaded_image_url = None + if image: + decoded_string = unquote(image) + base64_data = decoded_string.split(",", 1)[1] + image_bytes = base64.b64decode(base64_data) + webp_image_bytes = convert_image_to_webp(image_bytes) + try: + uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) + except: + uploaded_image_url = None + + async def send_event(event_type: ChatEvent, data: str | dict): + nonlocal connection_alive, ttft + if not connection_alive or await request.is_disconnected(): + connection_alive = False + logger.warning(f"User {user} disconnected from {common.client} client") + return + try: + if event_type == ChatEvent.END_LLM_RESPONSE: + collect_telemetry() + if event_type == ChatEvent.START_LLM_RESPONSE: + ttft = time.perf_counter() - start_time + if event_type == ChatEvent.MESSAGE: + yield data + elif event_type == ChatEvent.REFERENCES or stream: + yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) + except asyncio.CancelledError as e: + connection_alive = False + logger.warn(f"User {user} disconnected from {common.client} client: {e}") + return + except Exception as e: + connection_alive = False + logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) + return + finally: + yield event_delimiter + + async def send_llm_response(response: str): + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + async for result in send_event(ChatEvent.MESSAGE, response): + yield result + async for result 
in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + + def collect_telemetry(): + # Gather chat response telemetry + nonlocal chat_metadata + latency = time.perf_counter() - start_time + cmd_set = set([cmd.value for cmd in conversation_commands]) + chat_metadata = chat_metadata or {} + chat_metadata["conversation_command"] = cmd_set + chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None + chat_metadata["latency"] = f"{latency:.3f}" + chat_metadata["ttft_latency"] = f"{ttft:.3f}" + + logger.info(f"Chat response time to first token: {ttft:.3f} seconds") + logger.info(f"Chat response total time: {latency:.3f} seconds") + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + metadata=chat_metadata, + ) + + conversation_commands = [get_conversation_command(query=q, any_references=True)] + + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=request.user.client_app, + conversation_id=conversation_id, + title=title, + create_new=body.create_new, + ) + if not conversation: + async for result in send_llm_response(f"Conversation {conversation_id} not found"): + yield result + return + conversation_id = conversation.id + + agent: Agent | None = None + default_agent = await AgentAdapters.aget_default_agent() + if conversation.agent and conversation.agent != default_agent: + agent = conversation.agent + + if not conversation.agent: + conversation.agent = default_agent + await conversation.asave() + agent = default_agent + + await is_ready_to_chat(user) + + user_name = await aget_user_name(user) + location = None + if city or region or country or country_code: + location = LocationData(city=city, region=region, country=country, country_code=country_code) + + if is_query_empty(q): + async for result in send_llm_response("Please ask your query to get 
started."): + yield result + return + + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + meta_log = conversation.conversation_log + is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] + if conversation_commands == [ConversationCommand.Default] or is_automated_task: conversation_commands = await aget_relevant_information_sources( q, @@ -738,47 +1259,15 @@ async def chat( async for result in send_llm_response(response_log): yield result else: - try: - file_object = None - if await EntryAdapters.aagent_has_entries(agent): - file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) - if len(file_names) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name( - None, file_names[0], agent - ) - - if len(file_filters) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - - if len(file_object) == 0: - response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again." - async for result in send_llm_response(response_log): - yield result - return - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - async for result in send_event( - ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}" - ): - yield result - - response = await extract_relevant_summary( - q, - contextual_data, - conversation_history=meta_log, - subscribed=subscribed, - uploaded_image_url=uploaded_image_url, - agent=agent, - ) - response_log = str(response) - async for result in send_llm_response(response_log): - yield result - except Exception as e: - response_log = "Error summarizing file. Please try again, or contact support." 
- logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_llm_response(response_log): - yield result + response_log = await generate_summary_from_files( + q=q, + user=user, + file_filters=file_filters, + meta_log=meta_log, + subscribed=subscribed, + send_status_func=partial(send_event, ChatEvent.STATUS), + send_response_func=partial(send_llm_response), + ) await sync_to_async(save_to_conversation_log)( q, response_log, @@ -867,8 +1356,6 @@ async def chat( async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): yield result - online_results: Dict = dict() - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): async for result in send_llm_response(f"{no_entries_found.format()}"): yield result @@ -1049,483 +1536,3 @@ async def chat( response_iterator = event_generator(q, image=image) response_data = await read_chat_stream(response_iterator) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) - - -# Deprecated API. Remove by end of September 2024 -@api_chat.get("") -@requires(["authenticated"]) -async def get_chat( - request: Request, - common: CommonQueryParams, - q: str, - n: int = 7, - d: float = None, - stream: Optional[bool] = False, - title: Optional[str] = None, - conversation_id: Optional[str] = None, - city: Optional[str] = None, - region: Optional[str] = None, - country: Optional[str] = None, - timezone: Optional[str] = None, - image: Optional[str] = None, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), -): - # Issue a deprecation warning - warnings.warn( - "The 'get_chat' API endpoint is deprecated. 
It will be removed by the end of September 2024.", - DeprecationWarning, - stacklevel=2, - ) - - async def event_generator(q: str, image: str): - start_time = time.perf_counter() - ttft = None - chat_metadata: dict = {} - connection_alive = True - user: KhojUser = request.user.object - subscribed: bool = has_required_scope(request, ["premium"]) - event_delimiter = "␃🔚␗" - q = unquote(q) - nonlocal conversation_id - - uploaded_image_url = None - if image: - decoded_string = unquote(image) - base64_data = decoded_string.split(",", 1)[1] - image_bytes = base64.b64decode(base64_data) - webp_image_bytes = convert_image_to_webp(image_bytes) - try: - uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) - except: - uploaded_image_url = None - - async def send_event(event_type: ChatEvent, data: str | dict): - nonlocal connection_alive, ttft - if not connection_alive or await request.is_disconnected(): - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client") - return - try: - if event_type == ChatEvent.END_LLM_RESPONSE: - collect_telemetry() - if event_type == ChatEvent.START_LLM_RESPONSE: - ttft = time.perf_counter() - start_time - if event_type == ChatEvent.MESSAGE: - yield data - elif event_type == ChatEvent.REFERENCES or stream: - yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except asyncio.CancelledError as e: - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client: {e}") - return - except Exception as e: - connection_alive = False - logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) - return - finally: - yield event_delimiter - - async def send_llm_response(response: str): - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - async for result in send_event(ChatEvent.MESSAGE, response): - yield result - async for result in 
send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - - def collect_telemetry(): - # Gather chat response telemetry - nonlocal chat_metadata - latency = time.perf_counter() - start_time - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata = chat_metadata or {} - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - chat_metadata["latency"] = f"{latency:.3f}" - chat_metadata["ttft_latency"] = f"{ttft:.3f}" - - logger.info(f"Chat response time to first token: {ttft:.3f} seconds") - logger.info(f"Chat response total time: {latency:.3f} seconds") - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - client=request.user.client_app, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - metadata=chat_metadata, - ) - - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, client_application=request.user.client_app, conversation_id=conversation_id, title=title - ) - if not conversation: - async for result in send_llm_response(f"Conversation {conversation_id} not found"): - yield result - return - conversation_id = conversation.id - agent = conversation.agent if conversation.agent else None - - await is_ready_to_chat(user) - - user_name = await aget_user_name(user) - location = None - if city or region or country: - location = LocationData(city=city, region=region, country=country) - - if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started."): - yield result - return - - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - meta_log = conversation.conversation_log - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - 
conversation_commands = await aget_relevant_information_sources( - q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url - ) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - async for result in send_event( - ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" - ): - yield result - - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url) - async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): - yield result - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - file_filters = conversation.file_filters if conversation else [] - # Skip trying to summarize if - if ( - # summarization intent was inferred - ConversationCommand.Summarize in conversation_commands - # and not triggered via slash command - and not used_slash_summarize - # but we can't actually summarize - and len(file_filters) != 1 - ): - conversation_commands.remove(ConversationCommand.Summarize) - elif ConversationCommand.Summarize in conversation_commands: - response_log = "" - if len(file_filters) == 0: - response_log = "No files selected for summarization. Please add files using the section on the left." - async for result in send_llm_response(response_log): - yield result - elif len(file_filters) > 1: - response_log = "Only one file can be selected for summarization." - async for result in send_llm_response(response_log): - yield result - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - response_log = "Sorry, we couldn't find the full text of this file. 
Please re-upload the document and try again." - async for result in send_llm_response(response_log): - yield result - return - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - async for result in send_event( - ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}" - ): - yield result - - response = await extract_relevant_summary( - q, - contextual_data, - conversation_history=meta_log, - subscribed=subscribed, - uploaded_image_url=uploaded_image_url, - ) - response_log = str(response) - async for result in send_llm_response(response_log): - yield result - except Exception as e: - response_log = "Error summarizing file." - logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_llm_response(response_log): - yield result - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - uploaded_image_url=uploaded_image_url, - ) - return - - custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - async for result in send_llm_response(formatted_help): - yield result - return - # Adding specification to search online specifically on khoj.dev pages. 
- custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - if ConversationCommand.Automation in conversation_commands: - try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log - ) - except Exception as e: - logger.error(f"Error scheduling task {q} for {user.email}: {e}") - error_message = f"Unable to create automation. Ensure the automation doesn't already exist." - async for result in send_llm_response(error_message): - yield result - return - - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=request.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - uploaded_image_url=uploaded_image_url, - ) - async for result in send_llm_response(llm_response): - yield result - return - - # Gather Context - ## Extract Document References - compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - - if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result - - 
online_results: Dict = dict() - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - async for result in send_llm_response(f"{no_entries_found.format()}"): - yield result - return - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - ## Gather Online References - if ConversationCommand.Online in conversation_commands: - try: - async for result in search_online( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - custom_filters, - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - online_results = result - except ValueError as e: - error_message = f"Error searching online: {e}. Attempting to respond without online results" - logger.warning(error_message) - async for result in send_llm_response(error_message): - yield result - return - - ## Gather Webpage References - if ConversationCommand.Webpage in conversation_commands: - try: - async for result in read_webpages( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages = result - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} - - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): - yield result - except ValueError as e: - logger.warning( - f"Error directly reading webpages: 
{e}. Attempting to respond without online results", - exc_info=True, - ) - - ## Send Gathered References - async for result in send_event( - ChatEvent.REFERENCES, - { - "inferredQueries": inferred_queries, - "context": compiled_references, - "onlineContext": online_results, - }, - ): - yield result - - # Generate Output - ## Generate Image Output - if ConversationCommand.Image in conversation_commands: - async for result in text_to_image( - q, - user, - meta_log, - location_data=location, - references=compiled_references, - online_results=online_results, - subscribed=subscribed, - send_status_func=partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - image, status_code, improved_image_prompt, intent_type = result - - if image is None or status_code != 200: - content_obj = { - "content-type": "application/json", - "intentType": intent_type, - "detail": improved_image_prompt, - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - uploaded_image_url=uploaded_image_url, - ) - content_obj = { - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - ## Generate Text Output - async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): - yield result - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - 
online_results, - inferred_queries, - conversation_commands, - user, - request.user.client_app, - conversation_id, - location, - user_name, - uploaded_image_url, - ) - - # Send Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - - continue_stream = True - iterator = AsyncIteratorWrapper(llm_response) - async for item in iterator: - if item is None: - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") - return - if not connection_alive or not continue_stream: - continue - try: - async for result in send_event(ChatEvent.MESSAGE, f"{item}"): - yield result - except Exception as e: - continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") - - ## Stream Text Response - if stream: - return StreamingResponse(event_generator(q, image=image), media_type="text/plain") - ## Non-Streaming Text Response - else: - response_iterator = event_generator(q, image=image) - response_data = await read_chat_stream(response_iterator) - return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index fdb1aa12..279ad85e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -14,6 +14,7 @@ from typing import ( Annotated, Any, AsyncGenerator, + Callable, Dict, Iterator, List, @@ -39,6 +40,7 @@ from khoj.database.adapters import ( AutomationAdapters, ConversationAdapters, EntryAdapters, + FileObjectAdapters, create_khoj_token, get_khoj_tokens, get_user_name, @@ -614,6 +616,58 @@ async def extract_relevant_summary( return response.strip() +async def generate_summary_from_files( + q: str, + user: KhojUser, + file_filters: List[str], + meta_log: dict, + subscribed: bool, + uploaded_image_url: str = None, + agent: Agent = None, + send_status_func: Optional[Callable] = None, + 
send_response_func: Optional[Callable] = None, +): + try: + file_object = None + if await EntryAdapters.aagent_has_entries(agent): + file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) + if len(file_names) > 0: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent) + + if len(file_filters) > 0: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) + + if len(file_object) == 0: + response_log = ( + "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again." + ) + async for result in send_response_func(response_log): + yield result + return + contextual_data = " ".join([file.raw_text for file in file_object]) + if not q: + q = "Create a general summary of the file" + async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"): + yield result + + response = await extract_relevant_summary( + q, + contextual_data, + conversation_history=meta_log, + subscribed=subscribed, + uploaded_image_url=uploaded_image_url, + agent=agent, + ) + response_log = str(response) + async for result in send_response_func(response_log): + yield result + except Exception as e: + response_log = "Error summarizing file. Please try again, or contact support." 
+ logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) + async for result in send_response_func(response_log): + yield result + + async def generate_better_image_prompt( q: str, conversation_history: str, @@ -893,6 +947,7 @@ def generate_chat_response( q: str, meta_log: dict, conversation: Conversation, + meta_research: str = "", compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, inferred_queries: List[str] = [], @@ -910,6 +965,9 @@ def generate_chat_response( metadata = {} agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + query_to_run = q + if meta_research: + query_to_run = f"AI Research: {meta_research} {q}" try: partial_completion = partial( save_to_conversation_log, @@ -937,7 +995,7 @@ def generate_chat_response( chat_response = converse_offline( references=compiled_references, online_results=online_results, - user_query=q, + user_query=query_to_run, loaded_model=loaded_model, conversation_log=meta_log, completion_func=partial_completion, @@ -956,7 +1014,7 @@ def generate_chat_response( chat_model = conversation_config.chat_model chat_response = converse( compiled_references, - q, + query_to_run, image_url=uploaded_image_url, online_results=online_results, conversation_log=meta_log, @@ -977,7 +1035,7 @@ def generate_chat_response( api_key = conversation_config.openai_config.api_key chat_response = converse_anthropic( compiled_references, - q, + query_to_run, online_results, meta_log, model=conversation_config.chat_model, @@ -994,7 +1052,7 @@ def generate_chat_response( api_key = conversation_config.openai_config.api_key chat_response = converse_gemini( compiled_references, - q, + query_to_run, online_results, meta_log, model=conversation_config.chat_model, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py new file mode 100644 index 00000000..65c3f42d --- /dev/null +++ b/src/khoj/routers/research.py @@ -0,0 +1,261 @@ +import json 
+import logging +from typing import Any, Callable, Dict, List, Optional + +from fastapi import Request + +from khoj.database.adapters import EntryAdapters +from khoj.database.models import Agent, KhojUser +from khoj.processor.conversation import prompts +from khoj.processor.conversation.utils import remove_json_codeblock +from khoj.processor.tools.online_search import read_webpages, search_online +from khoj.routers.api import extract_references_and_questions +from khoj.routers.helpers import ( + ChatEvent, + construct_chat_history, + generate_summary_from_files, + send_message_to_model_wrapper, +) +from khoj.utils.helpers import ( + ConversationCommand, + function_calling_description_for_llm, + timer, +) +from khoj.utils.rawconfig import LocationData + +logger = logging.getLogger(__name__) + + +class InformationCollectionIteration: + def __init__( + self, data_source: str, query: str, context: str = None, onlineContext: str = None, result: Any = None + ): + self.data_source = data_source + self.query = query + self.context = context + self.onlineContext = onlineContext + + +async def apick_next_tool( + query: str, + conversation_history: dict, + subscribed: bool, + uploaded_image_url: str = None, + agent: Agent = None, + previous_iterations: List[InformationCollectionIteration] = None, +): + """ + Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. 
+ """ + + tool_options = dict() + tool_options_str = "" + + agent_tools = agent.input_tools if agent else [] + + for tool, description in function_calling_description_for_llm.items(): + tool_options[tool.value] = description + if len(agent_tools) == 0 or tool.value in agent_tools: + tool_options_str += f'- "{tool.value}": "{description}"\n' + + chat_history = construct_chat_history(conversation_history) + + previous_iterations_history = "" + for iteration in previous_iterations: + iteration_data = prompts.previous_iteration.format( + query=iteration.query, + data_source=iteration.data_source, + context=str(iteration.context), + onlineContext=str(iteration.onlineContext), + ) + + previous_iterations_history += iteration_data + + if uploaded_image_url: + query = f"[placeholder for user attached image]\n{query}" + + personality_context = ( + prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" + ) + + function_planning_prompt = prompts.plan_function_execution.format( + query=query, + tools=tool_options_str, + chat_history=chat_history, + personality_context=personality_context, + previous_iterations=previous_iterations_history, + ) + + with timer("Chat actor: Infer information sources to refer", logger): + response = await send_message_to_model_wrapper( + function_planning_prompt, + response_type="json_object", + subscribed=subscribed, + ) + + try: + response = response.strip() + response = remove_json_codeblock(response) + response = json.loads(response) + suggested_data_source = response.get("data_source", None) + suggested_query = response.get("query", None) + + return InformationCollectionIteration( + data_source=suggested_data_source, + query=suggested_query, + ) + + except Exception as e: + logger.error(f"Invalid response for determining relevant tools: {response}. 
{e}", exc_info=True) + return InformationCollectionIteration( + data_source=None, + query=None, + ) + + +async def execute_information_collection( + request: Request, + user: KhojUser, + query: str, + conversation_id: str, + conversation_history: dict, + subscribed: bool, + uploaded_image_url: str = None, + agent: Agent = None, + send_status_func: Optional[Callable] = None, + location: LocationData = None, + file_filters: List[str] = [], +): + iteration = 0 + MAX_ITERATIONS = 2 + previous_iterations = [] + while iteration < MAX_ITERATIONS: + online_results: Dict = dict() + compiled_references, inferred_queries, defiltered_query = [], [], None + this_iteration = await apick_next_tool( + query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations + ) + if this_iteration.data_source == ConversationCommand.Notes: + ## Extract Document References + compiled_references, inferred_queries, defiltered_query = [], [], None + async for result in extract_references_and_questions( + request, + conversation_history, + this_iteration.query, + 7, + None, + conversation_id, + [ConversationCommand.Default], + location, + send_status_func, + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + context=str(compiled_references), + ) + ) + + elif this_iteration.data_source == ConversationCommand.Online: + async for result in search_online( + this_iteration.query, + conversation_history, + location, + user, + subscribed, + send_status_func, + [], + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + 
online_results = result + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + onlineContext=online_results, + ) + ) + + elif this_iteration.data_source == ConversationCommand.Webpage: + async for result in read_webpages( + this_iteration.query, + conversation_history, + location, + user, + subscribed, + send_status_func, + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages = result + + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + yield send_status_func(f"**Read web pages**: {webpages}") + + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + onlineContext=online_results, + ) + ) + + elif this_iteration.data_source == ConversationCommand.Summarize: + response_log = "" + agent_has_entries = await EntryAdapters.aagent_has_entries(agent) + if len(file_filters) == 0 and not agent_has_entries: + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + context="No files selected for summarization.", + ) + ) + elif len(file_filters) > 1 and not agent_has_entries: + response_log = "Only one file can be selected for summarization." 
+ previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + context=response_log, + ) + ) + else: + response_log = await generate_summary_from_files( + q=query, + user=user, + file_filters=file_filters, + meta_log=conversation_history, + subscribed=subscribed, + send_status_func=send_status_func, + ) + else: + iteration = MAX_ITERATIONS + + iteration += 1 + for completed_iter in previous_iterations: + yield completed_iter diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 4c7bf985..8538aace 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -345,6 +345,13 @@ tool_descriptions_for_llm = { ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.", } +function_calling_description_for_llm = { + ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.", + ConversationCommand.Online: "Use this if you think there's important information on the internet related to the query.", + ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are sure of a webpage URL that will help you directly answer this query", + ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.", +} + mode_descriptions_for_llm = { ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.", ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",