diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 8f010fc2..d91dd596 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -61,6 +61,36 @@ async def search( dedupe: Optional[bool] = True, ): user = request.user.object + + results = await execute_search( + user=user, + q=q, + n=n, + t=t, + r=r, + max_distance=max_distance, + dedupe=dedupe, + ) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="search", + **common.__dict__, + ) + + return results + + +async def execute_search( + user: KhojUser, + q: str, + n: Optional[int] = 5, + t: Optional[SearchType] = SearchType.All, + r: Optional[bool] = False, + max_distance: Optional[Union[float, None]] = None, + dedupe: Optional[bool] = True, +): start_time = time.time() # Run validation checks @@ -155,13 +185,6 @@ async def search( if user: state.query_cache[user.uuid][query_cache_key] = results - update_telemetry_state( - request=request, - telemetry_type="api", - api="search", - **common.__dict__, - ) - end_time = time.time() logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") @@ -349,14 +372,14 @@ async def extract_references_and_questions( for query in inferred_queries: n_items = min(n, 3) if using_offline_chat else n result_list.extend( - await search( + await execute_search( + user, f"{query} {filters_in_query}", - request=request, n=n_items, + t=SearchType.All, r=True, max_distance=d, dedupe=False, - common=common, ) ) result_list = text_search.deduplicated_search_responses(result_list) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index be0ac6c1..dc240a78 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -5,10 +5,12 @@ from typing import Dict, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, Request, WebSocket from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse from starlette.authentication import requires +from starlette.websockets import WebSocketDisconnect +from websockets import ConnectionClosedOK from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name from khoj.database.models import KhojUser @@ -229,6 +231,220 @@ async def set_conversation_title( ) +@api_chat.websocket("/ws") +async def websocket_endpoint( + websocket: WebSocket, + conversation_id: int, + city: Optional[str] = None, + region: Optional[str] = None, + country: Optional[str] = None, +): + connection_alive = True + + async def send_status_update(message: str): + nonlocal connection_alive + if not connection_alive: + return + + status_packet = { + "type": "status", + "message": message, + "content-type": "application/json", + } + try: + await websocket.send_text(json.dumps(status_packet)) + except ConnectionClosedOK: + connection_alive = False + logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + + async def send_complete_llm_response(llm_response: str): + nonlocal connection_alive + if not connection_alive: + return + try: + await websocket.send_text("start_llm_response") + await websocket.send_text(llm_response) + await websocket.send_text("end_llm_response") + except ConnectionClosedOK: + connection_alive = False + logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + + user: KhojUser = websocket.user.object + conversation = await ConversationAdapters.aget_conversation_by_user( + user, client_application=websocket.user.client_app, conversation_id=conversation_id + ) + + hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") + + daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") + + await is_ready_to_chat(user) + + user_name = await aget_user_name(user) + + location = None + + if city or region or country: + location = LocationData(city=city, region=region, country=country) + + await websocket.accept() + while connection_alive: + try: + q = await websocket.receive_text() + except WebSocketDisconnect: + logger.debug(f"User {user} disconnected web socket") + break + + await sync_to_async(hourly_limiter)(websocket) + await sync_to_async(daily_limiter)(websocket) + + conversation_commands = [get_conversation_command(query=q, any_references=True)] + + await send_status_update(f"**Processing query**: {q}") + + if conversation_commands == [ConversationCommand.Help]: + conversation_config = await ConversationAdapters.aget_user_conversation_config(user) + if conversation_config == None: + conversation_config = await ConversationAdapters.aget_default_conversation_config() + model_type = conversation_config.model_type + formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) + await send_complete_llm_response(formatted_help) + continue + + meta_log = conversation.conversation_log + + if conversation_commands == [ConversationCommand.Default]: + conversation_commands = await aget_relevant_information_sources(q, meta_log) + mode = await aget_relevant_output_modes(q, meta_log) + if mode not in conversation_commands: + conversation_commands.append(mode) + + for cmd in conversation_commands: + await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + + await send_status_update( + f"**Using conversation commands:** {', '.join([cmd.value for cmd in conversation_commands])}" + ) + + compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( + websocket, None, meta_log, q, 7, 0.18, conversation_commands, location + ) + + if compiled_references: + headings = set([c.split("\n")[0] for c in compiled_references]) + await send_status_update(f"**Searching references**: {headings}") + + online_results: Dict = dict() + + if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): + await send_complete_llm_response(f"{no_entries_found.format()}") + continue + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + if ConversationCommand.Online in conversation_commands: + try: + await send_status_update("Searching the web for relevant information 🌐") + online_results = await search_online(defiltered_query, meta_log, location) + online_searches = "".join([f"{query}" for query in online_results.keys()]) + await send_status_update(f"**Online searches**: {online_searches}") + except ValueError as e: + await send_complete_llm_response( + "Please set your SERPER_DEV_API_KEY to get started with online searches 🌐" + ) + continue + + if ConversationCommand.Image in conversation_commands: + update_telemetry_state( + request=websocket, + telemetry_type="api", + api="chat", + metadata={"conversation_command": conversation_commands[0].value}, + ) + intent_type = "text-to-image" + image, status_code, improved_image_prompt, image_url = await text_to_image( + q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results + ) + if image is None or status_code != 200: + content_obj = { + "image": image, + "intentType": intent_type, + "detail": improved_image_prompt, + "content-type": "application/json", + } + await send_complete_llm_response(json.dumps(content_obj)) + continue + + if image_url: + intent_type = "text-to-image2" + image = image_url + await sync_to_async(save_to_conversation_log)( + q, + image, + user, + meta_log, + intent_type=intent_type, + inferred_queries=[improved_image_prompt], + client_application=websocket.user.client_app, + conversation_id=conversation_id, + compiled_references=compiled_references, + online_results=online_results, + ) + content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "content-type": "application/json", "online_results": online_results} # type: ignore + + await send_complete_llm_response(json.dumps(content_obj)) + continue + + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + meta_log, + conversation, + compiled_references, + online_results, + inferred_queries, + conversation_commands, + user, + websocket.user.client_app, + conversation_id, + location, + user_name, + ) + + update_telemetry_state( + request=websocket, + telemetry_type="api", + api="chat", + metadata=chat_metadata, + ) + iterator = AsyncIteratorWrapper(llm_response) + + if connection_alive: + try: + await websocket.send_text("start_llm_response") + except ConnectionClosedOK: + connection_alive = False + logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + + async for item in iterator: + if item is None: + break + if connection_alive: + try: + await websocket.send_text(f"{item}") + except ConnectionClosedOK: + connection_alive = False + logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + + if connection_alive: + try: + await websocket.send_text("end_llm_response") + except ConnectionClosedOK: + connection_alive = False + logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") + + @api_chat.get("", response_class=Response) @requires(["authenticated"]) async def chat(