Parse Anthropic reasoning model thoughts returned by API

This commit is contained in:
Debanjum
2025-05-02 19:08:57 -06:00
parent ae4e352b42
commit 8cadb0dbc0
2 changed files with 18 additions and 5 deletions

View File

@@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import (
     format_messages_for_anthropic,
 )
 from khoj.processor.conversation.utils import (
+    ResponseWithThought,
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
@@ -162,7 +163,7 @@ async def converse_anthropic(
     generated_asset_results: Dict[str, Dict] = {},
     deepthought: Optional[bool] = False,
     tracer: dict = {},
-) -> AsyncGenerator[str, None]:
+) -> AsyncGenerator[ResponseWithThought, None]:
     """
     Converse with user using Anthropic's Claude
     """
@@ -247,7 +248,8 @@ async def converse_anthropic(
         deepthought=deepthought,
         tracer=tracer,
     ):
-        full_response += chunk
+        if chunk.response:
+            full_response += chunk.response
         yield chunk
     # Call completion_func once finish streaming and we have the full response

View File

@@ -13,6 +13,7 @@ from tenacity import (
 )
 from khoj.processor.conversation.utils import (
+    ResponseWithThought,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
@@ -154,13 +155,23 @@ async def anthropic_chat_completion_with_backoff(
         max_tokens=max_tokens,
         **model_kwargs,
     ) as stream:
-        async for text in stream.text_stream:
+        async for chunk in stream:
             # Log the time taken to start response
             if aggregated_response == "":
                 logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Skip empty chunks
+            if chunk.type != "content_block_delta":
+                continue
             # Handle streamed response chunk
-            aggregated_response += text
-            yield text
+            response_chunk: ResponseWithThought = None
+            if chunk.delta.type == "text_delta":
+                response_chunk = ResponseWithThought(response=chunk.delta.text)
+                aggregated_response += chunk.delta.text
+            if chunk.delta.type == "thinking_delta":
+                response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
+            # Handle streamed response chunk
+            if response_chunk:
+                yield response_chunk
         final_message = await stream.get_final_message()
         # Log the time taken to stream the entire response