Parse Anthropic reasoning model thoughts returned by the API

This commit is contained in:
Debanjum
2025-05-02 19:08:57 -06:00
parent ae4e352b42
commit 8cadb0dbc0
2 changed files with 18 additions and 5 deletions

View File

@@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import (
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
ResponseWithThought,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -162,7 +163,7 @@ async def converse_anthropic(
generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using Anthropic's Claude
"""
@@ -247,7 +248,8 @@ async def converse_anthropic(
deepthought=deepthought,
tracer=tracer,
):
full_response += chunk
if chunk.response:
full_response += chunk.response
yield chunk
# Call completion_func once finish streaming and we have the full response

View File

@@ -13,6 +13,7 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
ResponseWithThought,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
@@ -154,13 +155,23 @@ async def anthropic_chat_completion_with_backoff(
max_tokens=max_tokens,
**model_kwargs,
) as stream:
async for text in stream.text_stream:
async for chunk in stream:
# Log the time taken to start response
if aggregated_response == "":
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Skip empty chunks
if chunk.type != "content_block_delta":
continue
# Handle streamed response chunk
aggregated_response += text
yield text
response_chunk: ResponseWithThought = None
if chunk.delta.type == "text_delta":
response_chunk = ResponseWithThought(response=chunk.delta.text)
aggregated_response += chunk.delta.text
if chunk.delta.type == "thinking_delta":
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
# Handle streamed response chunk
if response_chunk:
yield response_chunk
final_message = await stream.get_final_message()
# Log the time taken to stream the entire response