Parse Qwen3 reasoning model thoughts served via OpenAI compatible API

The Qwen3 reasoning models return thoughts within <think></think> tags
before the response.

This change parses the thoughts out of the final response in the
response stream and returns them as a structured response with thoughts.

These thoughts aren't passed to the client yet.
This commit is contained in:
Debanjum
2025-05-02 10:33:42 -06:00
parent 7b9f2c21c7
commit 6eaf54eb7a

View File

@@ -79,6 +79,8 @@ def completion_with_backoff(
elif is_twitter_reasoning_model(model_name, api_base_url):
reasoning_effort = "high" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort
elif is_qwen_reasoning_model(model_name, api_base_url):
stream_processor = partial(in_stream_thought_processor, thought_tag="think")
model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"):
@@ -189,6 +191,8 @@ async def chat_completion_with_backoff(
updated_messages.append(message)
formatted_messages = updated_messages
elif is_qwen_reasoning_model(model_name, api_base_url):
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
stream = True
model_kwargs["stream_options"] = {"include_usage": True}
@@ -279,6 +283,13 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
)
def is_qwen_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
    """
    Check if the model is a Qwen reasoning model.

    A model counts as a Qwen reasoning model when its name contains
    "qwen3" (case-insensitive) and it is served via a custom API base URL.
    """
    if api_base_url is None:
        return False
    return "qwen3" in model_name.lower()
class ThoughtDeltaEvent(ContentDeltaEvent):
"""
Chat completion chunk with thoughts, reasoning support.
@@ -326,3 +337,201 @@ async def adefault_stream_processor(
"""
async for chunk in chat_stream:
yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())
def in_stream_thought_processor(
    chat_stream: openai.Stream[ChatCompletionChunk], thought_tag="think"
) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
    """
    Generator for chat completion with thought chunks.

    Splits reasoning wrapped in <thought_tag>...</thought_tag> out of the
    content stream, yielding it as "thought.delta" events and everything
    else unchanged. Assumes the tag can only appear once, at the start of
    the response. Handles partial tags split across streamed chunks.

    Args:
        chat_stream: Raw OpenAI chat completion stream to process.
        thought_tag: Name of the tag wrapping reasoning (default "think").

    Yields:
        Chat completion events, with reasoning marked as "thought.delta".
    """
    start_tag = f"<{thought_tag}>"
    end_tag = f"</{thought_tag}>"
    buf: str = ""
    # Last content-bearing event seen. End-of-stream flushing must reuse a
    # content event rather than whatever event happened to arrive last
    # (e.g. a usage event), which it must not mutate.
    tail: ChatCompletionStreamWithThoughtEvent = None
    # Modes and transitions: detect_start > thought (optional) > message
    mode = "detect_start"
    for chunk in default_stream_processor(chat_stream):
        if mode == "message" or chunk.type != "content.delta":
            # Message mode is terminal, so just yield chunks, no processing
            yield chunk
            continue
        tail = chunk
        buf += chunk.delta
        if mode == "detect_start":
            # Try to determine if we start with thought tag
            if buf.startswith(start_tag):
                # Found start tag, switch mode
                buf = buf[len(start_tag) :]  # Remove start tag
                mode = "thought"
                # Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
            elif len(buf) >= len(start_tag):
                # Buffer is long enough, definitely doesn't start with tag
                chunk.delta = buf
                yield chunk
                mode = "message"
                buf = ""
                continue
            elif start_tag.startswith(buf):
                # Buffer is a prefix of the start tag, need more data
                continue
            else:
                # Buffer doesn't match start tag prefix and is shorter than tag
                chunk.delta = buf
                yield chunk
                mode = "message"
                buf = ""
                continue
        if mode == "thought":
            # Look for the end tag
            idx = buf.find(end_tag)
            if idx != -1:
                # Found end tag. Yield thought content before it.
                if idx > 0 and buf[:idx].strip():
                    chunk.type = "thought.delta"
                    chunk.delta = buf[:idx]
                    yield chunk
                # Process content *after* the tag as message
                buf = buf[idx + len(end_tag) :]
                if buf:
                    # Reset event type so trailing message content in this
                    # same chunk isn't mis-tagged as a thought delta.
                    chunk.type = "content.delta"
                    chunk.delta = buf
                    yield chunk
                mode = "message"
                buf = ""
                continue
            else:
                # End tag not found yet. Yield thought content, holding back potential partial end tag.
                send_upto = len(buf)
                # Check if buffer ends with a prefix of end_tag
                for i in range(len(end_tag) - 1, 0, -1):
                    if buf.endswith(end_tag[:i]):
                        send_upto = len(buf) - i  # Don't send the partial tag yet
                        break
                if send_upto > 0 and buf[:send_upto].strip():
                    chunk.type = "thought.delta"
                    chunk.delta = buf[:send_upto]
                    yield chunk
                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
                # Need more data to find the complete end tag
                continue

    # End of stream handling: flush any held-back buffer via the last
    # content event seen (guard in case the stream had no content at all).
    if buf and tail is not None:
        if mode == "thought":  # Stream ended before </think> was found
            tail.type = "thought.delta"
            tail.delta = buf
            yield tail
        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
            # If it wasn't a partial start tag, treat as message
            if not start_tag.startswith(buf):
                tail.delta = buf
                yield tail
            # else: discard partial <think>
        # If mode == "message", buffer should be empty due to logic above, but yield just in case
        elif mode == "message":
            tail.delta = buf
            yield tail
