diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts index 9ccd1316..4823659a 100644 --- a/src/interface/web/app/common/chatFunctions.ts +++ b/src/interface/web/app/common/chatFunctions.ts @@ -97,6 +97,17 @@ export function processMessageChunk( console.log(`status: ${chunk.data}`); const statusMessage = chunk.data as string; currentMessage.trainOfThought.push(statusMessage); + } else if (chunk.type === "thought") { + const thoughtChunk = chunk.data as string; + const lastThoughtIndex = currentMessage.trainOfThought.length - 1; + const previousThought = + lastThoughtIndex >= 0 ? currentMessage.trainOfThought[lastThoughtIndex] : ""; + // If the last train of thought started with "Thinking: " append the new thought chunk to it + if (previousThought.startsWith("**Thinking:** ")) { + currentMessage.trainOfThought[lastThoughtIndex] += thoughtChunk; + } else { + currentMessage.trainOfThought.push(`**Thinking:** ${thoughtChunk}`); + } } else if (chunk.type === "references") { const references = chunk.data as RawReferenceData; diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 4808bc60..65b2d83f 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import ( ) from khoj.processor.conversation.utils import ( JsonSupport, + ResponseWithThought, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -188,7 +189,7 @@ async def converse_openai( program_execution_context: List[str] = None, deepthought: Optional[bool] = False, tracer: dict = {}, -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using OpenAI's ChatGPT """ @@ -273,7 +274,8 @@ async def converse_openai( model_kwargs={"stop": ["Notes:\n["]}, tracer=tracer, ): - full_response += chunk + if chunk.response: + full_response += chunk.response yield chunk # Call completion_func once finish streaming and we have the full response diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index d4ad9105..bec9997b 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -25,7 +25,11 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace +from khoj.processor.conversation.utils import ( + JsonSupport, + ResponseWithThought, + commit_conversation_trace, +) from khoj.utils.helpers import ( get_chat_usage_metrics, get_openai_async_client, @@ -99,10 +103,7 @@ def completion_with_backoff( **model_kwargs, ) as chat: for chunk in stream_processor(chat): - if chunk.type == "error": - logger.error(f"Openai api response error: {chunk.error}", exc_info=True) - continue - elif chunk.type == "content.delta": + if chunk.type == "content.delta": aggregated_response += chunk.delta elif chunk.type == "thought.delta": pass @@ -149,7 +150,7 @@ async def chat_completion_with_backoff( deepthought=False, model_kwargs: dict = {}, tracer: dict = {}, -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[ResponseWithThought, None]: try: client_key = f"{openai_api_key}--{api_base_url}" client = openai_async_clients.get(client_key) @@ -224,18 +225,19 @@ async def chat_completion_with_backoff( logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") # Keep track of the last chunk for usage data final_chunk = chunk - # Handle streamed response chunk + # Skip empty chunks if len(chunk.choices) == 0: continue - delta_chunk = chunk.choices[0].delta - text_chunk = "" - if isinstance(delta_chunk, str): - text_chunk = delta_chunk - elif delta_chunk and delta_chunk.content: - text_chunk = delta_chunk.content - if text_chunk: - aggregated_response += text_chunk - yield text_chunk + # Handle streamed response chunk + response_chunk: ResponseWithThought = None + response_delta = chunk.choices[0].delta + if response_delta.content: + response_chunk = ResponseWithThought(response=response_delta.content) + aggregated_response += response_chunk.response + elif response_delta.thought: + response_chunk = ResponseWithThought(thought=response_delta.thought) + if response_chunk: + yield response_chunk # Log the time taken to stream the entire response logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 9601a5cd..e86834f9 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -191,6 +191,7 @@ class ChatEvent(Enum): REFERENCES = "references" GENERATED_ASSETS = "generated_assets" STATUS = "status" + THOUGHT = "thought" METADATA = "metadata" USAGE = "usage" END_RESPONSE = "end_response" @@ -873,3 +874,9 @@ class JsonSupport(int, Enum): NONE = 0 OBJECT = 1 SCHEMA = 2 + + +class ResponseWithThought: + def __init__(self, response: str = None, thought: str = None): + self.response = response + self.thought = thought diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a15f788f..6201a483 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -25,7 +25,11 @@ from khoj.database.adapters import ( from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.prompts import help_message, no_entries_found -from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log +from khoj.processor.conversation.utils import ( + ResponseWithThought, + defilter_query, + save_to_conversation_log, +) from khoj.processor.image.generate import text_to_image from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import ( @@ -726,6 +730,16 @@ async def chat( ttft = time.perf_counter() - start_time elif event_type == ChatEvent.STATUS: train_of_thought.append({"type": event_type.value, "data": data}) + elif event_type == ChatEvent.THOUGHT: + # Append the data to the last thought as thoughts are streamed + if ( + len(train_of_thought) > 0 + and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value + and type(train_of_thought[-1]["data"]) == type(data) == str + ): + train_of_thought[-1]["data"] += data + else: + train_of_thought.append({"type": event_type.value, "data": data}) if event_type == ChatEvent.MESSAGE: yield data @@ -1306,10 +1320,6 @@ async def chat( tracer, ) - # Send Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - continue_stream = True async for item in llm_response: # Should not happen with async generator, end is signaled by loop exit. Skip. @@ -1318,8 +1328,18 @@ async def chat( if not connection_alive or not continue_stream: # Drain the generator if disconnected but keep processing internally continue + message = item.response if isinstance(item, ResponseWithThought) else item + if isinstance(item, ResponseWithThought) and item.thought: + async for result in send_event(ChatEvent.THOUGHT, item.thought): + yield result + continue + + # Start sending response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + try: - async for result in send_event(ChatEvent.MESSAGE, f"{item}"): + async for result in send_event(ChatEvent.MESSAGE, message): yield result except Exception as e: continue_stream = False diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a0baffb9..5d881ce5 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -93,6 +93,7 @@ from khoj.processor.conversation.openai.gpt import ( ) from khoj.processor.conversation.utils import ( ChatEvent, + ResponseWithThought, clean_json, clean_mermaidjs, construct_chat_history, @@ -1432,9 +1433,9 @@ async def agenerate_chat_response( generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, tracer: dict = {}, -) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]: +) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]: # Initialize Variables - chat_response_generator = None + chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None logger.debug(f"Conversation Types: {conversation_commands}") metadata = {}