mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Only get text search results above confidence threshold via API
- During the migration, the confidence score stopped being used: it was still passed down from the API to some point, but then went unused.
- Remove score thresholding for images, as the image search confidence score differs from the text search model's distance score.
- The default score threshold of 0.15 was determined experimentally, by manually comparing search results against distance for a few queries.
- Use distance instead of confidence as the metric for search result quality. Previously we had moved text search from a confidence score to a distance metric. Now also convert the cross-encoder and image search scores to a distance metric, so results are sorted consistently.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
from typing import Optional, Type, TypeVar, List
|
from typing import Optional, Type, TypeVar, List
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
import secrets
|
import secrets
|
||||||
@@ -437,12 +438,19 @@ class EntryAdapters:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search_with_embeddings(
|
def search_with_embeddings(
|
||||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
user: KhojUser,
|
||||||
|
embeddings: Tensor,
|
||||||
|
max_results: int = 10,
|
||||||
|
file_type_filter: str = None,
|
||||||
|
raw_query: str = None,
|
||||||
|
max_distance: float = math.inf,
|
||||||
):
|
):
|
||||||
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||||
relevant_entries = relevant_entries.filter(user=user).annotate(
|
relevant_entries = relevant_entries.filter(user=user).annotate(
|
||||||
distance=CosineDistance("embeddings", embeddings)
|
distance=CosineDistance("embeddings", embeddings)
|
||||||
)
|
)
|
||||||
|
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
|
||||||
|
|
||||||
if file_type_filter:
|
if file_type_filter:
|
||||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||||
relevant_entries = relevant_entries.order_by("distance")
|
relevant_entries = relevant_entries.order_by("distance")
|
||||||
|
|||||||
@@ -356,7 +356,7 @@ async def search(
|
|||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
t: Optional[SearchType] = SearchType.All,
|
t: Optional[SearchType] = SearchType.All,
|
||||||
r: Optional[bool] = False,
|
r: Optional[bool] = False,
|
||||||
score_threshold: Optional[Union[float, None]] = None,
|
max_distance: Optional[Union[float, None]] = None,
|
||||||
dedupe: Optional[bool] = True,
|
dedupe: Optional[bool] = True,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
@@ -375,12 +375,12 @@ async def search(
|
|||||||
# initialize variables
|
# initialize variables
|
||||||
user_query = q.strip()
|
user_query = q.strip()
|
||||||
results_count = n or 5
|
results_count = n or 5
|
||||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
max_distance = max_distance if max_distance is not None else math.inf
|
||||||
search_futures: List[concurrent.futures.Future] = []
|
search_futures: List[concurrent.futures.Future] = []
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
if user:
|
if user:
|
||||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
||||||
if query_cache_key in state.query_cache[user.uuid]:
|
if query_cache_key in state.query_cache[user.uuid]:
|
||||||
logger.debug(f"Return response from query cache")
|
logger.debug(f"Return response from query cache")
|
||||||
return state.query_cache[user.uuid][query_cache_key]
|
return state.query_cache[user.uuid][query_cache_key]
|
||||||
@@ -418,7 +418,7 @@ async def search(
|
|||||||
t,
|
t,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
max_distance=max_distance,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -431,7 +431,6 @@ async def search(
|
|||||||
results_count,
|
results_count,
|
||||||
state.search_models.image_search,
|
state.search_models.image_search,
|
||||||
state.content_index.image,
|
state.content_index.image,
|
||||||
score_threshold=score_threshold,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -454,11 +453,10 @@ async def search(
|
|||||||
# Collate results
|
# Collate results
|
||||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||||
|
|
||||||
if r:
|
|
||||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
|
|
||||||
else:
|
|
||||||
# Sort results across all content types and take top results
|
# Sort results across all content types and take top results
|
||||||
results = sorted(results, key=lambda x: float(x.score))[:results_count]
|
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
||||||
|
:results_count
|
||||||
|
]
|
||||||
|
|
||||||
# Cache results
|
# Cache results
|
||||||
if user:
|
if user:
|
||||||
@@ -583,6 +581,7 @@ async def chat(
|
|||||||
request: Request,
|
request: Request,
|
||||||
q: str,
|
q: str,
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
|
d: Optional[float] = 0.15,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
@@ -599,7 +598,7 @@ async def chat(
|
|||||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||||
|
|
||||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||||
request, meta_log, q, (n or 5), conversation_command
|
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||||
)
|
)
|
||||||
|
|
||||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||||
@@ -663,6 +662,7 @@ async def extract_references_and_questions(
|
|||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
q: str,
|
q: str,
|
||||||
n: int,
|
n: int,
|
||||||
|
d: float,
|
||||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||||
):
|
):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
@@ -723,7 +723,7 @@ async def extract_references_and_questions(
|
|||||||
request=request,
|
request=request,
|
||||||
n=n_items,
|
n=n_items,
|
||||||
r=True,
|
r=True,
|
||||||
score_threshold=-5.0,
|
max_distance=d,
|
||||||
dedupe=False,
|
dedupe=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ def extract_metadata(image_name):
|
|||||||
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
|
||||||
):
|
):
|
||||||
# Set query to image content if query is of form file:/path/to/file.png
|
# Set query to image content if query is of form file:/path/to/file.png
|
||||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||||
@@ -167,7 +167,8 @@ async def query(
|
|||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||||
with timer("Search Time", logger):
|
with timer("Search Time", logger):
|
||||||
image_hits = {
|
image_hits = {
|
||||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
# Map scores to distance metric by multiplying by -1
|
||||||
|
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
|
||||||
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,7 +205,7 @@ async def query(
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Filter results by score threshold
|
# Filter results by score threshold
|
||||||
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
|
hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
|
||||||
|
|
||||||
# Sort the images based on their combined metadata, image scores
|
# Sort the images based on their combined metadata, image scores
|
||||||
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ async def query(
|
|||||||
type: SearchType = SearchType.All,
|
type: SearchType = SearchType.All,
|
||||||
question_embedding: Union[torch.Tensor, None] = None,
|
question_embedding: Union[torch.Tensor, None] = None,
|
||||||
rank_results: bool = False,
|
rank_results: bool = False,
|
||||||
score_threshold: float = -math.inf,
|
max_distance: float = math.inf,
|
||||||
) -> Tuple[List[dict], List[Entry]]:
|
) -> Tuple[List[dict], List[Entry]]:
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
|
|
||||||
@@ -127,6 +127,7 @@ async def query(
|
|||||||
max_results=top_k,
|
max_results=top_k,
|
||||||
file_type_filter=file_type,
|
file_type_filter=file_type,
|
||||||
raw_query=raw_query,
|
raw_query=raw_query,
|
||||||
|
max_distance=max_distance,
|
||||||
).all()
|
).all()
|
||||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||||
|
|
||||||
@@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def rerank_and_sort_results(hits, query):
|
def rerank_and_sort_results(hits, query, rank_results):
|
||||||
|
# If we have more than one result and reranking is enabled
|
||||||
|
rank_results = rank_results and len(list(hits)) > 1
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
hits = cross_encoder_score(query, hits)
|
if rank_results:
|
||||||
|
hits = cross_encoder_score(query, hits)
|
||||||
|
|
||||||
# Sort results by cross-encoder score followed by bi-encoder score
|
# Sort results by cross-encoder score followed by bi-encoder score
|
||||||
hits = sort_results(rank_results=True, hits=hits)
|
hits = sort_results(rank_results=rank_results, hits=hits)
|
||||||
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
@@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
|||||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
hits[idx]["cross_score"] = cross_scores[idx]
|
hits[idx]["cross_score"] = -1 * cross_scores[idx]
|
||||||
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
@@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
|||||||
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||||
"""Order results by cross-encoder score followed by bi-encoder score"""
|
"""Order results by cross-encoder score followed by bi-encoder score"""
|
||||||
with timer("Rank Time", logger, state.device):
|
with timer("Rank Time", logger, state.device):
|
||||||
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
|
||||||
if rank_results:
|
if rank_results:
|
||||||
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
|
||||||
return hits
|
return hits
|
||||||
|
|||||||
Reference in New Issue
Block a user