async def ain_stream_thought_processor(
    chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think"
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
    """
    Async generator for chat completion with thought chunks.

    Splits reasoning wrapped in <thought_tag>...</thought_tag> out of the
    content stream, yielding it on the delta's thought field and message
    content on the delta's content field. Assumes the tag can only appear
    once, at the start of the response. Handles partial tags split across
    streamed chunks.

    Args:
        chat_stream: Raw async OpenAI chat completion stream to process.
        thought_tag: Name of the tag wrapping reasoning (default "think").

    Yields:
        Chat completion chunks, with reasoning set on delta.thought.
    """
    start_tag = f"<{thought_tag}>"
    end_tag = f"</{thought_tag}>"
    buf: str = ""
    # Last chunk with a populated choices list. End-of-stream flushing must
    # not index into the final chunk directly: with stream_options
    # include_usage, the last chunk has an empty choices list.
    tail: ChatCompletionWithThoughtsChunk = None
    # Modes and transitions: detect_start > thought (optional) > message
    mode = "detect_start"
    async for chunk in adefault_stream_processor(chat_stream):
        if len(chunk.choices) == 0:
            continue
        if mode == "message":
            # Message mode is terminal, so just yield chunks, no processing
            yield chunk
            continue
        tail = chunk
        # Delta content may be None (e.g. role-only chunks); treat as empty.
        buf += chunk.choices[0].delta.content or ""
        if mode == "detect_start":
            # Try to determine if we start with thought tag
            if buf.startswith(start_tag):
                # Found start tag, switch mode
                buf = buf[len(start_tag) :]  # Remove start tag
                mode = "thought"
                # Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
            elif len(buf) >= len(start_tag):
                # Buffer is long enough, definitely doesn't start with tag
                chunk.choices[0].delta.content = buf
                yield chunk
                mode = "message"
                buf = ""
                continue
            elif start_tag.startswith(buf):
                # Buffer is a prefix of the start tag, need more data
                continue
            else:
                # Buffer doesn't match start tag prefix and is shorter than tag
                chunk.choices[0].delta.content = buf
                yield chunk
                mode = "message"
                buf = ""
                continue
        if mode == "thought":
            # Look for the end tag
            idx = buf.find(end_tag)
            if idx != -1:
                # Found end tag. Yield thought content before it.
                if idx > 0 and buf[:idx].strip():
                    chunk.choices[0].delta.thought = buf[:idx]
                    chunk.choices[0].delta.content = ""
                    yield chunk
                # Process content *after* the tag as message
                buf = buf[idx + len(end_tag) :]
                if buf:
                    # Clear thought so trailing message content in this same
                    # chunk isn't yielded with the thought duplicated on it.
                    chunk.choices[0].delta.thought = None
                    chunk.choices[0].delta.content = buf
                    yield chunk
                mode = "message"
                buf = ""
                continue
            else:
                # End tag not found yet. Yield thought content, holding back potential partial end tag.
                send_upto = len(buf)
                # Check if buffer ends with a prefix of end_tag
                for i in range(len(end_tag) - 1, 0, -1):
                    if buf.endswith(end_tag[:i]):
                        send_upto = len(buf) - i  # Don't send the partial tag yet
                        break
                if send_upto > 0 and buf[:send_upto].strip():
                    chunk.choices[0].delta.thought = buf[:send_upto]
                    chunk.choices[0].delta.content = ""
                    yield chunk
                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
                # Need more data to find the complete end tag
                continue

    # End of stream handling: flush any held-back buffer via the last chunk
    # that actually had choices (guard in case the stream had none).
    if buf and tail is not None:
        if mode == "thought":  # Stream ended before </think> was found
            tail.choices[0].delta.thought = buf
            tail.choices[0].delta.content = ""
            yield tail
        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
            # If it wasn't a partial start tag, treat as message
            if not start_tag.startswith(buf):
                tail.choices[0].delta.content = buf
                yield tail
            # else: discard partial <think>
        # If mode == "message", buffer should be empty due to logic above, but yield just in case
        elif mode == "message":
            tail.choices[0].delta.content = buf
            yield tail