diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 81febe1e..97ed2a8a 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -17,7 +17,7 @@ import dateparser as dtparse date_regex = r"dt([:><=]{1,2})\"(.*?)\"" -def date_filter(query, entries, embeddings): +def date_filter(query, entries, embeddings, entry_key='raw'): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query query_daterange = extract_date_range(query) @@ -34,7 +34,7 @@ def date_filter(query, entries, embeddings): entries_to_include = set() for id, entry in enumerate(entries): # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[1]): + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]): # Convert date string in entry to unix timestamp try: date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index f913a820..61576bdf 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -5,7 +5,7 @@ import re import torch -def explicit_filter(raw_query, entries, embeddings): +def explicit_filter(raw_query, entries, embeddings, entry_key='raw'): # Separate natural query from explicit required, blocked words filters query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) @@ -19,7 +19,7 @@ def explicit_filter(raw_query, entries, embeddings): entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' entries_by_word_set = [set(word.lower() for word - in re.split(entry_splitter, entry[1]) + in re.split(entry_splitter, entry[entry_key]) if word != "") for entry in entries] diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index f8a750f2..e087b580 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -63,7 +63,7 @@ def extract_entries(notesfile, verbose=0): note_string = f'{note["Title"]}' \ f'\t{note["Tags"] if "Tags" in note else ""}' \ f'\n{note["Body"] if "Body" in note else ""}' - entries.append([note_string, note["Raw"]]) + entries.append({'embed': note_string, 'raw': note["Raw"]}) # Close File jsonl_file.close() @@ -83,7 +83,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d print(f"Loaded embeddings from {embeddings_file}") else: # Else compute the corpus_embeddings from scratch, which can take a while - corpus_embeddings = bi_encoder.encode([entry[0] for entry in entries], convert_to_tensor=True, show_progress_bar=True) + corpus_embeddings = bi_encoder.encode([entry['embed'] for entry in entries], convert_to_tensor=True, show_progress_bar=True) corpus_embeddings.to(device) corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, get_absolute_path(embeddings_file)) @@ -116,7 +116,7 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), fi hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] # Score all retrieved entries using the cross-encoder - cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits] + cross_inp = [[query, entries[hit['corpus_id']]['embed']] for hit in hits] cross_scores = model.cross_encoder.predict(cross_inp) # Store cross-encoder scores in results dictionary for ranking @@ -138,20 +138,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False): print(f"Top-{count} Bi-Encoder Retrieval hits") hits = sorted(hits, key=lambda x: x['score'], reverse=True) for hit in hits[0:count]: - print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']][0]}") + print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['embed']}") # Output of top hits from re-ranker print("\n-------------------------\n") print(f"Top-{count} Cross-Encoder Re-ranker hits") hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) for hit in hits[0:count]: - print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']][0]}") + print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['embed']}") def collate_results(hits, entries, count=5): return [ { - "Entry": entries[hit['corpus_id']][1], + "Entry": entries[hit['corpus_id']]['raw'], "Score": f"{hit['cross-score']:.3f}" } for hit diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py index 814813a1..b42369c5 100644 --- a/src/search_type/symmetric_ledger.py +++ b/src/search_type/symmetric_ledger.py @@ -38,7 +38,7 @@ def initialize_model(search_config: SymmetricSearchConfig): def extract_entries(notesfile, verbose=0): "Load entries from compressed jsonl" - return [f'{entry["Title"]}' + return [{'raw': f'{entry["Title"]}', 'embed': f'{entry["Title"]}'} for entry in load_jsonl(notesfile, verbose=verbose)] @@ -80,7 +80,7 @@ def query(raw_query, model: TextSearchModel, filters=[]): hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0] # Score all retrieved entries using the cross-encoder - cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits] + cross_inp = [[query, entries[hit['corpus_id']]['embed']] for hit in hits] cross_scores = model.cross_encoder.predict(cross_inp) # Store cross-encoder scores in results dictionary for ranking @@ -102,20 +102,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False): print(f"Top-{count} Bi-Encoder Retrieval hits") hits = sorted(hits, key=lambda x: x['score'], reverse=True) for hit in hits[0:count]: - print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]}") + print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['embed']}") # Output of top hits from re-ranker print("\n-------------------------\n") print(f"Top-{count} Cross-Encoder Re-ranker hits") hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) for hit in hits[0:count]: - print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]}") + print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['embed']}") def collate_results(hits, entries, count=5): return [ { - "Entry": entries[hit['corpus_id']], + "Entry": entries[hit['corpus_id']]['raw'], "Score": f"{hit['cross-score']:.3f}" } for hit diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 525f011e..44d052f0 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -13,9 +13,9 @@ from src.search_filter import date_filter def test_date_filter(): embeddings = torch.randn(3, 10) entries = [ - ['', 'Entry with no date'], - ['', 'April Fools entry: 1984-04-01'], - ['', 'Entry with date:1984-04-02']] + {'embed': '', 'raw': 'Entry with no date'}, + {'embed': '', 'raw': 'April Fools entry: 1984-04-01'}, + {'embed': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings)