diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 019c785a..be7ea165 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -79,6 +79,8 @@ def completion_with_backoff(
     elif is_twitter_reasoning_model(model_name, api_base_url):
         reasoning_effort = "high" if deepthought else "low"
         model_kwargs["reasoning_effort"] = reasoning_effort
+    elif is_qwen_reasoning_model(model_name, api_base_url):
+        stream_processor = partial(in_stream_thought_processor, thought_tag="think")
 
     model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
@@ -189,6 +191,8 @@ async def chat_completion_with_backoff(
                 updated_messages.append(message)
 
         formatted_messages = updated_messages
+    elif is_qwen_reasoning_model(model_name, api_base_url):
+        stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
 
     stream = True
     model_kwargs["stream_options"] = {"include_usage": True}
@@ -279,6 +283,13 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
     )
 
 
+def is_qwen_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
+    """
+    Check if the model is a Qwen reasoning model
+    """
+    return "qwen3" in model_name.lower() and api_base_url is not None
+
+
 class ThoughtDeltaEvent(ContentDeltaEvent):
     """
     Chat completion chunk with thoughts, reasoning support.
@@ -326,3 +337,201 @@ async def adefault_stream_processor(
     """
     async for chunk in chat_stream:
         yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())
+
+
+def in_stream_thought_processor(
+    chat_stream: openai.Stream[ChatCompletionChunk], thought_tag="think"
+) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
+    """
+    Generator for chat completion with thought chunks.
+    Assumes <think>...</think> can only appear once at the start.
+    Handles partial tags across streamed chunks.
+    """
+    start_tag = f"<{thought_tag}>"
+    end_tag = f"</{thought_tag}>"
+    buf: str = ""
+    # Modes and transitions: detect_start > thought (optional) > message
+    mode = "detect_start"
+
+    for chunk in default_stream_processor(chat_stream):
+        if mode == "message" or chunk.type != "content.delta":
+            # Message mode is terminal, so just yield chunks, no processing
+            yield chunk
+            continue
+
+        buf += chunk.delta
+
+        if mode == "detect_start":
+            # Try to determine if we start with thought tag
+            if buf.startswith(start_tag):
+                # Found start tag, switch mode
+                buf = buf[len(start_tag) :]  # Remove start tag
+                mode = "thought"
+                # Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
+            elif len(buf) >= len(start_tag):
+                # Buffer is long enough, definitely doesn't start with tag
+                chunk.delta = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            elif start_tag.startswith(buf):
+                # Buffer is a prefix of the start tag, need more data
+                continue
+            else:
+                # Buffer doesn't match start tag prefix and is shorter than tag
+                chunk.delta = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+
+        if mode == "thought":
+            # Look for the end tag
+            idx = buf.find(end_tag)
+            if idx != -1:
+                # Found end tag. Yield thought content before it.
+                if idx > 0 and buf[:idx].strip():
+                    chunk.type = "thought.delta"
+                    chunk.delta = buf[:idx]
+                    yield chunk
+                # Process content *after* the tag as message
+                buf = buf[idx + len(end_tag) :]
+                if buf:
+                    chunk.delta = buf
+                    yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            else:
+                # End tag not found yet. Yield thought content, holding back potential partial end tag.
+                send_upto = len(buf)
+                # Check if buffer ends with a prefix of end_tag
+                for i in range(len(end_tag) - 1, 0, -1):
+                    if buf.endswith(end_tag[:i]):
+                        send_upto = len(buf) - i  # Don't send the partial tag yet
+                        break
+                if send_upto > 0 and buf[:send_upto].strip():
+                    chunk.type = "thought.delta"
+                    chunk.delta = buf[:send_upto]
+                    yield chunk
+                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
+                # Need more data to find the complete end tag
+                continue
+
+    # End of stream handling
+    if buf:
+        if mode == "thought":  # Stream ended before </think> was found
+            chunk.type = "thought.delta"
+            chunk.delta = buf
+            yield chunk
+        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
+            # If it wasn't a partial start tag, treat as message
+            if not start_tag.startswith(buf):
+                chunk.delta = buf
+                yield chunk
+            # else: discard partial
+        # If mode == "message", buffer should be empty due to logic above, but yield just in case
+        elif mode == "message":
+            chunk.delta = buf
+            yield chunk
+
+
+async def ain_stream_thought_processor(
+    chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think"
+) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
+    """
+    Async generator for chat completion with thought chunks.
+    Assumes <think>...</think> can only appear once at the start.
+    Handles partial tags across streamed chunks.
+    """
+    start_tag = f"<{thought_tag}>"
+    end_tag = f"</{thought_tag}>"
+    buf: str = ""
+    # Modes and transitions: detect_start > thought (optional) > message
+    mode = "detect_start"
+
+    async for chunk in adefault_stream_processor(chat_stream):
+        if len(chunk.choices) == 0:
+            continue
+        if mode == "message":
+            # Message mode is terminal, so just yield chunks, no processing
+            yield chunk
+            continue
+
+        buf += chunk.choices[0].delta.content or ""
+
+        if mode == "detect_start":
+            # Try to determine if we start with thought tag
+            if buf.startswith(start_tag):
+                # Found start tag, switch mode
+                buf = buf[len(start_tag) :]  # Remove start tag
+                mode = "thought"
+                # Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
+            elif len(buf) >= len(start_tag):
+                # Buffer is long enough, definitely doesn't start with tag
+                chunk.choices[0].delta.content = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            elif start_tag.startswith(buf):
+                # Buffer is a prefix of the start tag, need more data
+                continue
+            else:
+                # Buffer doesn't match start tag prefix and is shorter than tag
+                chunk.choices[0].delta.content = buf
+                yield chunk
+                mode = "message"
+                buf = ""
+                continue
+
+        if mode == "thought":
+            # Look for the end tag
+            idx = buf.find(end_tag)
+            if idx != -1:
+                # Found end tag. Yield thought content before it.
+                if idx > 0 and buf[:idx].strip():
+                    chunk.choices[0].delta.thought = buf[:idx]
+                    chunk.choices[0].delta.content = ""
+                    yield chunk
+                # Process content *after* the tag as message
+                buf = buf[idx + len(end_tag) :]
+                if buf:
+                    chunk.choices[0].delta.content = buf
+                    yield chunk
+                mode = "message"
+                buf = ""
+                continue
+            else:
+                # End tag not found yet. Yield thought content, holding back potential partial end tag.
+                send_upto = len(buf)
+                # Check if buffer ends with a prefix of end_tag
+                for i in range(len(end_tag) - 1, 0, -1):
+                    if buf.endswith(end_tag[:i]):
+                        send_upto = len(buf) - i  # Don't send the partial tag yet
+                        break
+                if send_upto > 0 and buf[:send_upto].strip():
+                    chunk.choices[0].delta.thought = buf[:send_upto]
+                    chunk.choices[0].delta.content = ""
+                    yield chunk
+                buf = buf[send_upto:]  # Keep only the partial tag (or empty)
+                # Need more data to find the complete end tag
+                continue
+
+    # End of stream handling
+    if buf:
+        if mode == "thought":  # Stream ended before </think> was found
+            chunk.choices[0].delta.thought = buf
+            chunk.choices[0].delta.content = ""
+            yield chunk
+        elif mode == "detect_start":  # Stream ended before start tag could be confirmed/denied
+            # If it wasn't a partial start tag, treat as message
+            if not start_tag.startswith(buf):
+                chunk.choices[0].delta.content = buf
+                yield chunk
+            # else: discard partial
+        # If mode == "message", buffer should be empty due to logic above, but yield just in case
+        elif mode == "message":
+            chunk.choices[0].delta.content = buf
+            yield chunk