mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Refactor Offline chat response to stream async, with separate thread
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
from time import perf_counter
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
commit_conversation_trace,
|
||||
generate_chatml_messages_with_context,
|
||||
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
|
||||
return list(filtered_questions)
|
||||
|
||||
|
||||
def converse_offline(
|
||||
async def converse_offline(
|
||||
user_query,
|
||||
references=[],
|
||||
online_results={},
|
||||
@@ -167,9 +167,9 @@ def converse_offline(
|
||||
additional_context: List[str] = None,
|
||||
generated_asset_results: Dict[str, Dict] = {},
|
||||
tracer: dict = {},
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
Converse with user using Llama (Async Version)
|
||||
"""
|
||||
# Initialize Variables
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
@@ -200,10 +200,17 @@ def converse_offline(
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
response = prompts.no_notes_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
response = prompts.no_online_results_found.format()
|
||||
if completion_func:
|
||||
await completion_func(chat_response=response)
|
||||
yield response
|
||||
return
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(references):
|
||||
@@ -240,33 +247,77 @@ def converse_offline(
|
||||
|
||||
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
|
||||
|
||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
|
||||
# Use asyncio.Queue and a thread to bridge sync iterator
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||
aggregated_response = ""
|
||||
aggregated_response_container = {"response": ""}
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response_iterator = send_message_to_model_offline(
|
||||
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||
)
|
||||
for response in response_iterator:
|
||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||
aggregated_response += response_delta
|
||||
g.send(response_delta)
|
||||
def _sync_llm_thread():
|
||||
"""Synchronous function to run in a separate thread."""
|
||||
aggregated_response = ""
|
||||
start_time = perf_counter()
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response_iterator = send_message_to_model_offline(
|
||||
messages,
|
||||
loaded_model=offline_chat_model,
|
||||
stop=stop_phrases,
|
||||
max_prompt_size=max_prompt_size,
|
||||
streaming=True,
|
||||
tracer=tracer,
|
||||
)
|
||||
for response in response_iterator:
|
||||
response_delta = response["choices"][0]["delta"].get("content", "")
|
||||
# Log the time taken to start response
|
||||
if aggregated_response == "" and response_delta != "":
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Handle response chunk
|
||||
aggregated_response += response_delta
|
||||
# Put chunk into the asyncio queue (non-blocking)
|
||||
try:
|
||||
queue.put_nowait(response_delta)
|
||||
except asyncio.QueueFull:
|
||||
# Should not happen with default queue size unless consumer is very slow
|
||||
logger.warning("Asyncio queue full during offline LLM streaming.")
|
||||
# Potentially block here or handle differently if needed
|
||||
asyncio.run(queue.put(response_delta))
|
||||
|
||||
# Save conversation trace
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
# Log the time taken to stream the entire response
|
||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
g.close()
|
||||
# Save conversation trace
|
||||
tracer["chat_model"] = model_name
|
||||
if is_promptrace_enabled():
|
||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
# Signal end of stream
|
||||
queue.put_nowait(None)
|
||||
aggregated_response_container["response"] = aggregated_response
|
||||
|
||||
# Start the synchronous thread
|
||||
thread = Thread(target=_sync_llm_thread)
|
||||
thread.start()
|
||||
|
||||
# Asynchronously consume from the queue
|
||||
while True:
|
||||
chunk = await queue.get()
|
||||
if chunk is None: # End of stream signal
|
||||
queue.task_done()
|
||||
break
|
||||
yield chunk
|
||||
queue.task_done()
|
||||
|
||||
# Wait for the thread to finish (optional, ensures cleanup)
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, thread.join)
|
||||
|
||||
# Call the completion function after streaming is done
|
||||
if completion_func:
|
||||
await completion_func(chat_response=aggregated_response_container["response"])
|
||||
|
||||
|
||||
def send_message_to_model_offline(
|
||||
|
||||
@@ -93,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
construct_chat_history,
|
||||
@@ -1480,16 +1479,14 @@ async def agenerate_chat_response(
|
||||
vision_available = True
|
||||
|
||||
if chat_model.model_type == "offline":
|
||||
# Assuming converse_offline remains sync or is refactored separately
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
# If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
|
||||
chat_response_generator = converse_offline( # Needs adaptation if it becomes async
|
||||
chat_response_generator = converse_offline(
|
||||
user_query=query_to_run,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion, # Pass the async wrapper
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
model_name=chat_model.name,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
|
||||
Reference in New Issue
Block a user