Refactor Anthropic chat response to stream async, no separate thread

This commit is contained in:
Debanjum
2025-04-20 03:18:32 +05:30
parent a557031447
commit 932a9615ef
3 changed files with 43 additions and 64 deletions

View File

@@ -1,6 +1,6 @@
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
import pyjson5 import pyjson5
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
@@ -137,7 +137,7 @@ def anthropic_send_message_to_model(
) )
def converse_anthropic( async def converse_anthropic(
references, references,
user_query, user_query,
online_results: Optional[Dict[str, Dict]] = None, online_results: Optional[Dict[str, Dict]] = None,
@@ -161,7 +161,7 @@ def converse_anthropic(
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False, deepthought: Optional[bool] = False,
tracer: dict = {}, tracer: dict = {},
): ) -> AsyncGenerator[str, None]:
""" """
Converse with user using Anthropic's Claude Converse with user using Anthropic's Claude
""" """
@@ -191,11 +191,17 @@ def converse_anthropic(
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format()) response = prompts.no_notes_found.format()
return iter([prompts.no_notes_found.format()]) if completion_func:
await completion_func(chat_response=response)
yield response
return
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format()) response = prompts.no_online_results_found.format()
return iter([prompts.no_online_results_found.format()]) if completion_func:
await completion_func(chat_response=response)
yield response
return
context_message = "" context_message = ""
if not is_none_or_empty(references): if not is_none_or_empty(references):
@@ -228,17 +234,21 @@ def converse_anthropic(
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}") logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
# Get Response from Claude # Get Response from Claude
return anthropic_chat_completion_with_backoff( full_response = ""
async for chunk in anthropic_chat_completion_with_backoff(
messages=messages, messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model, model_name=model,
temperature=0.2, temperature=0.2,
api_key=api_key, api_key=api_key,
api_base_url=api_base_url, api_base_url=api_base_url,
system_prompt=system_prompt, system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
deepthought=deepthought, deepthought=deepthought,
tracer=tracer, tracer=tracer,
) ):
full_response += chunk
yield chunk
# Call completion_func once we finish streaming and have the full response
if completion_func:
await completion_func(chat_response=full_response)

View File

@@ -1,5 +1,5 @@
import logging import logging
from threading import Thread from time import perf_counter
from typing import Dict, List from typing import Dict, List
import anthropic import anthropic
@@ -13,12 +13,12 @@ from tenacity import (
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace, commit_conversation_trace,
get_image_from_base64, get_image_from_base64,
get_image_from_url, get_image_from_url,
) )
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_anthropic_async_client,
get_anthropic_client, get_anthropic_client,
get_chat_usage_metrics, get_chat_usage_metrics,
is_none_or_empty, is_none_or_empty,
@@ -28,6 +28,7 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {} anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000 DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000 MAX_REASONING_TOKENS_ANTHROPIC = 12000
@@ -113,60 +114,23 @@ def anthropic_completion_with_backoff(
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
def anthropic_chat_completion_with_backoff( async def anthropic_chat_completion_with_backoff(
messages: list[ChatMessage], messages: list[ChatMessage],
compiled_references,
online_results,
model_name, model_name,
temperature, temperature,
api_key, api_key,
api_base_url, api_base_url,
system_prompt: str, system_prompt: str,
max_prompt_size=None, max_prompt_size=None,
completion_func=None,
deepthought=False,
model_kwargs=None,
tracer={},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url,
max_prompt_size,
deepthought,
model_kwargs,
tracer,
),
)
t.start()
return g
def anthropic_llm_thread(
g,
messages: list[ChatMessage],
system_prompt: str,
model_name: str,
temperature,
api_key,
api_base_url=None,
max_prompt_size=None,
deepthought=False, deepthought=False,
model_kwargs=None, model_kwargs=None,
tracer={}, tracer={},
): ):
try: try:
client = anthropic_clients.get(api_key) client = anthropic_async_clients.get(api_key)
if not client: if not client:
client = get_anthropic_client(api_key, api_base_url) client = get_anthropic_async_client(api_key, api_base_url)
anthropic_clients[api_key] = client anthropic_async_clients[api_key] = client
model_kwargs = model_kwargs or dict() model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
@@ -180,7 +144,8 @@ def anthropic_llm_thread(
aggregated_response = "" aggregated_response = ""
final_message = None final_message = None
with client.messages.stream( start_time = perf_counter()
async with client.messages.stream(
messages=formatted_messages, messages=formatted_messages,
model=model_name, # type: ignore model=model_name, # type: ignore
temperature=temperature, temperature=temperature,
@@ -189,10 +154,17 @@ def anthropic_llm_thread(
max_tokens=max_tokens, max_tokens=max_tokens,
**model_kwargs, **model_kwargs,
) as stream: ) as stream:
for text in stream.text_stream: async for text in stream.text_stream:
# Log the time taken to start response
if aggregated_response == "":
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Handle streamed response chunk
aggregated_response += text aggregated_response += text
g.send(text) yield text
final_message = stream.get_final_message() final_message = await stream.get_final_message()
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat # Calculate cost of chat
input_tokens = final_message.usage.input_tokens input_tokens = final_message.usage.input_tokens
@@ -209,9 +181,7 @@ def anthropic_llm_thread(
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e: except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
finally:
g.close()
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):

View File

@@ -1534,10 +1534,9 @@ async def agenerate_chat_response(
) )
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Assuming converse_anthropic remains sync or is refactored separately
api_key = chat_model.ai_model_api.api_key api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url api_base_url = chat_model.ai_model_api.api_base_url
chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async chat_response_generator = converse_anthropic(
compiled_references, compiled_references,
query_to_run, query_to_run,
query_images=query_images, query_images=query_images,