mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 05:40:17 +00:00
Add per-user support for configuring the preferred search model from the config page
- Honor this setting across the relevant places where embeddings are used
- Change the VectorField object to use None for dimensions so that the search model is easily configurable
This commit is contained in:
@@ -332,6 +332,31 @@ async def update_chat_model(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api.post("/config/data/search/model", status_code=200)
@requires(["authenticated"])
# NOTE(review): this handler reuses the name `update_chat_model` from the
# preceding chat-model route; it actually updates the *search* model. The name
# is kept unchanged here so nothing that references it breaks -- consider
# renaming it to `update_search_model` in a follow-up.
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the authenticated user's preferred search model.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        id: Identifier of the search model to select. It is converted with
            ``int(id)`` before being handed to the adapter.
        client: Optional client name forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` on success, or
        ``{"status": "error", "message": "Model not found"}`` when the adapter
        returns ``None`` for the requested id.
    """
    user = request.user.object

    new_config = await adapters.aset_user_search_model(user, int(id))

    # Guard before touching the result: the telemetry metadata below reads
    # `new_config.setting.name`, which raises AttributeError when the adapter
    # returned None and would make this error response unreachable.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_search_model",
        client=client,
        metadata={"search_model": new_config.setting.name},
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
# Create Routes
|
||||
@api.get("/config/data/default")
|
||||
def get_default_config_data():
|
||||
@@ -410,14 +435,10 @@ async def search(
|
||||
defiltered_query = filter.defilter(defiltered_query)
|
||||
|
||||
encoded_asymmetric_query = None
|
||||
if t == SearchType.All or t != SearchType.Image:
|
||||
text_search_models: List[TextSearchModel] = [
|
||||
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
|
||||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||
if t != SearchType.Image:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if t in [
|
||||
@@ -473,9 +494,9 @@ async def search(
|
||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||
|
||||
# Sort results across all content types and take top results
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
||||
:results_count
|
||||
]
|
||||
results = text_search.rerank_and_sort_results(
|
||||
results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
|
||||
)[:results_count]
|
||||
|
||||
# Cache results
|
||||
if user:
|
||||
|
||||
Reference in New Issue
Block a user