Only allow supported search types to /search, /regenerate APIs

- Use a SearchType to limit types that can be passed by user
- FastAPI automatically validates type passed in query param
- Available type options show up in Swagger UI, FastAPI docs
- controller code looks neater instead of doing string comparisons for type
- Test invalid, valid search types via pytest
This commit is contained in:
Debanjum Singh Solanky
2021-09-29 19:02:55 -07:00
parent 150593c776
commit 81ce0cacc3
3 changed files with 58 additions and 10 deletions

View File

@@ -11,13 +11,14 @@ from fastapi import FastAPI
from search_type import asymmetric, symmetric_ledger, image_search
from utils.helpers import get_from_dict
from utils.cli import cli
from utils.config import SearchType
app = FastAPI()
@app.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if q is None or q == '':
print(f'No query param (q) passed in API call to initiate search')
return {}
@@ -25,7 +26,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
user_query = q
results_count = n
if (t == 'notes' or t == None) and notes_search_enabled:
if (t == SearchType.Notes or t == None) and notes_search_enabled:
# query notes
hits = asymmetric.query_notes(
user_query,
@@ -38,7 +39,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
# collate and return results
return asymmetric.collate_results(hits, entries, results_count)
if (t == 'music' or t == None) and music_search_enabled:
if (t == SearchType.Music or t == None) and music_search_enabled:
# query music library
hits = asymmetric.query_notes(
user_query,
@@ -51,7 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
# collate and return results
return asymmetric.collate_results(hits, songs, results_count)
if (t == 'ledger' or t == None) and ledger_search_enabled:
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
# query transactions
hits = symmetric_ledger.query_transactions(
user_query,
@@ -63,7 +64,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
# collate and return results
return symmetric_ledger.collate_results(hits, transactions, results_count)
if (t == 'image' or t == None) and image_search_enabled:
if (t == SearchType.Image or t == None) and image_search_enabled:
# query transactions
hits = image_search.query_images(
user_query,
@@ -85,8 +86,8 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None):
@app.get('/regenerate')
def regenerate(t: Optional[str] = None):
if (t == 'notes' or t == None) and notes_search_enabled:
def regenerate(t: Optional[SearchType] = None):
if (t == SearchType.Notes or t == None) and notes_search_enabled:
# Extract Entries, Generate Embeddings
global corpus_embeddings
global entries
@@ -98,7 +99,7 @@ def regenerate(t: Optional[str] = None):
regenerate=True,
verbose=args.verbose)
if (t == 'music' or t == None) and music_search_enabled:
if (t == SearchType.Music or t == None) and music_search_enabled:
# Extract Entries, Generate Song Embeddings
global song_embeddings
global songs
@@ -110,7 +111,7 @@ def regenerate(t: Optional[str] = None):
regenerate=True,
verbose=args.verbose)
if (t == 'ledger' or t == None) and ledger_search_enabled:
if (t == SearchType.Ledger or t == None) and ledger_search_enabled:
# Extract Entries, Generate Embeddings
global transaction_embeddings
global transactions
@@ -122,7 +123,7 @@ def regenerate(t: Optional[str] = None):
regenerate=True,
verbose=args.verbose)
if (t == 'image' or t == None) and image_search_enabled:
if (t == SearchType.Image or t == None) and image_search_enabled:
# Extract Images, Generate Embeddings
global image_embeddings
global image_metadata_embeddings