Only get text search results above confidence threshold via API

- During the migration, the confidence score stopped being used. It
  was passed down from the API partway through the pipeline and then
  went unused

- Remove score thresholding for images, as the image search confidence
  score differs from the text search model's distance score

- Default score threshold of 0.15 is experimentally determined by
  manually looking at search results vs distance for a few queries

- Use distance instead of confidence as metric for search result quality
  Previously we'd moved text search to a distance metric from a
  confidence score.

  Now convert the cross-encoder and image search scores to a distance
  metric as well, for consistent sorting of results
This commit is contained in:
Debanjum Singh Solanky
2023-11-11 03:30:35 -08:00
parent e44e6df221
commit 941c7f23a3
4 changed files with 37 additions and 23 deletions

View File

@@ -1,3 +1,4 @@
import math
from typing import Optional, Type, TypeVar, List from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
import secrets import secrets
@@ -437,12 +438,19 @@ class EntryAdapters:
@staticmethod @staticmethod
def search_with_embeddings( def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None user: KhojUser,
embeddings: Tensor,
max_results: int = 10,
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
): ):
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter) relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate( relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings) distance=CosineDistance("embeddings", embeddings)
) )
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
if file_type_filter: if file_type_filter:
relevant_entries = relevant_entries.filter(file_type=file_type_filter) relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_entries = relevant_entries.order_by("distance") relevant_entries = relevant_entries.order_by("distance")

View File

@@ -356,7 +356,7 @@ async def search(
n: Optional[int] = 5, n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All, t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False, r: Optional[bool] = False,
score_threshold: Optional[Union[float, None]] = None, max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True, dedupe: Optional[bool] = True,
client: Optional[str] = None, client: Optional[str] = None,
user_agent: Optional[str] = Header(None), user_agent: Optional[str] = Header(None),
@@ -375,12 +375,12 @@ async def search(
# initialize variables # initialize variables
user_query = q.strip() user_query = q.strip()
results_count = n or 5 results_count = n or 5
score_threshold = score_threshold if score_threshold is not None else -math.inf max_distance = max_distance if max_distance is not None else math.inf
search_futures: List[concurrent.futures.Future] = [] search_futures: List[concurrent.futures.Future] = []
# return cached results, if available # return cached results, if available
if user: if user:
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]: if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache") logger.debug(f"Return response from query cache")
return state.query_cache[user.uuid][query_cache_key] return state.query_cache[user.uuid][query_cache_key]
@@ -418,7 +418,7 @@ async def search(
t, t,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, max_distance=max_distance,
) )
] ]
@@ -431,7 +431,6 @@ async def search(
results_count, results_count,
state.search_models.image_search, state.search_models.image_search,
state.content_index.image, state.content_index.image,
score_threshold=score_threshold,
) )
] ]
@@ -454,11 +453,10 @@ async def search(
# Collate results # Collate results
results += text_search.collate_results(hits, dedupe=dedupe) results += text_search.collate_results(hits, dedupe=dedupe)
if r:
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
else:
# Sort results across all content types and take top results # Sort results across all content types and take top results
results = sorted(results, key=lambda x: float(x.score))[:results_count] results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
:results_count
]
# Cache results # Cache results
if user: if user:
@@ -583,6 +581,7 @@ async def chat(
request: Request, request: Request,
q: str, q: str,
n: Optional[int] = 5, n: Optional[int] = 5,
d: Optional[float] = 0.15,
client: Optional[str] = None, client: Optional[str] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None), user_agent: Optional[str] = Header(None),
@@ -599,7 +598,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), conversation_command request, meta_log, q, (n or 5), (d or math.inf), conversation_command
) )
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@@ -663,6 +662,7 @@ async def extract_references_and_questions(
meta_log: dict, meta_log: dict,
q: str, q: str,
n: int, n: int,
d: float,
conversation_type: ConversationCommand = ConversationCommand.Default, conversation_type: ConversationCommand = ConversationCommand.Default,
): ):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
@@ -723,7 +723,7 @@ async def extract_references_and_questions(
request=request, request=request,
n=n_items, n=n_items,
r=True, r=True,
score_threshold=-5.0, max_distance=d,
dedupe=False, dedupe=False,
) )
) )

View File

@@ -146,7 +146,7 @@ def extract_metadata(image_name):
async def query( async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
): ):
# Set query to image content if query is of form file:/path/to/file.png # Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
@@ -167,7 +167,8 @@ async def query(
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger): with timer("Search Time", logger):
image_hits = { image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} # Map scores to distance metric by multiplying by -1
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0] for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
} }
@@ -204,7 +205,7 @@ async def query(
] ]
# Filter results by score threshold # Filter results by score threshold
hits = [hit for hit in hits if hit["image_score"] >= score_threshold] hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
# Sort the images based on their combined metadata, image scores # Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True) return sorted(hits, key=lambda hit: hit["score"], reverse=True)

View File

@@ -105,7 +105,7 @@ async def query(
type: SearchType = SearchType.All, type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None, question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False, rank_results: bool = False,
score_threshold: float = -math.inf, max_distance: float = math.inf,
) -> Tuple[List[dict], List[Entry]]: ) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query" "Search for entries that answer the query"
@@ -127,6 +127,7 @@ async def query(
max_results=top_k, max_results=top_k,
file_type_filter=file_type, file_type_filter=file_type,
raw_query=raw_query, raw_query=raw_query,
max_distance=max_distance,
).all() ).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg] hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
@@ -177,12 +178,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
) )
def rerank_and_sort_results(hits, query): def rerank_and_sort_results(hits, query, rank_results):
# If we have more than one result and reranking is enabled
rank_results = rank_results and len(list(hits)) > 1
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
hits = cross_encoder_score(query, hits) if rank_results:
hits = cross_encoder_score(query, hits)
# Sort results by cross-encoder score followed by bi-encoder score # Sort results by cross-encoder score followed by bi-encoder score
hits = sort_results(rank_results=True, hits=hits) hits = sort_results(rank_results=rank_results, hits=hits)
return hits return hits
@@ -217,9 +222,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
with timer("Cross-Encoder Predict Time", logger, state.device): with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model.predict(query, hits) cross_scores = state.cross_encoder_model.predict(query, hits)
# Store cross-encoder scores in results dictionary for ranking # Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)): for idx in range(len(cross_scores)):
hits[idx]["cross_score"] = cross_scores[idx] hits[idx]["cross_score"] = -1 * cross_scores[idx]
return hits return hits
@@ -227,7 +232,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]: def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
"""Order results by cross-encoder score followed by bi-encoder score""" """Order results by cross-encoder score followed by bi-encoder score"""
with timer("Rank Time", logger, state.device): with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
if rank_results: if rank_results:
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
return hits return hits