Refactor Offline chat response to stream async, with separate thread

Debanjum
2025-04-20 03:42:04 +05:30
parent 932a9615ef
commit 763fa2fa79
2 changed files with 85 additions and 37 deletions
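The refactor's core pattern: a worker thread runs the blocking llama.cpp streaming call and pushes each chunk onto an asyncio.Queue, while the async generator awaits the queue and yields chunks until it sees a None sentinel. Below is a minimal, self-contained sketch of that bridge, assuming an illustrative generate_chunks() stub in place of the real offline model call, and using loop.call_soon_threadsafe to hand chunks to the event loop (the commit itself calls queue.put_nowait directly from the worker thread):

    import asyncio
    from threading import Thread
    from typing import AsyncGenerator, Iterator


    def generate_chunks() -> Iterator[str]:
        # Illustrative stand-in for the blocking llama.cpp streaming iterator.
        yield from ["Hello", ", ", "offline ", "world"]


    async def stream_chunks() -> AsyncGenerator[str, None]:
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def worker():
            # Runs in a separate thread; hand each chunk to the event loop thread-safely.
            for chunk in generate_chunks():
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
            loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel

        Thread(target=worker, daemon=True).start()
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield chunk


    async def main():
        async for chunk in stream_chunks():
            print(chunk, end="", flush=True)
        print()


    asyncio.run(main())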

View File

@@ -1,9 +1,10 @@
import json
import asyncio
import logging
import os
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Dict, Iterator, List, Optional, Union
from time import perf_counter
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import pyjson5
from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
    ThreadedGenerator,
    clean_json,
    commit_conversation_trace,
    generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
    return list(filtered_questions)


def converse_offline(
async def converse_offline(
    user_query,
    references=[],
    online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
    additional_context: List[str] = None,
    generated_asset_results: Dict[str, Dict] = {},
    tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
) -> AsyncGenerator[str, None]:
    """
    Converse with user using Llama
    Converse with user using Llama (Async Version)
    """
    # Initialize Variables
    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(
    # Get Conversation Primer appropriate to Conversation Type
    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
        return iter([prompts.no_notes_found.format()])
        response = prompts.no_notes_found.format()
        if completion_func:
            await completion_func(chat_response=response)
        yield response
        return
    elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
        completion_func(chat_response=prompts.no_online_results_found.format())
        return iter([prompts.no_online_results_found.format()])
        response = prompts.no_online_results_found.format()
        if completion_func:
            await completion_func(chat_response=response)
        yield response
        return

    context_message = ""
    if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
# Use asyncio.Queue and a thread to bridge sync iterator
queue: asyncio.Queue = asyncio.Queue()
stop_phrases = ["<s>", "INST]", "Notes:"]
aggregated_response = ""
aggregated_response_container = {"response": ""}
state.chat_lock.acquire()
try:
response_iterator = send_message_to_model_offline(
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
response_delta = response["choices"][0]["delta"].get("content", "")
aggregated_response += response_delta
g.send(response_delta)

    def _sync_llm_thread():
        """Synchronous function to run in a separate thread."""
        aggregated_response = ""
        start_time = perf_counter()
        state.chat_lock.acquire()
        try:
            response_iterator = send_message_to_model_offline(
                messages,
                loaded_model=offline_chat_model,
                stop=stop_phrases,
                max_prompt_size=max_prompt_size,
                streaming=True,
                tracer=tracer,
            )
            for response in response_iterator:
                response_delta = response["choices"][0]["delta"].get("content", "")
                # Log the time taken to start response
                if aggregated_response == "" and response_delta != "":
                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
                # Handle response chunk
                aggregated_response += response_delta
                # Put chunk into the asyncio queue (non-blocking)
                try:
                    queue.put_nowait(response_delta)
                except asyncio.QueueFull:
                    # Should not happen with default queue size unless consumer is very slow
                    logger.warning("Asyncio queue full during offline LLM streaming.")
                    # Potentially block here or handle differently if needed
                    asyncio.run(queue.put(response_delta))

            # Save conversation trace
            if is_promptrace_enabled():
                commit_conversation_trace(messages, aggregated_response, tracer)

            # Log the time taken to stream the entire response
            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")

    finally:
        state.chat_lock.release()
        g.close()

    # Save conversation trace
    tracer["chat_model"] = model_name
    if is_promptrace_enabled():
        commit_conversation_trace(messages, aggregated_response, tracer)

        except Exception as e:
            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
        finally:
            state.chat_lock.release()
            # Signal end of stream
            queue.put_nowait(None)
            aggregated_response_container["response"] = aggregated_response

    # Start the synchronous thread
    thread = Thread(target=_sync_llm_thread)
    thread.start()

    # Asynchronously consume from the queue
    while True:
        chunk = await queue.get()
        if chunk is None:  # End of stream signal
            queue.task_done()
            break
        yield chunk
        queue.task_done()

    # Wait for the thread to finish (optional, ensures cleanup)
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, thread.join)

    # Call the completion function after streaming is done
    if completion_func:
        await completion_func(chat_response=aggregated_response_container["response"])

def send_message_to_model_offline(
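
One detail from the hunk above worth noting: thread.join() blocks, so the coroutine hands it to the default executor rather than calling it directly. A tiny standalone illustration of that pattern, with blocking_work() as an assumed stand-in for the model call:

    import asyncio
    import time
    from threading import Thread


    def blocking_work():
        time.sleep(0.1)  # stand-in for the blocking LLM call


    async def main():
        worker = Thread(target=blocking_work)
        worker.start()
        # Awaiting join() via the default executor keeps the event loop responsive.
        await asyncio.get_running_loop().run_in_executor(None, worker.join)
        print("worker finished")


    asyncio.run(main())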

View File

@@ -93,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
)
from khoj.processor.conversation.utils import (
    ChatEvent,
    ThreadedGenerator,
    clean_json,
    clean_mermaidjs,
    construct_chat_history,
@@ -1480,16 +1479,14 @@ async def agenerate_chat_response(
            vision_available = True

        if chat_model.model_type == "offline":
            # Assuming converse_offline remains sync or is refactored separately
            loaded_model = state.offline_chat_processor_config.loaded_model
            # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
            chat_response_generator = converse_offline(  # Needs adaptation if it becomes async
            chat_response_generator = converse_offline(
                user_query=query_to_run,
                references=compiled_references,
                online_results=online_results,
                loaded_model=loaded_model,
                conversation_log=meta_log,
                completion_func=partial_completion,  # Pass the async wrapper
                completion_func=partial_completion,
                conversation_commands=conversation_commands,
                model_name=chat_model.name,
                max_prompt_size=chat_model.max_prompt_size,
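
With converse_offline now an async generator, the caller can consume it with async for instead of draining a ThreadedGenerator. A hypothetical consumer, assuming the collect_response name is illustrative and not part of this diff:

    async def collect_response(chat_response_generator) -> str:
        # Accumulate streamed chunks into the full chat response.
        full_response = ""
        async for chunk in chat_response_generator:
            full_response += chunk
        return full_response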