Support Incremental Search in Khoj

# Details
## Improve Search API Latency
- Improve Search API Latency by ~50-100x to <100ms
- Trade accuracy for speed in the default, fast path of the /search API by not re-ranking results with the cross-encoder
- Make re-ranking of results via the cross-encoder configurable through a new `r=<false|true>` query param to the /search API
- Only deep-copy entries, embeddings to apply filters if query has any filter keywords

## Support Incremental Update via Khoj Emacs Frontend
- Use default, fast path to query /search API while user is typing
- Upgrade to cross-encoder re-ranked results once user goes idle (or ends search normally)

Closes #37
This commit is contained in:
Debanjum
2022-07-28 09:10:50 -07:00
committed by GitHub
9 changed files with 421 additions and 252 deletions

View File

@@ -1,9 +1,9 @@
;;; khoj.el --- Natural Search via Emacs ;;; khoj.el --- Natural, Incremental Search via Emacs
;; Copyright (C) 2021-2022 Debanjum Singh Solanky ;; Copyright (C) 2021-2022 Debanjum Singh Solanky
;; Author: Debanjum Singh Solanky <debanjum@gmail.com> ;; Author: Debanjum Singh Solanky <debanjum@gmail.com>
;; Version: 1.0 ;; Version: 2.0
;; Keywords: search, org-mode, outlines, markdown, image ;; Keywords: search, org-mode, outlines, markdown, image
;; URL: http://github.com/debanjum/khoj/interface/emacs ;; URL: http://github.com/debanjum/khoj/interface/emacs
@@ -26,26 +26,50 @@
;;; Commentary: ;;; Commentary:
;; This package provides natural language search on org-mode notes, ;; This package provides a natural, incremental search interface to your
;; markdown files, beancount transactions and images. ;; org-mode notes, markdown files, beancount transactions and images.
;; It is a wrapper that interfaces with transformer based ML models. ;; It is a wrapper that interfaces with transformer based ML models.
;; The models search capabilities are exposed via the Khoj HTTP API ;; The models search capabilities are exposed via the Khoj HTTP API.
;;; Code: ;;; Code:
(require 'url) (require 'url)
(require 'json) (require 'json)
(defcustom khoj--server-url "http://localhost:8000" (defcustom khoj--server-url "http://localhost:8000"
"Location of Khoj API server." "Location of Khoj API server."
:group 'khoj :group 'khoj
:type 'string) :type 'string)
(defcustom khoj--image-width 156 (defcustom khoj--image-width 156
"Width of rendered images returned by Khoj" "Width of rendered images returned by Khoj."
:group 'khoj :group 'khoj
:type 'integer) :type 'integer)
(defcustom khoj--rerank-after-idle-time 1.0
  "Idle time (in seconds) before incremental search results are re-ranked.
Once the user has been idle this long in the Khoj minibuffer, the query is
sent again with cross-encoder re-ranking enabled."
  :group 'khoj
  ;; `number' rather than `float' so whole seconds (e.g. 2) are accepted too
  :type 'number)

(defcustom khoj--results-count 5
  "Number of results to get from Khoj API for each query."
  :group 'khoj
  :type 'integer)

(defvar khoj--rerank-timer nil
  "Idle timer used to re-rank incremental search results once user is idle.")

(defvar khoj--minibuffer-window nil
  "Minibuffer being used by user to enter query.
Despite the name, this holds the minibuffer's buffer object, not a window;
it is compared against `current-buffer' before querying the Khoj API.")

(defconst khoj--query-prompt "Khoj: "
  "Query prompt shown to user in the minibuffer.")

(defvar khoj--search-type "org"
  "The type of content to perform search on.")
(defun khoj--extract-entries-as-markdown (json-response query) (defun khoj--extract-entries-as-markdown (json-response query)
"Convert json response from API to markdown entries" "Convert json response from API to markdown entries"
;; remove leading (, ) or SPC from extracted entries string ;; remove leading (, ) or SPC from extracted entries string
@@ -118,25 +142,20 @@
((or (equal file-extension "markdown") (equal file-extension "md")) "markdown") ((or (equal file-extension "markdown") (equal file-extension "md")) "markdown")
(t "org")))) (t "org"))))
(defun khoj--construct-api-query (query search-type &optional rerank)
  "Return the Khoj /search API URL for QUERY on content of SEARCH-TYPE.
RERANK is the string sent as the `r' query parameter; it defaults to \"false\".
The number of results requested is `khoj--results-count', or 5 if that is nil."
  (format "%s/search?q=%s&t=%s&r=%s&n=%s"
          khoj--server-url
          (url-hexify-string query)   ; percent-encode the query for use in a URL
          search-type
          (or rerank "false")
          (or khoj--results-count 5)))
;;;###autoload (defun khoj--query-api-and-render-results (query search-type query-url buffer-name)
(defun khoj (query)
"Search your content naturally using the Khoj API"
(interactive "sQuery: ")
(let* ((default-type (khoj--buffer-name-to-search-type (buffer-name)))
(search-type (completing-read "Type: " '("org" "markdown" "ledger" "music" "image") nil t default-type))
(url (khoj--construct-api-query query search-type))
(buff (get-buffer-create (format "*Khoj (q:%s t:%s)*" query search-type))))
;; get json response from api ;; get json response from api
(with-current-buffer buff (with-current-buffer buffer-name
(let ((inhibit-read-only t)) (let ((inhibit-read-only t))
(erase-buffer) (erase-buffer)
(url-insert-file-contents url))) (url-insert-file-contents query-url)))
;; render json response into formatted entries ;; render json response into formatted entries
(with-current-buffer buff (with-current-buffer buffer-name
(let ((inhibit-read-only t) (let ((inhibit-read-only t)
(json-response (json-parse-buffer :object-type 'alist))) (json-response (json-parse-buffer :object-type 'alist)))
(erase-buffer) (erase-buffer)
@@ -154,8 +173,82 @@
((equal search-type "image") (progn (shr-render-region (point-min) (point-max)) ((equal search-type "image") (progn (shr-render-region (point-min) (point-max))
(goto-char (point-min)))) (goto-char (point-min))))
(t (fundamental-mode)))) (t (fundamental-mode))))
(read-only-mode t)) (read-only-mode t)))
(switch-to-buffer buff)))
;; Incremental Search on Khoj
(defun khoj--incremental-search (&optional rerank)
  "Query the Khoj API with the current minibuffer contents and render results.
When RERANK is non-nil, ask the API to re-rank results with the cross-encoder
\(sent as r=true); otherwise use the default, fast path (sent as r=false).
Does nothing unless the current buffer is the Khoj minibuffer recorded in
`khoj--minibuffer-window'."
  ;; Query khoj API only when user in khoj minibuffer.
  ;; Prevents querying during recursive edits or with contents of other buffers user may jump to.
  ;; This check comes before any other work: the function runs from
  ;; `post-command-hook' and from an idle timer, so reading the buffer contents
  ;; or creating the results buffer first would be repeated, wasted work
  ;; whenever the user is somewhere other than the khoj minibuffer.
  (when (and (active-minibuffer-window)
             (equal (current-buffer) khoj--minibuffer-window))
    (let* ((rerank-str (if rerank "true" "false"))
           (search-type khoj--search-type)
           (results-buffer (get-buffer-create (format "*Khoj (t:%s)*" search-type)))
           (query (minibuffer-contents-no-properties))
           (query-url (khoj--construct-api-query query search-type rerank-str)))
      (when rerank
        (message "[Khoj]: Rerank Results"))
      (khoj--query-api-and-render-results
       query
       search-type
       query-url
       results-buffer))))
(defun khoj--teardown-incremental-search ()
  "Undo the incremental search setup done by `khoj'.
Removes the rerank advice on `exit-minibuffer', forgets the Khoj minibuffer,
cancels the rerank idle timer and removes the incremental search hooks,
including this function itself from `minibuffer-exit-hook'."
  ;; remove advice to rerank results on normal exit from minibuffer
  (advice-remove 'exit-minibuffer #'khoj--minibuffer-exit-advice)
  ;; unset khoj minibuffer window
  (setq khoj--minibuffer-window nil)
  ;; cancel rerank timer and drop the reference to it, so a cancelled
  ;; timer object is not left behind in `khoj--rerank-timer'
  (when (timerp khoj--rerank-timer)
    (cancel-timer khoj--rerank-timer))
  (setq khoj--rerank-timer nil)
  ;; remove hooks for khoj incremental query and self
  (remove-hook 'post-command-hook #'khoj--incremental-search)
  (remove-hook 'minibuffer-exit-hook #'khoj--teardown-incremental-search))
(defun khoj--minibuffer-exit-advice (&rest _)
  "Re-rank the incremental search results; arguments are ignored.
Meant to be installed as advice, so it accepts and discards whatever
arguments the advised function receives."
  (khoj--incremental-search t))
;;;###autoload
(defun khoj ()
  "Natural, incremental search for your notes, transactions and music using Khoj.
Prompt for the content type to search, show the results buffer, then read the
query in the minibuffer.  Results are refreshed after every command while
typing, and re-ranked by the cross-encoder once idle for
`khoj--rerank-after-idle-time' seconds or on normal exit from the minibuffer."
  (interactive)
  (let* ((default-type (khoj--buffer-name-to-search-type (buffer-name)))
         (search-type (completing-read "Type: " '("org" "markdown" "ledger" "music") nil t default-type))
         (results-buffer (get-buffer-create (format "*Khoj (t:%s)*" search-type))))
    (setq khoj--search-type search-type)
    ;; cancel any rerank timer left over from an earlier, unfinished search,
    ;; so repeating idle timers do not pile up
    (when (timerp khoj--rerank-timer)
      (cancel-timer khoj--rerank-timer))
    ;; setup rerank to improve results once user idle for KHOJ--RERANK-AFTER-IDLE-TIME seconds
    (setq khoj--rerank-timer
          (run-with-idle-timer khoj--rerank-after-idle-time t #'khoj--incremental-search t))
    ;; switch to khoj results buffer
    (switch-to-buffer results-buffer)
    ;; open and setup minibuffer for incremental search
    (minibuffer-with-setup-hook
        (lambda ()
          ;; set current (mini-)buffer entered as khoj minibuffer
          ;; used to query khoj API only when user in khoj minibuffer
          (setq khoj--minibuffer-window (current-buffer))
          ;; rerank results on normal exit from minibuffer
          (advice-add 'exit-minibuffer :before #'khoj--minibuffer-exit-advice)
          ;; do khoj incremental search after every user action
          (add-hook 'post-command-hook #'khoj--incremental-search)
          ;; teardown khoj incremental search on minibuffer exit
          (add-hook 'minibuffer-exit-hook #'khoj--teardown-incremental-search))
      (read-string khoj--query-prompt))))
;;;###autoload
(defun khoj-simple (query)
  "Search for QUERY in your notes, transactions, music and images using Khoj.
Prompt for the content type, request cross-encoder re-ranked results from the
Khoj API in a single call, then render and show them in a results buffer."
  (interactive "sQuery: ")
  (let* ((content-type
          (completing-read "Type: "
                           '("org" "markdown" "ledger" "music" "image")
                           nil t
                           ;; offer the type guessed from the current buffer as initial input
                           (khoj--buffer-name-to-search-type (buffer-name))))
         (api-url (khoj--construct-api-query query content-type "true"))
         (results-buffer
          (get-buffer-create (format "*Khoj (q:%s t:%s)*" query content-type))))
    (khoj--query-api-and-render-results query content-type api-url results-buffer)
    (switch-to-buffer results-buffer)))
(provide 'khoj) (provide 'khoj)

View File

@@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import sys, json, yaml, os import sys, json, yaml, os
import time
from typing import Optional from typing import Optional
# External Packages # External Packages
@@ -20,8 +21,8 @@ from src.utils.cli import cli
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import explicit_filter from src.search_filter.explicit_filter import ExplicitFilter
from src.search_filter.date_filter import date_filter from src.search_filter.date_filter import DateFilter
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
@@ -58,7 +59,7 @@ async def config_data(updated_config: FullConfig):
return config return config
@app.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
    """Run the query against the requested content type and return collated results.

    q: query string; an empty query returns {} without searching.
    n: number of results to collate.
    t: content type to search. When None, every type with a loaded search model
       is queried in turn and the results of the last one queried are returned.
    r: passed to text search as rank_results; image search does not use it.
    """
    if q is None or q == '':
        print(f'No query param (q) passed in API call to initiate search')
        return {}

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    user_query = q
    results_count = n
    results = {}
    # Timings stay None when no search ran, so the verbose report below can be skipped
    query_start = query_end = collate_start = collate_end = None

    if (t == SearchType.Org or t is None) and model.orgmode_search:
        # query org-mode notes
        query_start = time.time()
        hits, entries = text_search.query(user_query, model.orgmode_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    if (t == SearchType.Music or t is None) and model.music_search:
        # query music library
        query_start = time.time()
        hits, entries = text_search.query(user_query, model.music_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    # Gate on the markdown model itself: checking model.orgmode_search here would
    # skip markdown search when only markdown is loaded, and would query a
    # missing markdown model when only org-mode is loaded
    if (t == SearchType.Markdown or t is None) and model.markdown_search:
        # query markdown files
        query_start = time.time()
        hits, entries = text_search.query(user_query, model.markdown_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    if (t == SearchType.Ledger or t is None) and model.ledger_search:
        # query transactions
        query_start = time.time()
        hits, entries = text_search.query(user_query, model.ledger_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    if (t == SearchType.Image or t is None) and model.image_search:
        # query images
        query_start = time.time()
        hits = image_search.query(user_query, results_count, model.image_search)
        output_directory = f'{os.getcwd()}/{web_directory}'
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = image_search.collate_results(
            hits,
            image_names=model.image_search.image_names,
            output_directory=output_directory,
            static_files_url='/static',
            count=results_count)
        collate_end = time.time()

    # Report timings of the last search run; nothing to report if none ran
    if verbose > 1 and query_start is not None:
        print(f"Query took {query_end - query_start:.3f} seconds")
        print(f"Collating results took {collate_end - collate_start:.3f} seconds")

    return results
@app.get('/reload') @app.get('/reload')

View File

@@ -82,7 +82,7 @@ def convert_org_entries_to_jsonl(entries, verbose=0):
continue continue
entry_dict["compiled"] = f'{entry.Heading()}.' entry_dict["compiled"] = f'{entry.Heading()}.'
if verbose > 1: if verbose > 2:
print(f"Title: {entry.Heading()}") print(f"Title: {entry.Heading()}")
if entry.Tags(): if entry.Tags():

View File

@@ -9,6 +9,7 @@ import torch
import dateparser as dtparse import dateparser as dtparse
class DateFilter:
# Date Range Filter Regexes # Date Range Filter Regexes
# Example filter queries: # Example filter queries:
# - dt>="yesterday" dt<"tomorrow" # - dt>="yesterday" dt<"tomorrow"
@@ -16,18 +17,22 @@ import dateparser as dtparse
# - dt:"2 years ago" # - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\"" date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def can_filter(self, raw_query):
    "Check if query contains date filters"
    # A usable date filter is present only when a date range can be extracted
    date_range = self.extract_date_range(raw_query)
    return date_range is not None
def date_filter(query, entries, embeddings, entry_key='raw'):
def filter(self, query, entries, embeddings, entry_key='raw'):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
query_daterange = extract_date_range(query) query_daterange = self.extract_date_range(query)
# if no date in query, return all entries # if no date in query, return all entries
if query_daterange is None: if query_daterange is None:
return query, entries, embeddings return query, entries, embeddings
# remove date range filter from query # remove date range filter from query
query = re.sub(f'\s+{date_regex}', ' ', query) query = re.sub(f'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# find entries containing any dates that fall with date range specified in query # find entries containing any dates that fall with date range specified in query
@@ -54,9 +59,9 @@ def date_filter(query, entries, embeddings, entry_key='raw'):
return query, entries, embeddings return query, entries, embeddings
def extract_date_range(query): def extract_date_range(self, query):
# find date range filter in query # find date range filter in query
date_range_matches = re.findall(date_regex, query) date_range_matches = re.findall(self.date_regex, query)
if len(date_range_matches) == 0: if len(date_range_matches) == 0:
return None return None
@@ -65,8 +70,8 @@ def extract_date_range(query):
# e.g today maps to (start_of_day, start_of_tomorrow) # e.g today maps to (start_of_day, start_of_tomorrow)
date_ranges_from_filter = [] date_ranges_from_filter = []
for (cmp, date_str) in date_range_matches: for (cmp, date_str) in date_range_matches:
if parse(date_str): if self.parse(date_str):
dt_start, dt_end = parse(date_str) dt_start, dt_end = self.parse(date_str)
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]] date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
# Combine dates with their comparators to form date range intervals # Combine dates with their comparators to form date range intervals
@@ -103,7 +108,7 @@ def extract_date_range(query):
return effective_date_range return effective_date_range
def parse(date_str, relative_base=None): def parse(self, date_str, relative_base=None):
"Parse date string passed in date filter of query to datetime object" "Parse date string passed in date filter of query to datetime object"
# clean date string to handle future date parsing by date parser # clean date string to handle future date parsing by date parser
future_strings = ['later', 'from now', 'from today'] future_strings = ['later', 'from now', 'from today']
@@ -122,10 +127,10 @@ def parse(date_str, relative_base=None):
if parsed_date is None: if parsed_date is None:
return None return None
return date_to_daterange(parsed_date, date_str) return self.date_to_daterange(parsed_date, date_str)
def date_to_daterange(parsed_date, date_str): def date_to_daterange(self, parsed_date, date_str):
"Convert parsed date to date ranges at natural granularity (day, week, month or year)" "Convert parsed date to date ranges at natural granularity (day, week, month or year)"
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0) start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)

View File

@@ -5,7 +5,18 @@ import re
import torch import torch
def explicit_filter(raw_query, entries, embeddings, entry_key='raw'): class ExplicitFilter:
def can_filter(self, raw_query):
    "Check if query contains explicit filters"
    # The query has explicit filters when any whitespace-separated word is
    # marked as required (+word) or blocked (-word)
    return any(word.startswith(("+", "-")) for word in raw_query.split())
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters # Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])

View File

@@ -2,6 +2,7 @@
import argparse import argparse
import pathlib import pathlib
from copy import deepcopy from copy import deepcopy
import time
# External Packages # External Packages
import torch import torch
@@ -19,7 +20,7 @@ def initialize_model(search_config: TextSearchConfig):
torch.set_num_threads(4) torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder # Number of entries we want to retrieve with the bi-encoder
top_k = 30 top_k = 15
# The bi-encoder encodes all entries to use for semantic search # The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model( bi_encoder = load_model(
@@ -62,38 +63,71 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
    """Search for entries that answer the query

    Returns (hits, entries): hits sorted by bi-encoder score, or by cross-encoder
    score when rank_results is set, and the entries the hits' corpus_id index into.
    Returns ([], []) when no entries are left to search after filtering.
    Timings of each stage are printed when verbose > 1.
    """
    query = raw_query

    # Use deep copy of original embeddings, entries to filter if query contains filters
    start = time.time()
    filters_in_query = [filter for filter in filters if filter.can_filter(query)]
    if filters_in_query:
        corpus_embeddings = deepcopy(model.corpus_embeddings)
        entries = deepcopy(model.entries)
    else:
        # No filter applies to this query, so the model's own data is used without copying
        corpus_embeddings = model.corpus_embeddings
        entries = model.entries
    end = time.time()
    if verbose > 1:
        print(f"Copy Time: {end - start:.3f} seconds")

    # Filter query, entries and embeddings before semantic search
    start = time.time()
    for filter in filters_in_query:
        query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
    end = time.time()
    if verbose > 1:
        print(f"Filter Time: {end - start:.3f} seconds")

    if entries is None or len(entries) == 0:
        return [], []

    # Encode the query using the bi-encoder
    start = time.time()
    question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
    # Tensor.to() returns the moved tensor rather than moving it in place,
    # so the result has to be kept for the device argument to have any effect
    question_embedding = question_embedding.to(device)
    question_embedding = util.normalize_embeddings(question_embedding)
    end = time.time()
    if verbose > 1:
        print(f"Query Encode Time: {end - start:.3f} seconds")

    # Find relevant entries for the query
    start = time.time()
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
    end = time.time()
    if verbose > 1:
        print(f"Search Time: {end - start:.3f} seconds")

    # Score all retrieved entries using the cross-encoder
    if rank_results:
        start = time.time()
        cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
        cross_scores = model.cross_encoder.predict(cross_inp)
        end = time.time()
        if verbose > 1:
            print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")

        # Store cross-encoder scores in results dictionary for ranking
        for hit, cross_score in zip(hits, cross_scores):
            hit['cross-score'] = cross_score

    # Order results by cross-encoder score followed by bi-encoder score
    start = time.time()
    hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
    if rank_results:
        hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
    end = time.time()
    if verbose > 1:
        print(f"Rank Time: {end - start:.3f} seconds")

    return hits, entries
@@ -120,7 +154,7 @@ def collate_results(hits, entries, count=5):
return [ return [
{ {
"entry": entries[hit['corpus_id']]['raw'], "entry": entries[hit['corpus_id']]['raw'],
"score": f"{hit['cross-score']:.3f}" "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
} }
for hit for hit
in hits[0:count]] in hits[0:count]]

View File

@@ -29,7 +29,8 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# Act # Act
hits, entries = text_search.query( hits, entries = text_search.query(
query, query,
model = model.notes_search) model = model.notes_search,
rank_results=True)
results = text_search.collate_results( results = text_search.collate_results(
hits, hits,

View File

@@ -119,7 +119,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
user_query = "How to git install application?" user_query = "How to git install application?"
# Act # Act
response = client.get(f"/search?q={user_query}&n=1&t=org") response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -7,7 +7,7 @@ from math import inf
import torch import torch
# Application Packages # Application Packages
from src.search_filter import date_filter from src.search_filter.date_filter import DateFilter
def test_date_filter(): def test_date_filter():
@@ -18,99 +18,99 @@ def test_date_filter():
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}] {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail' q_with_no_date_filter = 'head tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 3 assert len(ret_emb) == 3
assert ret_entries == entries assert ret_entries == entries
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 0 assert len(ret_emb) == 0
assert ret_entries == [] assert ret_entries == []
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert ret_entries == [entries[2]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1]] assert ret_entries == [entries[1]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert ret_entries == [entries[2]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1], entries[2]] assert ret_entries == [entries[1], entries[2]]
assert len(ret_emb) == 2 assert len(ret_emb) == 2
def test_extract_date_range():
    "Check the date range extracted from the date filters in a query"
    # (query, expected date range) pairs
    cases = [
        ('head dt>"1984-01-04" dt<"1984-01-07" tail', [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]),
        ('head dt<="1984-01-01"', [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]),
        ('head dt>="1984-01-01"', [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]),
        ('head dt:"1984-01-01"', [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]),
        # Unparseable date filter specified in query
        ('head dt:"Summer of 69" tail', None),
        # No date filter specified in query
        ('head tail', None),
        # Non intersecting date ranges
        ('head dt>"1984-01-01" dt<"1984-01-01" tail', None),
    ]

    for query, expected_date_range in cases:
        assert DateFilter().extract_date_range(query) == expected_date_range
def test_parse():
    """Check `DateFilter.parse` against natural language dates, relative to a fixed "now".

    Each case pairs a date phrase with the (start, end) datetime tuple that
    `parse` is expected to return when `relative_base` is 1984-04-01 21:21:21.
    """
    test_now = datetime(1984, 4, 1, 21, 21, 21)

    expectations = [
        # day variations
        ('today', datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0)),
        ('tomorrow', datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0)),
        ('yesterday', datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0)),
        ('5 days ago', datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0)),
        # week variations
        ('last week', datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0)),
        ('2 weeks ago', datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0)),
        # month variations
        ('next month', datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0)),
        ('2 months ago', datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0)),
        # year variations
        ('this year', datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0)),
        ('20 years later', datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0)),
        # specific month/date variation
        ('in august', datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)),
        ('on 1983-08-01', datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)),
    ]

    for phrase, range_start, range_end in expectations:
        # A fresh filter per case, so no case depends on an earlier one
        parsed = DateFilter().parse(phrase, relative_base=test_now)
        assert parsed == (range_start, range_end)
def test_date_filter_regex():
    """Check which (operator, date text) pairs `DateFilter.date_regex` finds in a query.

    Each case pairs a query string with the list that `re.findall` is expected
    to return for it: one (operator, quoted date text) tuple per `dt` filter,
    in order of appearance, or an empty list when the query has no filter.
    """
    cases = (
        ('multi word head dt>"today" dt:"1984-01-01"', [('>', 'today'), (':', '1984-01-01')]),
        ('head dt>"today" dt:"1984-01-01" multi word tail', [('>', 'today'), (':', '1984-01-01')]),
        ('multi word head dt>="today" dt="1984-01-01"', [('>=', 'today'), ('=', '1984-01-01')]),
        ('dt<"multi word date" multi word tail', [('<', 'multi word date')]),
        ('head dt<="multi word date"', [('<=', 'multi word date')]),
        ('head tail', []),
    )

    for query, expected_matches in cases:
        dtrange_match = re.findall(DateFilter().date_regex, query)
        assert dtrange_match == expected_matches