Parse and show reasoning model thoughts (#1172)

### Major
All reasoning models return thoughts differently due to lack of
standardization.
We normalize thoughts by reasoning models and providers to ease handling
within Khoj.

The model thoughts are parsed during research mode when generating final
response.
These model thoughts are returned by the chat API and shown in the train of
thought on the web app.

Thoughts are enabled for Deepseek, Anthropic, Grok and Qwen3 reasoning
models served via API.
Gemini and OpenAI reasoning models do not expose their thoughts via
standard APIs.

### Minor
- Fix ability to use Deepseek reasoner for intermediate stages of chat
- Enable handling Qwen3 reasoning models
This commit is contained in:
Debanjum
2025-05-02 20:29:38 -06:00
committed by GitHub
8 changed files with 403 additions and 38 deletions

View File

@@ -97,6 +97,17 @@ export function processMessageChunk(
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data as string;
currentMessage.trainOfThought.push(statusMessage);
} else if (chunk.type === "thought") {
const thoughtChunk = chunk.data as string;
const lastThoughtIndex = currentMessage.trainOfThought.length - 1;
const previousThought =
lastThoughtIndex >= 0 ? currentMessage.trainOfThought[lastThoughtIndex] : "";
                // If the last train of thought started with "**Thinking:** ", append the new thought chunk to it
if (previousThought.startsWith("**Thinking:** ")) {
currentMessage.trainOfThought[lastThoughtIndex] += thoughtChunk;
} else {
currentMessage.trainOfThought.push(`**Thinking:** ${thoughtChunk}`);
}
} else if (chunk.type === "references") {
const references = chunk.data as RawReferenceData;

View File

@@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import (
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
ResponseWithThought,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -162,7 +163,7 @@ async def converse_anthropic(
generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using Anthropic's Claude
"""
@@ -247,7 +248,8 @@ async def converse_anthropic(
deepthought=deepthought,
tracer=tracer,
):
full_response += chunk
if chunk.response:
full_response += chunk.response
yield chunk
# Call completion_func once finish streaming and we have the full response

View File

@@ -13,6 +13,7 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
ResponseWithThought,
commit_conversation_trace,
get_image_from_base64,
get_image_from_url,
@@ -154,13 +155,23 @@ async def anthropic_chat_completion_with_backoff(
max_tokens=max_tokens,
**model_kwargs,
) as stream:
async for text in stream.text_stream:
async for chunk in stream:
# Log the time taken to start response
if aggregated_response == "":
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Skip empty chunks
if chunk.type != "content_block_delta":
continue
# Handle streamed response chunk
aggregated_response += text
yield text
response_chunk: ResponseWithThought = None
if chunk.delta.type == "text_delta":
response_chunk = ResponseWithThought(response=chunk.delta.text)
aggregated_response += chunk.delta.text
if chunk.delta.type == "thinking_delta":
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
# Handle streamed response chunk
if response_chunk:
yield response_chunk
final_message = await stream.get_final_message()
# Log the time taken to stream the entire response

View File

@@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
ResponseWithThought,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -188,7 +189,7 @@ async def converse_openai(
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using OpenAI's ChatGPT
"""
@@ -273,7 +274,8 @@ async def converse_openai(
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
):
full_response += chunk
if chunk.response:
full_response += chunk.response
yield chunk
# Call completion_func once finish streaming and we have the full response

View File

@@ -1,12 +1,21 @@
import logging
import os
from functools import partial
from time import perf_counter
from typing import AsyncGenerator, Dict, List
from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
from urllib.parse import urlparse
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.lib.streaming.chat import (
ChatCompletionStream,
ChatCompletionStreamEvent,
ContentDeltaEvent,
)
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
Choice,
ChoiceDelta,
)
from tenacity import (
before_sleep_log,
retry,
@@ -16,7 +25,11 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
from khoj.processor.conversation.utils import (
JsonSupport,
ResponseWithThought,
commit_conversation_trace,
)
from khoj.utils.helpers import (
get_chat_usage_metrics,
get_openai_async_client,
@@ -59,6 +72,7 @@ def completion_with_backoff(
client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client
stream_processor = default_stream_processor
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
# Tune reasoning models arguments
@@ -69,6 +83,24 @@ def completion_with_backoff(
elif is_twitter_reasoning_model(model_name, api_base_url):
reasoning_effort = "high" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort
elif model_name.startswith("deepseek-reasoner"):
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
# The first message should always be a user message (except system message).
updated_messages: List[dict] = []
for i, message in enumerate(formatted_messages):
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
updated_messages[-1]["content"] += " " + message["content"]
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
updated_messages[-1]["content"] += " " + message["content"]
else:
updated_messages.append(message)
formatted_messages = updated_messages
elif is_qwen_reasoning_model(model_name, api_base_url):
stream_processor = partial(in_stream_thought_processor, thought_tag="think")
# Reasoning is enabled by default. Disable when deepthought is False.
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
if not deepthought and len(formatted_messages) > 0:
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"):
@@ -82,12 +114,11 @@ def completion_with_backoff(
timeout=20,
**model_kwargs,
) as chat:
for chunk in chat:
if chunk.type == "error":
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
continue
elif chunk.type == "content.delta":
for chunk in stream_processor(chat):
if chunk.type == "content.delta":
aggregated_response += chunk.delta
elif chunk.type == "thought.delta":
pass
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
@@ -124,14 +155,14 @@ def completion_with_backoff(
)
async def chat_completion_with_backoff(
messages,
model_name,
model_name: str,
temperature,
openai_api_key=None,
api_base_url=None,
deepthought=False,
model_kwargs: dict = {},
tracer: dict = {},
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ResponseWithThought, None]:
try:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_async_clients.get(client_key)
@@ -139,6 +170,7 @@ async def chat_completion_with_backoff(
client = get_openai_async_client(openai_api_key, api_base_url)
openai_async_clients[client_key] = client
stream_processor = adefault_stream_processor
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
# Configure thinking for openai reasoning models
@@ -161,9 +193,11 @@ async def chat_completion_with_backoff(
"content"
] = f"{first_system_message_content}\nFormatting re-enabled"
elif is_twitter_reasoning_model(model_name, api_base_url):
stream_processor = adeepseek_stream_processor
reasoning_effort = "high" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort
elif model_name.startswith("deepseek-reasoner"):
stream_processor = adeepseek_stream_processor
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
# The first message should always be a user message (except system message).
updated_messages: List[dict] = []
@@ -174,8 +208,13 @@ async def chat_completion_with_backoff(
updated_messages[-1]["content"] += " " + message["content"]
else:
updated_messages.append(message)
formatted_messages = updated_messages
elif is_qwen_reasoning_model(model_name, api_base_url):
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
# Reasoning is enabled by default. Disable when deepthought is False.
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
if not deepthought and len(formatted_messages) > 0:
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
stream = True
model_kwargs["stream_options"] = {"include_usage": True}
@@ -193,24 +232,25 @@ async def chat_completion_with_backoff(
timeout=20,
**model_kwargs,
)
async for chunk in chat_stream:
async for chunk in stream_processor(chat_stream):
# Log the time taken to start response
if final_chunk is None:
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
# Skip empty chunks
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
text_chunk = delta_chunk
elif delta_chunk and delta_chunk.content:
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
yield text_chunk
# Handle streamed response chunk
response_chunk: ResponseWithThought = None
response_delta = chunk.choices[0].delta
if response_delta.content:
response_chunk = ResponseWithThought(response=response_delta.content)
aggregated_response += response_chunk.response
elif response_delta.thought:
response_chunk = ResponseWithThought(thought=response_delta.thought)
if response_chunk:
yield response_chunk
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
@@ -264,3 +304,274 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
and api_base_url is not None
and api_base_url.startswith("https://api.x.ai/v1")
)
def is_qwen_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
    """
    Check if the model is a Qwen reasoning model served via an API.

    Only Qwen3 models served over an API base URL are considered, as thought
    normalization is handled for API-served models.
    """
    served_via_api = api_base_url is not None
    return served_via_api and "qwen3" in model_name.lower()
class ThoughtDeltaEvent(ContentDeltaEvent):
    """
    Stream event carrying a delta of the model's thought, reasoning text.

    Mirrors openai's ContentDeltaEvent but typed "thought.delta", so stream
    consumers can distinguish reasoning text from response text.
    """

    # Event discriminator: marks this delta as reasoning rather than response content
    type: Literal["thought.delta"]
    """The thought or reasoning generated by the model."""


# Either a standard openai stream event or the custom thought delta event above
ChatCompletionStreamWithThoughtEvent = Union[ChatCompletionStreamEvent, ThoughtDeltaEvent]
class ChoiceDeltaWithThoughts(ChoiceDelta):
    """
    Chat completion choice delta extended with thoughts, reasoning support.
    """

    # Normalized reasoning text, populated by provider-specific stream processors
    thought: Optional[str] = None
    """The thought or reasoning generated by the model."""
class ChoiceWithThoughts(Choice):
    """Chat completion choice whose delta may carry a thought field."""

    delta: ChoiceDeltaWithThoughts
class ChatCompletionWithThoughtsChunk(ChatCompletionChunk):
    """Chat completion chunk whose choices may carry thought deltas."""

    choices: List[ChoiceWithThoughts]  # Override the choices type
def default_stream_processor(
    chat_stream: "ChatCompletionStream",
) -> "Generator[ChatCompletionStreamWithThoughtEvent, None, None]":
    """
    Generator to pass through chunks from the standard openai chat completions stream.

    No casting or transformation is applied; the stream's events are yielded as-is.
    """
    # NOTE: original docstring claimed this was an "Async generator" that casts
    # chunks — it is a synchronous passthrough.
    yield from chat_stream
async def adefault_stream_processor(
    chat_stream: openai.AsyncStream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
    """
    Async generator to cast and return chunks from the standard openai chat completions stream.
    """
    async for raw_chunk in chat_stream:
        # Round-trip each chunk through a dict to recast it into the
        # thought-aware chunk type expected by downstream consumers
        chunk_data = raw_chunk.model_dump()
        yield ChatCompletionWithThoughtsChunk.model_validate(chunk_data)
async def adeepseek_stream_processor(
    chat_stream: openai.AsyncStream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
    """
    Async generator to cast and return chunks from the deepseek chat completions stream.
    """
    async for raw_chunk in chat_stream:
        tchunk = ChatCompletionWithThoughtsChunk.model_validate(raw_chunk.model_dump())
        # Deepseek streams reasoning under a non-standard `reasoning_content`
        # delta field. Copy it into the normalized `thought` field when present.
        delta = tchunk.choices[0].delta if tchunk.choices else None
        if delta is not None and getattr(delta, "reasoning_content", None):
            delta.thought = raw_chunk.choices[0].delta.reasoning_content
        yield tchunk
def in_stream_thought_processor(
    chat_stream: "openai.Stream[ChatCompletionChunk]", thought_tag="think"
) -> "Generator[ChatCompletionStreamWithThoughtEvent, None, None]":
    """
    Generator for chat completion with thought chunks.

    Splits a single leading <thought_tag>...</thought_tag> section out of the
    content stream, re-labeling its text as "thought.delta" events.
    Assumes <thought_tag>...</thought_tag> can only appear once at the start.
    Handles partial tags across streamed chunks.
    """
    start_tag = f"<{thought_tag}>"
    end_tag = f"</{thought_tag}>"
    buf: str = ""
    # Modes and transitions: detect_start > thought (optional) > message
    mode = "detect_start"

    for chunk in chat_stream:
        if mode == "message" or chunk.type != "content.delta":
            # Message mode is terminal, so just yield chunks, no processing
            yield chunk
            continue

        buf += chunk.delta

        if mode == "detect_start":
            # Try to determine if we start with thought tag
            if buf.startswith(start_tag):
                # Found start tag, switch mode. Fall through to process the rest
                # of the buffer in 'thought' mode *within this iteration*.
                buf = buf[len(start_tag) :]  # Remove start tag
                mode = "thought"
            elif len(buf) >= len(start_tag):
                # Buffer is long enough, definitely doesn't start with tag
                chunk.delta = buf
                yield chunk
                mode = "message"
                buf = ""
                continue
            elif start_tag.startswith(buf):
                # Buffer is a prefix of the start tag, need more data
                continue
            else:
                # Buffer doesn't match start tag prefix and is shorter than tag
                chunk.delta = buf
                yield chunk
                mode = "message"
                buf = ""
                continue

        if mode == "thought":
            # Look for the end tag
            idx = buf.find(end_tag)
            if idx != -1:
                # Found end tag. Yield thought content before it.
                if idx > 0 and buf[:idx].strip():
                    chunk.type = "thought.delta"
                    chunk.delta = buf[:idx]
                    yield chunk
                # Process content *after* the tag as message.
                # BUGFIX: reset chunk type to content.delta — it may have been
                # re-labeled as thought.delta above, which mislabeled trailing
                # response text in the same chunk as thought.
                buf = buf[idx + len(end_tag) :]
                if buf:
                    chunk.type = "content.delta"
                    chunk.delta = buf
                    yield chunk
                mode = "message"
                buf = ""
                continue
            else:
                # End tag not found yet. Yield thought content, holding back potential partial end tag.
                send_upto = len(buf)
                # Check if buffer ends with a prefix of end_tag
                for i in range(len(end_tag) - 1, 0, -1):
                    if buf.endswith(end_tag[:i]):
                        send_upto = len(buf) - i  # Don't send the partial tag yet
                        break
                if send_upto > 0 and buf[:send_upto].strip():
                    chunk.type = "thought.delta"
                    chunk.delta = buf[:send_upto]
                    yield chunk
                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
                # Need more data to find the complete end tag
                continue

    # End of stream handling
    if buf:
        if mode == "thought":  # Stream ended before </think> was found
            chunk.type = "thought.delta"
            chunk.delta = buf
            yield chunk
        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
            # If it wasn't a partial start tag, treat as message
            if not start_tag.startswith(buf):
                chunk.delta = buf
                yield chunk
            # else: discard partial <think>
        # If mode == "message", buffer should be empty due to logic above, but yield just in case
        elif mode == "message":
            chunk.delta = buf
            yield chunk
async def ain_stream_thought_processor(
    chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think"
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
    """
    Async generator for chat completion with thought chunks.

    Splits a single leading <thought_tag>...</thought_tag> section out of the
    streamed content, moving its text into the delta's `thought` field.
    Assumes <thought_tag>...</thought_tag> can only appear once at the start.
    Handles partial tags across streamed chunks.
    """
    start_tag = f"<{thought_tag}>"
    end_tag = f"</{thought_tag}>"
    buf: str = ""
    # Modes and transitions: detect_start > thought (optional) > message
    mode = "detect_start"

    async for chunk in adefault_stream_processor(chat_stream):
        if len(chunk.choices) == 0:
            # BUGFIX: pass through choice-less chunks (e.g. usage-only chunks)
            # instead of dropping them, so usage data reaches the caller.
            yield chunk
            continue
        if mode == "message":
            # Message mode is terminal, so just yield chunks, no processing
            yield chunk
            continue

        # BUGFIX: delta content can be None (e.g. role-only first chunk).
        # Coalesce to "" to avoid a TypeError on concatenation.
        buf += chunk.choices[0].delta.content or ""

        if mode == "detect_start":
            # Try to determine if we start with thought tag
            if buf.startswith(start_tag):
                # Found start tag, switch mode. Fall through to process the rest
                # of the buffer in 'thought' mode *within this iteration*.
                buf = buf[len(start_tag) :]  # Remove start tag
                mode = "thought"
            elif len(buf) >= len(start_tag):
                # Buffer is long enough, definitely doesn't start with tag
                chunk.choices[0].delta.content = buf
                yield chunk
                mode = "message"
                buf = ""
                continue
            elif start_tag.startswith(buf):
                # Buffer is a prefix of the start tag, need more data
                continue
            else:
                # Buffer doesn't match start tag prefix and is shorter than tag
                chunk.choices[0].delta.content = buf
                yield chunk
                mode = "message"
                buf = ""
                continue

        if mode == "thought":
            # Look for the end tag
            idx = buf.find(end_tag)
            if idx != -1:
                # Found end tag. Yield thought content before it.
                if idx > 0 and buf[:idx].strip():
                    chunk.choices[0].delta.thought = buf[:idx]
                    chunk.choices[0].delta.content = ""
                    yield chunk
                # Process content *after* the tag as message.
                # BUGFIX: clear any thought set above so the trailing response
                # text in the same chunk is not also labeled as thought.
                buf = buf[idx + len(end_tag) :]
                if buf:
                    chunk.choices[0].delta.thought = None
                    chunk.choices[0].delta.content = buf
                    yield chunk
                mode = "message"
                buf = ""
                continue
            else:
                # End tag not found yet. Yield thought content, holding back potential partial end tag.
                send_upto = len(buf)
                # Check if buffer ends with a prefix of end_tag
                for i in range(len(end_tag) - 1, 0, -1):
                    if buf.endswith(end_tag[:i]):
                        send_upto = len(buf) - i  # Don't send the partial tag yet
                        break
                if send_upto > 0 and buf[:send_upto].strip():
                    chunk.choices[0].delta.thought = buf[:send_upto]
                    chunk.choices[0].delta.content = ""
                    yield chunk
                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
                # Need more data to find the complete end tag
                continue

    # End of stream handling
    if buf:
        if mode == "thought":  # Stream ended before </think> was found
            chunk.choices[0].delta.thought = buf
            chunk.choices[0].delta.content = ""
            yield chunk
        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
            # If it wasn't a partial start tag, treat as message
            if not start_tag.startswith(buf):
                chunk.choices[0].delta.content = buf
                yield chunk
            # else: discard partial <think>
        # If mode == "message", buffer should be empty due to logic above, but yield just in case
        elif mode == "message":
            chunk.choices[0].delta.content = buf
            yield chunk

View File

@@ -191,6 +191,7 @@ class ChatEvent(Enum):
REFERENCES = "references"
GENERATED_ASSETS = "generated_assets"
STATUS = "status"
THOUGHT = "thought"
METADATA = "metadata"
USAGE = "usage"
END_RESPONSE = "end_response"
@@ -873,3 +874,9 @@ class JsonSupport(int, Enum):
NONE = 0
OBJECT = 1
SCHEMA = 2
class ResponseWithThought:
    """
    Container normalizing a chat model's streamed output across providers.

    Carries a chunk of the user-facing response text and/or a chunk of the
    model's reasoning (thought); either may be None.
    """

    def __init__(self, response: str = None, thought: str = None):
        # User-facing response text chunk, if any
        self.response = response
        # Model reasoning/thought text chunk, if any
        self.thought = thought

View File

@@ -25,7 +25,11 @@ from khoj.database.adapters import (
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
from khoj.processor.conversation.utils import (
ResponseWithThought,
defilter_query,
save_to_conversation_log,
)
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import (
@@ -726,6 +730,16 @@ async def chat(
ttft = time.perf_counter() - start_time
elif event_type == ChatEvent.STATUS:
train_of_thought.append({"type": event_type.value, "data": data})
elif event_type == ChatEvent.THOUGHT:
# Append the data to the last thought as thoughts are streamed
if (
len(train_of_thought) > 0
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
and type(train_of_thought[-1]["data"]) == type(data) == str
):
train_of_thought[-1]["data"] += data
else:
train_of_thought.append({"type": event_type.value, "data": data})
if event_type == ChatEvent.MESSAGE:
yield data
@@ -1306,10 +1320,6 @@ async def chat(
tracer,
)
# Send Response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
continue_stream = True
async for item in llm_response:
# Should not happen with async generator, end is signaled by loop exit. Skip.
@@ -1318,8 +1328,18 @@ async def chat(
if not connection_alive or not continue_stream:
# Drain the generator if disconnected but keep processing internally
continue
message = item.response if isinstance(item, ResponseWithThought) else item
if isinstance(item, ResponseWithThought) and item.thought:
async for result in send_event(ChatEvent.THOUGHT, item.thought):
yield result
continue
# Start sending response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
try:
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
async for result in send_event(ChatEvent.MESSAGE, message):
yield result
except Exception as e:
continue_stream = False

View File

@@ -93,6 +93,7 @@ from khoj.processor.conversation.openai.gpt import (
)
from khoj.processor.conversation.utils import (
ChatEvent,
ResponseWithThought,
clean_json,
clean_mermaidjs,
construct_chat_history,
@@ -1432,9 +1433,9 @@ async def agenerate_chat_response(
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
tracer: dict = {},
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]:
# Initialize Variables
chat_response_generator = None
chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}