Refactor Anthropic chat response to stream async, no separate thread

Debanjum
2025-04-20 03:18:32 +05:30
parent a557031447
commit 932a9615ef
3 changed files with 43 additions and 64 deletions
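For context on the diffs below: the refactor swaps a thread-plus-queue streaming pattern (a ThreadedGenerator fed by a worker thread) for a native async generator. A minimal sketch of the before and after shapes, with illustrative names rather than the actual Khoj helpers:

import asyncio
from queue import Queue
from threading import Thread


def stream_via_thread(chunks: list[str]):
    # Before: a worker thread pushes chunks into a queue-backed generator.
    q: Queue = Queue()

    def worker():
        for chunk in chunks:
            q.put(chunk)
        q.put(None)  # sentinel marks the end of the stream

    Thread(target=worker).start()
    while (chunk := q.get()) is not None:
        yield chunk


async def stream_async(chunks: list[str]):
    # After: a plain async generator yields chunks directly on the event loop.
    for chunk in chunks:
        await asyncio.sleep(0)  # stand-in for awaiting the provider's stream
        yield chunk


async def main():
    print(list(stream_via_thread(["a", "b", "c"])))
    print([chunk async for chunk in stream_async(["a", "b", "c"])])


asyncio.run(main())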


@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional

 import pyjson5
 from langchain.schema import ChatMessage
@@ -137,7 +137,7 @@ def anthropic_send_message_to_model(
     )


-def converse_anthropic(
+async def converse_anthropic(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -161,7 +161,7 @@ def converse_anthropic(
    generated_asset_results: Dict[str, Dict] = {},
    deepthought: Optional[bool] = False,
    tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
    """
    Converse with user using Anthropic's Claude
    """
@@ -191,11 +191,17 @@ def converse_anthropic(
     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        completion_func(chat_response=prompts.no_notes_found.format())
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -228,17 +234,21 @@
     logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")

     # Get Response from Claude
-    return anthropic_chat_completion_with_backoff(
+    full_response = ""
+    async for chunk in anthropic_chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=0.2,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
-        completion_func=completion_func,
         max_prompt_size=max_prompt_size,
         deepthought=deepthought,
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once done streaming and we have the full response
+    if completion_func:
+        await completion_func(chat_response=full_response)
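The new contract of converse_anthropic: chunks are yielded as they stream, and completion_func is awaited once with the aggregated response at the end. A self-contained sketch of that contract, with fake_converse standing in for converse_anthropic (whose real signature takes many more arguments):

import asyncio
from typing import AsyncGenerator, Awaitable, Callable, Optional


async def fake_converse(
    chunks: list[str],
    completion_func: Optional[Callable[..., Awaitable[None]]] = None,
) -> AsyncGenerator[str, None]:
    # Mirrors the refactor: yield chunks as they arrive, then await
    # completion_func once with the aggregated response.
    full_response = ""
    for chunk in chunks:
        full_response += chunk
        yield chunk
    if completion_func:
        await completion_func(chat_response=full_response)


async def main():
    async def on_complete(chat_response: str):
        print(f"\nsaved {len(chat_response)} chars")

    async for chunk in fake_converse(["Hel", "lo ", "world"], completion_func=on_complete):
        print(chunk, end="")


asyncio.run(main())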


@@ -1,5 +1,5 @@
 import logging
-from threading import Thread
+from time import perf_counter
 from typing import Dict, List

 import anthropic
@@ -13,12 +13,12 @@ from tenacity import (
 )

 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
 )
 from khoj.utils.helpers import (
+    get_anthropic_async_client,
     get_anthropic_client,
     get_chat_usage_metrics,
     is_none_or_empty,
@@ -28,6 +28,7 @@ from khoj.utils.helpers import (
 logger = logging.getLogger(__name__)

 anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
+anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {}

 DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
 MAX_REASONING_TOKENS_ANTHROPIC = 12000
@@ -113,60 +114,23 @@ def anthropic_completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def anthropic_chat_completion_with_backoff(
+async def anthropic_chat_completion_with_backoff(
     messages: list[ChatMessage],
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     api_key,
     api_base_url,
     system_prompt: str,
     max_prompt_size=None,
-    completion_func=None,
     deepthought=False,
     model_kwargs=None,
     tracer={},
 ):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=anthropic_llm_thread,
-        args=(
-            g,
-            messages,
-            system_prompt,
-            model_name,
-            temperature,
-            api_key,
-            api_base_url,
-            max_prompt_size,
-            deepthought,
-            model_kwargs,
-            tracer,
-        ),
-    )
-    t.start()
-    return g
-
-
-def anthropic_llm_thread(
-    g,
-    messages: list[ChatMessage],
-    system_prompt: str,
-    model_name: str,
-    temperature,
-    api_key,
-    api_base_url=None,
-    max_prompt_size=None,
-    deepthought=False,
-    model_kwargs=None,
-    tracer={},
-):
     try:
-        client = anthropic_clients.get(api_key)
+        client = anthropic_async_clients.get(api_key)
         if not client:
-            client = get_anthropic_client(api_key, api_base_url)
-            anthropic_clients[api_key] = client
+            client = get_anthropic_async_client(api_key, api_base_url)
+            anthropic_async_clients[api_key] = client

         model_kwargs = model_kwargs or dict()
         max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
@@ -180,7 +144,8 @@ def anthropic_llm_thread(
         aggregated_response = ""
         final_message = None

-        with client.messages.stream(
+        start_time = perf_counter()
+        async with client.messages.stream(
             messages=formatted_messages,
             model=model_name,  # type: ignore
             temperature=temperature,
@@ -189,10 +154,17 @@
             max_tokens=max_tokens,
             **model_kwargs,
         ) as stream:
-            for text in stream.text_stream:
+            async for text in stream.text_stream:
+                # Log the time taken to start response
+                if aggregated_response == "":
+                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                # Handle streamed response chunk
                 aggregated_response += text
-                g.send(text)
-            final_message = stream.get_final_message()
+                yield text
+            final_message = await stream.get_final_message()
+
+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")

         # Calculate cost of chat
         input_tokens = final_message.usage.input_tokens
@@ -209,9 +181,7 @@
         if is_promptrace_enabled():
             commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
-        logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
-    finally:
-        g.close()
+        logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)


 def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
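The async with client.messages.stream(...) and async for text in stream.text_stream shape above follows the anthropic SDK's async streaming helper. A standalone sketch of that loop; the model name is a placeholder and ANTHROPIC_API_KEY must be set in the environment:

import asyncio

import anthropic


async def main():
    client = anthropic.AsyncAnthropic()  # reads ANTHROPIC_API_KEY from the environment
    async with client.messages.stream(
        model="claude-3-5-sonnet-latest",  # placeholder model name
        max_tokens=256,
        messages=[{"role": "user", "content": "Say hello in five words."}],
    ) as stream:
        async for text in stream.text_stream:
            print(text, end="", flush=True)
        final_message = await stream.get_final_message()
    print(f"\ninput tokens: {final_message.usage.input_tokens}")


asyncio.run(main())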


@@ -1534,10 +1534,9 @@ async def agenerate_chat_response(
             )
         elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
-            # Assuming converse_anthropic remains sync or is refactored separately
             api_key = chat_model.ai_model_api.api_key
             api_base_url = chat_model.ai_model_api.api_base_url
-            chat_response_generator = converse_anthropic(  # Needs adaptation if it becomes async
+            chat_response_generator = converse_anthropic(
                 compiled_references,
                 query_to_run,
                 query_images=query_images,
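One detail at this call site: calling an async generator function needs no await. It returns the generator object immediately, and the request only runs as the caller iterates. A minimal illustration:

import asyncio


async def generate():
    yield "chunk"


async def main():
    chat_response_generator = generate()  # no await, mirroring the call above
    async for chunk in chat_response_generator:
        print(chunk)


asyncio.run(main())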