From 8054bdc89616679811b17fb83728a2c027425a37 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Mon, 31 Jul 2023 23:25:08 -0700
Subject: [PATCH] Use n_batch parameter to increase resource consumption on
 host machine (and implicitly engage GPU)

---
 src/khoj/processor/conversation/gpt4all/chat_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 1047ac3b..7b7ff31d 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -58,7 +58,7 @@ def extract_questions_offline(
         next_christmas_date=next_christmas_date,
     )
     message = system_prompt + example_questions
-    response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0)
+    response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=128)

     # Extract, Clean Message from GPT's Response
     try:
@@ -161,7 +161,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
         templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
         templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
         prompted_message = templated_system_message + chat_history + templated_user_message
-        response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000)
+        response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000, n_batch=256)
         for response in response_iterator:
             g.send(response)
     g.close()
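
For context, below is a minimal standalone sketch of how the n_batch argument is passed to the GPT4All Python bindings' generate() call, as the patch does. The model file name and prompt are assumptions for illustration only and are not part of the patch.

    from gpt4all import GPT4All

    # Hypothetical model file name, used here only to make the sketch runnable.
    model = GPT4All("llama-2-7b-chat.ggmlv3.q4_0.bin")

    # n_batch sets how many prompt tokens are processed per batch. Larger values
    # increase memory and compute use on the host, which is what lets the backend
    # make fuller use of available hardware acceleration.
    response = model.generate("Summarize my notes.", max_tokens=200, top_k=2, temp=0, n_batch=128)
    print(response)

The patch uses a smaller batch size (128) for the short question-extraction prompt and a larger one (256) for the longer streamed chat prompt, trading memory for prompt-processing throughput in each case.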