diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index 585df6a6..d5201780 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -11,4 +11,12 @@ def download_model(model_name: str): logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") raise e - return GPT4All(model_name=model_name) + # Use GPU for Chat Model, if available + try: + model = GPT4All(model_name=model_name, device="gpu") + logger.debug("Loaded chat model to GPU.") + except ValueError: + model = GPT4All(model_name=model_name) + logger.debug("Loaded chat model to CPU.") + + return model