From 763fa2fa794e15c5ba4ac17216bd0b46267ad994 Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Sun, 20 Apr 2025 03:42:04 +0530
Subject: [PATCH] Refactor Offline chat response to stream async, with separate thread

---
 .../conversation/offline/chat_model.py | 115 +++++++++++++-----
 src/khoj/routers/helpers.py            |   7 +-
 2 files changed, 85 insertions(+), 37 deletions(-)

diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index f727fd1d..b7f89c8d 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -1,9 +1,10 @@
-import json
+import asyncio
 import logging
 import os
 from datetime import datetime, timedelta
 from threading import Thread
-from typing import Any, Dict, Iterator, List, Optional, Union
+from time import perf_counter
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
 import pyjson5
 from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     clean_json,
     commit_conversation_trace,
     generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
     return list(filtered_questions)
 
 
-def converse_offline(
+async def converse_offline(
     user_query,
     references=[],
     online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
     additional_context: List[str] = None,
     generated_asset_results: Dict[str, Dict] = {},
     tracer: dict = {},
-) -> Union[ThreadedGenerator, Iterator[str]]:
+) -> AsyncGenerator[str, None]:
     """
-    Converse with user using Llama
+    Converse with user using Llama (Async Version)
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(
 
     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
 
     context_message = ""
     if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(
 
     logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
 
-    g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
-    t.start()
-    return g
-
-
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
+    # Use asyncio.Queue and a thread to bridge sync iterator
+    queue: asyncio.Queue = asyncio.Queue()
     stop_phrases = ["</s>", "INST]", "Notes:"]
-    aggregated_response = ""
+    aggregated_response_container = {"response": ""}
 
-    state.chat_lock.acquire()
-    try:
-        response_iterator = send_message_to_model_offline(
-            messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
-        )
-        for response in response_iterator:
-            response_delta = response["choices"][0]["delta"].get("content", "")
-            aggregated_response += response_delta
-            g.send(response_delta)
+    def _sync_llm_thread():
+        """Synchronous function to run in a separate thread."""
+        aggregated_response = ""
+        start_time = perf_counter()
+        state.chat_lock.acquire()
+        try:
+            response_iterator = send_message_to_model_offline(
+                messages,
+                loaded_model=offline_chat_model,
+                stop=stop_phrases,
+                max_prompt_size=max_prompt_size,
+                streaming=True,
+                tracer=tracer,
+            )
+            for response in response_iterator:
+                response_delta = response["choices"][0]["delta"].get("content", "")
+                # Log the time taken to start response
+                if aggregated_response == "" and response_delta != "":
+                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                # Handle response chunk
+                aggregated_response += response_delta
+                # Put chunk into the asyncio queue (non-blocking)
+                try:
+                    queue.put_nowait(response_delta)
+                except asyncio.QueueFull:
+                    # Should not happen with default queue size unless consumer is very slow
+                    logger.warning("Asyncio queue full during offline LLM streaming.")
+                    # Potentially block here or handle differently if needed
+                    asyncio.run(queue.put(response_delta))
 
-        # Save conversation trace
-        if is_promptrace_enabled():
-            commit_conversation_trace(messages, aggregated_response, tracer)
+            # Log the time taken to stream the entire response
+            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
 
-    finally:
-        state.chat_lock.release()
-        g.close()
+            # Save conversation trace
+            tracer["chat_model"] = model_name
+            if is_promptrace_enabled():
+                commit_conversation_trace(messages, aggregated_response, tracer)
+
+        except Exception as e:
+            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
+        finally:
+            state.chat_lock.release()
+            # Signal end of stream
+            queue.put_nowait(None)
+            aggregated_response_container["response"] = aggregated_response
+
+    # Start the synchronous thread
+    thread = Thread(target=_sync_llm_thread)
+    thread.start()
+
+    # Asynchronously consume from the queue
+    while True:
+        chunk = await queue.get()
+        if chunk is None:  # End of stream signal
+            queue.task_done()
+            break
+        yield chunk
+        queue.task_done()
+
+    # Wait for the thread to finish (optional, ensures cleanup)
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, thread.join)
+
+    # Call the completion function after streaming is done
+    if completion_func:
+        await completion_func(chat_response=aggregated_response_container["response"])
 
 
 def send_message_to_model_offline(
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 3d154059..a0baffb9 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -93,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
 )
 from khoj.processor.conversation.utils import (
     ChatEvent,
-    ThreadedGenerator,
     clean_json,
     clean_mermaidjs,
     construct_chat_history,
@@ -1480,16 +1479,14 @@ async def agenerate_chat_response(
     vision_available = True
 
     if chat_model.model_type == "offline":
-        # Assuming converse_offline remains sync or is refactored separately
         loaded_model = state.offline_chat_processor_config.loaded_model
-        # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
-        chat_response_generator = converse_offline(  # Needs adaptation if it becomes async
+        chat_response_generator = converse_offline(
             user_query=query_to_run,
             references=compiled_references,
             online_results=online_results,
             loaded_model=loaded_model,
             conversation_log=meta_log,
-            completion_func=partial_completion,  # Pass the async wrapper
+            completion_func=partial_completion,
             conversation_commands=conversation_commands,
             model_name=chat_model.name,
             max_prompt_size=chat_model.max_prompt_size,
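
For reference, the core of this change is a sync-to-async bridge: a worker thread drains the blocking llama.cpp streaming iterator into an asyncio.Queue, a None sentinel marks end of stream, and the async generator yields chunks to the caller before invoking completion_func. Below is a minimal, self-contained sketch of that pattern. The names (blocking_token_source, stream_tokens, producer) are illustrative rather than khoj identifiers, and the sketch hands chunks to the event loop with loop.call_soon_threadsafe, a more conservative choice than the direct queue.put_nowait call used in the patch.

import asyncio
from threading import Thread


def blocking_token_source():
    # Stand-in for the blocking llama.cpp streaming iterator (illustrative only).
    for token in ["Hello", ", ", "offline ", "world", "!"]:
        yield token


async def stream_tokens():
    # Bridge a blocking iterator to async callers: a worker thread produces into
    # an asyncio.Queue and a None sentinel signals end of stream.
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer():
        for token in blocking_token_source():
            # Hand each chunk to the event loop thread-safely.
            loop.call_soon_threadsafe(queue.put_nowait, token)
        loop.call_soon_threadsafe(queue.put_nowait, None)

    thread = Thread(target=producer)
    thread.start()

    while True:
        chunk = await queue.get()
        if chunk is None:  # End of stream sentinel
            break
        yield chunk

    # Join the worker off the event loop so the loop is never blocked.
    await loop.run_in_executor(None, thread.join)


async def main():
    async for chunk in stream_tokens():
        print(chunk, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())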