Extract Entries in a standardized format across text search types
Issue:
- Had different schemas of extracted entries for symmetric_ledger vs asymmetric search
- Entry extraction for asymmetric search was messy, relying on cryptic list indices to distinguish the raw entry from the cleaned entry meant to be passed to the embedding model
- This pushed the load of figuring out which property to extract from each entry onto downstream processes like the filters
- This limited the filters to work only for asymmetric search, not for symmetric_ledger

Fix:
- Use a consistent format for extracted entries:
  {
    'embed': entry_string_meant_to_be_passed_to_model_and_get_embeddings,
    'raw'  : raw_entry_string_meant_to_be_passed_to_user
  }

Result:
- Filters can now be applied across search types, and the specific field they should be applied on can be configured by each search type
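
A minimal sketch (not from the repository) of what this standardized schema enables; make_entry and keyword_filter are hypothetical names used only to illustrate how a filter can pick its field through an entry_key parameter:

    # Standardized entry: 'embed' feeds the embedding model, 'raw' is returned to the user
    def make_entry(cleaned, raw):
        return {'embed': cleaned, 'raw': raw}

    # Hypothetical filter: each search type configures which field to inspect via entry_key
    def keyword_filter(keyword, entries, entry_key='raw'):
        return [entry for entry in entries if keyword.lower() in entry[entry_key].lower()]

    entries = [
        make_entry('Trip to Goa\tholiday', '* Trip to Goa :holiday:\nNotes about the trip'),
        make_entry('Weekly review', '* Weekly review\nPlanning notes'),
    ]
    print(keyword_filter('goa', entries, entry_key='embed'))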
@@ -17,7 +17,7 @@ import dateparser as dtparse
 
 date_regex = r"dt([:><=]{1,2})\"(.*?)\""
 
 
-def date_filter(query, entries, embeddings):
+def date_filter(query, entries, embeddings, entry_key='raw'):
     "Find entries containing any dates that fall within date range specified in query"
     # extract date range specified in date filter of query
     query_daterange = extract_date_range(query)
@@ -34,7 +34,7 @@ def date_filter(query, entries, embeddings):
     entries_to_include = set()
     for id, entry in enumerate(entries):
         # Extract dates from entry
-        for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[1]):
+        for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
             # Convert date string in entry to unix timestamp
             try:
                 date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
@@ -5,7 +5,7 @@ import re
 import torch
 
 
-def explicit_filter(raw_query, entries, embeddings):
+def explicit_filter(raw_query, entries, embeddings, entry_key='raw'):
     # Separate natural query from explicit required, blocked words filters
     query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
     required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
@@ -19,7 +19,7 @@ def explicit_filter(raw_query, entries, embeddings):
     entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
     entries_by_word_set = [set(word.lower()
                                for word
-                               in re.split(entry_splitter, entry[1])
+                               in re.split(entry_splitter, entry[entry_key])
                                if word != "")
                            for entry in entries]
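
A hedged usage sketch of the updated date_filter signature, following the call pattern in the repository's test; the entry contents and the dt query syntax (derived from date_regex above) are assumptions:

    import torch
    from src.search_filter import date_filter

    entries = [
        {'embed': '', 'raw': 'April Fools entry: 1984-04-01'},
        {'embed': '', 'raw': 'Entry with no date'},
    ]
    embeddings = torch.randn(len(entries), 10)

    # entry_key='raw' (the default) tells the filter which field to scan for dates
    query, entries, embeddings = date_filter.date_filter(
        'notes dt>"1984-04-01"', entries.copy(), embeddings, entry_key='raw')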
@@ -63,7 +63,7 @@ def extract_entries(notesfile, verbose=0):
         note_string = f'{note["Title"]}' \
                       f'\t{note["Tags"] if "Tags" in note else ""}' \
                       f'\n{note["Body"] if "Body" in note else ""}'
-        entries.append([note_string, note["Raw"]])
+        entries.append({'embed': note_string, 'raw': note["Raw"]})
 
     # Close File
     jsonl_file.close()
@@ -83,7 +83,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
         print(f"Loaded embeddings from {embeddings_file}")
 
     else: # Else compute the corpus_embeddings from scratch, which can take a while
-        corpus_embeddings = bi_encoder.encode([entry[0] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
+        corpus_embeddings = bi_encoder.encode([entry['embed'] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
         corpus_embeddings.to(device)
         corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
         torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
@@ -116,7 +116,7 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), fi
     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
 
     # Score all retrieved entries using the cross-encoder
-    cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits]
+    cross_inp = [[query, entries[hit['corpus_id']]['embed']] for hit in hits]
     cross_scores = model.cross_encoder.predict(cross_inp)
 
     # Store cross-encoder scores in results dictionary for ranking
@@ -138,20 +138,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
         print(f"Top-{count} Bi-Encoder Retrieval hits")
         hits = sorted(hits, key=lambda x: x['score'], reverse=True)
         for hit in hits[0:count]:
-            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']][0]}")
+            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['embed']}")
 
     # Output of top hits from re-ranker
     print("\n-------------------------\n")
     print(f"Top-{count} Cross-Encoder Re-ranker hits")
     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
     for hit in hits[0:count]:
-        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']][0]}")
+        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['embed']}")
 
 
 def collate_results(hits, entries, count=5):
     return [
         {
-            "Entry": entries[hit['corpus_id']][1],
+            "Entry": entries[hit['corpus_id']]['raw'],
             "Score": f"{hit['cross-score']:.3f}"
         }
         for hit
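
In short, after the hunks above the asymmetric pipeline encodes and re-ranks on the 'embed' field and surfaces the 'raw' field back to the caller. A tiny sketch of that split, with illustrative entry contents:

    entry = {'embed': 'Title\tTags\nBody', 'raw': '* Title :Tags:\nBody'}

    to_encode = entry['embed']  # consumed by compute_embeddings and the cross-encoder
    to_return = entry['raw']    # surfaced by collate_results in the "Entry" field
    print(to_encode, to_return, sep='\n---\n')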
@@ -38,7 +38,7 @@ def initialize_model(search_config: SymmetricSearchConfig):
 
 def extract_entries(notesfile, verbose=0):
     "Load entries from compressed jsonl"
-    return [f'{entry["Title"]}'
+    return [{'raw': f'{entry["Title"]}', 'embed': f'{entry["Title"]}'}
             for entry
             in load_jsonl(notesfile, verbose=verbose)]
 
@@ -80,7 +80,7 @@ def query(raw_query, model: TextSearchModel, filters=[]):
     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0]
 
     # Score all retrieved entries using the cross-encoder
-    cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
+    cross_inp = [[query, entries[hit['corpus_id']]['embed']] for hit in hits]
     cross_scores = model.cross_encoder.predict(cross_inp)
 
     # Store cross-encoder scores in results dictionary for ranking
@@ -102,20 +102,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
         print(f"Top-{count} Bi-Encoder Retrieval hits")
         hits = sorted(hits, key=lambda x: x['score'], reverse=True)
         for hit in hits[0:count]:
-            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]}")
+            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['embed']}")
 
     # Output of top hits from re-ranker
     print("\n-------------------------\n")
     print(f"Top-{count} Cross-Encoder Re-ranker hits")
     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
     for hit in hits[0:count]:
-        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]}")
+        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['embed']}")
 
 
 def collate_results(hits, entries, count=5):
     return [
         {
-            "Entry": entries[hit['corpus_id']],
+            "Entry": entries[hit['corpus_id']]['raw'],
             "Score": f"{hit['cross-score']:.3f}"
         }
         for hit
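
For symmetric_ledger, extract_entries above stores the same Title string under both keys, so a filter that reads entry['raw'] now applies to ledger entries as well. A small sketch of that shape, assuming a loaded entry with a "Title" field:

    ledger_entry = {"Title": "[X] Pay electricity bill"}
    entry = {'raw': f'{ledger_entry["Title"]}', 'embed': f'{ledger_entry["Title"]}'}
    assert entry['raw'] == entry['embed']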
@@ -13,9 +13,9 @@ from src.search_filter import date_filter
 def test_date_filter():
     embeddings = torch.randn(3, 10)
     entries = [
-        ['', 'Entry with no date'],
-        ['', 'April Fools entry: 1984-04-01'],
-        ['', 'Entry with date:1984-04-02']]
+        {'embed': '', 'raw': 'Entry with no date'},
+        {'embed': '', 'raw': 'April Fools entry: 1984-04-01'},
+        {'embed': '', 'raw': 'Entry with date:1984-04-02'}]
 
     q_with_no_date_filter = 'head tail'
     ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings)