diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 977b25de..5bad38ef 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import AsyncGenerator, Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage @@ -137,7 +137,7 @@ def anthropic_send_message_to_model( ) -def converse_anthropic( +async def converse_anthropic( references, user_query, online_results: Optional[Dict[str, Dict]] = None, @@ -161,7 +161,7 @@ def converse_anthropic( generated_asset_results: Dict[str, Dict] = {}, deepthought: Optional[bool] = False, tracer: dict = {}, -): +) -> AsyncGenerator[str, None]: """ Converse with user using Anthropic's Claude """ @@ -191,11 +191,17 @@ def converse_anthropic( # Get Conversation Primer appropriate to Conversation Type if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): - completion_func(chat_response=prompts.no_notes_found.format()) - return iter([prompts.no_notes_found.format()]) + response = prompts.no_notes_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): - completion_func(chat_response=prompts.no_online_results_found.format()) - return iter([prompts.no_online_results_found.format()]) + response = prompts.no_online_results_found.format() + if completion_func: + await completion_func(chat_response=response) + yield response + return context_message = "" if not is_none_or_empty(references): @@ -228,17 +234,21 @@ def converse_anthropic( logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}") # Get Response from Claude - return 
anthropic_chat_completion_with_backoff( + full_response = "" + async for chunk in anthropic_chat_completion_with_backoff( messages=messages, - compiled_references=references, - online_results=online_results, model_name=model, temperature=0.2, api_key=api_key, api_base_url=api_base_url, system_prompt=system_prompt, - completion_func=completion_func, max_prompt_size=max_prompt_size, deepthought=deepthought, tracer=tracer, - ) + ): + full_response += chunk + yield chunk + + # Call completion_func once we finish streaming and have the full response + if completion_func: + await completion_func(chat_response=full_response) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 48c6515f..442e1cb3 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,5 +1,5 @@ import logging -from threading import Thread +from time import perf_counter from typing import Dict, List import anthropic @@ -13,12 +13,12 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( - ThreadedGenerator, commit_conversation_trace, get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( + get_anthropic_async_client, get_anthropic_client, get_chat_usage_metrics, is_none_or_empty, @@ -28,6 +28,7 @@ from khoj.utils.helpers import ( logger = logging.getLogger(__name__) anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {} +anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {} DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 MAX_REASONING_TOKENS_ANTHROPIC = 12000 @@ -113,60 +114,23 @@ def anthropic_completion_with_backoff( before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def anthropic_chat_completion_with_backoff( +async def anthropic_chat_completion_with_backoff( messages: list[ChatMessage], - compiled_references, - online_results, model_name,
temperature, api_key, api_base_url, system_prompt: str, max_prompt_size=None, - completion_func=None, - deepthought=False, - model_kwargs=None, - tracer={}, -): - g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) - t = Thread( - target=anthropic_llm_thread, - args=( - g, - messages, - system_prompt, - model_name, - temperature, - api_key, - api_base_url, - max_prompt_size, - deepthought, - model_kwargs, - tracer, - ), - ) - t.start() - return g - - -def anthropic_llm_thread( - g, - messages: list[ChatMessage], - system_prompt: str, - model_name: str, - temperature, - api_key, - api_base_url=None, - max_prompt_size=None, deepthought=False, model_kwargs=None, tracer={}, ): try: - client = anthropic_clients.get(api_key) + client = anthropic_async_clients.get(api_key) if not client: - client = get_anthropic_client(api_key, api_base_url) - anthropic_clients[api_key] = client + client = get_anthropic_async_client(api_key, api_base_url) + anthropic_async_clients[api_key] = client model_kwargs = model_kwargs or dict() max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC @@ -180,7 +144,8 @@ def anthropic_llm_thread( aggregated_response = "" final_message = None - with client.messages.stream( + start_time = perf_counter() + async with client.messages.stream( messages=formatted_messages, model=model_name, # type: ignore temperature=temperature, @@ -189,10 +154,17 @@ def anthropic_llm_thread( max_tokens=max_tokens, **model_kwargs, ) as stream: - for text in stream.text_stream: + async for text in stream.text_stream: + # Log the time taken to start response + if aggregated_response == "": + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Handle streamed response chunk aggregated_response += text - g.send(text) - final_message = stream.get_final_message() + yield text + final_message = await stream.get_final_message() + + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: 
{perf_counter() - start_time:.3f} seconds") # Calculate cost of chat input_tokens = final_message.usage.input_tokens @@ -209,9 +181,7 @@ def anthropic_llm_thread( if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: - logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) - finally: - g.close() + logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True) def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a5aa415a..3d154059 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1534,10 +1534,9 @@ async def agenerate_chat_response( ) elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - # Assuming converse_anthropic remains sync or is refactored separately api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async + chat_response_generator = converse_anthropic( compiled_references, query_to_run, query_images=query_images,