From ca87b4ede969d52880643147a7761f0c79d23d7e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 17 Nov 2023 18:22:45 -0800 Subject: [PATCH] Wrap common API query parameters into shared class to deduplicate code - Upgrade FastAPI to >= latest version. Required upgrade of FastAPI. Earlier version didn't support wrapping common query params in class - Use per fixture app instead of a global FastAPI app in conftest - Upgrade minimum required Django version - Fix no notes chat director test with updated no notes message No notes message was updated in commit 118f1143 --- pyproject.toml | 4 +-- src/khoj/routers/api.py | 57 ++++++++---------------------- src/khoj/routers/helpers.py | 21 +++++++++-- tests/conftest.py | 4 +-- tests/test_openai_chat_director.py | 2 +- 5 files changed, 38 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 15a4c8e2..a457aec6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "bs4 >= 0.0.1", "dateparser >= 1.1.1", "defusedxml == 0.7.1", - "fastapi == 0.77.1", + "fastapi >= 0.104.1", "python-multipart >= 0.0.5", "jinja2 == 3.1.2", "openai >= 0.27.0, < 1.0.0", @@ -60,7 +60,7 @@ dependencies = [ "bs4 >= 0.0.1", "anyio == 3.7.1", "pymupdf >= 1.23.5", - "django == 4.2.5", + "django == 4.2.7", "authlib == 1.2.1", "gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'", "gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'", diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 85fba38c..be2643bd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -4,7 +4,7 @@ import math import time import logging import json -from typing import List, Optional, Union, Any +from typing import Annotated, List, Optional, Union, Any # External Packages from fastapi import APIRouter, Depends, HTTPException, Header, Request @@ -31,6 +31,7 @@ from khoj.utils import state, constants from khoj.utils.helpers import AsyncIteratorWrapper, get_device from fastapi.responses import StreamingResponse, Response from khoj.routers.helpers import ( + CommonQueryParams, get_conversation_command, validate_conversation_config, agenerate_chat_response, @@ -354,15 +355,12 @@ def get_config_types( async def search( q: str, request: Request, + common: CommonQueryParams, n: Optional[int] = 5, t: Optional[SearchType] = SearchType.All, r: Optional[bool] = False, max_distance: Optional[Union[float, None]] = None, dedupe: Optional[bool] = True, - client: Optional[str] = None, - user_agent: Optional[str] = Header(None), - referer: Optional[str] = Header(None), - host: Optional[str] = Header(None), ): user = request.user.object start_time = time.time() @@ -466,10 +464,7 @@ async def search( request=request, telemetry_type="api", api="search", - client=client, - user_agent=user_agent, - referer=referer, - host=host, + **common.__dict__, ) end_time = time.time() @@ -482,12 +477,9 @@ async def search( @requires(["authenticated"]) def update( request: Request, + common: CommonQueryParams, t: Optional[SearchType] = None, force: Optional[bool] = False, - client: Optional[str] = None, - user_agent: Optional[str] = Header(None), - referer: Optional[str] = Header(None), - host: Optional[str] = Header(None), ): user = request.user.object if not state.config: @@ -513,10 +505,7 @@ def update( request=request, telemetry_type="api", api="update", - client=client, - user_agent=user_agent, - referer=referer, - host=host, + **common.__dict__, ) return {"status": "ok", "message": "khoj reloaded"} @@ -526,10 +515,7 @@ def update( @requires(["authenticated"]) def chat_history( request: Request, - client: Optional[str] = None, - user_agent: Optional[str] = Header(None), - referer: Optional[str] = Header(None), - host: Optional[str] = Header(None), + common: CommonQueryParams, ): user = request.user.object validate_conversation_config() @@ -541,10 +527,7 @@ def chat_history( request=request, telemetry_type="api", api="chat", - client=client, - user_agent=user_agent, - referer=referer, - host=host, + **common.__dict__, ) return {"status": "ok", "response": meta_log.get("chat", [])} @@ -554,10 +537,7 @@ def chat_history( @requires(["authenticated"]) async def chat_options( request: Request, - client: Optional[str] = None, - user_agent: Optional[str] = Header(None), - referer: Optional[str] = Header(None), - host: Optional[str] = Header(None), + common: CommonQueryParams, ) -> Response: cmd_options = {} for cmd in ConversationCommand: @@ -567,10 +547,7 @@ async def chat_options( request=request, telemetry_type="api", api="chat_options", - client=client, - user_agent=user_agent, - referer=referer, - host=host, + **common.__dict__, ) return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200) @@ -579,14 +556,11 @@ async def chat_options( @requires(["authenticated"]) async def chat( request: Request, + common: CommonQueryParams, q: str, n: Optional[int] = 5, d: Optional[float] = 0.18, - client: Optional[str] = None, stream: Optional[bool] = False, - user_agent: Optional[str] = Header(None), - referer: Optional[str] = Header(None), - host: Optional[str] = Header(None), rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)), ) -> Response: @@ -600,7 +574,7 @@ async def chat( meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - request, meta_log, q, (n or 5), (d or math.inf), conversation_command + request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command ) if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): @@ -634,11 +608,8 @@ async def chat( request=request, telemetry_type="api", api="chat", - client=client, - user_agent=user_agent, - referer=referer, - host=host, metadata=chat_metadata, + **common.__dict__, ) if llm_response is None: @@ -665,6 +636,7 @@ async def chat( async def extract_references_and_questions( request: Request, + common: CommonQueryParams, meta_log: dict, q: str, n: int, @@ -731,6 +703,7 @@ async def extract_references_and_questions( r=True, max_distance=d, dedupe=False, + common=common, ) ) # Dedupe the results again, as duplicates may be returned across queries. diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index b52098e7..272f962d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -6,10 +6,10 @@ from datetime import datetime from functools import partial import logging from time import time -from typing import Iterator, List, Optional, Union, Tuple, Dict +from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict # External Packages -from fastapi import HTTPException, Request +from fastapi import HTTPException, Header, Request, Depends # Internal Packages from khoj.utils import state @@ -221,3 +221,20 @@ class ApiUserRateLimiter: # Add the current request to the cache user_requests.append(time()) + + +class CommonQueryParamsClass: + def __init__( + self, + client: Optional[str] = None, + user_agent: Optional[str] = Header(None), + referer: Optional[str] = Header(None), + host: Optional[str] = Header(None), + ): + self.client = client + self.user_agent = user_agent + self.referer = referer + self.host = host + + +CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()] diff --git a/tests/conftest.py b/tests/conftest.py index d90bae95..e7a73d6c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,9 +9,6 @@ import os from fastapi import FastAPI -app = FastAPI() - - # Internal Packages from khoj.configure import configure_routes, configure_search_types, configure_middleware from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel @@ -320,6 +317,7 @@ def client( state.anonymous_mode = False + app = FastAPI() configure_routes(app) configure_middleware(app) app.mount("/static", StaticFiles(directory=web_directory), name="static") diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index a8c85787..07c4e0d8 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -227,7 +227,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default # Assert assert response.status_code == 200 - assert response_message == prompts.no_notes_found.format() + assert response_message == prompts.no_entries_found.format() # ----------------------------------------------------------------------------------------------------