diff --git a/pyproject.toml b/pyproject.toml index 693415d4..fbfa7dac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,8 +62,8 @@ dependencies = [ "pymupdf >= 1.23.5", "django == 4.2.7", "authlib == 1.2.1", - "gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'", - "gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'", + "gpt4all >= 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'", + "gpt4all >= 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'", "itsdangerous == 2.1.2", "httpx == 0.25.0", "pgvector == 0.2.4", diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 23a77bb2..3e0f5380 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -123,9 +123,9 @@ def filter_questions(questions: List[str]): def converse_offline( - references, - online_results, user_query, + references=[], + online_results=[], conversation_log={}, model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", loaded_model: Union[Any, None] = None, diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py index 3a1862f7..9a2223c6 100644 --- a/src/khoj/processor/conversation/offline/utils.py +++ b/src/khoj/processor/conversation/offline/utils.py @@ -21,9 +21,11 @@ def download_model(model_name: str): # Try load chat model to GPU if: # 1. Loading chat model to GPU isn't disabled via CLI and # 2. Machine has GPU - # 3. GPU has enough free memory to load the chat model + # 3. GPU has enough free memory to load the chat model with max context length of 4096 device = ( - "gpu" if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu" + "gpu" + if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096) + else "cpu" ) except ValueError: device = "cpu" @@ -35,7 +37,7 @@ def download_model(model_name: str): raise e # Now load the downloaded chat model onto appropriate device - chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False) + chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False) logger.debug(f"Loaded chat model to {device.upper()}.") return chat_model diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 575f094c..4cf670fa 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -6,6 +6,7 @@ import os import time import uuid from typing import Any, Dict, List, Optional, Union +from urllib.parse import unquote from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile @@ -704,6 +705,7 @@ async def chat( rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), ) -> Response: user: KhojUser = request.user.object + q = unquote(q) await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py index 28bc3a8f..0173ff7b 100644 --- a/tests/test_gpt4all_chat_director.py +++ b/tests/test_gpt4all_chat_director.py @@ -1,3 +1,4 @@ +import os import urllib.parse from urllib.parse import quote @@ -53,6 +54,7 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c # ---------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY") @pytest.mark.chatquality @pytest.mark.django_db(transaction=True) def test_chat_with_online_content(chat_client):