Wrap common API query parameters into shared class to deduplicate code

- Upgrade FastAPI to >= latest version. Required upgrade of FastAPI.
  Earlier version didn't support wrapping common query params in class

- Use per fixture app instead of a global FastAPI app in conftest

- Upgrade minimum required Django version

- Fix no notes chat director test with updated no notes message
  No notes message was updated in commit 118f1143
This commit is contained in:
Debanjum Singh Solanky
2023-11-17 18:22:45 -08:00
parent 68ac1e0193
commit ca87b4ede9
5 changed files with 38 additions and 50 deletions

View File

@@ -4,7 +4,7 @@ import math
import time
import logging
import json
from typing import List, Optional, Union, Any
from typing import Annotated, List, Optional, Union, Any
# External Packages
from fastapi import APIRouter, Depends, HTTPException, Header, Request
@@ -31,6 +31,7 @@ from khoj.utils import state, constants
from khoj.utils.helpers import AsyncIteratorWrapper, get_device
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import (
CommonQueryParams,
get_conversation_command,
validate_conversation_config,
agenerate_chat_response,
@@ -354,15 +355,12 @@ def get_config_types(
async def search(
q: str,
request: Request,
common: CommonQueryParams,
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object
start_time = time.time()
@@ -466,10 +464,7 @@ async def search(
request=request,
telemetry_type="api",
api="search",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
end_time = time.time()
@@ -482,12 +477,9 @@ async def search(
@requires(["authenticated"])
def update(
request: Request,
common: CommonQueryParams,
t: Optional[SearchType] = None,
force: Optional[bool] = False,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object
if not state.config:
@@ -513,10 +505,7 @@ def update(
request=request,
telemetry_type="api",
api="update",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
return {"status": "ok", "message": "khoj reloaded"}
@@ -526,10 +515,7 @@ def update(
@requires(["authenticated"])
def chat_history(
request: Request,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
common: CommonQueryParams,
):
user = request.user.object
validate_conversation_config()
@@ -541,10 +527,7 @@ def chat_history(
request=request,
telemetry_type="api",
api="chat",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
return {"status": "ok", "response": meta_log.get("chat", [])}
@@ -554,10 +537,7 @@ def chat_history(
@requires(["authenticated"])
async def chat_options(
request: Request,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
common: CommonQueryParams,
) -> Response:
cmd_options = {}
for cmd in ConversationCommand:
@@ -567,10 +547,7 @@ async def chat_options(
request=request,
telemetry_type="api",
api="chat_options",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
@@ -579,14 +556,11 @@ async def chat_options(
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.18,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
) -> Response:
@@ -600,7 +574,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@@ -634,11 +608,8 @@ async def chat(
request=request,
telemetry_type="api",
api="chat",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
metadata=chat_metadata,
**common.__dict__,
)
if llm_response is None:
@@ -665,6 +636,7 @@ async def chat(
async def extract_references_and_questions(
request: Request,
common: CommonQueryParams,
meta_log: dict,
q: str,
n: int,
@@ -731,6 +703,7 @@ async def extract_references_and_questions(
r=True,
max_distance=d,
dedupe=False,
common=common,
)
)
# Dedupe the results again, as duplicates may be returned across queries.