Refactor Offline chat response to stream async, with separate thread
@@ -1,9 +1,10 @@
-import json
+import asyncio
 import logging
 import os
 from datetime import datetime, timedelta
 from threading import Thread
-from typing import Any, Dict, Iterator, List, Optional, Union
+from time import perf_counter
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import pyjson5
 from langchain.schema import ChatMessage
@@ -13,7 +14,6 @@ from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     clean_json,
     commit_conversation_trace,
     generate_chatml_messages_with_context,
@@ -147,7 +147,7 @@ def filter_questions(questions: List[str]):
     return list(filtered_questions)


-def converse_offline(
+async def converse_offline(
     user_query,
     references=[],
     online_results={},
@@ -167,9 +167,9 @@ def converse_offline(
     additional_context: List[str] = None,
     generated_asset_results: Dict[str, Dict] = {},
     tracer: dict = {},
-) -> Union[ThreadedGenerator, Iterator[str]]:
+) -> AsyncGenerator[str, None]:
     """
-    Converse with user using Llama
+    Converse with user using Llama (Async Version)
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
@@ -200,10 +200,17 @@ def converse_offline(

     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return

     context_message = ""
     if not is_none_or_empty(references):
@@ -240,33 +247,77 @@ def converse_offline(

     logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

-    g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
-    t.start()
-    return g
-
-
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
+    # Use asyncio.Queue and a thread to bridge sync iterator
+    queue: asyncio.Queue = asyncio.Queue()
     stop_phrases = ["<s>", "INST]", "Notes:"]
-    aggregated_response = ""
+    aggregated_response_container = {"response": ""}

-    state.chat_lock.acquire()
-    try:
-        response_iterator = send_message_to_model_offline(
-            messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
-        )
-        for response in response_iterator:
-            response_delta = response["choices"][0]["delta"].get("content", "")
-            aggregated_response += response_delta
-            g.send(response_delta)
-
-        # Save conversation trace
-        if is_promptrace_enabled():
-            commit_conversation_trace(messages, aggregated_response, tracer)
-    finally:
-        state.chat_lock.release()
-        g.close()
+    def _sync_llm_thread():
+        """Synchronous function to run in a separate thread."""
+        aggregated_response = ""
+        start_time = perf_counter()
+        state.chat_lock.acquire()
+        try:
+            response_iterator = send_message_to_model_offline(
+                messages,
+                loaded_model=offline_chat_model,
+                stop=stop_phrases,
+                max_prompt_size=max_prompt_size,
+                streaming=True,
+                tracer=tracer,
+            )
+            for response in response_iterator:
+                response_delta = response["choices"][0]["delta"].get("content", "")
+                # Log the time taken to start response
+                if aggregated_response == "" and response_delta != "":
+                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                # Handle response chunk
+                aggregated_response += response_delta
+                # Put chunk into the asyncio queue (non-blocking)
+                try:
+                    queue.put_nowait(response_delta)
+                except asyncio.QueueFull:
+                    # Should not happen with default queue size unless consumer is very slow
+                    logger.warning("Asyncio queue full during offline LLM streaming.")
+                    # Potentially block here or handle differently if needed
+                    asyncio.run(queue.put(response_delta))
+
+            # Log the time taken to stream the entire response
+            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
+
+            # Save conversation trace
+            tracer["chat_model"] = model_name
+            if is_promptrace_enabled():
+                commit_conversation_trace(messages, aggregated_response, tracer)
+
+        except Exception as e:
+            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
+        finally:
+            state.chat_lock.release()
+            # Signal end of stream
+            queue.put_nowait(None)
+            aggregated_response_container["response"] = aggregated_response
+
+    # Start the synchronous thread
+    thread = Thread(target=_sync_llm_thread)
+    thread.start()
+
+    # Asynchronously consume from the queue
+    while True:
+        chunk = await queue.get()
+        if chunk is None:  # End of stream signal
+            queue.task_done()
+            break
+        yield chunk
+        queue.task_done()
+
+    # Wait for the thread to finish (optional, ensures cleanup)
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, thread.join)
+
+    # Call the completion function after streaming is done
+    if completion_func:
+        await completion_func(chat_response=aggregated_response_container["response"])


 def send_message_to_model_offline(
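
The refactored converse_offline above bridges llama.cpp's synchronous token iterator to an async generator: a worker thread pulls chunks from the blocking iterator and pushes them into an asyncio.Queue, a None sentinel marks end of stream, and the coroutine drains the queue with await queue.get(). Below is a minimal, self-contained sketch of that same bridge pattern; the names (fake_blocking_token_stream, stream_async) are illustrative rather than khoj APIs, and it hands chunks to the event loop via loop.call_soon_threadsafe, a thread-safe variant of the direct queue.put_nowait calls shown in the diff.

import asyncio
from threading import Thread
from typing import AsyncGenerator, Iterator


def fake_blocking_token_stream() -> Iterator[str]:
    # Stand-in for the synchronous llama.cpp streaming iterator.
    for token in ["Hello", ", ", "offline ", "world", "!"]:
        yield token


async def stream_async() -> AsyncGenerator[str, None]:
    # Bridge: worker thread -> asyncio.Queue -> async generator.
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer():
        # Runs in a separate thread; hand each chunk to the event loop thread-safely.
        try:
            for chunk in fake_blocking_token_stream():
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
        finally:
            # None is the end-of-stream sentinel, mirroring the diff above.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    thread = Thread(target=producer)
    thread.start()

    while True:
        chunk = await queue.get()
        if chunk is None:
            break
        yield chunk

    # Join the worker without blocking the event loop.
    await loop.run_in_executor(None, thread.join)


async def main():
    async for chunk in stream_async():
        print(chunk, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())

The sentinel-plus-join arrangement ensures the consumer only leaves its loop after the producer has emitted everything, which is the same guarantee the diff gets from queue.put_nowait(None) in the finally block.
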
@@ -93,7 +93,6 @@ from khoj.processor.conversation.openai.gpt import (
 )
 from khoj.processor.conversation.utils import (
     ChatEvent,
-    ThreadedGenerator,
     clean_json,
     clean_mermaidjs,
     construct_chat_history,
@@ -1480,16 +1479,14 @@ async def agenerate_chat_response(
         vision_available = True

     if chat_model.model_type == "offline":
-        # Assuming converse_offline remains sync or is refactored separately
         loaded_model = state.offline_chat_processor_config.loaded_model
-        # If converse_offline returns an iterator, wrap it if needed, or refactor it to async generator
-        chat_response_generator = converse_offline(  # Needs adaptation if it becomes async
+        chat_response_generator = converse_offline(
             user_query=query_to_run,
             references=compiled_references,
             online_results=online_results,
             loaded_model=loaded_model,
             conversation_log=meta_log,
-            completion_func=partial_completion,  # Pass the async wrapper
+            completion_func=partial_completion,
             conversation_commands=conversation_commands,
             model_name=chat_model.name,
             max_prompt_size=chat_model.max_prompt_size,
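
On the caller side, agenerate_chat_response now receives an async generator from converse_offline, which can be relayed to the client chunk by chunk. Below is a hypothetical sketch of that consumption pattern with FastAPI (the framework khoj's server is built on); the /chat route and the dummy_chat_stream generator are placeholders standing in for the real endpoint and for the refactored converse_offline.

import asyncio

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def dummy_chat_stream():
    # Placeholder for the refactored converse_offline async generator.
    for token in ["Streaming ", "an ", "offline ", "chat ", "response."]:
        await asyncio.sleep(0.05)  # simulate per-token generation latency
        yield token


@app.get("/chat")
async def chat():
    # An async generator plugs straight into StreamingResponse, chunk by chunk.
    return StreamingResponse(dummy_chat_stream(), media_type="text/plain")

Served with uvicorn, such an endpoint sends each token as it is produced instead of buffering the full reply.
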