diff --git a/asymmetric.py b/asymmetric.py index 957f7aa8..449b137f 100644 --- a/asymmetric.py +++ b/asymmetric.py @@ -6,6 +6,7 @@ import time import gzip import os import sys +import re import torch import argparse import pathlib @@ -56,8 +57,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, verbose=False): return corpus_embeddings -def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, topk=100): +def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100): "Search all notes for entries that answer the query" + # Separate natural query from explicit required, blocked words filters + query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) + required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) + blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + # Encode the query using the bi-encoder question_embedding = bi_encoder.encode(query, convert_to_tensor=True) @@ -65,6 +71,11 @@ def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, to hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k) hits = hits[0] # Get the hits for the first query + # Filter results using explicit filters + hits = explicit_filter(hits, entries, required_words, blocked_words) + if hits is None or len(hits) == 0: + return hits + # Score all retrieved entries using the cross-encoder cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits] cross_scores = cross_encoder.predict(cross_inp) @@ -76,6 +87,28 @@ def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, to # Order results by cross encoder score followed by biencoder score hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score + + return hits + + +def 
def explicit_filter(hits, entries, required_words, blocked_words):
    """Drop search hits whose entry text violates the explicit word filters.

    An entry is split into lowercase words on whitespace, commas, full stops
    and any round, square or curly bracket. A hit is kept only when its entry
    contains every required word and none of the blocked words.

    Args:
        hits: List of hit dicts; each is indexed by its 'corpus_id' key.
        entries: Sequence of entry strings, indexed by corpus id.
        required_words: Lowercase words that must all occur in the entry.
        blocked_words: Lowercase words that must not occur in the entry.

    Returns:
        The hits object itself when both filters are empty, otherwise a new
        list of the surviving hits in their original order.
    """
    # Nothing to filter on: skip tokenizing the entries altogether.
    if not required_words and not blocked_words:
        return hits

    required = set(required_words)
    blocked = set(blocked_words)

    # A character class treats each bracket as its own separator; \s also
    # covers newlines and tabs so words on adjacent lines are not fused.
    separators = re.compile(r"[\s,.\[\](){}]+")

    filtered_hits = []
    for hit in hits:
        entry = entries[hit['corpus_id']]
        words_in_entry = {word.lower() for word in separators.split(entry) if word}
        # Every required word must be present, and no blocked word may be.
        if required <= words_in_entry and blocked.isdisjoint(words_in_entry):
            filtered_hits.append(hit)
    return filtered_hits