From 1c52a6993f9c0f73b33d4f2bda100790ec40ba0f Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 1 Aug 2023 00:23:17 -0700 Subject: [PATCH] add a lock around chat operations to prevent the offline model from getting bombarded and stealing a bunch of compute resources - This also solves #367 --- .../processor/conversation/gpt4all/chat_model.py | 15 ++++++++++++--- src/khoj/utils/state.py | 1 + 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py index 7b7ff31d..199f6e44 100644 --- a/src/khoj/processor/conversation/gpt4all/chat_model.py +++ b/src/khoj/processor/conversation/gpt4all/chat_model.py @@ -10,6 +10,7 @@ from gpt4all import GPT4All from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context from khoj.processor.conversation import prompts from khoj.utils.constants import empty_escape_sequences +from khoj.utils import state logger = logging.getLogger(__name__) @@ -58,7 +59,11 @@ def extract_questions_offline( next_christmas_date=next_christmas_date, ) message = system_prompt + example_questions - response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=128) + state.chat_lock.acquire() + try: + response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=128) + finally: + state.chat_lock.release() # Extract, Clean Message from GPT's Response try: @@ -162,6 +167,10 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All): templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content) prompted_message = templated_system_message + chat_history + templated_user_message response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000, n_batch=256) - for response in response_iterator: - g.send(response) + state.chat_lock.acquire() + try: + for response in response_iterator: + g.send(response) + finally: + state.chat_lock.release() g.close() diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 40b3daae..5e6baeae 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -25,6 +25,7 @@ port: int = None cli_args: List[str] = None query_cache = LRU() config_lock = threading.Lock() +chat_lock = threading.Lock() SearchType = utils_config.SearchType telemetry: List[Dict[str, str]] = [] previous_query: str = None