Refactor Offline chat response to stream async, with separate thread

Debanjum
2025-04-20 03:42:04 +05:30
parent 932a9615ef
commit 763fa2fa79
2 changed files with 85 additions and 37 deletions

File 1 of 2:

@@ -1,9 +1,10 @@
-import json
+import asyncio
 import logging
 import os
 from datetime import datetime, timedelta
 from threading import Thread
-from typing import Any, Dict, Iterator, List, Optional, Union
+from time import perf_counter
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import pyjson5
 from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     clean_json,
     commit_conversation_trace,
     generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
     return list(filtered_questions)


-def converse_offline(
+async def converse_offline(
     user_query,
     references=[],
     online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
     additional_context: List[str] = None,
     generated_asset_results: Dict[str, Dict] = {},
     tracer: dict = {},
-) -> Union[ThreadedGenerator, Iterator[str]]:
+) -> AsyncGenerator[str, None]:
     """
-    Converse with user using Llama
+    Converse with user using Llama (Async Version)
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(

     logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

-    g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
-    t.start()
-    return g
-
-
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
+    # Use asyncio.Queue and a thread to bridge sync iterator
+    queue: asyncio.Queue = asyncio.Queue()
     stop_phrases = ["<s>", "INST]", "Notes:"]
-    aggregated_response = ""
+    aggregated_response_container = {"response": ""}

-    state.chat_lock.acquire()
-    try:
-        response_iterator = send_message_to_model_offline(
-            messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
-        )
-        for response in response_iterator:
-            response_delta = response["choices"][0]["delta"].get("content", "")
-            aggregated_response += response_delta
-            g.send(response_delta)
-
-        # Save conversation trace
-        tracer["chat_model"] = model_name
-        if is_promptrace_enabled():
-            commit_conversation_trace(messages, aggregated_response, tracer)
-
-    finally:
-        state.chat_lock.release()
-        g.close()
+    def _sync_llm_thread():
+        """Synchronous function to run in a separate thread."""
+        aggregated_response = ""
+        start_time = perf_counter()
+        state.chat_lock.acquire()
+        try:
+            response_iterator = send_message_to_model_offline(
+                messages,
+                loaded_model=offline_chat_model,
+                stop=stop_phrases,
+                max_prompt_size=max_prompt_size,
+                streaming=True,
+                tracer=tracer,
+            )
+            for response in response_iterator:
+                response_delta = response["choices"][0]["delta"].get("content", "")
+                # Log the time taken to start response
+                if aggregated_response == "" and response_delta != "":
+                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                # Handle response chunk
+                aggregated_response += response_delta
+                # Put chunk into the asyncio queue (non-blocking)
+                try:
+                    queue.put_nowait(response_delta)
+                except asyncio.QueueFull:
+                    # Should not happen with default queue size unless consumer is very slow
+                    logger.warning("Asyncio queue full during offline LLM streaming.")
+                    # Potentially block here or handle differently if needed
+                    asyncio.run(queue.put(response_delta))
+
+            # Log the time taken to stream the entire response
+            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+            # Save conversation trace
+            if is_promptrace_enabled():
+                commit_conversation_trace(messages, aggregated_response, tracer)
+        except Exception as e:
+            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
+        finally:
+            state.chat_lock.release()
+            # Signal end of stream
+            queue.put_nowait(None)
+            aggregated_response_container["response"] = aggregated_response
+
+    # Start the synchronous thread
+    thread = Thread(target=_sync_llm_thread)
+    thread.start()
+
+    # Asynchronously consume from the queue
+    while True:
+        chunk = await queue.get()
+        if chunk is None:  # End of stream signal
+            queue.task_done()
+            break
+        yield chunk
+        queue.task_done()
+
+    # Wait for the thread to finish (optional, ensures cleanup)
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, thread.join)
+
+    # Call the completion function after streaming is done
+    if completion_func:
+        await completion_func(chat_response=aggregated_response_container["response"])


 def send_message_to_model_offline(
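The refactor's central move is bridging llama.cpp's blocking token iterator into an async generator with a worker thread and an asyncio.Queue. Below is a minimal, self-contained sketch of that pattern; slow_sync_tokens and stream_async are hypothetical names, and unlike the committed code this sketch hands chunks to the queue via loop.call_soon_threadsafe, since asyncio.Queue methods are not thread-safe when invoked directly from another thread.

import asyncio
import time
from threading import Thread
from typing import AsyncGenerator, Iterator


def slow_sync_tokens() -> Iterator[str]:
    """Hypothetical stand-in for a blocking token iterator, e.g. llama.cpp streaming."""
    for token in ["Hello", ", ", "world", "!"]:
        time.sleep(0.1)  # Simulate blocking model inference
        yield token


async def stream_async(sync_iterator: Iterator[str]) -> AsyncGenerator[str, None]:
    """Bridge a blocking iterator into an async generator via a worker thread and queue."""
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer():
        # Runs in a worker thread so the event loop stays responsive.
        try:
            for chunk in sync_iterator:
                # Thread-safe handoff into the loop-bound queue.
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
        finally:
            # None is the end-of-stream sentinel, as in the commit above.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    thread = Thread(target=producer, daemon=True)
    thread.start()

    while True:
        chunk = await queue.get()
        if chunk is None:
            break
        yield chunk

    # Join the finished worker off the event loop to avoid blocking it.
    await loop.run_in_executor(None, thread.join)


async def main():
    async for token in stream_async(slow_sync_tokens()):
        print(token, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())

The None sentinel plays the same end-of-stream role as in the committed code, and joining the worker via run_in_executor keeps the event loop unblocked during cleanup.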

File 2 of 2:

@@ -93,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ChatEvent, ChatEvent,
ThreadedGenerator,
clean_json, clean_json,
clean_mermaidjs, clean_mermaidjs,
construct_chat_history, construct_chat_history,
@@ -1480,16 +1479,14 @@ async def agenerate_chat_response(
         vision_available = True

     if chat_model.model_type == "offline":
-        # Assuming converse_offline remains sync or is refactored separately
         loaded_model = state.offline_chat_processor_config.loaded_model
-        # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
-        chat_response_generator = converse_offline(  # Needs adaptation if it becomes async
+        chat_response_generator = converse_offline(
             user_query=query_to_run,
             references=compiled_references,
             online_results=online_results,
             loaded_model=loaded_model,
             conversation_log=meta_log,
-            completion_func=partial_completion,  # Pass the async wrapper
+            completion_func=partial_completion,
             conversation_commands=conversation_commands,
             model_name=chat_model.name,
             max_prompt_size=chat_model.max_prompt_size,
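Because converse_offline is now an async generator, callers iterate it with async for instead of reading from a ThreadedGenerator, and completion_func fires once streaming finishes. A rough consumption sketch follows, assuming the module context from the diffs above; collect_offline_response is a hypothetical helper, the argument values are placeholders, and the full set of required keyword arguments may differ from this subset.

# Rough sketch: collect the streamed offline chat response.
# Assumes the imports and module state from the files above.
async def collect_offline_response() -> str:
    chunks = []
    async for chunk in converse_offline(
        user_query="What did I note about asyncio?",  # placeholder query
        references=[],
        online_results={},
        loaded_model=state.offline_chat_processor_config.loaded_model,
        model_name="<offline-chat-model-name>",  # placeholder model name
    ):
        chunks.append(chunk)  # or forward each chunk to the client as it arrives
    return "".join(chunks)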