Handle thinking by reasoning models. Show in train of thought on web client

This commit is contained in:
Debanjum
2025-05-02 06:41:50 -06:00
parent d10dcc83d4
commit 16f3c85dde
6 changed files with 69 additions and 26 deletions

View File

@@ -97,6 +97,17 @@ export function processMessageChunk(
console.log(`status: ${chunk.data}`);
const statusMessage = chunk.data as string;
currentMessage.trainOfThought.push(statusMessage);
} else if (chunk.type === "thought") {
const thoughtChunk = chunk.data as string;
const lastThoughtIndex = currentMessage.trainOfThought.length - 1;
const previousThought =
lastThoughtIndex >= 0 ? currentMessage.trainOfThought[lastThoughtIndex] : "";
// If the last train of thought started with "Thinking: " append the new thought chunk to it
if (previousThought.startsWith("**Thinking:** ")) {
currentMessage.trainOfThought[lastThoughtIndex] += thoughtChunk;
} else {
currentMessage.trainOfThought.push(`**Thinking:** ${thoughtChunk}`);
}
} else if (chunk.type === "references") {
const references = chunk.data as RawReferenceData;

View File

@@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
ResponseWithThought,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -188,7 +189,7 @@ async def converse_openai(
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using OpenAI's ChatGPT
"""
@@ -273,7 +274,8 @@ async def converse_openai(
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
):
full_response += chunk
if chunk.response:
full_response += chunk.response
yield chunk
# Call completion_func once finish streaming and we have the full response

View File

@@ -25,7 +25,11 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
from khoj.processor.conversation.utils import (
JsonSupport,
ResponseWithThought,
commit_conversation_trace,
)
from khoj.utils.helpers import (
get_chat_usage_metrics,
get_openai_async_client,
@@ -99,10 +103,7 @@ def completion_with_backoff(
**model_kwargs,
) as chat:
for chunk in stream_processor(chat):
if chunk.type == "error":
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
continue
elif chunk.type == "content.delta":
if chunk.type == "content.delta":
aggregated_response += chunk.delta
elif chunk.type == "thought.delta":
pass
@@ -149,7 +150,7 @@ async def chat_completion_with_backoff(
deepthought=False,
model_kwargs: dict = {},
tracer: dict = {},
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[ResponseWithThought, None]:
try:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_async_clients.get(client_key)
@@ -224,18 +225,19 @@ async def chat_completion_with_backoff(
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
# Skip empty chunks
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
text_chunk = delta_chunk
elif delta_chunk and delta_chunk.content:
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
yield text_chunk
# Handle streamed response chunk
response_chunk: ResponseWithThought = None
response_delta = chunk.choices[0].delta
if response_delta.content:
response_chunk = ResponseWithThought(response=response_delta.content)
aggregated_response += response_chunk.response
elif response_delta.thought:
response_chunk = ResponseWithThought(thought=response_delta.thought)
if response_chunk:
yield response_chunk
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")

View File

@@ -191,6 +191,7 @@ class ChatEvent(Enum):
REFERENCES = "references"
GENERATED_ASSETS = "generated_assets"
STATUS = "status"
THOUGHT = "thought"
METADATA = "metadata"
USAGE = "usage"
END_RESPONSE = "end_response"
@@ -873,3 +874,9 @@ class JsonSupport(int, Enum):
NONE = 0
OBJECT = 1
SCHEMA = 2
class ResponseWithThought:
def __init__(self, response: str = None, thought: str = None):
self.response = response
self.thought = thought

View File

@@ -25,7 +25,11 @@ from khoj.database.adapters import (
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
from khoj.processor.conversation.utils import (
ResponseWithThought,
defilter_query,
save_to_conversation_log,
)
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import (
@@ -726,6 +730,16 @@ async def chat(
ttft = time.perf_counter() - start_time
elif event_type == ChatEvent.STATUS:
train_of_thought.append({"type": event_type.value, "data": data})
elif event_type == ChatEvent.THOUGHT:
# Append the data to the last thought as thoughts are streamed
if (
len(train_of_thought) > 0
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
and type(train_of_thought[-1]["data"]) == type(data) == str
):
train_of_thought[-1]["data"] += data
else:
train_of_thought.append({"type": event_type.value, "data": data})
if event_type == ChatEvent.MESSAGE:
yield data
@@ -1306,10 +1320,6 @@ async def chat(
tracer,
)
# Send Response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
continue_stream = True
async for item in llm_response:
# Should not happen with async generator, end is signaled by loop exit. Skip.
@@ -1318,8 +1328,18 @@ async def chat(
if not connection_alive or not continue_stream:
# Drain the generator if disconnected but keep processing internally
continue
message = item.response if isinstance(item, ResponseWithThought) else item
if isinstance(item, ResponseWithThought) and item.thought:
async for result in send_event(ChatEvent.THOUGHT, item.thought):
yield result
continue
# Start sending response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
try:
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
async for result in send_event(ChatEvent.MESSAGE, message):
yield result
except Exception as e:
continue_stream = False

View File

@@ -93,6 +93,7 @@ from khoj.processor.conversation.openai.gpt import (
)
from khoj.processor.conversation.utils import (
ChatEvent,
ResponseWithThought,
clean_json,
clean_mermaidjs,
construct_chat_history,
@@ -1432,9 +1433,9 @@ async def agenerate_chat_response(
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
tracer: dict = {},
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]:
# Initialize Variables
chat_response_generator = None
chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}