diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts index 9ccd1316..4823659a 100644 --- a/src/interface/web/app/common/chatFunctions.ts +++ b/src/interface/web/app/common/chatFunctions.ts @@ -97,6 +97,17 @@ export function processMessageChunk( console.log(`status: ${chunk.data}`); const statusMessage = chunk.data as string; currentMessage.trainOfThought.push(statusMessage); + } else if (chunk.type === "thought") { + const thoughtChunk = chunk.data as string; + const lastThoughtIndex = currentMessage.trainOfThought.length - 1; + const previousThought = + lastThoughtIndex >= 0 ? currentMessage.trainOfThought[lastThoughtIndex] : ""; + // If the last train of thought started with "Thinking: " append the new thought chunk to it + if (previousThought.startsWith("**Thinking:** ")) { + currentMessage.trainOfThought[lastThoughtIndex] += thoughtChunk; + } else { + currentMessage.trainOfThought.push(`**Thinking:** ${thoughtChunk}`); + } } else if (chunk.type === "references") { const references = chunk.data as RawReferenceData; diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index f76126ab..aba69dfb 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import ( format_messages_for_anthropic, ) from khoj.processor.conversation.utils import ( + ResponseWithThought, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -162,7 +163,7 @@ async def converse_anthropic( generated_asset_results: Dict[str, Dict] = {}, deepthought: Optional[bool] = False, tracer: dict = {}, -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using Anthropic's Claude """ @@ -247,7 +248,8 @@ async def converse_anthropic( deepthought=deepthought, tracer=tracer, ): - full_response += chunk + if chunk.response: + full_response += chunk.response yield chunk # Call completion_func once finish streaming and we have the full response diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 442e1cb3..baf8fade 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -13,6 +13,7 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( + ResponseWithThought, commit_conversation_trace, get_image_from_base64, get_image_from_url, @@ -154,13 +155,23 @@ async def anthropic_chat_completion_with_backoff( max_tokens=max_tokens, **model_kwargs, ) as stream: - async for text in stream.text_stream: + async for chunk in stream: # Log the time taken to start response if aggregated_response == "": logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Skip empty chunks + if chunk.type != "content_block_delta": + continue # Handle streamed response chunk - aggregated_response += text - yield text + response_chunk: ResponseWithThought = None + if chunk.delta.type == "text_delta": + response_chunk = ResponseWithThought(response=chunk.delta.text) + aggregated_response += chunk.delta.text + if chunk.delta.type == "thinking_delta": + response_chunk = ResponseWithThought(thought=chunk.delta.thinking) + # Handle streamed response chunk + if response_chunk: + yield response_chunk final_message = await stream.get_final_message() # Log the time taken to stream the entire response diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 4808bc60..65b2d83f 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import ( ) from khoj.processor.conversation.utils import ( JsonSupport, + ResponseWithThought, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -188,7 +189,7 @@ async def converse_openai( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer: dict = {}, -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using OpenAI's ChatGPT """ @@ -273,7 +274,8 @@ async def converse_openai( model_kwargs={"stop": ["Notes:\n["]}, tracer=tracer, ): - full_response += chunk + if chunk.response: + full_response += chunk.response yield chunk # Call completion_func once finish streaming and we have the full response diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 7fab44aa..3a1b8947 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,12 +1,21 @@ import logging import os +from functools import partial from time import perf_counter -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union from urllib.parse import urlparse import openai -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.lib.streaming.chat import ( + ChatCompletionStream, + ChatCompletionStreamEvent, + ContentDeltaEvent, +) +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, +) from tenacity import ( before_sleep_log, retry, @@ -16,7 +25,11 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace +from khoj.processor.conversation.utils import ( + JsonSupport, + ResponseWithThought, + commit_conversation_trace, +) from khoj.utils.helpers import ( get_chat_usage_metrics, get_openai_async_client, @@ -59,6 +72,7 @@ def completion_with_backoff( client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client + stream_processor = default_stream_processor formatted_messages = [{"role": message.role, "content": message.content} for message in messages] # Tune reasoning models arguments @@ -69,6 +83,24 @@ def completion_with_backoff( elif is_twitter_reasoning_model(model_name, api_base_url): reasoning_effort = "high" if deepthought else "low" model_kwargs["reasoning_effort"] = reasoning_effort + elif model_name.startswith("deepseek-reasoner"): + # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role. + # The first message should always be a user message (except system message). + updated_messages: List[dict] = [] + for i, message in enumerate(formatted_messages): + if i > 0 and message["role"] == formatted_messages[i - 1]["role"]: + updated_messages[-1]["content"] += " " + message["content"] + elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant": + updated_messages[-1]["content"] += " " + message["content"] + else: + updated_messages.append(message) + formatted_messages = updated_messages + elif is_qwen_reasoning_model(model_name, api_base_url): + stream_processor = partial(in_stream_thought_processor, thought_tag="think") + # Reasoning is enabled by default. Disable when deepthought is False. + # See https://qwenlm.github.io/blog/qwen3/#advanced-usages + if not deepthought and len(formatted_messages) > 0: + formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think" model_kwargs["stream_options"] = {"include_usage": True} if os.getenv("KHOJ_LLM_SEED"): @@ -82,12 +114,11 @@ def completion_with_backoff( timeout=20, **model_kwargs, ) as chat: - for chunk in chat: - if chunk.type == "error": - logger.error(f"Openai api response error: {chunk.error}", exc_info=True) - continue - elif chunk.type == "content.delta": + for chunk in stream_processor(chat): + if chunk.type == "content.delta": aggregated_response += chunk.delta + elif chunk.type == "thought.delta": + pass # Calculate cost of chat input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 @@ -124,14 +155,14 @@ def completion_with_backoff( ) async def chat_completion_with_backoff( messages, - model_name, + model_name: str, temperature, openai_api_key=None, api_base_url=None, deepthought=False, model_kwargs: dict = {}, tracer: dict = {}, -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[ResponseWithThought, None]: try: client_key = f"{openai_api_key}--{api_base_url}" client = openai_async_clients.get(client_key) @@ -139,6 +170,7 @@ async def chat_completion_with_backoff( client = get_openai_async_client(openai_api_key, api_base_url) openai_async_clients[client_key] = client + stream_processor = adefault_stream_processor formatted_messages = [{"role": message.role, "content": message.content} for message in messages] # Configure thinking for openai reasoning models @@ -161,9 +193,11 @@ async def chat_completion_with_backoff( "content" ] = f"{first_system_message_content}\nFormatting re-enabled" elif is_twitter_reasoning_model(model_name, api_base_url): + stream_processor = adeepseek_stream_processor reasoning_effort = "high" if deepthought else "low" model_kwargs["reasoning_effort"] = reasoning_effort elif model_name.startswith("deepseek-reasoner"): + stream_processor = adeepseek_stream_processor # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role. # The first message should always be a user message (except system message). updated_messages: List[dict] = [] @@ -174,8 +208,13 @@ async def chat_completion_with_backoff( updated_messages[-1]["content"] += " " + message["content"] else: updated_messages.append(message) - formatted_messages = updated_messages + elif is_qwen_reasoning_model(model_name, api_base_url): + stream_processor = partial(ain_stream_thought_processor, thought_tag="think") + # Reasoning is enabled by default. Disable when deepthought is False. + # See https://qwenlm.github.io/blog/qwen3/#advanced-usages + if not deepthought and len(formatted_messages) > 0: + formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think" stream = True model_kwargs["stream_options"] = {"include_usage": True} @@ -193,24 +232,25 @@ async def chat_completion_with_backoff( timeout=20, **model_kwargs, ) - async for chunk in chat_stream: + async for chunk in stream_processor(chat_stream): # Log the time taken to start response if final_chunk is None: logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") # Keep track of the last chunk for usage data final_chunk = chunk - # Handle streamed response chunk + # Skip empty chunks if len(chunk.choices) == 0: continue - delta_chunk = chunk.choices[0].delta - text_chunk = "" - if isinstance(delta_chunk, str): - text_chunk = delta_chunk - elif delta_chunk and delta_chunk.content: - text_chunk = delta_chunk.content - if text_chunk: - aggregated_response += text_chunk - yield text_chunk + # Handle streamed response chunk + response_chunk: ResponseWithThought = None + response_delta = chunk.choices[0].delta + if response_delta.content: + response_chunk = ResponseWithThought(response=response_delta.content) + aggregated_response += response_chunk.response + elif response_delta.thought: + response_chunk = ResponseWithThought(thought=response_delta.thought) + if response_chunk: + yield response_chunk # Log the time taken to stream the entire response logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") @@ -264,3 +304,274 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo and api_base_url is not None and api_base_url.startswith("https://api.x.ai/v1") ) + + +def is_qwen_reasoning_model(model_name: str, api_base_url: str = None) -> bool: + """ + Check if the model is a Qwen reasoning model + """ + return "qwen3" in model_name.lower() and api_base_url is not None + + +class ThoughtDeltaEvent(ContentDeltaEvent): + """ + Chat completion chunk with thoughts, reasoning support. + """ + + type: Literal["thought.delta"] + """The thought or reasoning generated by the model.""" + + +ChatCompletionStreamWithThoughtEvent = Union[ChatCompletionStreamEvent, ThoughtDeltaEvent] + + +class ChoiceDeltaWithThoughts(ChoiceDelta): + """ + Chat completion chunk with thoughts, reasoning support. + """ + + thought: Optional[str] = None + """The thought or reasoning generated by the model.""" + + +class ChoiceWithThoughts(Choice): + delta: ChoiceDeltaWithThoughts + + +class ChatCompletionWithThoughtsChunk(ChatCompletionChunk): + choices: List[ChoiceWithThoughts] # Override the choices type + + +def default_stream_processor( + chat_stream: ChatCompletionStream, +) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]: + """ + Async generator to cast and return chunks from the standard openai chat completions stream. + """ + for chunk in chat_stream: + yield chunk + + +async def adefault_stream_processor( + chat_stream: openai.AsyncStream[ChatCompletionChunk], +) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]: + """ + Async generator to cast and return chunks from the standard openai chat completions stream. + """ + async for chunk in chat_stream: + yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump()) + + +async def adeepseek_stream_processor( + chat_stream: openai.AsyncStream[ChatCompletionChunk], +) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]: + """ + Async generator to cast and return chunks from the deepseek chat completions stream. + """ + async for chunk in chat_stream: + tchunk = ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump()) + if ( + len(tchunk.choices) > 0 + and hasattr(tchunk.choices[0].delta, "reasoning_content") + and tchunk.choices[0].delta.reasoning_content + ): + tchunk.choices[0].delta.thought = chunk.choices[0].delta.reasoning_content + yield tchunk + + +def in_stream_thought_processor( + chat_stream: openai.Stream[ChatCompletionChunk], thought_tag="think" +) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]: + """ + Generator for chat completion with thought chunks. + Assumes ... can only appear once at the start. + Handles partial tags across streamed chunks. + """ + start_tag = f"<{thought_tag}>" + end_tag = f"" + buf: str = "" + # Modes and transitions: detect_start > thought (optional) > message + mode = "detect_start" + + for chunk in default_stream_processor(chat_stream): + if mode == "message" or chunk.type != "content.delta": + # Message mode is terminal, so just yield chunks, no processing + yield chunk + continue + + buf += chunk.delta + + if mode == "detect_start": + # Try to determine if we start with thought tag + if buf.startswith(start_tag): + # Found start tag, switch mode + buf = buf[len(start_tag) :] # Remove start tag + mode = "thought" + # Fall through to process the rest of the buffer in 'thought' mode *within this iteration* + elif len(buf) >= len(start_tag): + # Buffer is long enough, definitely doesn't start with tag + chunk.delta = buf + yield chunk + mode = "message" + buf = "" + continue + elif start_tag.startswith(buf): + # Buffer is a prefix of the start tag, need more data + continue + else: + # Buffer doesn't match start tag prefix and is shorter than tag + chunk.delta = buf + yield chunk + mode = "message" + buf = "" + continue + + if mode == "thought": + # Look for the end tag + idx = buf.find(end_tag) + if idx != -1: + # Found end tag. Yield thought content before it. + if idx > 0 and buf[:idx].strip(): + chunk.type = "thought.delta" + chunk.delta = buf[:idx] + yield chunk + # Process content *after* the tag as message + buf = buf[idx + len(end_tag) :] + if buf: + chunk.delta = buf + yield chunk + mode = "message" + buf = "" + continue + else: + # End tag not found yet. Yield thought content, holding back potential partial end tag. + send_upto = len(buf) + # Check if buffer ends with a prefix of end_tag + for i in range(len(end_tag) - 1, 0, -1): + if buf.endswith(end_tag[:i]): + send_upto = len(buf) - i # Don't send the partial tag yet + break + if send_upto > 0 and buf[:send_upto].strip(): + chunk.type = "thought.delta" + chunk.delta = buf[:send_upto] + yield chunk + buf = buf[send_upto:] # Keep only the partial tag (or empty) + # Need more data to find the complete end tag + continue + + # End of stream handling + if buf: + if mode == "thought": # Stream ended before was found + chunk.type = "thought.delta" + chunk.delta = buf + yield chunk + elif mode == "detect_start": # Stream ended before start tag could be confirmed/denied + # If it wasn't a partial start tag, treat as message + if not start_tag.startswith(buf): + chunk.delta = buf + yield chunk + # else: discard partial + # If mode == "message", buffer should be empty due to logic above, but yield just in case + elif mode == "message": + chunk.delta = buf + yield chunk + + +async def ain_stream_thought_processor( + chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think" +) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]: + """ + Async generator for chat completion with thought chunks. + Assumes ... can only appear once at the start. + Handles partial tags across streamed chunks. + """ + start_tag = f"<{thought_tag}>" + end_tag = f"" + buf: str = "" + # Modes and transitions: detect_start > thought (optional) > message + mode = "detect_start" + + async for chunk in adefault_stream_processor(chat_stream): + if len(chunk.choices) == 0: + continue + if mode == "message": + # Message mode is terminal, so just yield chunks, no processing + yield chunk + continue + + buf += chunk.choices[0].delta.content + + if mode == "detect_start": + # Try to determine if we start with thought tag + if buf.startswith(start_tag): + # Found start tag, switch mode + buf = buf[len(start_tag) :] # Remove start tag + mode = "thought" + # Fall through to process the rest of the buffer in 'thought' mode *within this iteration* + elif len(buf) >= len(start_tag): + # Buffer is long enough, definitely doesn't start with tag + chunk.choices[0].delta.content = buf + yield chunk + mode = "message" + buf = "" + continue + elif start_tag.startswith(buf): + # Buffer is a prefix of the start tag, need more data + continue + else: + # Buffer doesn't match start tag prefix and is shorter than tag + chunk.choices[0].delta.content = buf + yield chunk + mode = "message" + buf = "" + continue + + if mode == "thought": + # Look for the end tag + idx = buf.find(end_tag) + if idx != -1: + # Found end tag. Yield thought content before it. + if idx > 0 and buf[:idx].strip(): + chunk.choices[0].delta.thought = buf[:idx] + chunk.choices[0].delta.content = "" + yield chunk + # Process content *after* the tag as message + buf = buf[idx + len(end_tag) :] + if buf: + chunk.choices[0].delta.content = buf + yield chunk + mode = "message" + buf = "" + continue + else: + # End tag not found yet. Yield thought content, holding back potential partial end tag. + send_upto = len(buf) + # Check if buffer ends with a prefix of end_tag + for i in range(len(end_tag) - 1, 0, -1): + if buf.endswith(end_tag[:i]): + send_upto = len(buf) - i # Don't send the partial tag yet + break + if send_upto > 0 and buf[:send_upto].strip(): + chunk.choices[0].delta.thought = buf[:send_upto] + chunk.choices[0].delta.content = "" + yield chunk + buf = buf[send_upto:] # Keep only the partial tag (or empty) + # Need more data to find the complete end tag + continue + + # End of stream handling + if buf: + if mode == "thought": # Stream ended before was found + chunk.choices[0].delta.thought = buf + chunk.choices[0].delta.content = "" + yield chunk + elif mode == "detect_start": # Stream ended before start tag could be confirmed/denied + # If it wasn't a partial start tag, treat as message + if not start_tag.startswith(buf): + chunk.choices[0].delta.content = buf + yield chunk + # else: discard partial + # If mode == "message", buffer should be empty due to logic above, but yield just in case + elif mode == "message": + chunk.choices[0].delta.content = buf + yield chunk diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 9601a5cd..e86834f9 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -191,6 +191,7 @@ class ChatEvent(Enum): REFERENCES = "references" GENERATED_ASSETS = "generated_assets" STATUS = "status" + THOUGHT = "thought" METADATA = "metadata" USAGE = "usage" END_RESPONSE = "end_response" @@ -873,3 +874,9 @@ class JsonSupport(int, Enum): NONE = 0 OBJECT = 1 SCHEMA = 2 + + +class ResponseWithThought: + def __init__(self, response: str = None, thought: str = None): + self.response = response + self.thought = thought diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a15f788f..6201a483 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -25,7 +25,11 @@ from khoj.database.adapters import ( from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.prompts import help_message, no_entries_found -from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log +from khoj.processor.conversation.utils import ( + ResponseWithThought, + defilter_query, + save_to_conversation_log, +) from khoj.processor.image.generate import text_to_image from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import ( @@ -726,6 +730,16 @@ async def chat( ttft = time.perf_counter() - start_time elif event_type == ChatEvent.STATUS: train_of_thought.append({"type": event_type.value, "data": data}) + elif event_type == ChatEvent.THOUGHT: + # Append the data to the last thought as thoughts are streamed + if ( + len(train_of_thought) > 0 + and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value + and type(train_of_thought[-1]["data"]) == type(data) == str + ): + train_of_thought[-1]["data"] += data + else: + train_of_thought.append({"type": event_type.value, "data": data}) if event_type == ChatEvent.MESSAGE: yield data @@ -1306,10 +1320,6 @@ async def chat( tracer, ) - # Send Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - continue_stream = True async for item in llm_response: # Should not happen with async generator, end is signaled by loop exit. Skip. @@ -1318,8 +1328,18 @@ async def chat( if not connection_alive or not continue_stream: # Drain the generator if disconnected but keep processing internally continue + message = item.response if isinstance(item, ResponseWithThought) else item + if isinstance(item, ResponseWithThought) and item.thought: + async for result in send_event(ChatEvent.THOUGHT, item.thought): + yield result + continue + + # Start sending response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + try: - async for result in send_event(ChatEvent.MESSAGE, f"{item}"): + async for result in send_event(ChatEvent.MESSAGE, message): yield result except Exception as e: continue_stream = False diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a0baffb9..5d881ce5 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -93,6 +93,7 @@ from khoj.processor.conversation.openai.gpt import ( ) from khoj.processor.conversation.utils import ( ChatEvent, + ResponseWithThought, clean_json, clean_mermaidjs, construct_chat_history, @@ -1432,9 +1433,9 @@ async def agenerate_chat_response( generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, tracer: dict = {}, -) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]: +) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]: # Initialize Variables - chat_response_generator = None + chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None logger.debug(f"Conversation Types: {conversation_commands}") metadata = {}