mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Wrap common API query parameters into shared class to deduplicate code
- Upgrade FastAPI to >= latest version. Required upgrade of FastAPI.
Earlier version didn't support wrapping common query params in class
- Use per fixture app instead of a global FastAPI app in conftest
- Upgrade minimum required Django version
- Fix no notes chat director test with updated no notes message
No notes message was updated in commit 118f1143
This commit is contained in:
@@ -39,7 +39,7 @@ dependencies = [
|
|||||||
"bs4 >= 0.0.1",
|
"bs4 >= 0.0.1",
|
||||||
"dateparser >= 1.1.1",
|
"dateparser >= 1.1.1",
|
||||||
"defusedxml == 0.7.1",
|
"defusedxml == 0.7.1",
|
||||||
"fastapi == 0.77.1",
|
"fastapi >= 0.104.1",
|
||||||
"python-multipart >= 0.0.5",
|
"python-multipart >= 0.0.5",
|
||||||
"jinja2 == 3.1.2",
|
"jinja2 == 3.1.2",
|
||||||
"openai >= 0.27.0, < 1.0.0",
|
"openai >= 0.27.0, < 1.0.0",
|
||||||
@@ -60,7 +60,7 @@ dependencies = [
|
|||||||
"bs4 >= 0.0.1",
|
"bs4 >= 0.0.1",
|
||||||
"anyio == 3.7.1",
|
"anyio == 3.7.1",
|
||||||
"pymupdf >= 1.23.5",
|
"pymupdf >= 1.23.5",
|
||||||
"django == 4.2.5",
|
"django == 4.2.7",
|
||||||
"authlib == 1.2.1",
|
"authlib == 1.2.1",
|
||||||
"gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
"gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||||
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
|
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import math
|
|||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional, Union, Any
|
from typing import Annotated, List, Optional, Union, Any
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||||
@@ -31,6 +31,7 @@ from khoj.utils import state, constants
|
|||||||
from khoj.utils.helpers import AsyncIteratorWrapper, get_device
|
from khoj.utils.helpers import AsyncIteratorWrapper, get_device
|
||||||
from fastapi.responses import StreamingResponse, Response
|
from fastapi.responses import StreamingResponse, Response
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
|
CommonQueryParams,
|
||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
@@ -354,15 +355,12 @@ def get_config_types(
|
|||||||
async def search(
|
async def search(
|
||||||
q: str,
|
q: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
t: Optional[SearchType] = SearchType.All,
|
t: Optional[SearchType] = SearchType.All,
|
||||||
r: Optional[bool] = False,
|
r: Optional[bool] = False,
|
||||||
max_distance: Optional[Union[float, None]] = None,
|
max_distance: Optional[Union[float, None]] = None,
|
||||||
dedupe: Optional[bool] = True,
|
dedupe: Optional[bool] = True,
|
||||||
client: Optional[str] = None,
|
|
||||||
user_agent: Optional[str] = Header(None),
|
|
||||||
referer: Optional[str] = Header(None),
|
|
||||||
host: Optional[str] = Header(None),
|
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -466,10 +464,7 @@ async def search(
|
|||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="search",
|
api="search",
|
||||||
client=client,
|
**common.__dict__,
|
||||||
user_agent=user_agent,
|
|
||||||
referer=referer,
|
|
||||||
host=host,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
@@ -482,12 +477,9 @@ async def search(
|
|||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
def update(
|
def update(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
t: Optional[SearchType] = None,
|
t: Optional[SearchType] = None,
|
||||||
force: Optional[bool] = False,
|
force: Optional[bool] = False,
|
||||||
client: Optional[str] = None,
|
|
||||||
user_agent: Optional[str] = Header(None),
|
|
||||||
referer: Optional[str] = Header(None),
|
|
||||||
host: Optional[str] = Header(None),
|
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
if not state.config:
|
if not state.config:
|
||||||
@@ -513,10 +505,7 @@ def update(
|
|||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="update",
|
api="update",
|
||||||
client=client,
|
**common.__dict__,
|
||||||
user_agent=user_agent,
|
|
||||||
referer=referer,
|
|
||||||
host=host,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"status": "ok", "message": "khoj reloaded"}
|
return {"status": "ok", "message": "khoj reloaded"}
|
||||||
@@ -526,10 +515,7 @@ def update(
|
|||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
def chat_history(
|
def chat_history(
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Optional[str] = None,
|
common: CommonQueryParams,
|
||||||
user_agent: Optional[str] = Header(None),
|
|
||||||
referer: Optional[str] = Header(None),
|
|
||||||
host: Optional[str] = Header(None),
|
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
validate_conversation_config()
|
validate_conversation_config()
|
||||||
@@ -541,10 +527,7 @@ def chat_history(
|
|||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="chat",
|
api="chat",
|
||||||
client=client,
|
**common.__dict__,
|
||||||
user_agent=user_agent,
|
|
||||||
referer=referer,
|
|
||||||
host=host,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"status": "ok", "response": meta_log.get("chat", [])}
|
return {"status": "ok", "response": meta_log.get("chat", [])}
|
||||||
@@ -554,10 +537,7 @@ def chat_history(
|
|||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def chat_options(
|
async def chat_options(
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Optional[str] = None,
|
common: CommonQueryParams,
|
||||||
user_agent: Optional[str] = Header(None),
|
|
||||||
referer: Optional[str] = Header(None),
|
|
||||||
host: Optional[str] = Header(None),
|
|
||||||
) -> Response:
|
) -> Response:
|
||||||
cmd_options = {}
|
cmd_options = {}
|
||||||
for cmd in ConversationCommand:
|
for cmd in ConversationCommand:
|
||||||
@@ -567,10 +547,7 @@ async def chat_options(
|
|||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="chat_options",
|
api="chat_options",
|
||||||
client=client,
|
**common.__dict__,
|
||||||
user_agent=user_agent,
|
|
||||||
referer=referer,
|
|
||||||
host=host,
|
|
||||||
)
|
)
|
||||||
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
@@ -579,14 +556,11 @@ async def chat_options(
|
|||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def chat(
|
async def chat(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
q: str,
|
q: str,
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
d: Optional[float] = 0.18,
|
d: Optional[float] = 0.18,
|
||||||
client: Optional[str] = None,
|
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
user_agent: Optional[str] = Header(None),
|
|
||||||
referer: Optional[str] = Header(None),
|
|
||||||
host: Optional[str] = Header(None),
|
|
||||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
||||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
@@ -600,7 +574,7 @@ async def chat(
|
|||||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||||
|
|
||||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||||
)
|
)
|
||||||
|
|
||||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||||
@@ -634,11 +608,8 @@ async def chat(
|
|||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="chat",
|
api="chat",
|
||||||
client=client,
|
|
||||||
user_agent=user_agent,
|
|
||||||
referer=referer,
|
|
||||||
host=host,
|
|
||||||
metadata=chat_metadata,
|
metadata=chat_metadata,
|
||||||
|
**common.__dict__,
|
||||||
)
|
)
|
||||||
|
|
||||||
if llm_response is None:
|
if llm_response is None:
|
||||||
@@ -665,6 +636,7 @@ async def chat(
|
|||||||
|
|
||||||
async def extract_references_and_questions(
|
async def extract_references_and_questions(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
common: CommonQueryParams,
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
q: str,
|
q: str,
|
||||||
n: int,
|
n: int,
|
||||||
@@ -731,6 +703,7 @@ async def extract_references_and_questions(
|
|||||||
r=True,
|
r=True,
|
||||||
max_distance=d,
|
max_distance=d,
|
||||||
dedupe=False,
|
dedupe=False,
|
||||||
|
common=common,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Dedupe the results again, as duplicates may be returned across queries.
|
# Dedupe the results again, as duplicates may be returned across queries.
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ from datetime import datetime
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Iterator, List, Optional, Union, Tuple, Dict
|
from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Header, Request, Depends
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
@@ -221,3 +221,20 @@ class ApiUserRateLimiter:
|
|||||||
|
|
||||||
# Add the current request to the cache
|
# Add the current request to the cache
|
||||||
user_requests.append(time())
|
user_requests.append(time())
|
||||||
|
|
||||||
|
|
||||||
|
class CommonQueryParamsClass:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
user_agent: Optional[str] = Header(None),
|
||||||
|
referer: Optional[str] = Header(None),
|
||||||
|
host: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
self.client = client
|
||||||
|
self.user_agent = user_agent
|
||||||
|
self.referer = referer
|
||||||
|
self.host = host
|
||||||
|
|
||||||
|
|
||||||
|
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
||||||
|
|||||||
@@ -9,9 +9,6 @@ import os
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
||||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
@@ -320,6 +317,7 @@ def client(
|
|||||||
|
|
||||||
state.anonymous_mode = False
|
state.anonymous_mode = False
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
configure_middleware(app)
|
configure_middleware(app)
|
||||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
|
|||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response_message == prompts.no_notes_found.format()
|
assert response_message == prompts.no_entries_found.format()
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user