diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index e5682bde..e5db2c45 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -100,6 +100,7 @@ def completion_with_backoff(
         reasoning_effort = "high" if deepthought else "low"
         model_kwargs["reasoning_effort"] = reasoning_effort
     elif model_name.startswith("deepseek-reasoner"):
+        stream_processor = in_stream_thought_processor
         # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
         # The first message should always be a user message (except system message).
         updated_messages: List[dict] = []
@@ -112,7 +113,7 @@ def completion_with_backoff(
                 updated_messages.append(message)
         formatted_messages = updated_messages
     elif is_qwen_style_reasoning_model(model_name, api_base_url):
-        stream_processor = partial(in_stream_thought_processor, thought_tag="think")
+        stream_processor = in_stream_thought_processor
        # Reasoning is enabled by default. Disable when deepthought is False.
        # See https://qwenlm.github.io/blog/qwen3/#advanced-usages
        if not deepthought:
@@ -144,6 +145,14 @@ def completion_with_backoff(
                 elif chunk.type == "tool_calls.function.arguments.done":
                     tool_calls += [ToolCall(name=chunk.name, args=json.loads(chunk.arguments), id=None)]
         if tool_calls:
+            # If there are tool calls, aggregate thoughts and responses into thoughts
+            if thoughts and aggregated_response:
+                # wrap each line of thought in italics
+                thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()])
+                thoughts = f"{thoughts}\n\n{aggregated_response}"
+            else:
+                thoughts = thoughts or aggregated_response
+            # JSON dump tool calls into aggregated response
             tool_calls = [
                 ToolCall(name=chunk.name, args=chunk.args, id=tool_id) for chunk, tool_id in zip(tool_calls, tool_ids)
             ]
@@ -158,12 +167,24 @@ def completion_with_backoff(
             **model_kwargs,
         )
         aggregated_response = chunk.choices[0].message.content
+        if hasattr(chunk.choices[0].message, "reasoning_content"):
+            thoughts = chunk.choices[0].message.reasoning_content
+        else:
+            thoughts = chunk.choices[0].message.model_extra.get("reasoning_content", "")
         raw_tool_calls = chunk.choices[0].message.tool_calls
         if raw_tool_calls:
             tool_calls = [
                 ToolCall(name=tool.function.name, args=tool.function.parsed_arguments, id=tool.id)
                 for tool in raw_tool_calls
             ]
+            # If there are tool calls, aggregate thoughts and responses into thoughts
+            if thoughts and aggregated_response:
+                # wrap each line of thought in italics
+                thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()])
+                thoughts = f"{thoughts}\n\n{aggregated_response}"
+            else:
+                thoughts = thoughts or aggregated_response
+            # JSON dump tool calls into aggregated response
             aggregated_response = json.dumps([tool_call.__dict__ for tool_call in tool_calls])
 
     # Calculate cost of chat
@@ -223,7 +244,7 @@ async def chat_completion_with_backoff(
         openai_async_clients[client_key] = client
 
     stream = not is_non_streaming_model(model_name, api_base_url)
-    stream_processor = adefault_stream_processor
+    stream_processor = astream_thought_processor
     if stream:
         model_kwargs["stream_options"] = {"include_usage": True}
     else:
@@ -251,13 +272,13 @@ async def chat_completion_with_backoff(
                 "content"
             ] = f"{first_system_message_content}\nFormatting re-enabled"
     elif is_twitter_reasoning_model(model_name, api_base_url):
-        stream_processor = adeepseek_stream_processor
         reasoning_effort = "high" if deepthought else "low"
         model_kwargs["reasoning_effort"] = reasoning_effort
     elif model_name.startswith("deepseek-reasoner") or "deepseek-r1" in model_name:
-        # Official Deepseek reasoner model returns structured thinking output.
-        # Deepseek r1 served via other AI model API providers return it in response stream
-        stream_processor = ain_stream_thought_processor if "deepseek-r1" in model_name else adeepseek_stream_processor  # type: ignore[assignment]
+        # The official Deepseek reasoner model and some inference APIs like vLLM return structured thinking output.
+        # Others like DeepInfra return thoughts inline in the response stream.
+        # The in-stream thought processor handles both cases: structured and inline thoughts.
+        stream_processor = ain_stream_thought_processor
         # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
         # The first message should always be a user message (except system message).
         updated_messages: List[dict] = []
@@ -274,7 +295,7 @@ async def chat_completion_with_backoff(
                 updated_messages.append(message)
         formatted_messages = updated_messages
     elif is_qwen_style_reasoning_model(model_name, api_base_url):
-        stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
+        stream_processor = ain_stream_thought_processor
         # Reasoning is enabled by default. Disable when deepthought is False.
         # See https://qwenlm.github.io/blog/qwen3/#advanced-usages
         if not deepthought:
@@ -551,39 +572,17 @@ def default_stream_processor(
     chat_stream: ChatCompletionStream,
 ) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
     """
-    Async generator to cast and return chunks from the standard openai chat completions stream.
+    Generator of chunks from the standard openai chat completions stream.
     """
     for chunk in chat_stream:
         yield chunk
 
 
-async def adefault_stream_processor(
+async def astream_thought_processor(
     chat_stream: openai.AsyncStream[ChatCompletionChunk],
 ) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
     """
-    Async generator to cast and return chunks from the standard openai chat completions stream.
-    """
-    async for chunk in chat_stream:
-        try:
-            # Validate the chunk has the required fields before processing
-            chunk_data = chunk.model_dump()
-
-            # Skip chunks that don't have the required object field or have invalid values
-            if not chunk_data.get("object") or chunk_data.get("object") != "chat.completion.chunk":
-                logger.warning(f"Skipping invalid chunk with object field: {chunk_data.get('object', 'missing')}")
-                continue
-
-            yield ChatCompletionWithThoughtsChunk.model_validate(chunk_data)
-        except Exception as e:
-            logger.warning(f"Error processing chunk: {e}. Skipping malformed chunk.")
-            continue
-
-
-async def adeepseek_stream_processor(
-    chat_stream: openai.AsyncStream[ChatCompletionChunk],
-) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
-    """
-    Async generator to cast and return chunks from the deepseek chat completions stream.
+    Async generator of chunks from the standard openai chat completions stream with thoughts/reasoning.
     """
     async for chunk in chat_stream:
         try:
@@ -596,12 +595,19 @@ async def adeepseek_stream_processor(
                 continue
 
             tchunk = ChatCompletionWithThoughtsChunk.model_validate(chunk_data)
+
+            # Handle deepseek style responses with thoughts. Used by AI APIs like vLLM, SGLang, DeepSeek, LiteLLM.
             if (
                 len(tchunk.choices) > 0
                 and hasattr(tchunk.choices[0].delta, "reasoning_content")
                 and tchunk.choices[0].delta.reasoning_content
             ):
                 tchunk.choices[0].delta.thought = chunk.choices[0].delta.reasoning_content
+
+            # Handle llama.cpp server style responses with thoughts.
+            elif len(tchunk.choices) > 0 and tchunk.choices[0].delta.model_extra.get("reasoning_content"):
+                tchunk.choices[0].delta.thought = tchunk.choices[0].delta.model_extra.get("reasoning_content")
+
             yield tchunk
         except Exception as e:
             logger.warning(f"Error processing chunk: {e}. Skipping malformed chunk.")
@@ -710,7 +716,7 @@ async def ain_stream_thought_processor(
     chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think"
 ) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
     """
-    Async generator for chat completion with thought chunks.
+    Async generator for chat completion with structured and inline thought chunks.
     Assumes <thought_tag>...</thought_tag> can only appear once at the start.
     Handles partial tags across streamed chunks.
     """
@@ -720,7 +726,7 @@ async def ain_stream_thought_processor(
     # Modes and transitions: detect_start > thought (optional) > message
     mode = "detect_start"
 
-    async for chunk in adefault_stream_processor(chat_stream):
+    async for chunk in astream_thought_processor(chat_stream):
         if len(chunk.choices) == 0:
             continue
         if mode == "message":
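---

Note: the tool-call thought aggregation block is now duplicated across the streaming and
non-streaming branches of completion_with_backoff. A minimal sketch of how it could be
factored into a shared helper; the name aggregate_thoughts_with_response is hypothetical
and not part of this patch, though the body is lifted from it:

    def aggregate_thoughts_with_response(thoughts: str, aggregated_response: str) -> str:
        """Fold any interim text response into the thoughts when the model emitted
        tool calls, so the tool-call JSON can take over the response slot."""
        if thoughts and aggregated_response:
            # wrap each line of thought in italics
            thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()])
            return f"{thoughts}\n\n{aggregated_response}"
        return thoughts or aggregated_response

    # Usage in both branches would then collapse to:
    #   thoughts = aggregate_thoughts_with_response(thoughts, aggregated_response)
    #   aggregated_response = json.dumps([tool_call.__dict__ for tool_call in tool_calls])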