diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 92846020..248a78e8 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -763,9 +763,9 @@ class AgentAdapters: return False @staticmethod - def get_conversation_agent_by_id(agent_id: int): - agent = Agent.objects.filter(id=agent_id).first() - if agent == AgentAdapters.get_default_agent(): + async def aget_conversation_agent_by_id(agent_id: int): + agent = await Agent.objects.filter(id=agent_id).afirst() + if agent == await AgentAdapters.aget_default_agent(): # If the agent is set to the default agent, then return None and let the default application code be used return None return agent @@ -1109,14 +1109,6 @@ class ConversationAdapters: async def aget_all_chat_models(): return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all()) - @staticmethod - def get_vision_enabled_config(): - chat_models = ConversationAdapters.get_all_chat_models() - for config in chat_models: - if config.vision_enabled: - return config - return None - @staticmethod async def aget_vision_enabled_config(): chat_models = await ConversationAdapters.aget_all_chat_models() @@ -1171,7 +1163,11 @@ class ConversationAdapters: @staticmethod async def aget_chat_model(user: KhojUser): subscribed = await ais_user_subscribed(user) - config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() + config = ( + await UserConversationConfig.objects.filter(user=user) + .prefetch_related("setting", "setting__ai_model_api") + .afirst() + ) if subscribed: # Subscibed users can use any available chat model if config: @@ -1387,7 +1383,7 @@ class ConversationAdapters: @staticmethod @require_valid_user - def save_conversation( + async def save_conversation( user: KhojUser, conversation_log: dict, client_application: ClientApplication = None, @@ -1396,19 +1392,21 @@ class ConversationAdapters: ): slug = user_message.strip()[:200] if user_message else None if conversation_id: - conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() + conversation = await Conversation.objects.filter( + user=user, client=client_application, id=conversation_id + ).afirst() else: conversation = ( - Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first() + await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() ) if conversation: conversation.conversation_log = conversation_log conversation.slug = slug conversation.updated_at = datetime.now(tz=timezone.utc) - conversation.save() + await conversation.asave() else: - Conversation.objects.create( + await Conversation.objects.acreate( user=user, conversation_log=conversation_log, client=client_application, slug=slug ) @@ -1455,17 +1453,21 @@ class ConversationAdapters: return random.sample(all_questions, max_results) @staticmethod - def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): + async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): agent: Agent = ( - conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None + conversation.agent + if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent + else None ) if agent and agent.chat_model: - chat_model = conversation.agent.chat_model + chat_model = await ChatModel.objects.select_related("ai_model_api").aget( + pk=conversation.agent.chat_model.pk + ) else: - chat_model = ConversationAdapters.get_chat_model(user) + chat_model = await ConversationAdapters.aget_chat_model(user) if chat_model is None: - chat_model = ConversationAdapters.get_default_chat_model() + chat_model = await ConversationAdapters.aget_default_chat_model() if chat_model.model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 11e6a03d..b5fbdcf2 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -162,7 +162,7 @@ def send_message_to_model( ) -def converse_openai( +async def converse_openai( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -187,7 +187,7 @@ def converse_openai( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using OpenAI's ChatGPT """ @@ -217,11 +217,17 @@ def converse_openai( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -255,19 +261,23 @@ def converse_openai( logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") # Get Response from GPT - return chat_completion_with_backoff( + full_response = "" + async for chunk in chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=temperature, openai_api_key=api_key, api_base_url=api_base_url, - completion_func=completion_func, deepthought=deepthought, model_kwargs={"stop": ["Notes:\n["]}, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once finish streaming and we have the full response + if completion_func: + await completion_func(chat_response=full_response) def clean_response_schema(schema: BaseModel | dict) -> dict: diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index b73903ae..7fab44aa 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,7 +1,7 @@ import logging import os -from threading import Thread -from typing import Dict, List +from time import perf_counter +from typing import AsyncGenerator, Dict, List from urllib.parse import urlparse import openai @@ -16,13 +16,10 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ( - JsonSupport, - ThreadedGenerator, - commit_conversation_trace, -) +from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace from khoj.utils.helpers import ( get_chat_usage_metrics, + get_openai_async_client, get_openai_client, is_promptrace_enabled, ) @@ -30,6 +27,7 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) openai_clients: Dict[str, openai.OpenAI] = {} +openai_async_clients: Dict[str, openai.AsyncOpenAI] = {} @retry( @@ -124,45 +122,22 @@ def completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def chat_completion_with_backoff( +async def chat_completion_with_backoff( messages, - compiled_references, - online_results, model_name, temperature, openai_api_key=None, api_base_url=None, - completion_func=None, - deepthought=False, - model_kwargs=None, - tracer: dict = {}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=llm_thread, - args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer), - ) - t.start() - return g - - -def llm_thread( - g, - messages, - model_name: str, - temperature, - openai_api_key=None, - api_base_url=None, deepthought=False, model_kwargs: dict = {}, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: try: client_key = f"{openai_api_key}--{api_base_url}" - client = openai_clients.get(client_key) + client = openai_async_clients.get(client_key) if not client: - client = get_openai_client(openai_api_key, api_base_url) - openai_clients[client_key] = client + client = get_openai_async_client(openai_api_key, api_base_url) + openai_async_clients[client_key] = client formatted_messages = [{"role": message.role, "content": message.content} for message in messages] @@ -207,53 +182,58 @@ def llm_thread( if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( - messages=formatted_messages, - model=model_name, # type: ignore + aggregated_response = "" + final_chunk = None + start_time = perf_counter() + chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( + messages=formatted_messages, # type: ignore + model=model_name, stream=stream, temperature=temperature, timeout=20, **model_kwargs, ) + async for chunk in chat_stream: + # Log the time taken to start response + if final_chunk is None: + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Keep track of the last chunk for usage data + final_chunk = chunk + # Handle streamed response chunk + if len(chunk.choices) == 0: + continue + delta_chunk = chunk.choices[0].delta + text_chunk = "" + if isinstance(delta_chunk, str): + text_chunk = delta_chunk + elif delta_chunk and delta_chunk.content: + text_chunk = delta_chunk.content + if text_chunk: + aggregated_response += text_chunk + yield text_chunk - aggregated_response = "" - if not stream: - chunk = chat - aggregated_response = chunk.choices[0].message.content - g.send(aggregated_response) - else: - for chunk in chat: - if len(chunk.choices) == 0: - continue - delta_chunk = chunk.choices[0].delta - text_chunk = "" - if isinstance(delta_chunk, str): - text_chunk = delta_chunk - elif delta_chunk.content: - text_chunk = delta_chunk.content - if text_chunk: - aggregated_response += text_chunk - g.send(text_chunk) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") - # Calculate cost of chat - input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - cost = ( - chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0 - ) # Estimated costs returned by DeepInfra API - tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost - ) + # Calculate cost of chat after stream finishes + input_tokens, output_tokens, cost = 0, 0, 0 + if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage: + input_tokens = final_chunk.usage.prompt_tokens + output_tokens = final_chunk.usage.completion_tokens + # Estimated costs returned by DeepInfra API + if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra: + cost = final_chunk.usage.model_extra.get("estimated_cost", 0) # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: - logger.error(f"Error in llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True) def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 01c25cf4..2bb6265e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -254,7 +254,7 @@ def message_to_log( return conversation_log -def save_to_conversation_log( +async def save_to_conversation_log( q: str, chat_response: str, user: KhojUser, @@ -306,7 +306,7 @@ def save_to_conversation_log( khoj_message_metadata=khoj_message_metadata, conversation_log=meta_log.get("chat", []), ) - ConversationAdapters.save_conversation( + await ConversationAdapters.save_conversation( user, {"chat": updated_conversation}, client_application=client_application, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 03f367dd..dd951238 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -67,7 +67,6 @@ from khoj.routers.research import ( from khoj.routers.storage import upload_user_image_to_bucket from khoj.utils import state from khoj.utils.helpers import ( - AsyncIteratorWrapper, ConversationCommand, command_descriptions, convert_image_to_webp, @@ -999,7 +998,7 @@ async def chat( return llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( + await save_to_conversation_log( q, llm_response, user, @@ -1308,26 +1307,31 @@ async def chat( yield result continue_stream = True - iterator = AsyncIteratorWrapper(llm_response) - async for item in iterator: + async for item in llm_response: + # Should not happen with async generator, end is signaled by loop exit. Skip. if item is None: - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - # Send Usage Metadata once llm interactions are complete - async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): - yield event - async for result in send_event(ChatEvent.END_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") - return + continue if not connection_alive or not continue_stream: + # Drain the generator if disconnected but keep processing internally continue try: async for result in send_event(ChatEvent.MESSAGE, f"{item}"): yield result except Exception as e: continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + logger.info(f"User {user} disconnected or error during streaming. Stopping send: {e}") + + # Signal end of LLM response after the loop finishes + if connection_alive: + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + # Send Usage Metadata once llm interactions are complete + if tracer.get("usage"): + async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") ## Stream Text Response if stream: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 742a8708..6042469a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,4 +1,3 @@ -import asyncio import base64 import hashlib import json @@ -6,9 +5,7 @@ import logging import math import os import re -from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone -from enum import Enum from functools import partial from random import random from typing import ( @@ -17,7 +14,6 @@ from typing import ( AsyncGenerator, Callable, Dict, - Iterator, List, Optional, Set, @@ -126,8 +122,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca logger = logging.getLogger(__name__) -executor = ThreadPoolExecutor(max_workers=1) - NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID") NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET") @@ -262,11 +256,6 @@ def get_conversation_command(query: str) -> ConversationCommand: return ConversationCommand.Default -async def agenerate_chat_response(*args): - loop = asyncio.get_event_loop() - return await loop.run_in_executor(executor, generate_chat_response, *args) - - def gather_raw_query_files( query_files: Dict[str, str], ): @@ -1418,7 +1407,7 @@ def send_message_to_model_wrapper_sync( raise HTTPException(status_code=500, detail="Invalid conversation config") -def generate_chat_response( +async def agenerate_chat_response( q: str, meta_log: dict, conversation: Conversation, @@ -1444,13 +1433,14 @@ def generate_chat_response( generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, tracer: dict = {}, -) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: +) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]: # Initialize Variables - chat_response = None + chat_response_generator = None logger.debug(f"Conversation Types: {conversation_commands}") metadata = {} - agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + try: partial_completion = partial( save_to_conversation_log, @@ -1481,23 +1471,25 @@ def generate_chat_response( code_results = {} deepthought = True - chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed) + chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed) vision_available = chat_model.vision_enabled if not vision_available and query_images: - vision_enabled_config = ConversationAdapters.get_vision_enabled_config() + vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() if vision_enabled_config: chat_model = vision_enabled_config vision_available = True if chat_model.model_type == "offline": + # Assuming converse_offline remains sync or is refactored separately loaded_model = state.offline_chat_processor_config.loaded_model - chat_response = converse_offline( + # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator + chat_response_generator = converse_offline( # Needs adaptation if it becomes async user_query=query_to_run, references=compiled_references, online_results=online_results, loaded_model=loaded_model, conversation_log=meta_log, - completion_func=partial_completion, + completion_func=partial_completion, # Pass the async wrapper conversation_commands=conversation_commands, model_name=chat_model.name, max_prompt_size=chat_model.max_prompt_size, @@ -1515,7 +1507,7 @@ def generate_chat_response( openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key chat_model_name = chat_model.name - chat_response = converse_openai( + chat_response_generator = converse_openai( compiled_references, query_to_run, query_images=query_images, @@ -1542,9 +1534,10 @@ def generate_chat_response( ) elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + # Assuming converse_anthropic remains sync or is refactored separately api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response = converse_anthropic( + chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async compiled_references, query_to_run, query_images=query_images, @@ -1570,9 +1563,10 @@ def generate_chat_response( tracer=tracer, ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: + # Assuming converse_gemini remains sync or is refactored separately api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response = converse_gemini( + chat_response_generator = converse_gemini( # Needs adaptation if it becomes async compiled_references, query_to_run, online_results, @@ -1604,7 +1598,8 @@ def generate_chat_response( logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) - return chat_response, metadata + # Return the generator directly + return chat_response_generator, metadata class DeleteMessageRequestBody(BaseModel):