From 6eaf54eb7aa06e383237919ac887ce6755366898 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 2 May 2025 10:33:42 -0600 Subject: [PATCH] Parse Qwen3 reasoning model thoughts served via OpenAI compatible API The Qwen3 reasoning models return thoughts within tags before response. This change parses the thoughts out from final response from the response stream and returns as structured response with thoughts. These thoughts aren't passed to client yet --- .../processor/conversation/openai/utils.py | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 019c785a..be7ea165 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -79,6 +79,8 @@ def completion_with_backoff( elif is_twitter_reasoning_model(model_name, api_base_url): reasoning_effort = "high" if deepthought else "low" model_kwargs["reasoning_effort"] = reasoning_effort + elif is_qwen_reasoning_model(model_name, api_base_url): + stream_processor = partial(in_stream_thought_processor, thought_tag="think") model_kwargs["stream_options"] = {"include_usage": True} if os.getenv("KHOJ_LLM_SEED"): @@ -189,6 +191,8 @@ async def chat_completion_with_backoff( updated_messages.append(message) formatted_messages = updated_messages + elif is_qwen_reasoning_model(model_name, api_base_url): + stream_processor = partial(ain_stream_thought_processor, thought_tag="think") stream = True model_kwargs["stream_options"] = {"include_usage": True} @@ -279,6 +283,13 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo ) +def is_qwen_reasoning_model(model_name: str, api_base_url: str = None) -> bool: + """ + Check if the model is a Qwen reasoning model + """ + return "qwen3" in model_name.lower() and api_base_url is not None + + class ThoughtDeltaEvent(ContentDeltaEvent): """ Chat completion chunk with thoughts, reasoning support. @@ -326,3 +337,201 @@ async def adefault_stream_processor( """ async for chunk in chat_stream: yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump()) + + +def in_stream_thought_processor( + chat_stream: openai.Stream[ChatCompletionChunk], thought_tag="think" +) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]: + """ + Generator for chat completion with thought chunks. + Assumes ... can only appear once at the start. + Handles partial tags across streamed chunks. + """ + start_tag = f"<{thought_tag}>" + end_tag = f"" + buf: str = "" + # Modes and transitions: detect_start > thought (optional) > message + mode = "detect_start" + + for chunk in default_stream_processor(chat_stream): + if mode == "message" or chunk.type != "content.delta": + # Message mode is terminal, so just yield chunks, no processing + yield chunk + continue + + buf += chunk.delta + + if mode == "detect_start": + # Try to determine if we start with thought tag + if buf.startswith(start_tag): + # Found start tag, switch mode + buf = buf[len(start_tag) :] # Remove start tag + mode = "thought" + # Fall through to process the rest of the buffer in 'thought' mode *within this iteration* + elif len(buf) >= len(start_tag): + # Buffer is long enough, definitely doesn't start with tag + chunk.delta = buf + yield chunk + mode = "message" + buf = "" + continue + elif start_tag.startswith(buf): + # Buffer is a prefix of the start tag, need more data + continue + else: + # Buffer doesn't match start tag prefix and is shorter than tag + chunk.delta = buf + yield chunk + mode = "message" + buf = "" + continue + + if mode == "thought": + # Look for the end tag + idx = buf.find(end_tag) + if idx != -1: + # Found end tag. Yield thought content before it. + if idx > 0 and buf[:idx].strip(): + chunk.type = "thought.delta" + chunk.delta = buf[:idx] + yield chunk + # Process content *after* the tag as message + buf = buf[idx + len(end_tag) :] + if buf: + chunk.delta = buf + yield chunk + mode = "message" + buf = "" + continue + else: + # End tag not found yet. Yield thought content, holding back potential partial end tag. + send_upto = len(buf) + # Check if buffer ends with a prefix of end_tag + for i in range(len(end_tag) - 1, 0, -1): + if buf.endswith(end_tag[:i]): + send_upto = len(buf) - i # Don't send the partial tag yet + break + if send_upto > 0 and buf[:send_upto].strip(): + chunk.type = "thought.delta" + chunk.delta = buf[:send_upto] + yield chunk + buf = buf[send_upto:] # Keep only the partial tag (or empty) + # Need more data to find the complete end tag + continue + + # End of stream handling + if buf: + if mode == "thought": # Stream ended before was found + chunk.type = "thought.delta" + chunk.delta = buf + yield chunk + elif mode == "detect_start": # Stream ended before start tag could be confirmed/denied + # If it wasn't a partial start tag, treat as message + if not start_tag.startswith(buf): + chunk.delta = buf + yield chunk + # else: discard partial + # If mode == "message", buffer should be empty due to logic above, but yield just in case + elif mode == "message": + chunk.delta = buf + yield chunk + + +async def ain_stream_thought_processor( + chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think" +) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]: + """ + Async generator for chat completion with thought chunks. + Assumes ... can only appear once at the start. + Handles partial tags across streamed chunks. + """ + start_tag = f"<{thought_tag}>" + end_tag = f"" + buf: str = "" + # Modes and transitions: detect_start > thought (optional) > message + mode = "detect_start" + + async for chunk in adefault_stream_processor(chat_stream): + if len(chunk.choices) == 0: + continue + if mode == "message": + # Message mode is terminal, so just yield chunks, no processing + yield chunk + continue + + buf += chunk.choices[0].delta.content + + if mode == "detect_start": + # Try to determine if we start with thought tag + if buf.startswith(start_tag): + # Found start tag, switch mode + buf = buf[len(start_tag) :] # Remove start tag + mode = "thought" + # Fall through to process the rest of the buffer in 'thought' mode *within this iteration* + elif len(buf) >= len(start_tag): + # Buffer is long enough, definitely doesn't start with tag + chunk.choices[0].delta.content = buf + yield chunk + mode = "message" + buf = "" + continue + elif start_tag.startswith(buf): + # Buffer is a prefix of the start tag, need more data + continue + else: + # Buffer doesn't match start tag prefix and is shorter than tag + chunk.choices[0].delta.content = buf + yield chunk + mode = "message" + buf = "" + continue + + if mode == "thought": + # Look for the end tag + idx = buf.find(end_tag) + if idx != -1: + # Found end tag. Yield thought content before it. + if idx > 0 and buf[:idx].strip(): + chunk.choices[0].delta.thought = buf[:idx] + chunk.choices[0].delta.content = "" + yield chunk + # Process content *after* the tag as message + buf = buf[idx + len(end_tag) :] + if buf: + chunk.choices[0].delta.content = buf + yield chunk + mode = "message" + buf = "" + continue + else: + # End tag not found yet. Yield thought content, holding back potential partial end tag. + send_upto = len(buf) + # Check if buffer ends with a prefix of end_tag + for i in range(len(end_tag) - 1, 0, -1): + if buf.endswith(end_tag[:i]): + send_upto = len(buf) - i # Don't send the partial tag yet + break + if send_upto > 0 and buf[:send_upto].strip(): + chunk.choices[0].delta.thought = buf[:send_upto] + chunk.choices[0].delta.content = "" + yield chunk + buf = buf[send_upto:] # Keep only the partial tag (or empty) + # Need more data to find the complete end tag + continue + + # End of stream handling + if buf: + if mode == "thought": # Stream ended before was found + chunk.choices[0].delta.thought = buf + chunk.choices[0].delta.content = "" + yield chunk + elif mode == "detect_start": # Stream ended before start tag could be confirmed/denied + # If it wasn't a partial start tag, treat as message + if not start_tag.startswith(buf): + chunk.choices[0].delta.content = buf + yield chunk + # else: discard partial + # If mode == "message", buffer should be empty due to logic above, but yield just in case + elif mode == "message": + chunk.choices[0].delta.content = buf + yield chunk