mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Handle thinking by reasoning models. Show in train of thought on web client
This commit is contained in:
@@ -97,6 +97,17 @@ export function processMessageChunk(
|
|||||||
console.log(`status: ${chunk.data}`);
|
console.log(`status: ${chunk.data}`);
|
||||||
const statusMessage = chunk.data as string;
|
const statusMessage = chunk.data as string;
|
||||||
currentMessage.trainOfThought.push(statusMessage);
|
currentMessage.trainOfThought.push(statusMessage);
|
||||||
|
} else if (chunk.type === "thought") {
|
||||||
|
const thoughtChunk = chunk.data as string;
|
||||||
|
const lastThoughtIndex = currentMessage.trainOfThought.length - 1;
|
||||||
|
const previousThought =
|
||||||
|
lastThoughtIndex >= 0 ? currentMessage.trainOfThought[lastThoughtIndex] : "";
|
||||||
|
// If the last train of thought entry starts with the "**Thinking:** " prefix, append the new thought chunk to it
|
||||||
|
if (previousThought.startsWith("**Thinking:** ")) {
|
||||||
|
currentMessage.trainOfThought[lastThoughtIndex] += thoughtChunk;
|
||||||
|
} else {
|
||||||
|
currentMessage.trainOfThought.push(`**Thinking:** ${thoughtChunk}`);
|
||||||
|
}
|
||||||
} else if (chunk.type === "references") {
|
} else if (chunk.type === "references") {
|
||||||
const references = chunk.data as RawReferenceData;
|
const references = chunk.data as RawReferenceData;
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from khoj.processor.conversation.openai.utils import (
|
|||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
JsonSupport,
|
JsonSupport,
|
||||||
|
ResponseWithThought,
|
||||||
clean_json,
|
clean_json,
|
||||||
construct_structured_message,
|
construct_structured_message,
|
||||||
generate_chatml_messages_with_context,
|
generate_chatml_messages_with_context,
|
||||||
@@ -188,7 +189,7 @@ async def converse_openai(
|
|||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
deepthought: Optional[bool] = False,
|
deepthought: Optional[bool] = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
"""
|
"""
|
||||||
@@ -273,7 +274,8 @@ async def converse_openai(
|
|||||||
model_kwargs={"stop": ["Notes:\n["]},
|
model_kwargs={"stop": ["Notes:\n["]},
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
full_response += chunk
|
if chunk.response:
|
||||||
|
full_response += chunk.response
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
# Call completion_func once finish streaming and we have the full response
|
# Call completion_func once finish streaming and we have the full response
|
||||||
|
|||||||
@@ -25,7 +25,11 @@ from tenacity import (
|
|||||||
wait_random_exponential,
|
wait_random_exponential,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khoj.processor.conversation.utils import JsonSupport, commit_conversation_trace
|
from khoj.processor.conversation.utils import (
|
||||||
|
JsonSupport,
|
||||||
|
ResponseWithThought,
|
||||||
|
commit_conversation_trace,
|
||||||
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
get_openai_async_client,
|
get_openai_async_client,
|
||||||
@@ -99,10 +103,7 @@ def completion_with_backoff(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) as chat:
|
) as chat:
|
||||||
for chunk in stream_processor(chat):
|
for chunk in stream_processor(chat):
|
||||||
if chunk.type == "error":
|
if chunk.type == "content.delta":
|
||||||
logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
|
|
||||||
continue
|
|
||||||
elif chunk.type == "content.delta":
|
|
||||||
aggregated_response += chunk.delta
|
aggregated_response += chunk.delta
|
||||||
elif chunk.type == "thought.delta":
|
elif chunk.type == "thought.delta":
|
||||||
pass
|
pass
|
||||||
@@ -149,7 +150,7 @@ async def chat_completion_with_backoff(
|
|||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
try:
|
try:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client = openai_async_clients.get(client_key)
|
client = openai_async_clients.get(client_key)
|
||||||
@@ -224,18 +225,19 @@ async def chat_completion_with_backoff(
|
|||||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
# Keep track of the last chunk for usage data
|
# Keep track of the last chunk for usage data
|
||||||
final_chunk = chunk
|
final_chunk = chunk
|
||||||
# Handle streamed response chunk
|
# Skip empty chunks
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
delta_chunk = chunk.choices[0].delta
|
# Handle streamed response chunk
|
||||||
text_chunk = ""
|
response_chunk: ResponseWithThought = None
|
||||||
if isinstance(delta_chunk, str):
|
response_delta = chunk.choices[0].delta
|
||||||
text_chunk = delta_chunk
|
if response_delta.content:
|
||||||
elif delta_chunk and delta_chunk.content:
|
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||||
text_chunk = delta_chunk.content
|
aggregated_response += response_chunk.response
|
||||||
if text_chunk:
|
elif response_delta.thought:
|
||||||
aggregated_response += text_chunk
|
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||||
yield text_chunk
|
if response_chunk:
|
||||||
|
yield response_chunk
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
# Log the time taken to stream the entire response
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|||||||
@@ -191,6 +191,7 @@ class ChatEvent(Enum):
|
|||||||
REFERENCES = "references"
|
REFERENCES = "references"
|
||||||
GENERATED_ASSETS = "generated_assets"
|
GENERATED_ASSETS = "generated_assets"
|
||||||
STATUS = "status"
|
STATUS = "status"
|
||||||
|
THOUGHT = "thought"
|
||||||
METADATA = "metadata"
|
METADATA = "metadata"
|
||||||
USAGE = "usage"
|
USAGE = "usage"
|
||||||
END_RESPONSE = "end_response"
|
END_RESPONSE = "end_response"
|
||||||
@@ -873,3 +874,9 @@ class JsonSupport(int, Enum):
|
|||||||
NONE = 0
|
NONE = 0
|
||||||
OBJECT = 1
|
OBJECT = 1
|
||||||
SCHEMA = 2
|
SCHEMA = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseWithThought:
|
||||||
|
def __init__(self, response: str = None, thought: str = None):
|
||||||
|
self.response = response
|
||||||
|
self.thought = thought
|
||||||
|
|||||||
@@ -25,7 +25,11 @@ from khoj.database.adapters import (
|
|||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||||
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
|
from khoj.processor.conversation.utils import (
|
||||||
|
ResponseWithThought,
|
||||||
|
defilter_query,
|
||||||
|
save_to_conversation_log,
|
||||||
|
)
|
||||||
from khoj.processor.image.generate import text_to_image
|
from khoj.processor.image.generate import text_to_image
|
||||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||||
from khoj.processor.tools.online_search import (
|
from khoj.processor.tools.online_search import (
|
||||||
@@ -726,6 +730,16 @@ async def chat(
|
|||||||
ttft = time.perf_counter() - start_time
|
ttft = time.perf_counter() - start_time
|
||||||
elif event_type == ChatEvent.STATUS:
|
elif event_type == ChatEvent.STATUS:
|
||||||
train_of_thought.append({"type": event_type.value, "data": data})
|
train_of_thought.append({"type": event_type.value, "data": data})
|
||||||
|
elif event_type == ChatEvent.THOUGHT:
|
||||||
|
# Append the data to the last thought as thoughts are streamed
|
||||||
|
if (
|
||||||
|
len(train_of_thought) > 0
|
||||||
|
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
|
||||||
|
and type(train_of_thought[-1]["data"]) == type(data) == str
|
||||||
|
):
|
||||||
|
train_of_thought[-1]["data"] += data
|
||||||
|
else:
|
||||||
|
train_of_thought.append({"type": event_type.value, "data": data})
|
||||||
|
|
||||||
if event_type == ChatEvent.MESSAGE:
|
if event_type == ChatEvent.MESSAGE:
|
||||||
yield data
|
yield data
|
||||||
@@ -1306,10 +1320,6 @@ async def chat(
|
|||||||
tracer,
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send Response
|
|
||||||
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
continue_stream = True
|
continue_stream = True
|
||||||
async for item in llm_response:
|
async for item in llm_response:
|
||||||
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
# Should not happen with async generator, end is signaled by loop exit. Skip.
|
||||||
@@ -1318,8 +1328,18 @@ async def chat(
|
|||||||
if not connection_alive or not continue_stream:
|
if not connection_alive or not continue_stream:
|
||||||
# Drain the generator if disconnected but keep processing internally
|
# Drain the generator if disconnected but keep processing internally
|
||||||
continue
|
continue
|
||||||
|
message = item.response if isinstance(item, ResponseWithThought) else item
|
||||||
|
if isinstance(item, ResponseWithThought) and item.thought:
|
||||||
|
async for result in send_event(ChatEvent.THOUGHT, item.thought):
|
||||||
|
yield result
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Start sending response
|
||||||
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for result in send_event(ChatEvent.MESSAGE, f"{item}"):
|
async for result in send_event(ChatEvent.MESSAGE, message):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue_stream = False
|
continue_stream = False
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ from khoj.processor.conversation.openai.gpt import (
|
|||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
|
ResponseWithThought,
|
||||||
clean_json,
|
clean_json,
|
||||||
clean_mermaidjs,
|
clean_mermaidjs,
|
||||||
construct_chat_history,
|
construct_chat_history,
|
||||||
@@ -1432,9 +1433,9 @@ async def agenerate_chat_response(
|
|||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
is_subscribed: bool = False,
|
is_subscribed: bool = False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Tuple[AsyncGenerator[str, None], Dict[str, str]]:
|
) -> Tuple[AsyncGenerator[str | ResponseWithThought, None], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response_generator = None
|
chat_response_generator: AsyncGenerator[str | ResponseWithThought, None] = None
|
||||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user