mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Only allow supported search types to /search, /regenerate APIs
- Use a SearchType enum to limit the types that can be passed by the user
- FastAPI automatically validates the type passed in the query param
- Available type options show up in the Swagger UI and FastAPI docs
- Controller code looks neater than doing string comparisons on the type
- Test invalid and valid search types via pytest
This commit is contained in:
21
src/main.py
21
src/main.py
@@ -11,13 +11,14 @@ from fastapi import FastAPI
|
|||||||
from search_type import asymmetric, symmetric_ledger, image_search
|
from search_type import asymmetric, symmetric_ledger, image_search
|
||||||
from utils.helpers import get_from_dict
|
from utils.helpers import get_from_dict
|
||||||
from utils.cli import cli
|
from utils.cli import cli
|
||||||
|
from utils.config import SearchType
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@app.get('/search')
|
@app.get('/search')
|
||||||
def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
if q is None or q == '':
|
if q is None or q == '':
|
||||||
print(f'No query param (q) passed in API call to initiate search')
|
print(f'No query param (q) passed in API call to initiate search')
|
||||||
return {}
|
return {}
|
||||||
@@ -25,7 +26,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||||||
user_query = q
|
user_query = q
|
||||||
results_count = n
|
results_count = n
|
||||||
|
|
||||||
if (t == 'notes' or t == None) and notes_search_enabled:
|
if (t == SearchType.Notes or t == None) and notes_search_enabled:
|
||||||
# query notes
|
# query notes
|
||||||
hits = asymmetric.query_notes(
|
hits = asymmetric.query_notes(
|
||||||
user_query,
|
user_query,
|
||||||
@@ -38,7 +39,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, entries, results_count)
|
return asymmetric.collate_results(hits, entries, results_count)
|
||||||
|
|
||||||
if (t == 'music' or t == None) and music_search_enabled:
|
if (t == SearchType.Music or t == None) and music_search_enabled:
|
||||||
# query music library
|
# query music library
|
||||||
hits = asymmetric.query_notes(
|
hits = asymmetric.query_notes(
|
||||||
user_query,
|
user_query,
|
||||||
@@ -51,7 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, songs, results_count)
|
return asymmetric.collate_results(hits, songs, results_count)
|
||||||
|
|
||||||
if (t == 'ledger' or t == None) and ledger_search_enabled:
|
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
|
||||||
# query transactions
|
# query transactions
|
||||||
hits = symmetric_ledger.query_transactions(
|
hits = symmetric_ledger.query_transactions(
|
||||||
user_query,
|
user_query,
|
||||||
@@ -63,7 +64,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||||||
# collate and return results
|
# collate and return results
|
||||||
return symmetric_ledger.collate_results(hits, transactions, results_count)
|
return symmetric_ledger.collate_results(hits, transactions, results_count)
|
||||||
|
|
||||||
if (t == 'image' or t == None) and image_search_enabled:
|
if (t == SearchType.Image or t == None) and image_search_enabled:
|
||||||
# query transactions
|
# query transactions
|
||||||
hits = image_search.query_images(
|
hits = image_search.query_images(
|
||||||
user_query,
|
user_query,
|
||||||
@@ -85,8 +86,8 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
|
|||||||
|
|
||||||
|
|
||||||
@app.get('/regenerate')
|
@app.get('/regenerate')
|
||||||
def regenerate(t: Optional[str] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
if (t == 'notes' or t == None) and notes_search_enabled:
|
if (t == SearchType.Notes or t == None) and notes_search_enabled:
|
||||||
# Extract Entries, Generate Embeddings
|
# Extract Entries, Generate Embeddings
|
||||||
global corpus_embeddings
|
global corpus_embeddings
|
||||||
global entries
|
global entries
|
||||||
@@ -98,7 +99,7 @@ def regenerate(t: Optional[str] = None):
|
|||||||
regenerate=True,
|
regenerate=True,
|
||||||
verbose=args.verbose)
|
verbose=args.verbose)
|
||||||
|
|
||||||
if (t == 'music' or t == None) and music_search_enabled:
|
if (t == SearchType.Music or t == None) and music_search_enabled:
|
||||||
# Extract Entries, Generate Song Embeddings
|
# Extract Entries, Generate Song Embeddings
|
||||||
global song_embeddings
|
global song_embeddings
|
||||||
global songs
|
global songs
|
||||||
@@ -110,7 +111,7 @@ def regenerate(t: Optional[str] = None):
|
|||||||
regenerate=True,
|
regenerate=True,
|
||||||
verbose=args.verbose)
|
verbose=args.verbose)
|
||||||
|
|
||||||
if (t == 'ledger' or t == None) and ledger_search_enabled:
|
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
|
||||||
# Extract Entries, Generate Embeddings
|
# Extract Entries, Generate Embeddings
|
||||||
global transaction_embeddings
|
global transaction_embeddings
|
||||||
global transactions
|
global transactions
|
||||||
@@ -122,7 +123,7 @@ def regenerate(t: Optional[str] = None):
|
|||||||
regenerate=True,
|
regenerate=True,
|
||||||
verbose=args.verbose)
|
verbose=args.verbose)
|
||||||
|
|
||||||
if (t == 'image' or t == None) and image_search_enabled:
|
if (t == SearchType.Image or t == None) and image_search_enabled:
|
||||||
# Extract Images, Generate Embeddings
|
# Extract Images, Generate Embeddings
|
||||||
global image_embeddings
|
global image_embeddings
|
||||||
global image_metadata_embeddings
|
global image_metadata_embeddings
|
||||||
|
|||||||
@@ -16,6 +16,44 @@ client = TestClient(app)
|
|||||||
|
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_search_with_invalid_search_type():
    """An unsupported `t` value on /search is expected to yield HTTP 422."""
    # Arrange
    user_query = "How to call semantic search from Emacs?"

    # Act
    response = client.get(f"/search?q={user_query}&t=invalid_search_type")

    # Assert
    # The request must be rejected before any search runs
    assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_with_valid_search_type():
    """Each supported search type is expected to be accepted by /search."""
    # Arrange
    for search_type in ["notes", "ledger", "music", "image"]:
        # Act
        response = client.get(f"/search?q=random&t={search_type}")

        # Assert
        # Name the offending type so a failure inside the loop is identifiable
        assert response.status_code == 200, f"/search rejected t={search_type}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_regenerate_with_invalid_search_type():
    """An unsupported `t` value on /regenerate is expected to yield HTTP 422."""
    # Act
    # Plain string literal: the URL has no placeholders, so no f-prefix needed
    response = client.get("/regenerate?t=invalid_search_type")

    # Assert
    assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_regenerate_with_valid_search_type():
    """Each supported search type is expected to be accepted by /regenerate."""
    # Arrange
    for search_type in ["notes", "ledger", "music", "image"]:
        # Act
        response = client.get(f"/regenerate?t={search_type}")

        # Assert
        # Name the offending type so a failure inside the loop is identifiable
        assert response.status_code == 200, f"/regenerate rejected t={search_type}"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_asymmetric_setup():
|
def test_asymmetric_setup():
|
||||||
# Arrange
|
# Arrange
|
||||||
|
|||||||
9
src/utils/config.py
Normal file
9
src/utils/config.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class SearchType(str, Enum):
    """Search types accepted by the `t` query parameter of /search and /regenerate.

    Mixing in `str` makes every member compare equal to its plain string
    value, so a member can be used wherever the raw string was used before.
    """

    Notes = "notes"
    Ledger = "ledger"
    Music = "music"
    Image = "image"
|
||||||
|
|
||||||
Reference in New Issue
Block a user