diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index e07e3881..be1e1dcc 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -5,7 +5,7 @@ import math import os import time import uuid -from typing import Any, List, Optional, Union +from typing import Any, Callable, List, Optional, Union from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile @@ -274,6 +274,7 @@ async def extract_references_and_questions( d: float, conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], location_data: LocationData = None, + send_status_func: Optional[Callable] = None, ): user = request.user.object if request.user.is_authenticated else None @@ -345,6 +346,8 @@ async def extract_references_and_questions( with timer("Searching knowledge base took", logger): result_list = [] logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") + if send_status_func: + await send_status_func(f"**🔍 Searching Documents for:** {'\n- ' + '\n- '.join(inferred_queries)}") for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n result_list.extend( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 223a32b1..4e7a8cc9 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -371,7 +371,7 @@ async def websocket_endpoint( q = q.replace(f"/{cmd.value}", "").strip() compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - websocket, None, meta_log, q, 7, 0.18, conversation_commands, location + websocket, None, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update ) if compiled_references: