mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Run Explicit Filter on Entries, Embeddings before Semantic Search for Query
## Issue
- Explicit filtering was being done after search by the bi-encoder
but before re-ranking by the cross-encoder
- This limited the quality of results being returned for queries with explicit filters.
The bi-encoder returned results which were going to be excluded anyway.
So the burden of improving those limited results post-filtering was on the
cross-encoder, by re-ranking the remaining results to best match the query
## Fix
- Given that the entry and its embedding are at the same index in their respective lists,
we know which entries map to which embedding tensors.
So we can run the filter for blocked and required words before the bi-encoder search,
and limit the entries and embeddings being considered for the current query
## Result
- Semantic search by the bi-encoder returns the most relevant results
for the query, knowing that the results aren't going to be filtered out after.
So the cross-encoder shoulders less of the burden of improving the results
## Corollary
- This pre-filtering technique allows us to apply other explicit filters
on entries relevant for the current query, before calling search
- E.g. limit search to entries within the date/time specified in the query
This commit is contained in:
@@ -58,17 +58,17 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||||||
|
|
||||||
if (t == SearchType.Notes or t == None) and model.notes_search:
|
if (t == SearchType.Notes or t == None) and model.notes_search:
|
||||||
# query notes
|
# query notes
|
||||||
hits = asymmetric.query(user_query, model.notes_search, device=device)
|
hits, entries = asymmetric.query(user_query, model.notes_search, device=device)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
|
return asymmetric.collate_results(hits, entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and model.music_search:
|
if (t == SearchType.Music or t == None) and model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
hits = asymmetric.query(user_query, model.music_search, device=device)
|
hits, entries = asymmetric.query(user_query, model.music_search, device=device)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, model.music_search.entries, results_count)
|
return asymmetric.collate_results(hits, entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Ledger or t == None) and model.ledger_search:
|
if (t == SearchType.Ledger or t == None) and model.ledger_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import gzip
|
|||||||
import re
|
import re
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
import pathlib
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
@@ -100,24 +101,25 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
|
|||||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||||
|
|
||||||
|
# Copy original embeddings, entries to filter them for query
|
||||||
|
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||||
|
entries = deepcopy(model.entries)
|
||||||
|
|
||||||
|
# Filter to entries that contain all required_words and no blocked_words
|
||||||
|
entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words)
|
||||||
|
if entries is None or len(entries) == 0:
|
||||||
|
return {}
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
||||||
question_embedding.to(device)
|
question_embedding.to(device)
|
||||||
question_embedding = util.normalize_embeddings(question_embedding)
|
question_embedding = util.normalize_embeddings(question_embedding)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)
|
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||||
hits = hits[0] # Get the hits for the first query
|
|
||||||
|
|
||||||
# Filter out entries that contain required words and do not contain blocked words
|
|
||||||
hits = explicit_filter(hits,
|
|
||||||
[entry[0] for entry in model.entries],
|
|
||||||
required_words,blocked_words)
|
|
||||||
if hits is None or len(hits) == 0:
|
|
||||||
return hits
|
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
cross_inp = [[query, model.entries[hit['corpus_id']][0]] for hit in hits]
|
cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits]
|
||||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
@@ -128,28 +130,43 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
|
|||||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||||
|
|
||||||
return hits
|
return hits, entries
|
||||||
|
|
||||||
|
|
||||||
def explicit_filter(hits, entries, required_words, blocked_words):
|
def explicit_filter(entries, embeddings, required_words, blocked_words):
|
||||||
hits_by_word_set = [(set(word.lower()
|
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||||
|
return entries, embeddings
|
||||||
|
|
||||||
|
# convert each entry to a set of words
|
||||||
|
entries_by_word_set = [set(word.lower()
|
||||||
for word
|
for word
|
||||||
in re.split(
|
in re.split(
|
||||||
r',|\.| |\]|\[\(|\)|\{|\}',
|
r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets
|
||||||
entries[hit['corpus_id']])
|
entry[0])
|
||||||
if word != ""),
|
if word != "")
|
||||||
hit)
|
for entry in entries]
|
||||||
for hit in hits]
|
|
||||||
|
|
||||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
# track id of entries to exclude
|
||||||
return hits
|
entries_to_exclude = set()
|
||||||
|
|
||||||
|
# mark entries that do not contain all required_words for exclusion
|
||||||
if len(required_words) > 0:
|
if len(required_words) > 0:
|
||||||
return [hit for (words_in_entry, hit) in hits_by_word_set
|
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||||
if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
|
if not required_words.issubset(words_in_entry):
|
||||||
|
entries_to_exclude.add(id)
|
||||||
|
|
||||||
|
# mark entries that contain any blocked_words for exclusion
|
||||||
if len(blocked_words) > 0:
|
if len(blocked_words) > 0:
|
||||||
return [hit for (words_in_entry, hit) in hits_by_word_set
|
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||||
if not blocked_words.intersection(words_in_entry)]
|
if words_in_entry.intersection(blocked_words):
|
||||||
return hits
|
entries_to_exclude.add(id)
|
||||||
|
|
||||||
|
# delete entries (and their embeddings) marked for exclusion
|
||||||
|
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||||
|
del entries[id]
|
||||||
|
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||||
|
|
||||||
|
return entries, embeddings
|
||||||
|
|
||||||
|
|
||||||
def render_results(hits, entries, count=5, display_biencoder_results=False):
|
def render_results(hits, entries, count=5, display_biencoder_results=False):
|
||||||
|
|||||||
@@ -26,13 +26,13 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
|
|||||||
query = "How to git install application?"
|
query = "How to git install application?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
hits = asymmetric.query(
|
hits, entries = asymmetric.query(
|
||||||
query,
|
query,
|
||||||
model = model.notes_search)
|
model = model.notes_search)
|
||||||
|
|
||||||
results = asymmetric.collate_results(
|
results = asymmetric.collate_results(
|
||||||
hits,
|
hits,
|
||||||
model.notes_search.entries,
|
entries,
|
||||||
count=1)
|
count=1)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
|||||||
Reference in New Issue
Block a user