Add a lock around chat operations so the offline model is not bombarded with concurrent requests and does not monopolize compute resources

- This also resolves #367
@@ -10,6 +10,7 @@ from gpt4all import GPT4All
 from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
 from khoj.processor.conversation import prompts
 from khoj.utils.constants import empty_escape_sequences
+from khoj.utils import state
 
 logger = logging.getLogger(__name__)
 
@@ -58,7 +59,11 @@ def extract_questions_offline(
         next_christmas_date=next_christmas_date,
     )
     message = system_prompt + example_questions
-    response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=128)
+    state.chat_lock.acquire()
+    try:
+        response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=128)
+    finally:
+        state.chat_lock.release()
 
     # Extract, Clean Message from GPT's Response
     try:
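Since `threading.Lock` is a context manager, the acquire/try/finally sequence above can also be written as a `with` block, which releases the lock whether `generate()` returns normally or raises. A minimal equivalent sketch, reusing the names from the diff:

# Equivalent guard in context-manager form: __exit__ releases the lock
# on both normal return and exception.
with state.chat_lock:
    response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=128)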
@@ -162,6 +167,10 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
     templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
     prompted_message = templated_system_message + chat_history + templated_user_message
     response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000, n_batch=256)
-    for response in response_iterator:
-        g.send(response)
+    state.chat_lock.acquire()
+    try:
+        for response in response_iterator:
+            g.send(response)
+    finally:
+        state.chat_lock.release()
     g.close()
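Note that the lock wraps the iteration, not the call to `model.generate()`: with `streaming=True` the tokens are actually produced while the iterator is consumed, so this is where the compute happens. The `g` object is khoj's ThreadedGenerator, which `llm_thread` uses as a send/close channel from a worker thread back to the streaming HTTP response. As a rough illustration of how such a bridge is commonly built on `queue.Queue` (a hypothetical sketch, not khoj's actual implementation):

import queue
import threading

class QueueBackedGenerator:
    """Hypothetical stand-in for khoj's ThreadedGenerator: a producer
    thread pushes tokens with send(); a consumer iterates until close()."""

    _DONE = object()  # sentinel marking end of stream

    def __init__(self):
        self._queue: queue.Queue = queue.Queue()

    def send(self, item):
        self._queue.put(item)  # producer thread pushes one token

    def close(self):
        self._queue.put(self._DONE)  # signal the consumer to stop

    def __iter__(self):
        while True:
            item = self._queue.get()  # blocks until a token arrives
            if item is self._DONE:
                return
            yield item

def producer(g: QueueBackedGenerator):
    for token in ("Hello", ",", " world"):
        g.send(token)
    g.close()

g = QueueBackedGenerator()
threading.Thread(target=producer, args=(g,)).start()
print("".join(g))  # prints: Hello, world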
@@ -25,6 +25,7 @@ port: int = None
 cli_args: List[str] = None
 query_cache = LRU()
 config_lock = threading.Lock()
+chat_lock = threading.Lock()
 SearchType = utils_config.SearchType
 telemetry: List[Dict[str, str]] = []
 previous_query: str = None
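Because `chat_lock` is defined at module level alongside `config_lock`, every request handler that imports `state` shares the same lock object, so wrapping each generate call in it serializes offline inference across threads. A self-contained demo of that serialization, where `slow_generate` is a hypothetical stand-in for the GPT4All call:

import threading
import time

chat_lock = threading.Lock()  # one module-level lock shared by all handlers

def slow_generate(prompt: str) -> str:
    time.sleep(0.1)  # stand-in for a compute-heavy generate() call
    return f"response to {prompt!r}"

def handle_chat(prompt: str, results: list) -> None:
    with chat_lock:  # only one inference runs at a time
        results.append(slow_generate(prompt))

results: list = []
threads = [threading.Thread(target=handle_chat, args=(f"q{i}", results)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"{len(results)} responses, produced one at a time")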