mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 05:39:12 +00:00
Refactor Anthropic chat response to stream async, no separate thread
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
@@ -137,7 +137,7 @@ def anthropic_send_message_to_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def converse_anthropic(
|
async def converse_anthropic(
|
||||||
references,
|
references,
|
||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
@@ -161,7 +161,7 @@ def converse_anthropic(
|
|||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using Anthropic's Claude
|
Converse with user using Anthropic's Claude
|
||||||
"""
|
"""
|
||||||
@@ -191,11 +191,17 @@ def converse_anthropic(
|
|||||||
|
|
||||||
# Get Conversation Primer appropriate to Conversation Type
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||||
completion_func(chat_response=prompts.no_notes_found.format())
|
response = prompts.no_notes_found.format()
|
||||||
return iter([prompts.no_notes_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
response = prompts.no_online_results_found.format()
|
||||||
return iter([prompts.no_online_results_found.format()])
|
if completion_func:
|
||||||
|
await completion_func(chat_response=response)
|
||||||
|
yield response
|
||||||
|
return
|
||||||
|
|
||||||
context_message = ""
|
context_message = ""
|
||||||
if not is_none_or_empty(references):
|
if not is_none_or_empty(references):
|
||||||
@@ -228,17 +234,21 @@ def converse_anthropic(
|
|||||||
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from Claude
|
# Get Response from Claude
|
||||||
return anthropic_chat_completion_with_backoff(
|
full_response = ""
|
||||||
|
async for chunk in anthropic_chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
compiled_references=references,
|
|
||||||
online_results=online_results,
|
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
deepthought=deepthought,
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
):
|
||||||
|
full_response += chunk
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Call completion_func once finish streaming and we have the full response
|
||||||
|
if completion_func:
|
||||||
|
await completion_func(chat_response=full_response)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from threading import Thread
|
from time import perf_counter
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
@@ -13,12 +13,12 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ThreadedGenerator,
|
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
get_image_from_base64,
|
get_image_from_base64,
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
|
get_anthropic_async_client,
|
||||||
get_anthropic_client,
|
get_anthropic_client,
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
@@ -28,6 +28,7 @@ from khoj.utils.helpers import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
|
||||||
|
anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {}
|
||||||
|
|
||||||
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
|
||||||
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
MAX_REASONING_TOKENS_ANTHROPIC = 12000
|
||||||
@@ -113,60 +114,23 @@ def anthropic_completion_with_backoff(
|
|||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def anthropic_chat_completion_with_backoff(
|
async def anthropic_chat_completion_with_backoff(
|
||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
compiled_references,
|
|
||||||
online_results,
|
|
||||||
model_name,
|
model_name,
|
||||||
temperature,
|
temperature,
|
||||||
api_key,
|
api_key,
|
||||||
api_base_url,
|
api_base_url,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
completion_func=None,
|
|
||||||
deepthought=False,
|
|
||||||
model_kwargs=None,
|
|
||||||
tracer={},
|
|
||||||
):
|
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
|
||||||
t = Thread(
|
|
||||||
target=anthropic_llm_thread,
|
|
||||||
args=(
|
|
||||||
g,
|
|
||||||
messages,
|
|
||||||
system_prompt,
|
|
||||||
model_name,
|
|
||||||
temperature,
|
|
||||||
api_key,
|
|
||||||
api_base_url,
|
|
||||||
max_prompt_size,
|
|
||||||
deepthought,
|
|
||||||
model_kwargs,
|
|
||||||
tracer,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
t.start()
|
|
||||||
return g
|
|
||||||
|
|
||||||
|
|
||||||
def anthropic_llm_thread(
|
|
||||||
g,
|
|
||||||
messages: list[ChatMessage],
|
|
||||||
system_prompt: str,
|
|
||||||
model_name: str,
|
|
||||||
temperature,
|
|
||||||
api_key,
|
|
||||||
api_base_url=None,
|
|
||||||
max_prompt_size=None,
|
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
client = anthropic_clients.get(api_key)
|
client = anthropic_async_clients.get(api_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = get_anthropic_client(api_key, api_base_url)
|
client = get_anthropic_async_client(api_key, api_base_url)
|
||||||
anthropic_clients[api_key] = client
|
anthropic_async_clients[api_key] = client
|
||||||
|
|
||||||
model_kwargs = model_kwargs or dict()
|
model_kwargs = model_kwargs or dict()
|
||||||
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||||
@@ -180,7 +144,8 @@ def anthropic_llm_thread(
|
|||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
final_message = None
|
final_message = None
|
||||||
with client.messages.stream(
|
start_time = perf_counter()
|
||||||
|
async with client.messages.stream(
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -189,10 +154,17 @@ def anthropic_llm_thread(
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) as stream:
|
) as stream:
|
||||||
for text in stream.text_stream:
|
async for text in stream.text_stream:
|
||||||
|
# Log the time taken to start response
|
||||||
|
if aggregated_response == "":
|
||||||
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
# Handle streamed response chunk
|
||||||
aggregated_response += text
|
aggregated_response += text
|
||||||
g.send(text)
|
yield text
|
||||||
final_message = stream.get_final_message()
|
final_message = await stream.get_final_message()
|
||||||
|
|
||||||
|
# Log the time taken to stream the entire response
|
||||||
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = final_message.usage.input_tokens
|
input_tokens = final_message.usage.input_tokens
|
||||||
@@ -209,9 +181,7 @@ def anthropic_llm_thread(
|
|||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
|
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
|
||||||
finally:
|
|
||||||
g.close()
|
|
||||||
|
|
||||||
|
|
||||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||||
|
|||||||
@@ -1534,10 +1534,9 @@ async def agenerate_chat_response(
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
# Assuming converse_anthropic remains sync or is refactored separately
|
|
||||||
api_key = chat_model.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
api_base_url = chat_model.ai_model_api.api_base_url
|
api_base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_response_generator = converse_anthropic( # Needs adaptation if it becomes async
|
chat_response_generator = converse_anthropic(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
|||||||
Reference in New Issue
Block a user