diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 104426a3..d7a92a20 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -299,6 +299,11 @@ class ApiUserRateLimiter: self.cache: dict[str, list[float]] = defaultdict(list) def __call__(self, request: Request): + # Rate limiting is disabled if user unauthenticated. + # Other systems handle authentication + if not request.user.is_authenticated: + return + user: KhojUser = request.user.object subscribed = has_required_scope(request, ["premium"]) user_requests = self.cache[user.uuid] diff --git a/tests/conftest.py b/tests/conftest.py index f7756d99..a7ff1512 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -254,51 +254,45 @@ def md_content_config(): @pytest.fixture(scope="function") def chat_client(search_config: SearchConfig, default_user2: KhojUser): - # Initialize app state - state.config.search_type = search_config - state.SearchType = configure_search_types() + return chat_client_builder(search_config, default_user2, require_auth=False) - LocalMarkdownConfig.objects.create( - input_files=None, - input_filter=["tests/data/markdown/*.markdown"], - user=default_user2, - ) - # Index Markdown Content for Search - all_files = fs_syncer.collect_files(user=default_user2) - state.content_index, _ = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2 - ) - - # Initialize Processor from Config - if os.getenv("OPENAI_API_KEY"): - chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai") - OpenAIProcessorConversationConfigFactory() - UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model) - - state.anonymous_mode = True - - app = FastAPI() - - configure_routes(app) - configure_middleware(app) - app.mount("/static", StaticFiles(directory=web_directory), name="static") - return TestClient(app) +@pytest.fixture(scope="function") +def chat_client_with_auth(search_config: SearchConfig, default_user2: KhojUser): + return chat_client_builder(search_config, default_user2, require_auth=True) @pytest.fixture(scope="function") def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser): + return chat_client_builder(search_config, default_user2, index_content=False, require_auth=False) + + +@pytest.mark.django_db +def chat_client_builder(search_config, user, index_content=True, require_auth=False): # Initialize app state state.config.search_type = search_config state.SearchType = configure_search_types() + if index_content: + LocalMarkdownConfig.objects.create( + input_files=None, + input_filter=["tests/data/markdown/*.markdown"], + user=user, + ) + + # Index Markdown Content for Search + all_files = fs_syncer.collect_files(user=user) + state.content_index, _ = configure_content( + state.content_index, state.config.content_type, all_files, state.search_models, user=user + ) + # Initialize Processor from Config if os.getenv("OPENAI_API_KEY"): chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai") OpenAIProcessorConversationConfigFactory() - UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model) + UserConversationProcessorConfigFactory(user=user, setting=chat_model) - state.anonymous_mode = True + state.anonymous_mode = not require_auth app = FastAPI() diff --git a/tests/test_client.py b/tests/test_client.py index 0bc3c02f..3954254a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -461,6 +461,20 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU assert response.json() == [] +@pytest.mark.django_db(transaction=True) +def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): + # Arrange + headers = {"Authorization": f"Bearer {api_user2.token}"} + + # Act + auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers) + no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true') + + # Assert + assert auth_response.status_code == 200 + assert no_auth_response.status_code == 403 + + def get_sample_files_data(): return [ ("files", ("path/to/filename.org", "* practicing piano", "text/org")),