mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Improve Query Speed. Normalize Embeddings, Moving them to Cuda GPU
- Move embeddings to CUDA GPU for compute, when available - Normalize embeddings and Use Dot Product instead of Cosine
This commit is contained in:
14
src/main.py
14
src/main.py
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
# External Packages
|
||||
import uvicorn
|
||||
import torch
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -24,6 +25,7 @@ processor_config = ProcessorConfigModel()
|
||||
config_file = ""
|
||||
verbose = 0
|
||||
app = FastAPI()
|
||||
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
app.mount("/views", StaticFiles(directory="views"), name="views")
|
||||
templates = Jinja2Templates(directory="views/")
|
||||
@@ -56,14 +58,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
|
||||
if (t == SearchType.Notes or t == None) and model.notes_search:
|
||||
# query notes
|
||||
hits = asymmetric.query(user_query, model.notes_search)
|
||||
hits = asymmetric.query(user_query, model.notes_search, device=device)
|
||||
|
||||
# collate and return results
|
||||
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
|
||||
|
||||
if (t == SearchType.Music or t == None) and model.music_search:
|
||||
# query music library
|
||||
hits = asymmetric.query(user_query, model.music_search)
|
||||
hits = asymmetric.query(user_query, model.music_search, device=device)
|
||||
|
||||
# collate and return results
|
||||
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
||||
@@ -93,14 +95,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
@app.get('/reload')
|
||||
def regenerate(t: Optional[SearchType] = None):
|
||||
global model
|
||||
model = initialize_search(config, regenerate=False, t=t)
|
||||
model = initialize_search(config, regenerate=False, t=t, device=device)
|
||||
return {'status': 'ok', 'message': 'reload completed'}
|
||||
|
||||
|
||||
@app.get('/regenerate')
|
||||
def regenerate(t: Optional[SearchType] = None):
|
||||
global model
|
||||
model = initialize_search(config, regenerate=True, t=t)
|
||||
model = initialize_search(config, regenerate=True, t=t, device=device)
|
||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||
|
||||
|
||||
@@ -149,12 +151,12 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
|
||||
# Initialize Org Notes Search
|
||||
if (t == SearchType.Notes or t == None) and config.content_type.org:
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
||||
|
||||
# Initialize Org Music Search
|
||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||
# Extract Entries, Generate Music Embeddings
|
||||
model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
|
||||
|
||||
# Initialize Ledger Search
|
||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||
|
||||
Reference in New Issue
Block a user