mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 05:40:17 +00:00
Parse and show reasoning model thoughts (#1172)
### Major All reasoning models return thoughts differently due to lack of standardization. We normalize thoughts by reasoning models and providers to ease handling within Khoj. The model thoughts are parsed during research mode when generating final response. These model thoughts are returned by the chat API and shown in train of thought shown on web app. Thoughts are enabled for Deepseek, Anthropic, Grok and Qwen3 reasoning models served via API. Gemini and Openai reasoning models do not show their thoughts via standard APIs. ### Minor - Fix ability to use Deepseek reasoner for intermediate stages of chat - Enable handling Qwen3 reasoning models
This commit is contained in:
@@ -97,6 +97,17 @@ export function processMessageChunk(
|
||||
console.log(`status: ${chunk.data}`);
|
||||
const statusMessage = chunk.data as string;
|
||||
currentMessage.trainOfThought.push(statusMessage);
|
||||
} else if (chunk.type === "thought") {
|
||||
const thoughtChunk = chunk.data as string;
|
||||
const lastThoughtIndex = currentMessage.trainOfThought.length - 1;
|
||||
const previousThought =
|
||||
lastThoughtIndex >= 0 ? currentMessage.trainOfThought[lastThoughtIndex] : "";
|
||||
// If the last train of thought started with "Thinking: " append the new thought chunk to it
|
||||
if (previousThought.startsWith("**Thinking:** ")) {
|
||||
currentMessage.trainOfThought[lastThoughtIndex] += thoughtChunk;
|
||||
} else {
|
||||
currentMessage.trainOfThought.push(`**Thinking:** ${thoughtChunk}`);
|
||||
}
|
||||
} else if (chunk.type === "references") {
|
||||
const references = chunk.data as RawReferenceData;
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from khoj.processor.conversation.anthropic.utils import (
|
||||
format_messages_for_anthropic,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
@@ -162,7 +163,7 @@ async def converse_anthropic(
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
"""
|
||||
Converse with user using Anthropic's Claude
|
||||
"""
|
||||
@@ -247,7 +248,8 @@ async def converse_anthropic(
|
||||
deepthought=deepthought,
|
||||
tracer=tracer,
|
||||
):
|
||||
full_response += chunk
|
||||
if chunk.response:
|
||||
full_response += chunk.response
|
||||
yield chunk
|
||||
|
||||
# Call completion_func once finish streaming and we have the full response
|
||||
|
||||
@@ -13,6 +13,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
@@ -154,13 +155,23 @@ async def anthropic_chat_completion_with_backoff(
|
||||
max_tokens=max_tokens,
|
||||
**model_kwargs,
|
||||
) as stream:
|
||||
async for text in stream.text_stream:
|
||||
async for chunk in stream:
|
||||
# Log the time taken to start response
|
||||
if aggregated_response == "":
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Skip empty chunks
|
||||
if chunk.type != "content_block_delta":
|
||||
continue
|
||||
# Handle streamed response chunk
|
||||
aggregated_response += text
|
||||
yield text
|
||||
response_chunk: ResponseWithThought = None
|
||||
if chunk.delta.type == "text_delta":
|
||||
response_chunk = ResponseWithThought(response=chunk.delta.text)
|
||||
aggregated_response += chunk.delta.text
|
||||
if chunk.delta.type == "thinking_delta":
|
||||
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
|
||||
# Handle streamed response chunk
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
|
||||
@@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import (
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
@@ -188,7 +189,7 @@ async def converse_openai(
|
||||
program_execution_context: List[str] = None,
|
||||
deepthought: Optional[bool] = False,
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
"""
|
||||
@@ -273,7 +274,8 @@ async def converse_openai(
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
tracer=tracer,
|
||||
):
|
||||
full_response += chunk
|
||||
if chunk.response:
|
||||
full_response += chunk.response
|
||||
yield chunk
|
||||
|
||||
# Call completion_func once finish streaming and we have the full response
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import partial
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.lib.streaming.chat import (
|
||||
ChatCompletionStream,
|
||||
ChatCompletionStreamEvent,
|
||||
ContentDeltaEvent,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk,
|
||||
Choice,
|
||||
ChoiceDelta,
|
||||
)
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@@ -16,7 +25,11 @@ from tenacity import (
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
ResponseWithThought,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
get_chat_usage_metrics,
|
||||
get_openai_async_client,
|
||||
@@ -59,6 +72,7 @@ def completion_with_backoff(
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
|
||||
stream_processor = default_stream_processor
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
# Tune reasoning models arguments
|
||||
@@ -69,6 +83,24 @@ def completion_with_backoff(
|
||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||
reasoning_effort = "high" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
elif model_name.startswith("deepseek-reasoner"):
|
||||
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
||||
# The first message should always be a user message (except system message).
|
||||
updated_messages: List[dict] = []
|
||||
for i, message in enumerate(formatted_messages):
|
||||
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
|
||||
updated_messages[-1]["content"] += " " + message["content"]
|
||||
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
|
||||
updated_messages[-1]["content"] += " " + message["content"]
|
||||
else:
|
||||
updated_messages.append(message)
|
||||
formatted_messages = updated_messages
|
||||
elif is_qwen_reasoning_model(model_name, api_base_url):
|
||||
stream_processor = partial(in_stream_thought_processor, thought_tag="think")
|
||||
# Reasoning is enabled by default. Disable when deepthought is False.
|
||||
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
|
||||
if not deepthought and len(formatted_messages) > 0:
|
||||
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
|
||||
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
@@ -82,12 +114,11 @@ def completion_with_backoff(
|
||||
timeout=20,
|
||||
**model_kwargs,
|
||||
) as chat:
|
||||
for chunk in chat:
|
||||
if chunk.type == "error":
|
||||
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
|
||||
continue
|
||||
elif chunk.type == "content.delta":
|
||||
for chunk in stream_processor(chat):
|
||||
if chunk.type == "content.delta":
|
||||
aggregated_response += chunk.delta
|
||||
elif chunk.type == "thought.delta":
|
||||
pass
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||
@@ -124,14 +155,14 @@ def completion_with_backoff(
|
||||
)
|
||||
async def chat_completion_with_backoff(
|
||||
messages,
|
||||
model_name,
|
||||
model_name: str,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
deepthought=False,
|
||||
model_kwargs: dict = {},
|
||||
tracer: dict = {},
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||
try:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client = openai_async_clients.get(client_key)
|
||||
@@ -139,6 +170,7 @@ async def chat_completion_with_backoff(
|
||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||
openai_async_clients[client_key] = client
|
||||
|
||||
stream_processor = adefault_stream_processor
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
# Configure thinking for openai reasoning models
|
||||
@@ -161,9 +193,11 @@ async def chat_completion_with_backoff(
|
||||
"content"
|
||||
] = f"{first_system_message_content}\nFormatting re-enabled"
|
||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||
stream_processor = adeepseek_stream_processor
|
||||
reasoning_effort = "high" if deepthought else "low"
|
||||
model_kwargs["reasoning_effort"] = reasoning_effort
|
||||
elif model_name.startswith("deepseek-reasoner"):
|
||||
stream_processor = adeepseek_stream_processor
|
||||
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
|
||||
# The first message should always be a user message (except system message).
|
||||
updated_messages: List[dict] = []
|
||||
@@ -174,8 +208,13 @@ async def chat_completion_with_backoff(
|
||||
updated_messages[-1]["content"] += " " + message["content"]
|
||||
else:
|
||||
updated_messages.append(message)
|
||||
|
||||
formatted_messages = updated_messages
|
||||
elif is_qwen_reasoning_model(model_name, api_base_url):
|
||||
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
|
||||
# Reasoning is enabled by default. Disable when deepthought is False.
|
||||
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
|
||||
if not deepthought and len(formatted_messages) > 0:
|
||||
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
|
||||
|
||||
stream = True
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
@@ -193,24 +232,25 @@ async def chat_completion_with_backoff(
|
||||
timeout=20,
|
||||
**model_kwargs,
|
||||
)
|
||||
async for chunk in chat_stream:
|
||||
async for chunk in stream_processor(chat_stream):
|
||||
# Log the time taken to start response
|
||||
if final_chunk is None:
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Handle streamed response chunk
|
||||
# Skip empty chunks
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta_chunk = chunk.choices[0].delta
|
||||
text_chunk = ""
|
||||
if isinstance(delta_chunk, str):
|
||||
text_chunk = delta_chunk
|
||||
elif delta_chunk and delta_chunk.content:
|
||||
text_chunk = delta_chunk.content
|
||||
if text_chunk:
|
||||
aggregated_response += text_chunk
|
||||
yield text_chunk
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
response_delta = chunk.choices[0].delta
|
||||
if response_delta.content:
|
||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||
aggregated_response += response_chunk.response
|
||||
elif response_delta.thought:
|
||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
@@ -264,3 +304,274 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
|
||||
and api_base_url is not None
|
||||
and api_base_url.startswith("https://api.x.ai/v1")
|
||||
)
|
||||
|
||||
|
||||
def is_qwen_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
||||
"""
|
||||
Check if the model is a Qwen reasoning model
|
||||
"""
|
||||
return "qwen3" in model_name.lower() and api_base_url is not None
|
||||
|
||||
|
||||
class ThoughtDeltaEvent(ContentDeltaEvent):
|
||||
"""
|
||||
Chat completion chunk with thoughts, reasoning support.
|
||||
"""
|
||||
|
||||
type: Literal["thought.delta"]
|
||||
"""The thought or reasoning generated by the model."""
|
||||
|
||||
|
||||
ChatCompletionStreamWithThoughtEvent = Union[ChatCompletionStreamEvent, ThoughtDeltaEvent]
|
||||
|
||||
|
||||
class ChoiceDeltaWithThoughts(ChoiceDelta):
|
||||
"""
|
||||
Chat completion chunk with thoughts, reasoning support.
|
||||
"""
|
||||
|
||||
thought: Optional[str] = None
|
||||
"""The thought or reasoning generated by the model."""
|
||||
|
||||
|
||||
class ChoiceWithThoughts(Choice):
|
||||
delta: ChoiceDeltaWithThoughts
|
||||
|
||||
|
||||
class ChatCompletionWithThoughtsChunk(ChatCompletionChunk):
|
||||
choices: List[ChoiceWithThoughts] # Override the choices type
|
||||
|
||||
|
||||
def default_stream_processor(
|
||||
chat_stream: ChatCompletionStream,
|
||||
) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
|
||||
"""
|
||||
Async generator to cast and return chunks from the standard openai chat completions stream.
|
||||
"""
|
||||
for chunk in chat_stream:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def adefault_stream_processor(
|
||||
chat_stream: openai.AsyncStream[ChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
|
||||
"""
|
||||
Async generator to cast and return chunks from the standard openai chat completions stream.
|
||||
"""
|
||||
async for chunk in chat_stream:
|
||||
yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())
|
||||
|
||||
|
||||
async def adeepseek_stream_processor(
|
||||
chat_stream: openai.AsyncStream[ChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
|
||||
"""
|
||||
Async generator to cast and return chunks from the deepseek chat completions stream.
|
||||
"""
|
||||
async for chunk in chat_stream:
|
||||
tchunk = ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())
|
||||
if (
|
||||
len(tchunk.choices) > 0
|
||||
and hasattr(tchunk.choices[0].delta, "reasoning_content")
|
||||
and tchunk.choices[0].delta.reasoning_content
|
||||
):
|
||||
tchunk.choices[0].delta.thought = chunk.choices[0].delta.reasoning_content
|
||||
yield tchunk
|
||||
|
||||
|
||||
def in_stream_thought_processor(
|
||||
chat_stream: openai.Stream[ChatCompletionChunk], thought_tag="think"
|
||||
) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
|
||||
"""
|
||||
Generator for chat completion with thought chunks.
|
||||
Assumes <thought_tag>...</thought_tag> can only appear once at the start.
|
||||
Handles partial tags across streamed chunks.
|
||||
"""
|
||||
start_tag = f"<{thought_tag}>"
|
||||
end_tag = f"</{thought_tag}>"
|
||||
buf: str = ""
|
||||
# Modes and transitions: detect_start > thought (optional) > message
|
||||
mode = "detect_start"
|
||||
|
||||
for chunk in default_stream_processor(chat_stream):
|
||||
if mode == "message" or chunk.type != "content.delta":
|
||||
# Message mode is terminal, so just yield chunks, no processing
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
buf += chunk.delta
|
||||
|
||||
if mode == "detect_start":
|
||||
# Try to determine if we start with thought tag
|
||||
if buf.startswith(start_tag):
|
||||
# Found start tag, switch mode
|
||||
buf = buf[len(start_tag) :] # Remove start tag
|
||||
mode = "thought"
|
||||
# Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
|
||||
elif len(buf) >= len(start_tag):
|
||||
# Buffer is long enough, definitely doesn't start with tag
|
||||
chunk.delta = buf
|
||||
yield chunk
|
||||
mode = "message"
|
||||
buf = ""
|
||||
continue
|
||||
elif start_tag.startswith(buf):
|
||||
# Buffer is a prefix of the start tag, need more data
|
||||
continue
|
||||
else:
|
||||
# Buffer doesn't match start tag prefix and is shorter than tag
|
||||
chunk.delta = buf
|
||||
yield chunk
|
||||
mode = "message"
|
||||
buf = ""
|
||||
continue
|
||||
|
||||
if mode == "thought":
|
||||
# Look for the end tag
|
||||
idx = buf.find(end_tag)
|
||||
if idx != -1:
|
||||
# Found end tag. Yield thought content before it.
|
||||
if idx > 0 and buf[:idx].strip():
|
||||
chunk.type = "thought.delta"
|
||||
chunk.delta = buf[:idx]
|
||||
yield chunk
|
||||
# Process content *after* the tag as message
|
||||
buf = buf[idx + len(end_tag) :]
|
||||
if buf:
|
||||
chunk.delta = buf
|
||||
yield chunk
|
||||
mode = "message"
|
||||
buf = ""
|
||||
continue
|
||||
else:
|
||||
# End tag not found yet. Yield thought content, holding back potential partial end tag.
|
||||
send_upto = len(buf)
|
||||
# Check if buffer ends with a prefix of end_tag
|
||||
for i in range(len(end_tag) - 1, 0, -1):
|
||||
if buf.endswith(end_tag[:i]):
|
||||
send_upto = len(buf) - i # Don't send the partial tag yet
|
||||
break
|
||||
if send_upto > 0 and buf[:send_upto].strip():
|
||||
chunk.type = "thought.delta"
|
||||
chunk.delta = buf[:send_upto]
|
||||
yield chunk
|
||||
buf = buf[send_upto:] # Keep only the partial tag (or empty)
|
||||
# Need more data to find the complete end tag
|
||||
continue
|
||||
|
||||
# End of stream handling
|
||||
if buf:
|
||||
if mode == "thought": # Stream ended before </think> was found
|
||||
chunk.type = "thought.delta"
|
||||
chunk.delta = buf
|
||||
yield chunk
|
||||
elif mode == "detect_start": # Stream ended before start tag could be confirmed/denied
|
||||
# If it wasn't a partial start tag, treat as message
|
||||
if not start_tag.startswith(buf):
|
||||
chunk.delta = buf
|
||||
yield chunk
|
||||
# else: discard partial <think>
|
||||
# If mode == "message", buffer should be empty due to logic above, but yield just in case
|
||||
elif mode == "message":
|
||||
chunk.delta = buf
|
||||
yield chunk
|
||||
|
||||
|
||||
async def ain_stream_thought_processor(
|
||||
chat_stream: openai.AsyncStream[ChatCompletionChunk], thought_tag="think"
|
||||
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
|
||||
"""
|
||||
Async generator for chat completion with thought chunks.
|
||||
Assumes <thought_tag>...</thought_tag> can only appear once at the start.
|
||||
Handles partial tags across streamed chunks.
|
||||
"""
|
||||
start_tag = f"<{thought_tag}>"
|
||||
end_tag = f"</{thought_tag}>"
|
||||
buf: str = ""
|
||||
# Modes and transitions: detect_start > thought (optional) > message
|
||||
mode = "detect_start"
|
||||
|
||||
async for chunk in adefault_stream_processor(chat_stream):
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
if mode == "message":
|
||||
# Message mode is terminal, so just yield chunks, no processing
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
buf += chunk.choices[0].delta.content
|
||||
|
||||
if mode == "detect_start":
|
||||
# Try to determine if we start with thought tag
|
||||
if buf.startswith(start_tag):
|
||||
# Found start tag, switch mode
|
||||
buf = buf[len(start_tag) :] # Remove start tag
|
||||
mode = "thought"
|
||||
# Fall through to process the rest of the buffer in 'thought' mode *within this iteration*
|
||||
elif len(buf) >= len(start_tag):
|
||||
# Buffer is long enough, definitely doesn't start with tag
|
||||
chunk.choices[0].delta.content = buf
|
||||
yield chunk
|
||||
mode = "message"
|
||||
buf = ""
|
||||
continue
|
||||
elif start_tag.startswith(buf):
|
||||
# Buffer is a prefix of the start tag, need more data
|
||||
continue
|
||||
else:
|
||||
# Buffer doesn't match start tag prefix and is shorter than tag
|
||||
chunk.choices[0].delta.content = buf
|
||||
yield chunk
|
||||
mode = "message"
|
||||
buf = ""
|
||||
continue
|
||||
|
||||
if mode == "thought":
|
||||
# Look for the end tag
|
||||
idx = buf.find(end_tag)
|
||||
if idx != -1:
|
||||
# Found end tag. Yield thought content before it.
|
||||
if idx > 0 and buf[:idx].strip():
|
||||
chunk.choices[0].delta.thought = buf[:idx]
|
||||
chunk.choices[0].delta.content = ""
|
||||
yield chunk
|
||||
# Process content *after* the tag as message
|
||||
buf = buf[idx + len(end_tag) :]
|
||||
if buf:
|
||||
chunk.choices[0].delta.content = buf
|
||||
yield chunk
|
||||
mode = "message"
|
||||
buf = ""
|
||||
continue
|
||||
else:
|
||||
# End tag not found yet. Yield thought content, holding back potential partial end tag.
|
||||
send_upto = len(buf)
|
||||
# Check if buffer ends with a prefix of end_tag
|
||||
for i in range(len(end_tag) - 1, 0, -1):
|
||||
if buf.endswith(end_tag[:i]):
|
||||
send_upto = len(buf) - i # Don't send the partial tag yet
|
||||
break
|
||||
if send_upto > 0 and buf[:send_upto].strip():
|
||||
chunk.choices[0].delta.thought = buf[:send_upto]
|
||||
chunk.choices[0].delta.content = ""
|
||||
yield chunk
|
||||
buf = buf[send_upto:] # Keep only the partial tag (or empty)
|
||||
# Need more data to find the complete end tag
|
||||
continue
|
||||
|
||||
# End of stream handling
|
||||
if buf:
|
||||
if mode == "thought": # Stream ended before </think> was found
|
||||
chunk.choices[0].delta.thought = buf
|
||||
chunk.choices[0].delta.content = ""
|
||||
yield chunk
|
||||
elif mode == "detect_start": # Stream ended before start tag could be confirmed/denied
|
||||
# If it wasn't a partial start tag, treat as message
|
||||
if not start_tag.startswith(buf):
|
||||
chunk.choices[0].delta.content = buf
|
||||
yield chunk
|
||||
# else: discard partial <think>
|
||||
# If mode == "message", buffer should be empty due to logic above, but yield just in case
|
||||
elif mode == "message":
|
||||
chunk.choices[0].delta.content = buf
|
||||
yield chunk
|
||||
|
||||
@@ -191,6 +191,7 @@ class ChatEvent(Enum):
|
||||
REFERENCES = "references"
|
||||
GENERATED_ASSETS = "generated_assets"
|
||||
STATUS = "status"
|
||||
THOUGHT = "thought"
|
||||
METADATA = "metadata"
|
||||
USAGE = "usage"
|
||||
END_RESPONSE = "end_response"
|
||||
@@ -873,3 +874,9 @@ class JsonSupport(int, Enum):
|
||||
NONE = 0
|
||||
OBJECT = 1
|
||||
SCHEMA = 2
|
||||
|
||||
|
||||
class ResponseWithThought:
|
||||
def __init__(self, response: str = None, thought: str = None):
|
||||
self.response = response
|
||||
self.thought = thought
|
||||
|
||||
@@ -25,7 +25,11 @@ from khoj.database.adapters import (
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
defilter_query,
|
||||
save_to_conversation_log,
|
||||
)
|
||||
from khoj.processor.image.generate import text_to_image
|
||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||
from khoj.processor.tools.online_search import (
|
||||
@@ -726,6 +730,16 @@ async def chat(
|
||||
ttft = time.perf_counter() - start_time
|
||||
elif event_type == ChatEvent.STATUS:
|
||||
train_of_thought.append({"type": event_type.value, "data": data})
|
||||
elif event_type == ChatEvent.THOUGHT:
|
||||
# Append the data to the last thought as thoughts are streamed
|
||||
if (
|
||||
len(train_of_thought) > 0
|
||||
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
|
||||
and type(train_of_thought[-1]["data"]) == type(data) == str
|
||||
):
|
||||
train_of_thought[-1]["data"] += data
|
||||
else:
|
||||
train_of_thought.append({"type": event_type.value, "data": data})
|
||||
|
||||
if event_type == ChatEvent.MESSAGE:
|
||||
yield data
|
||||
@@ -1306,10 +1320,6 @@ async def chat(
|
||||
tracer,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||
yield result
|
||||
|
||||
continue_stream = True
|
||||
async for item in llm_response:
|
||||
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
||||
@@ -1318,8 +1328,18 @@ async def chat(
|
||||
if not connection_alive or not continue_stream:
|
||||
# Drain the generator if disconnected but keep processing internally
|
||||
continue
|
||||
message = item.response if isinstance(item, ResponseWithThought) else item
|
||||
if isinstance(item, ResponseWithThought) and item.thought:
|
||||
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
||||
yield result
|
||||
continue
|
||||
|
||||
# Start sending response
|
||||
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||
yield result
|
||||
|
||||
try:
|
||||
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
||||
async for result in send_event(ChatEvent.MESSAGE, message):
|
||||
yield result
|
||||
except Exception as e:
|
||||
continue_stream = False
|
||||
|
||||
@@ -93,6 +93,7 @@ from khoj.processor.conversation.openai.gpt import (
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
construct_chat_history,
|
||||
@@ -1432,9 +1433,9 @@ async def agenerate_chat_response(
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
is_subscribed: bool = False,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
|
||||
) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response_generator = None
|
||||
chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None
|
||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||
|
||||
metadata = {}
|
||||
|
||||
Reference in New Issue
Block a user