diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 92846020..248a78e8 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -763,9 +763,9 @@ class AgentAdapters: return False @staticmethod - def get_conversation_agent_by_id(agent_id: int): - agent = Agent.objects.filter(id=agent_id).first() - if agent == AgentAdapters.get_default_agent(): + async def aget_conversation_agent_by_id(agent_id: int): + agent = await Agent.objects.filter(id=agent_id).afirst() + if agent == await AgentAdapters.aget_default_agent(): # If the agent is set to the default agent, then return None and let the default application code be used return None return agent @@ -1109,14 +1109,6 @@ class ConversationAdapters: async def aget_all_chat_models(): return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all()) - @staticmethod - def get_vision_enabled_config(): - chat_models = ConversationAdapters.get_all_chat_models() - for config in chat_models: - if config.vision_enabled: - return config - return None - @staticmethod async def aget_vision_enabled_config(): chat_models = await ConversationAdapters.aget_all_chat_models() @@ -1171,7 +1163,11 @@ class ConversationAdapters: @staticmethod async def aget_chat_model(user: KhojUser): subscribed = await ais_user_subscribed(user) - config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() + config = ( + await UserConversationConfig.objects.filter(user=user) + .prefetch_related("setting", "setting__ai_model_api") + .afirst() + ) if subscribed: # Subscibed users can use any available chat model if config: @@ -1387,7 +1383,7 @@ class ConversationAdapters: @staticmethod @require_valid_user - def save_conversation( + async def save_conversation( user: KhojUser, conversation_log: dict, client_application: ClientApplication = None, @@ -1396,19 +1392,21 @@ class ConversationAdapters: ): slug = user_message.strip()[:200] if user_message else None if conversation_id: - conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first() + conversation = await Conversation.objects.filter( + user=user, client=client_application, id=conversation_id + ).afirst() else: conversation = ( - Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first() + await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() ) if conversation: conversation.conversation_log = conversation_log conversation.slug = slug conversation.updated_at = datetime.now(tz=timezone.utc) - conversation.save() + await conversation.asave() else: - Conversation.objects.create( + await Conversation.objects.acreate( user=user, conversation_log=conversation_log, client=client_application, slug=slug ) @@ -1455,17 +1453,21 @@ class ConversationAdapters: return random.sample(all_questions, max_results) @staticmethod - def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): + async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool): agent: Agent = ( - conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None + conversation.agent + if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent + else None ) if agent and agent.chat_model: - chat_model = conversation.agent.chat_model + chat_model = await 
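The adapter changes above move conversation and agent lookups onto Django's async ORM (`afirst`, `asave`, `acreate`) and eagerly load related rows such as `setting__ai_model_api`, so later attribute access does not fall back to a lazy synchronous query inside the event loop. A minimal sketch of that pattern, using hypothetical `Publisher`/`Book` models rather than Khoj's:

```python
# Minimal sketch of Django's async ORM pattern used above (hypothetical models).
# Lazy FK access issues a synchronous query, so select_related the relations you
# will touch after the await; otherwise Django raises SynchronousOnlyOperation
# when called from a running event loop.
from django.db import models


class Publisher(models.Model):
    name = models.CharField(max_length=100)


class Book(models.Model):
    title = models.CharField(max_length=200)
    publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)


async def rename_latest_book(new_title: str) -> Book | None:
    # select_related loads the FK in the same query, safe to read later
    book = await Book.objects.select_related("publisher").order_by("-id").afirst()
    if book is None:
        return None
    book.title = new_title
    await book.asave()          # async UPDATE
    print(book.publisher.name)  # no extra lazy query needed
    return book
```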
ChatModel.objects.select_related("ai_model_api").aget( + pk=conversation.agent.chat_model.pk + ) else: - chat_model = ConversationAdapters.get_chat_model(user) + chat_model = await ConversationAdapters.aget_chat_model(user) if chat_model is None: - chat_model = ConversationAdapters.get_default_chat_model() + chat_model = await ConversationAdapters.aget_default_chat_model() if chat_model.model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 977b25de..5bad38ef 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -137,7 +137,7 @@ def anthropic_send_message_to_model( ) -def converse_anthropic( +async def converse_anthropic( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -161,7 +161,7 @@ def converse_anthropic( generated_asset_results: Dict[str, Dict] = {}, deepthought: Optional[bool] = False, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using Anthropic's Claude """ @@ -191,11 +191,17 @@ def converse_anthropic( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -228,17 +234,21 @@ def converse_anthropic( logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}") # Get Response from Claude - return anthropic_chat_completion_with_backoff( + full_response = "" + async for chunk in anthropic_chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=0.2, api_key=api_key, api_base_url=api_base_url, system_prompt=system_prompt, - completion_func=completion_func, max_prompt_size=max_prompt_size, deepthought=deepthought, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once finish streaming and we have the full response + if completion_func: + await completion_func(chat_response=full_response) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 6c2ffb8a..442e1cb3 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,5 +1,5 @@ import logging -from threading import Thread +from time import perf_counter from typing import Dict, List 
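With `converse_anthropic` rewritten as an async generator, the completion callback is no longer smuggled through a `ThreadedGenerator`: the generator itself accumulates the streamed chunks and awaits `completion_func` once the stream ends. A condensed sketch of that contract, with made-up `fake_llm_stream` and `on_complete` names standing in for the real provider call and the conversation-log callback:

```python
# Sketch of the converse_* contract introduced above: an async generator that
# streams chunks and awaits an optional callback with the full response at the
# end. fake_llm_stream and on_complete are illustrative, not Khoj APIs.
import asyncio
from typing import AsyncGenerator, Awaitable, Callable, Optional


async def fake_llm_stream() -> AsyncGenerator[str, None]:
    for word in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0)  # stand-in for network latency
        yield word


async def converse(
    completion_func: Optional[Callable[..., Awaitable[None]]] = None,
) -> AsyncGenerator[str, None]:
    full_response = ""
    async for chunk in fake_llm_stream():
        full_response += chunk
        yield chunk
    # The callback runs exactly once, after streaming completes,
    # e.g. to persist the finished chat turn.
    if completion_func:
        await completion_func(chat_response=full_response)


async def main():
    async def on_complete(chat_response: str) -> None:
        print(f"\nsaved: {chat_response!r}")

    async for chunk in converse(completion_func=on_complete):
        print(chunk, end="", flush=True)


asyncio.run(main())
```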
import anthropic @@ -13,13 +13,13 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( - ThreadedGenerator, commit_conversation_trace, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( - get_ai_api_info, + get_anthropic_async_client, + get_anthropic_client, get_chat_usage_metrics, is_none_or_empty, is_promptrace_enabled, @@ -28,24 +28,12 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {} +anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {} DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 MAX_REASONING_TOKENS_ANTHROPIC = 12000 -def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: - api_info = get_ai_api_info(api_key, api_base_url) - if api_info.api_key: - client = anthropic.Anthropic(api_key=api_info.api_key) - else: - client = anthropic.AnthropicVertex( - region=api_info.region, - project_id=api_info.project, - credentials=api_info.credentials, - ) - return client - - @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), @@ -126,60 +114,23 @@ def anthropic_completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def anthropic_chat_completion_with_backoff( +async def anthropic_chat_completion_with_backoff( messages: list[ChatMessage], - compiled_references, - online_results, model_name, temperature, api_key, api_base_url, system_prompt: str, max_prompt_size=None, - completion_func=None, - deepthought=False, - model_kwargs=None, - tracer={}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=anthropic_llm_thread, - args=( - g, - messages, - system_prompt, - model_name, - temperature, - api_key, - api_base_url, - max_prompt_size, - deepthought, - model_kwargs, - tracer, - ), - ) - t.start() - return g - - -def anthropic_llm_thread( - g, - messages: list[ChatMessage], - system_prompt: str, - model_name: str, - temperature, - api_key, - api_base_url=None, - max_prompt_size=None, deepthought=False, model_kwargs=None, tracer={}, ): try: - client = anthropic_clients.get(api_key) + client = anthropic_async_clients.get(api_key) if not client: - client = get_anthropic_client(api_key, api_base_url) - anthropic_clients[api_key] = client + client = get_anthropic_async_client(api_key, api_base_url) + anthropic_async_clients[api_key] = client model_kwargs = model_kwargs or dict() max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC @@ -193,7 +144,8 @@ def anthropic_llm_thread( aggregated_response = "" final_message = None - with client.messages.stream( + start_time = perf_counter() + async with client.messages.stream( messages=formatted_messages, model=model_name, # type: ignore temperature=temperature, @@ -202,10 +154,17 @@ def anthropic_llm_thread( max_tokens=max_tokens, **model_kwargs, ) as stream: - for text in stream.text_stream: + async for text in stream.text_stream: + # Log the time taken to start response + if aggregated_response == "": + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Handle streamed response chunk aggregated_response += text - g.send(text) - final_message = stream.get_final_message() + yield text + final_message = await stream.get_final_message() + + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") # 
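The streaming helper above now talks to `anthropic.AsyncAnthropic` directly and yields text as it arrives. Stripped of the retry, token-limit, and prompt-trace scaffolding, the core call looks roughly like this; the model id and prompt are placeholders:

```python
# Condensed sketch of the async Anthropic streaming path added above.
import asyncio
import anthropic


async def stream_claude(api_key: str, prompt: str) -> str:
    client = anthropic.AsyncAnthropic(api_key=api_key)
    aggregated = ""
    async with client.messages.stream(
        model="claude-3-5-sonnet-latest",  # placeholder model id
        max_tokens=1024,
        system="You are a terse assistant.",
        messages=[{"role": "user", "content": prompt}],
    ) as stream:
        async for text in stream.text_stream:
            aggregated += text
            print(text, end="", flush=True)
        # Usage metadata only arrives with the final message
        final_message = await stream.get_final_message()
    print(f"\ntokens in/out: {final_message.usage.input_tokens}/{final_message.usage.output_tokens}")
    return aggregated


# asyncio.run(stream_claude("sk-ant-...", "Say hi"))
```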
Calculate cost of chat input_tokens = final_message.usage.input_tokens @@ -222,9 +181,7 @@ def anthropic_llm_thread( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: - logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True) def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 73167ca2..3c42ef06 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -160,7 +160,7 @@ def gemini_send_message_to_model( ) -def converse_gemini( +async def converse_gemini( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -185,7 +185,7 @@ def converse_gemini( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer={}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using Google's Gemini """ @@ -216,11 +216,17 @@ def converse_gemini( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -253,16 +259,20 @@ def converse_gemini( logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}") # Get Response from Google AI - return gemini_chat_completion_with_backoff( + full_response = "" + async for chunk in gemini_chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=temperature, api_key=api_key, api_base_url=api_base_url, system_prompt=system_prompt, - completion_func=completion_func, deepthought=deepthought, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once finish streaming and we have the full response + if completion_func: + await completion_func(chat_response=full_response) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 9a8b4132..d85cd09e 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -2,8 +2,8 @@ import logging import os import random from copy import deepcopy -from threading import Thread -from typing import Dict +from time import perf_counter +from typing import AsyncGenerator, AsyncIterator, Dict from google import genai 
from google.genai import errors as gerrors @@ -19,14 +19,13 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( - ThreadedGenerator, commit_conversation_trace, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( - get_ai_api_info, get_chat_usage_metrics, + get_gemini_client, is_none_or_empty, is_promptrace_enabled, ) @@ -62,17 +61,6 @@ SAFETY_SETTINGS = [ ] -def get_gemini_client(api_key, api_base_url=None) -> genai.Client: - api_info = get_ai_api_info(api_key, api_base_url) - return genai.Client( - location=api_info.region, - project=api_info.project, - credentials=api_info.credentials, - api_key=api_info.api_key, - vertexai=api_info.api_key is None, - ) - - @retry( wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), @@ -132,8 +120,8 @@ def gemini_completion_with_backoff( ) # Aggregate cost of chat - input_tokens = response.usage_metadata.prompt_token_count if response else 0 - output_tokens = response.usage_metadata.candidates_token_count if response else 0 + input_tokens = response.usage_metadata.prompt_token_count or 0 if response else 0 + output_tokens = response.usage_metadata.candidates_token_count or 0 if response else 0 thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0 tracer["usage"] = get_chat_usage_metrics( model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") @@ -154,52 +142,17 @@ def gemini_completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def gemini_chat_completion_with_backoff( +async def gemini_chat_completion_with_backoff( messages, - compiled_references, - online_results, model_name, temperature, api_key, api_base_url, system_prompt, - completion_func=None, model_kwargs=None, deepthought=False, tracer: dict = {}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=gemini_llm_thread, - args=( - g, - messages, - system_prompt, - model_name, - temperature, - api_key, - api_base_url, - model_kwargs, - deepthought, - tracer, - ), - ) - t.start() - return g - - -def gemini_llm_thread( - g, - messages, - system_prompt, - model_name, - temperature, - api_key, - api_base_url=None, - model_kwargs=None, - deepthought=False, - tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: try: client = gemini_clients.get(api_key) if not client: @@ -224,21 +177,32 @@ def gemini_llm_thread( ) aggregated_response = "" - - for chunk in client.models.generate_content_stream( + final_chunk = None + start_time = perf_counter() + chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream( model=model_name, config=config, contents=formatted_messages - ): + ) + async for chunk in chat_stream: + # Log the time taken to start response + if final_chunk is None: + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Keep track of the last chunk for usage data + final_chunk = chunk + # Handle streamed response chunk message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) message = message or chunk.text aggregated_response += message - g.send(message) + yield message if stopped: raise ValueError(message) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + # Calculate cost of chat - input_tokens = chunk.usage_metadata.prompt_token_count - output_tokens = 
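The Gemini path mirrors this: `client.aio.models.generate_content_stream` returns an async iterator, and the last chunk is kept around because it carries the usage metadata. A pared-down sketch outside Khoj's retry and safety-settings wrapper; the model id and prompt are placeholders:

```python
# Stripped-down sketch of the google-genai async streaming call used above.
import asyncio
from google import genai


async def stream_gemini(api_key: str, prompt: str) -> str:
    client = genai.Client(api_key=api_key)
    aggregated, final_chunk = "", None
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash",  # placeholder model id
        contents=prompt,
    )
    async for chunk in stream:
        final_chunk = chunk  # keep the last chunk: it carries usage_metadata
        if chunk.text:
            aggregated += chunk.text
            print(chunk.text, end="", flush=True)
    if final_chunk and final_chunk.usage_metadata:
        print(f"\nprompt tokens: {final_chunk.usage_metadata.prompt_token_count or 0}")
    return aggregated


# asyncio.run(stream_gemini("AIza...", "Say hi"))
```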
chunk.usage_metadata.candidates_token_count - thought_tokens = chunk.usage_metadata.thoughts_token_count or 0 + input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0 + output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0 + thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0 tracer["usage"] = get_chat_usage_metrics( model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") ) @@ -254,9 +218,7 @@ def gemini_llm_thread( + f"Last Message by {messages[-1].role}: {messages[-1].content}" ) except Exception as e: - logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True) def handle_gemini_response( diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index f727fd1d..b7f89c8d 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -1,9 +1,10 @@ -import json +import asyncio import logging import os from datetime import datetime, timedelta from threading import Thread -from typing import Any, Dict, Iterator, List, Optional, Union +from time import perf_counter +from typing import Any, AsyncGenerator, Dict, List, Optional, Union import pyjson5 from langchain.schema import ChatMessage @@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( - ThreadedGenerator, clean_json, commit_conversation_trace, generate_chatml_messages_with_context, @@ -147,7 +147,7 @@ def filter_questions(questions: List[str]): return list(filtered_questions) -def converse_offline( +async def converse_offline( user_query, references=[], online_results={}, @@ -167,9 +167,9 @@ def converse_offline( additional_context: List[str] = None, generated_asset_results: Dict[str, Dict] = {}, tracer: dict = {}, -) -> Union[ThreadedGenerator, Iterator[str]]: +) -> AsyncGenerator[str, None]: """ - Converse with user using Llama + Converse with user using Llama (Async Version) """ # Initialize Variables assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" @@ -200,10 +200,17 @@ def converse_offline( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -240,33 +247,77 @@ def converse_offline( logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}") - g = ThreadedGenerator(references, online_results, completion_func=completion_func) - t = 
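One subtlety in the token-count fallbacks above: `x or 0 if response else 0` parses as `(x or 0) if response else 0`, and the condition is evaluated first, so the attribute access never runs when no response arrived at all. A tiny illustration:

```python
# How Python groups the token-count fallback used above:
# `a or 0 if cond else 0` means `(a or 0) if cond else 0`,
# not `a or (0 if cond else 0)`.
prompt_token_count = None  # e.g. the SDK omitted the field
response = True
assert (prompt_token_count or 0 if response else 0) == 0

response = None  # condition checked first, so no attribute access happens
assert (prompt_token_count or 0 if response else 0) == 0
```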
Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer)) - t.start() - return g - - -def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}): + # Use asyncio.Queue and a thread to bridge sync iterator + queue: asyncio.Queue = asyncio.Queue() stop_phrases = ["", "INST]", "Notes:"] - aggregated_response = "" + aggregated_response_container = {"response": ""} - state.chat_lock.acquire() - try: - response_iterator = send_message_to_model_offline( - messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True - ) - for response in response_iterator: - response_delta = response["choices"][0]["delta"].get("content", "") - aggregated_response += response_delta - g.send(response_delta) + def _sync_llm_thread(): + """Synchronous function to run in a separate thread.""" + aggregated_response = "" + start_time = perf_counter() + state.chat_lock.acquire() + try: + response_iterator = send_message_to_model_offline( + messages, + loaded_model=offline_chat_model, + stop=stop_phrases, + max_prompt_size=max_prompt_size, + streaming=True, + tracer=tracer, + ) + for response in response_iterator: + response_delta = response["choices"][0]["delta"].get("content", "") + # Log the time taken to start response + if aggregated_response == "" and response_delta != "": + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Handle response chunk + aggregated_response += response_delta + # Put chunk into the asyncio queue (non-blocking) + try: + queue.put_nowait(response_delta) + except asyncio.QueueFull: + # Should not happen with default queue size unless consumer is very slow + logger.warning("Asyncio queue full during offline LLM streaming.") + # Potentially block here or handle differently if needed + asyncio.run(queue.put(response_delta)) - # Save conversation trace - if is_promptrace_enabled(): - commit_conversation_trace(messages, aggregated_response, tracer) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") - finally: - state.chat_lock.release() - g.close() + # Save conversation trace + tracer["chat_model"] = model_name + if is_promptrace_enabled(): + commit_conversation_trace(messages, aggregated_response, tracer) + + except Exception as e: + logger.error(f"Error in offline LLM thread: {e}", exc_info=True) + finally: + state.chat_lock.release() + # Signal end of stream + queue.put_nowait(None) + aggregated_response_container["response"] = aggregated_response + + # Start the synchronous thread + thread = Thread(target=_sync_llm_thread) + thread.start() + + # Asynchronously consume from the queue + while True: + chunk = await queue.get() + if chunk is None: # End of stream signal + queue.task_done() + break + yield chunk + queue.task_done() + + # Wait for the thread to finish (optional, ensures cleanup) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, thread.join) + + # Call the completion function after streaming is done + if completion_func: + await completion_func(chat_response=aggregated_response_container["response"]) def send_message_to_model_offline( diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 11e6a03d..b5fbdcf2 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, 
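Because llama.cpp only exposes a blocking iterator, the offline path cannot await the model directly; the change above bridges it by running the blocking loop in a worker thread under the chat lock and handing chunks to the async consumer through an `asyncio.Queue`, with `None` as the end-of-stream sentinel. Below is a self-contained sketch of the same bridge around a fake blocking generator; in this sketch the handoff goes through `loop.call_soon_threadsafe`, the documented thread-safe way to feed an `asyncio.Queue` from another thread:

```python
# Minimal sketch of bridging a blocking token iterator (like llama.cpp's) into
# an async generator via a worker thread and an asyncio.Queue. fake_llm is a
# stand-in for the offline model's streaming call.
import asyncio
import time
from threading import Thread
from typing import AsyncGenerator, Iterator


def fake_llm(prompt: str) -> Iterator[str]:
    for token in ["Streaming ", "from ", "a ", "blocking ", "iterator."]:
        time.sleep(0.05)  # simulate slow token generation
        yield token


async def astream(prompt: str) -> AsyncGenerator[str, None]:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker():
        try:
            for token in fake_llm(prompt):
                loop.call_soon_threadsafe(queue.put_nowait, token)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel

    thread = Thread(target=worker, daemon=True)
    thread.start()
    while (chunk := await queue.get()) is not None:
        yield chunk
    await loop.run_in_executor(None, thread.join)  # join without blocking the event loop


async def main():
    async for chunk in astream("hello"):
        print(chunk, end="", flush=True)


asyncio.run(main())
```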
timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -162,7 +162,7 @@ def send_message_to_model( ) -def converse_openai( +async def converse_openai( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -187,7 +187,7 @@ def converse_openai( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using OpenAI's ChatGPT """ @@ -217,11 +217,17 @@ def converse_openai( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -255,19 +261,23 @@ def converse_openai( logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") # Get Response from GPT - return chat_completion_with_backoff( + full_response = "" + async for chunk in chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=temperature, openai_api_key=api_key, api_base_url=api_base_url, - completion_func=completion_func, deepthought=deepthought, model_kwargs={"stop": ["Notes:\n["]}, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once finish streaming and we have the full response + if completion_func: + await completion_func(chat_response=full_response) def clean_response_schema(schema: BaseModel | dict) -> dict: diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index b73903ae..7fab44aa 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,7 +1,7 @@ import logging import os -from threading import Thread -from typing import Dict, List +from time import perf_counter +from typing import AsyncGenerator, Dict, List from urllib.parse import urlparse import openai @@ -16,13 +16,10 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ( - JsonSupport, - ThreadedGenerator, - commit_conversation_trace, -) +from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace from khoj.utils.helpers import ( get_chat_usage_metrics, + get_openai_async_client, get_openai_client, is_promptrace_enabled, ) @@ -30,6 +27,7 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) openai_clients: Dict[str, openai.OpenAI] = {} +openai_async_clients: Dict[str, openai.AsyncOpenAI] = {} @retry( @@ -124,45 +122,22 @@ def completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def chat_completion_with_backoff( +async def chat_completion_with_backoff( messages, - 
compiled_references, - online_results, model_name, temperature, openai_api_key=None, api_base_url=None, - completion_func=None, - deepthought=False, - model_kwargs=None, - tracer: dict = {}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=llm_thread, - args=(g, messages, model_name, temperature, openai_api_key, api_base_url, deepthought, model_kwargs, tracer), - ) - t.start() - return g - - -def llm_thread( - g, - messages, - model_name: str, - temperature, - openai_api_key=None, - api_base_url=None, deepthought=False, model_kwargs: dict = {}, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: try: client_key = f"{openai_api_key}--{api_base_url}" - client = openai_clients.get(client_key) + client = openai_async_clients.get(client_key) if not client: - client = get_openai_client(openai_api_key, api_base_url) - openai_clients[client_key] = client + client = get_openai_async_client(openai_api_key, api_base_url) + openai_async_clients[client_key] = client formatted_messages = [{"role": message.role, "content": message.content} for message in messages] @@ -207,53 +182,58 @@ def llm_thread( if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( - messages=formatted_messages, - model=model_name, # type: ignore + aggregated_response = "" + final_chunk = None + start_time = perf_counter() + chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( + messages=formatted_messages, # type: ignore + model=model_name, stream=stream, temperature=temperature, timeout=20, **model_kwargs, ) + async for chunk in chat_stream: + # Log the time taken to start response + if final_chunk is None: + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Keep track of the last chunk for usage data + final_chunk = chunk + # Handle streamed response chunk + if len(chunk.choices) == 0: + continue + delta_chunk = chunk.choices[0].delta + text_chunk = "" + if isinstance(delta_chunk, str): + text_chunk = delta_chunk + elif delta_chunk and delta_chunk.content: + text_chunk = delta_chunk.content + if text_chunk: + aggregated_response += text_chunk + yield text_chunk - aggregated_response = "" - if not stream: - chunk = chat - aggregated_response = chunk.choices[0].message.content - g.send(aggregated_response) - else: - for chunk in chat: - if len(chunk.choices) == 0: - continue - delta_chunk = chunk.choices[0].delta - text_chunk = "" - if isinstance(delta_chunk, str): - text_chunk = delta_chunk - elif delta_chunk.content: - text_chunk = delta_chunk.content - if text_chunk: - aggregated_response += text_chunk - g.send(text_chunk) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") - # Calculate cost of chat - input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - cost = ( - chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0 - ) # Estimated costs returned by DeepInfra API - tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost - ) + # Calculate cost of chat after stream finishes + input_tokens, output_tokens, cost = 0, 0, 0 + if 
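The OpenAI helper follows the same shape with `openai.AsyncOpenAI`: request the stream, iterate chunks asynchronously, and skip chunks that carry no choices (some providers send a final usage-only chunk). A pared-down sketch with a placeholder model name and prompt, omitting the retry and cost bookkeeping:

```python
# Pared-down sketch of the AsyncOpenAI streaming call used above.
import asyncio
import openai


async def stream_gpt(api_key: str, prompt: str) -> str:
    client = openai.AsyncOpenAI(api_key=api_key)
    aggregated = ""
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model id
        messages=[{"role": "user", "content": prompt}],
        stream=True,
        temperature=0.2,
    )
    async for chunk in stream:
        if not chunk.choices:
            continue  # e.g. a trailing usage-only chunk
        delta = chunk.choices[0].delta
        if delta and delta.content:
            aggregated += delta.content
            print(delta.content, end="", flush=True)
    return aggregated


# asyncio.run(stream_gpt("sk-...", "Say hi"))
```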
final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage: + input_tokens = final_chunk.usage.prompt_tokens + output_tokens = final_chunk.usage.completion_tokens + # Estimated costs returned by DeepInfra API + if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra: + cost = final_chunk.usage.model_extra.get("estimated_cost", 0) # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: - logger.error(f"Error in llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True) def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 01c25cf4..b8eea907 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -77,42 +77,6 @@ model_to_prompt_size = { model_to_tokenizer: Dict[str, str] = {} -class ThreadedGenerator: - def __init__(self, compiled_references, online_results, completion_func=None): - self.queue = queue.Queue() - self.compiled_references = compiled_references - self.online_results = online_results - self.completion_func = completion_func - self.response = "" - self.start_time = perf_counter() - - def __iter__(self): - return self - - def __next__(self): - item = self.queue.get() - if item is StopIteration: - time_to_response = perf_counter() - self.start_time - logger.info(f"Chat streaming took: {time_to_response:.3f} seconds") - if self.completion_func: - # The completion func effectively acts as a callback. - # It adds the aggregated response to the conversation history. 
- self.completion_func(chat_response=self.response) - raise StopIteration - return item - - def send(self, data): - if self.response == "": - time_to_first_response = perf_counter() - self.start_time - logger.info(f"First response took: {time_to_first_response:.3f} seconds") - - self.response += data - self.queue.put(data) - - def close(self): - self.queue.put(StopIteration) - - class InformationCollectionIteration: def __init__( self, @@ -254,7 +218,7 @@ def message_to_log( return conversation_log -def save_to_conversation_log( +async def save_to_conversation_log( q: str, chat_response: str, user: KhojUser, @@ -306,7 +270,7 @@ def save_to_conversation_log( khoj_message_metadata=khoj_message_metadata, conversation_log=meta_log.get("chat", []), ) - ConversationAdapters.save_conversation( + await ConversationAdapters.save_conversation( user, {"chat": updated_conversation}, client_application=client_application, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 03f367dd..dd951238 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -67,7 +67,6 @@ from khoj.routers.research import ( from khoj.routers.storage import upload_user_image_to_bucket from khoj.utils import state from khoj.utils.helpers import ( - AsyncIteratorWrapper, ConversationCommand, command_descriptions, convert_image_to_webp, @@ -999,7 +998,7 @@ async def chat( return llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( + await save_to_conversation_log( q, llm_response, user, @@ -1308,26 +1307,31 @@ async def chat( yield result continue_stream = True - iterator = AsyncIteratorWrapper(llm_response) - async for item in iterator: + async for item in llm_response: + # Should not happen with async generator, end is signaled by loop exit. Skip. if item is None: - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - # Send Usage Metadata once llm interactions are complete - async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): - yield event - async for result in send_event(ChatEvent.END_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") - return + continue if not connection_alive or not continue_stream: + # Drain the generator if disconnected but keep processing internally continue try: async for result in send_event(ChatEvent.MESSAGE, f"{item}"): yield result except Exception as e: continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") + logger.info(f"User {user} disconnected or error during streaming. 
Stopping send: {e}") + + # Signal end of LLM response after the loop finishes + if connection_alive: + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + # Send Usage Metadata once llm interactions are complete + if tracer.get("usage"): + async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") ## Stream Text Response if stream: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 742a8708..a0baffb9 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,4 +1,3 @@ -import asyncio import base64 import hashlib import json @@ -6,9 +5,7 @@ import logging import math import os import re -from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone -from enum import Enum from functools import partial from random import random from typing import ( @@ -17,7 +14,6 @@ from typing import ( AsyncGenerator, Callable, Dict, - Iterator, List, Optional, Set, @@ -97,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import ( ) from khoj.processor.conversation.utils import ( ChatEvent, - ThreadedGenerator, clean_json, clean_mermaidjs, construct_chat_history, @@ -126,8 +121,6 @@ from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, Loca logger = logging.getLogger(__name__) -executor = ThreadPoolExecutor(max_workers=1) - NOTION_OAUTH_CLIENT_ID = os.getenv("NOTION_OAUTH_CLIENT_ID") NOTION_OAUTH_CLIENT_SECRET = os.getenv("NOTION_OAUTH_CLIENT_SECRET") @@ -262,11 +255,6 @@ def get_conversation_command(query: str) -> ConversationCommand: return ConversationCommand.Default -async def agenerate_chat_response(*args): - loop = asyncio.get_event_loop() - return await loop.run_in_executor(executor, generate_chat_response, *args) - - def gather_raw_query_files( query_files: Dict[str, str], ): @@ -1418,7 +1406,7 @@ def send_message_to_model_wrapper_sync( raise HTTPException(status_code=500, detail="Invalid conversation config") -def generate_chat_response( +async def agenerate_chat_response( q: str, meta_log: dict, conversation: Conversation, @@ -1444,13 +1432,14 @@ def generate_chat_response( generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, tracer: dict = {}, -) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: +) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]: # Initialize Variables - chat_response = None + chat_response_generator = None logger.debug(f"Conversation Types: {conversation_commands}") metadata = {} - agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + agent = await AgentAdapters.aget_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + try: partial_completion = partial( save_to_conversation_log, @@ -1481,17 +1470,17 @@ def generate_chat_response( code_results = {} deepthought = True - chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed) + chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed) vision_available = chat_model.vision_enabled if not vision_available and query_images: - vision_enabled_config = ConversationAdapters.get_vision_enabled_config() + vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() if vision_enabled_config: chat_model = vision_enabled_config vision_available = 
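`save_to_conversation_log` is now a coroutine, but it is still wrapped in `functools.partial` before being passed along as the completion callback; calling the partial simply produces a coroutine that the converse generators await as `completion_func`. A small sketch of that wiring, with an illustrative `save_turn` stand-in for the real log writer:

```python
# Sketch of an async function wrapped in functools.partial acting as the
# completion callback awaited by the converse_* generators above.
# save_turn and its arguments are illustrative stand-ins.
import asyncio
from functools import partial


async def save_turn(query: str, chat_response: str, conversation_id: int) -> None:
    print(f"[{conversation_id}] {query!r} -> {chat_response!r}")


async def converse(completion_func) -> None:
    full_response = "streamed answer"
    # partial(...) is a plain callable; calling it returns a coroutine,
    # so it can be awaited exactly like the original async function.
    await completion_func(chat_response=full_response)


async def main():
    partial_completion = partial(save_turn, "what is khoj?", conversation_id=42)
    await converse(partial_completion)


asyncio.run(main())
```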
True if chat_model.model_type == "offline": loaded_model = state.offline_chat_processor_config.loaded_model - chat_response = converse_offline( + chat_response_generator = converse_offline( user_query=query_to_run, references=compiled_references, online_results=online_results, @@ -1515,7 +1504,7 @@ def generate_chat_response( openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key chat_model_name = chat_model.name - chat_response = converse_openai( + chat_response_generator = converse_openai( compiled_references, query_to_run, query_images=query_images, @@ -1544,7 +1533,7 @@ def generate_chat_response( elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response = converse_anthropic( + chat_response_generator = converse_anthropic( compiled_references, query_to_run, query_images=query_images, @@ -1572,7 +1561,7 @@ def generate_chat_response( elif chat_model.model_type == ChatModel.ModelType.GOOGLE: api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response = converse_gemini( + chat_response_generator = converse_gemini( compiled_references, query_to_run, online_results, @@ -1604,7 +1593,8 @@ def generate_chat_response( logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) - return chat_response, metadata + # Return the generator directly + return chat_response_generator, metadata class DeleteMessageRequestBody(BaseModel): diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index f0aa0cd6..4a756dcb 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -23,6 +23,7 @@ from time import perf_counter from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union from urllib.parse import ParseResult, urlparse +import anthropic import openai import psutil import pyjson5 @@ -30,6 +31,7 @@ import requests import torch from asgiref.sync import sync_to_async from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email +from google import genai from google.auth.credentials import Credentials from google.oauth2 import service_account from magika import Magika @@ -729,6 +731,60 @@ def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, o return client +def get_openai_async_client(api_key: str, api_base_url: str) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]: + """Get OpenAI or AzureOpenAI client based on the API Base URL""" + parsed_url = urlparse(api_base_url) + if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"): + client = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=api_base_url, + api_version="2024-10-21", + ) + else: + client = openai.AsyncOpenAI( + api_key=api_key, + base_url=api_base_url, + ) + return client + + +def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if api_info.api_key: + client = anthropic.Anthropic(api_key=api_info.api_key) + else: + client = anthropic.AnthropicVertex( + region=api_info.region, + project_id=api_info.project, + credentials=api_info.credentials, + ) + return client + + +def get_anthropic_async_client(api_key, api_base_url=None) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex: + api_info = get_ai_api_info(api_key, api_base_url) + if api_info.api_key: + client = 
anthropic.AsyncAnthropic(api_key=api_info.api_key)
+    else:
+        client = anthropic.AsyncAnthropicVertex(
+            region=api_info.region,
+            project_id=api_info.project,
+            credentials=api_info.credentials,
+        )
+    return client
+
+
+def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
+    api_info = get_ai_api_info(api_key, api_base_url)
+    return genai.Client(
+        location=api_info.region,
+        project=api_info.project,
+        credentials=api_info.credentials,
+        api_key=api_info.api_key,
+        vertexai=api_info.api_key is None,
+    )
+
+
 def normalize_email(email: str, check_deliverability=False) -> tuple[str, bool]:
     """Normalize, validate and check deliverability of email address"""
     lower_email = email.lower()
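Call sites cache these clients in module-level dicts keyed by API key (and base URL for OpenAI), so each credential reuses one async client and its connection pool across requests. A sketch of that cache pattern around a stand-in factory; `make_async_client` is illustrative, not one of the Khoj helpers above:

```python
# Sketch of the per-key client cache pattern used at the call sites above, so
# repeated chats with the same credentials reuse one async client (and its
# underlying HTTP connection pool) instead of constructing a new one per request.
from typing import Dict

import openai

_async_clients: Dict[str, openai.AsyncOpenAI] = {}


def make_async_client(api_key: str, api_base_url: str | None = None) -> openai.AsyncOpenAI:
    # Stand-in for the get_*_async_client helpers defined above
    return openai.AsyncOpenAI(api_key=api_key, base_url=api_base_url)


def get_cached_async_client(api_key: str, api_base_url: str | None = None) -> openai.AsyncOpenAI:
    cache_key = f"{api_key}--{api_base_url}"
    client = _async_clients.get(cache_key)
    if client is None:
        client = make_async_client(api_key, api_base_url)
        _async_clients[cache_key] = client
    return client
```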