Support Incremental Search in Khoj

# Details
## Improve Search API Latency
- Improve Search API Latency by ~50-100x to <100ms
- Trade-off speed for accuracy in default, fast path of /search API by not re-ranking results using cross-encoder
- Make re-ranking of results via cross-encoder configurable via new `r=<true|false>` query param to /search API
- Only deep-copy entries, embeddings to apply filters if query has any filter keywords

## Support Incremental Update via Khoj Emacs Frontend
- Use default, fast path to query /search API while user is typing
- Upgrade to cross-encoder re-ranked results once user goes idle (or ends search normally)

Closes #37
This commit is contained in:
Debanjum
2022-07-28 09:10:50 -07:00
committed by GitHub
9 changed files with 421 additions and 252 deletions

View File

@@ -1,9 +1,9 @@
;;; khoj.el --- Natural Search via Emacs ;;; khoj.el --- Natural, Incremental Search via Emacs
;; Copyright (C) 2021-2022 Debanjum Singh Solanky ;; Copyright (C) 2021-2022 Debanjum Singh Solanky
;; Author: Debanjum Singh Solanky <debanjum@gmail.com> ;; Author: Debanjum Singh Solanky <debanjum@gmail.com>
;; Version: 1.0 ;; Version: 2.0
;; Keywords: search, org-mode, outlines, markdown, image ;; Keywords: search, org-mode, outlines, markdown, image
;; URL: http://github.com/debanjum/khoj/interface/emacs ;; URL: http://github.com/debanjum/khoj/interface/emacs
@@ -26,26 +26,50 @@
;;; Commentary: ;;; Commentary:
;; This package provides natural language search on org-mode notes, ;; This package provides a natural, incremental search interface to your
;; markdown files, beancount transactions and images. ;; org-mode notes, markdown files, beancount transactions and images.
;; It is a wrapper that interfaces with transformer based ML models. ;; It is a wrapper that interfaces with transformer based ML models.
;; The models search capabilities are exposed via the Khoj HTTP API ;; The models search capabilities are exposed via the Khoj HTTP API.
;;; Code: ;;; Code:
(require 'url) (require 'url)
(require 'json) (require 'json)
(defcustom khoj--server-url "http://localhost:8000" (defcustom khoj--server-url "http://localhost:8000"
"Location of Khoj API server." "Location of Khoj API server."
:group 'khoj :group 'khoj
:type 'string) :type 'string)
(defcustom khoj--image-width 156 (defcustom khoj--image-width 156
"Width of rendered images returned by Khoj" "Width of rendered images returned by Khoj."
:group 'khoj :group 'khoj
:type 'integer) :type 'integer)
(defcustom khoj--rerank-after-idle-time 1.0
"Idle time (in seconds) to trigger cross-encoder to rerank incremental search results."
:group 'khoj
:type 'float)
(defcustom khoj--results-count 5
"Number of results to get from Khoj API for each query."
:group 'khoj
:type 'integer)
(defvar khoj--rerank-timer nil
"Idle timer to make cross-encoder re-rank incremental search results if user idle.")
(defvar khoj--minibuffer-window nil
"Minibuffer window being used by user to enter query.")
(defconst khoj--query-prompt "Khoj: "
"Query prompt shown to user in the minibuffer.")
(defvar khoj--search-type "org"
"The type of content to perform search on.")
(defun khoj--extract-entries-as-markdown (json-response query) (defun khoj--extract-entries-as-markdown (json-response query)
"Convert json response from API to markdown entries" "Convert json response from API to markdown entries"
;; remove leading (, ) or SPC from extracted entries string ;; remove leading (, ) or SPC from extracted entries string
@@ -118,34 +142,29 @@
((or (equal file-extension "markdown") (equal file-extension "md")) "markdown") ((or (equal file-extension "markdown") (equal file-extension "md")) "markdown")
(t "org")))) (t "org"))))
(defun khoj--construct-api-query (query search-type) (defun khoj--construct-api-query (query search-type &optional rerank)
(let ((encoded-query (url-hexify-string query))) (let ((rerank (or rerank "false"))
(format "%s/search?q=%s&t=%s" khoj--server-url encoded-query search-type))) (results-count (or khoj--results-count 5))
(encoded-query (url-hexify-string query)))
(format "%s/search?q=%s&t=%s&r=%s&n=%s" khoj--server-url encoded-query search-type rerank results-count)))
;;;###autoload (defun khoj--query-api-and-render-results (query search-type query-url buffer-name)
(defun khoj (query) ;; get json response from api
"Search your content naturally using the Khoj API" (with-current-buffer buffer-name
(interactive "sQuery: ") (let ((inhibit-read-only t))
(let* ((default-type (khoj--buffer-name-to-search-type (buffer-name))) (erase-buffer)
(search-type (completing-read "Type: " '("org" "markdown" "ledger" "music" "image") nil t default-type)) (url-insert-file-contents query-url)))
(url (khoj--construct-api-query query search-type)) ;; render json response into formatted entries
(buff (get-buffer-create (format "*Khoj (q:%s t:%s)*" query search-type)))) (with-current-buffer buffer-name
;; get json response from api (let ((inhibit-read-only t)
(with-current-buffer buff (json-response (json-parse-buffer :object-type 'alist)))
(let ((inhibit-read-only t)) (erase-buffer)
(erase-buffer) (insert
(url-insert-file-contents url))) (cond ((or (equal search-type "org") (equal search-type "music")) (khoj--extract-entries-as-org json-response query))
;; render json response into formatted entries ((equal search-type "markdown") (khoj--extract-entries-as-markdown json-response query))
(with-current-buffer buff ((equal search-type "ledger") (khoj--extract-entries-as-ledger json-response query))
(let ((inhibit-read-only t) ((equal search-type "image") (khoj--extract-entries-as-images json-response query))
(json-response (json-parse-buffer :object-type 'alist))) (t (format "%s" json-response))))
(erase-buffer)
(insert
(cond ((or (equal search-type "org") (equal search-type "music")) (khoj--extract-entries-as-org json-response query))
((equal search-type "markdown") (khoj--extract-entries-as-markdown json-response query))
((equal search-type "ledger") (khoj--extract-entries-as-ledger json-response query))
((equal search-type "image") (khoj--extract-entries-as-images json-response query))
(t (format "%s" json-response))))
(cond ((equal search-type "org") (org-mode)) (cond ((equal search-type "org") (org-mode))
((equal search-type "markdown") (markdown-mode)) ((equal search-type "markdown") (markdown-mode))
((equal search-type "ledger") (beancount-mode)) ((equal search-type "ledger") (beancount-mode))
@@ -154,8 +173,82 @@
((equal search-type "image") (progn (shr-render-region (point-min) (point-max)) ((equal search-type "image") (progn (shr-render-region (point-min) (point-max))
(goto-char (point-min)))) (goto-char (point-min))))
(t (fundamental-mode)))) (t (fundamental-mode))))
(read-only-mode t)) (read-only-mode t)))
(switch-to-buffer buff)))
;; Incremental Search on Khoj
(defun khoj--incremental-search (&optional rerank)
  "Run a khoj search with the current minibuffer contents as query.
When RERANK is non-nil, ask the Khoj API for cross-encoder
re-ranked results (slower, more accurate); otherwise use the
fast, default search path."
  (let* ((rerank-str (if rerank "true" "false"))
         (search-type khoj--search-type)
         ;; renamed from `buffer-name' to avoid shadowing the built-in function
         (results-buffer (get-buffer-create (format "*Khoj (t:%s)*" search-type)))
         (query (minibuffer-contents-no-properties))
         (query-url (khoj--construct-api-query query search-type rerank-str)))
    ;; Query khoj API only when user is in the khoj minibuffer.
    ;; Prevents querying during recursive edits or with contents of
    ;; other buffers the user may jump to.
    ;; NOTE(review): `khoj--minibuffer-window' holds a buffer, not a window.
    (when (and (active-minibuffer-window)
               (equal (current-buffer) khoj--minibuffer-window))
      (when rerank
        (message "[Khoj]: Rerank Results"))
      (khoj--query-api-and-render-results
       query
       search-type
       query-url
       results-buffer))))
(defun khoj--teardown-incremental-search ()
  "Undo the timers, hooks and advice installed for khoj incremental search."
  ;; stop re-ranking results on normal exit from minibuffer
  (advice-remove 'exit-minibuffer #'khoj--minibuffer-exit-advice)
  ;; forget which minibuffer belonged to khoj
  (setq khoj--minibuffer-window nil)
  ;; stop the idle re-rank timer, if one is running
  (if (timerp khoj--rerank-timer)
      (cancel-timer khoj--rerank-timer))
  ;; detach the incremental search and this teardown from their hooks
  (remove-hook 'post-command-hook #'khoj--incremental-search)
  (remove-hook 'minibuffer-exit-hook #'khoj--teardown-incremental-search))
(defun khoj--minibuffer-exit-advice (&rest _args)
  "Rerank results via cross-encoder before normal exit from the khoj minibuffer."
  (khoj--incremental-search t))
;;;###autoload
(defun khoj ()
  "Natural, Incremental Search for your personal notes, transactions and music using Khoj."
  (interactive)
  (let* ((default-type (khoj--buffer-name-to-search-type (buffer-name)))
         (search-type (completing-read "Type: " '("org" "markdown" "ledger" "music") nil t default-type))
         (buffer-name (get-buffer-create (format "*Khoj (t:%s)*" search-type))))
    ;; remember chosen type so `khoj--incremental-search' can reuse it on every keystroke
    (setq khoj--search-type search-type)
    ;; setup rerank to improve results once user idle for KHOJ--RERANK-AFTER-IDLE-TIME seconds
    ;; (repeating idle timer; cancelled by `khoj--teardown-incremental-search')
    (setq khoj--rerank-timer (run-with-idle-timer khoj--rerank-after-idle-time t 'khoj--incremental-search t))
    ;; switch to khoj results buffer
    (switch-to-buffer buffer-name)
    ;; open and setup minibuffer for incremental search
    (minibuffer-with-setup-hook
        (lambda ()
          ;; set current (mini-)buffer entered as khoj minibuffer
          ;; used to query khoj API only when user in khoj minibuffer
          ;; NOTE(review): the variable stores a buffer despite its -window name
          (setq khoj--minibuffer-window (current-buffer))
          ;; rerank results on normal exit from minibuffer
          (advice-add 'exit-minibuffer :before #'khoj--minibuffer-exit-advice)
          (add-hook 'post-command-hook #'khoj--incremental-search) ; do khoj incremental search after every user action
          (add-hook 'minibuffer-exit-hook #'khoj--teardown-incremental-search)) ; teardown khoj incremental search on minibuffer exit
      (read-string khoj--query-prompt))))
;;;###autoload
(defun khoj-simple (query)
  "Natural Search for QUERY in your personal notes, transactions, music and images using Khoj."
  (interactive "sQuery: ")
  (let* ((rerank "true")
         (default-type (khoj--buffer-name-to-search-type (buffer-name)))
         (search-type (completing-read "Type: " '("org" "markdown" "ledger" "music" "image") nil t default-type))
         (query-url (khoj--construct-api-query query search-type rerank))
         (results-buffer (get-buffer-create (format "*Khoj (q:%s t:%s)*" query search-type))))
    ;; fetch re-ranked results from the khoj API and render them into the results buffer
    (khoj--query-api-and-render-results query search-type query-url results-buffer)
    ;; show the rendered results to the user
    (switch-to-buffer results-buffer)))
(provide 'khoj) (provide 'khoj)

View File

@@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import sys, json, yaml, os import sys, json, yaml, os
import time
from typing import Optional from typing import Optional
# External Packages # External Packages
@@ -20,8 +21,8 @@ from src.utils.cli import cli
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import explicit_filter from src.search_filter.explicit_filter import ExplicitFilter
from src.search_filter.date_filter import date_filter from src.search_filter.date_filter import DateFilter
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
@@ -58,7 +59,7 @@ async def config_data(updated_config: FullConfig):
return config return config
@app.get('/search') @app.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
if q is None or q == '': if q is None or q == '':
print(f'No query param (q) passed in API call to initiate search') print(f'No query param (q) passed in API call to initiate search')
return {} return {}
@@ -66,50 +67,74 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
user_query = q user_query = q
results_count = n results_count = n
results = {}
if (t == SearchType.Org or t == None) and model.orgmode_search: if (t == SearchType.Org or t == None) and model.orgmode_search:
# query org-mode notes # query org-mode notes
hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[explicit_filter, date_filter]) query_start = time.time()
hits, entries = text_search.query(user_query, model.orgmode_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
query_end = time.time()
# collate and return results # collate and return results
return text_search.collate_results(hits, entries, results_count) collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Music or t == None) and model.music_search: if (t == SearchType.Music or t == None) and model.music_search:
# query music library # query music library
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter]) query_start = time.time()
hits, entries = text_search.query(user_query, model.music_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
query_end = time.time()
# collate and return results # collate and return results
return text_search.collate_results(hits, entries, results_count) collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Markdown or t == None) and model.orgmode_search: if (t == SearchType.Markdown or t == None) and model.orgmode_search:
# query markdown files # query markdown files
hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[explicit_filter, date_filter]) query_start = time.time()
hits, entries = text_search.query(user_query, model.markdown_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
query_end = time.time()
# collate and return results # collate and return results
return text_search.collate_results(hits, entries, results_count) collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Ledger or t == None) and model.ledger_search: if (t == SearchType.Ledger or t == None) and model.ledger_search:
# query transactions # query transactions
hits, entries = text_search.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter]) query_start = time.time()
hits, entries = text_search.query(user_query, model.ledger_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
query_end = time.time()
# collate and return results # collate and return results
return text_search.collate_results(hits, entries, results_count) collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Image or t == None) and model.image_search: if (t == SearchType.Image or t == None) and model.image_search:
# query images # query images
query_start = time.time()
hits = image_search.query(user_query, results_count, model.image_search) hits = image_search.query(user_query, results_count, model.image_search)
output_directory = f'{os.getcwd()}/{web_directory}' output_directory = f'{os.getcwd()}/{web_directory}'
query_end = time.time()
# collate and return results # collate and return results
return image_search.collate_results( collate_start = time.time()
results = image_search.collate_results(
hits, hits,
image_names=model.image_search.image_names, image_names=model.image_search.image_names,
output_directory=output_directory, output_directory=output_directory,
static_files_url='/static', static_files_url='/static',
count=results_count) count=results_count)
collate_end = time.time()
else: if verbose > 1:
return {} print(f"Query took {query_end - query_start:.3f} seconds")
print(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results
@app.get('/reload') @app.get('/reload')

View File

@@ -82,7 +82,7 @@ def convert_org_entries_to_jsonl(entries, verbose=0):
continue continue
entry_dict["compiled"] = f'{entry.Heading()}.' entry_dict["compiled"] = f'{entry.Heading()}.'
if verbose > 1: if verbose > 2:
print(f"Title: {entry.Heading()}") print(f"Title: {entry.Heading()}")
if entry.Tags(): if entry.Tags():

View File

@@ -9,138 +9,143 @@ import torch
import dateparser as dtparse import dateparser as dtparse
# Date Range Filter Regexes class DateFilter:
# Example filter queries: # Date Range Filter Regexes
# - dt>="yesterday" dt<"tomorrow" # Example filter queries:
# - dt>="last week" # - dt>="yesterday" dt<"tomorrow"
# - dt:"2 years ago" # - dt>="last week"
date_regex = r"dt([:><=]{1,2})\"(.*?)\"" # - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def can_filter(self, raw_query):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def date_filter(query, entries, embeddings, entry_key='raw'): def filter(self, query, entries, embeddings, entry_key='raw'):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
query_daterange = extract_date_range(query) query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange is None:
return query, entries, embeddings
# remove date range filter from query
query = re.sub(f'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# find entries containing any dates that fall with date range specified in query
entries_to_include = set()
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
except ValueError:
continue
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include.add(id)
break
# delete entries (and their embeddings) marked for exclusion
entries_to_exclude = set(range(len(entries))) - entries_to_include
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
# if no date in query, return all entries
if query_daterange is None:
return query, entries, embeddings return query, entries, embeddings
# remove date range filter from query
query = re.sub(f'\s+{date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# find entries containing any dates that fall with date range specified in query def extract_date_range(self, query):
entries_to_include = set() # find date range filter in query
for id, entry in enumerate(entries): date_range_matches = re.findall(self.date_regex, query)
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
except ValueError:
continue
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include.add(id)
break
# delete entries (and their embeddings) marked for exclusion if len(date_range_matches) == 0:
entries_to_exclude = set(range(len(entries))) - entries_to_include return None
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return query, entries, embeddings # extract, parse natural dates ranges from date range filter passed in query
# e.g today maps to (start_of_day, start_of_tomorrow)
date_ranges_from_filter = []
for (cmp, date_str) in date_range_matches:
if self.parse(date_str):
dt_start, dt_end = self.parse(date_str)
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
# Combine dates with their comparators to form date range intervals
# For e.g
# >=yesterday maps to [start_of_yesterday, inf)
# <tomorrow maps to [0, start_of_tomorrow)
# ---
effective_date_range = [0, inf]
date_range_considering_comparator = []
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
if cmp == '>':
date_range_considering_comparator += [[dtrange_end, inf]]
elif cmp == '>=':
date_range_considering_comparator += [[dtrange_start, inf]]
elif cmp == '<':
date_range_considering_comparator += [[0, dtrange_start]]
elif cmp == '<=':
date_range_considering_comparator += [[0, dtrange_end]]
elif cmp == '=' or cmp == ':' or cmp == '==':
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
# Combine above intervals (via AND/intersect)
# In the above example, this gives us [start_of_yesterday, start_of_tomorrow)
# This is the effective date range to filter entries by
# ---
for date_range in date_range_considering_comparator:
effective_date_range = [
max(effective_date_range[0], date_range[0]),
min(effective_date_range[1], date_range[1])]
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return None
else:
return effective_date_range
def extract_date_range(query): def parse(self, date_str, relative_base=None):
# find date range filter in query "Parse date string passed in date filter of query to datetime object"
date_range_matches = re.findall(date_regex, query) # clean date string to handle future date parsing by date parser
future_strings = ['later', 'from now', 'from today']
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
clean_date_str = re.sub('|'.join(future_strings), '', date_str)
if len(date_range_matches) == 0: # parse date passed in query date filter
return None parsed_date = dtparse.parse(
clean_date_str,
settings= {
'RELATIVE_BASE': relative_base or datetime.now(),
'PREFER_DAY_OF_MONTH': 'first',
'PREFER_DATES_FROM': prefer_dates_from
})
# extract, parse natural dates ranges from date range filter passed in query if parsed_date is None:
# e.g today maps to (start_of_day, start_of_tomorrow) return None
date_ranges_from_filter = []
for (cmp, date_str) in date_range_matches:
if parse(date_str):
dt_start, dt_end = parse(date_str)
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
# Combine dates with their comparators to form date range intervals return self.date_to_daterange(parsed_date, date_str)
# For e.g
# >=yesterday maps to [start_of_yesterday, inf)
# <tomorrow maps to [0, start_of_tomorrow)
# ---
effective_date_range = [0, inf]
date_range_considering_comparator = []
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
if cmp == '>':
date_range_considering_comparator += [[dtrange_end, inf]]
elif cmp == '>=':
date_range_considering_comparator += [[dtrange_start, inf]]
elif cmp == '<':
date_range_considering_comparator += [[0, dtrange_start]]
elif cmp == '<=':
date_range_considering_comparator += [[0, dtrange_end]]
elif cmp == '=' or cmp == ':' or cmp == '==':
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
# Combine above intervals (via AND/intersect)
# In the above example, this gives us [start_of_yesterday, start_of_tomorrow)
# This is the effective date range to filter entries by
# ---
for date_range in date_range_considering_comparator:
effective_date_range = [
max(effective_date_range[0], date_range[0]),
min(effective_date_range[1], date_range[1])]
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return None
else:
return effective_date_range
def parse(date_str, relative_base=None): def date_to_daterange(self, parsed_date, date_str):
"Parse date string passed in date filter of query to datetime object" "Convert parsed date to date ranges at natural granularity (day, week, month or year)"
# clean date string to handle future date parsing by date parser
future_strings = ['later', 'from now', 'from today']
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
clean_date_str = re.sub('|'.join(future_strings), '', date_str)
# parse date passed in query date filter start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
parsed_date = dtparse.parse(
clean_date_str,
settings= {
'RELATIVE_BASE': relative_base or datetime.now(),
'PREFER_DAY_OF_MONTH': 'first',
'PREFER_DATES_FROM': prefer_dates_from
})
if parsed_date is None: if 'year' in date_str:
return None return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))
if 'month' in date_str:
return date_to_daterange(parsed_date, date_str) start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
next_month = start_of_month + relativedelta(months=1)
return (start_of_month, next_month)
def date_to_daterange(parsed_date, date_str): if 'week' in date_str:
"Convert parsed date to date ranges at natural granularity (day, week, month or year)" # if week in date string, dateparser parses it to next week start
# so today = end of this week
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0) start_of_week = start_of_day - timedelta(days=7)
return (start_of_week, start_of_day)
if 'year' in date_str: else:
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0)) next_day = start_of_day + relativedelta(days=1)
if 'month' in date_str:
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
next_month = start_of_month + relativedelta(months=1)
return (start_of_month, next_month)
if 'week' in date_str:
# if week in date string, dateparser parses it to next week start
# so today = end of this week
start_of_week = start_of_day - timedelta(days=7)
return (start_of_week, start_of_day)
else:
next_day = start_of_day + relativedelta(days=1)
return (start_of_day, next_day) return (start_of_day, next_day)

View File

@@ -5,42 +5,53 @@ import re
import torch import torch
def explicit_filter(raw_query, entries, embeddings, entry_key='raw'): class ExplicitFilter:
# Separate natural query from explicit required, blocked words filters def can_filter(self, raw_query):
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) "Check if query contains explicit filters"
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) # Extract explicit query portion with required, blocked words to filter from natural query
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
return len(required_words) != 0 or len(blocked_words) != 0
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
if len(required_words) == 0 and len(blocked_words) == 0:
return query, entries, embeddings
# convert each entry to a set of words
# split on fullstop, comma, colon, tab, newline or any brackets
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
entries_by_word_set = [set(word.lower()
for word
in re.split(entry_splitter, entry[entry_key])
if word != "")
for entry in entries]
# track id of entries to exclude
entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
if len(required_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if not required_words.issubset(words_in_entry):
entries_to_exclude.add(id)
# mark entries that contain any blocked_words for exclusion
if len(blocked_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if words_in_entry.intersection(blocked_words):
entries_to_exclude.add(id)
# delete entries (and their embeddings) marked for exclusion
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
if len(required_words) == 0 and len(blocked_words) == 0:
return query, entries, embeddings return query, entries, embeddings
# convert each entry to a set of words
# split on fullstop, comma, colon, tab, newline or any brackets
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
entries_by_word_set = [set(word.lower()
for word
in re.split(entry_splitter, entry[entry_key])
if word != "")
for entry in entries]
# track id of entries to exclude
entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
if len(required_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if not required_words.issubset(words_in_entry):
entries_to_exclude.add(id)
# mark entries that contain any blocked_words for exclusion
if len(blocked_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if words_in_entry.intersection(blocked_words):
entries_to_exclude.add(id)
# delete entries (and their embeddings) marked for exclusion
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return query, entries, embeddings

View File

@@ -2,6 +2,7 @@
import argparse import argparse
import pathlib import pathlib
from copy import deepcopy from copy import deepcopy
import time
# External Packages # External Packages
import torch import torch
@@ -19,7 +20,7 @@ def initialize_model(search_config: TextSearchConfig):
torch.set_num_threads(4) torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder # Number of entries we want to retrieve with the bi-encoder
top_k = 30 top_k = 15
# The bi-encoder encodes all entries to use for semantic search # The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model( bi_encoder = load_model(
@@ -62,38 +63,71 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = []): def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
"Search for entries that answer the query" "Search for entries that answer the query"
# Copy original embeddings, entries to filter them for query
query = raw_query query = raw_query
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries) # Use deep copy of original embeddings, entries to filter if query contains filters
start = time.time()
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
if filters_in_query:
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
else:
corpus_embeddings = model.corpus_embeddings
entries = model.entries
end = time.time()
if verbose > 1:
print(f"Copy Time: {end - start:.3f} seconds")
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
for filter in filters: start = time.time()
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings) for filter in filters_in_query:
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
end = time.time()
if verbose > 1:
print(f"Filter Time: {end - start:.3f} seconds")
if entries is None or len(entries) == 0: if entries is None or len(entries) == 0:
return [], [] return [], []
# Encode the query using the bi-encoder # Encode the query using the bi-encoder
start = time.time()
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True) question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
question_embedding.to(device) question_embedding.to(device)
question_embedding = util.normalize_embeddings(question_embedding) question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
if verbose > 1:
print(f"Query Encode Time: {end - start:.3f} seconds")
# Find relevant entries for the query # Find relevant entries for the query
start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
if verbose > 1:
print(f"Search Time: {end - start:.3f} seconds")
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] if rank_results:
cross_scores = model.cross_encoder.predict(cross_inp) start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
if verbose > 1:
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
# Store cross-encoder scores in results dictionary for ranking # Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)): for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx] hits[idx]['cross-score'] = cross_scores[idx]
# Order results by cross-encoder score followed by bi-encoder score # Order results by cross-encoder score followed by bi-encoder score
start = time.time()
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
if verbose > 1:
print(f"Rank Time: {end - start:.3f} seconds")
return hits, entries return hits, entries
@@ -120,7 +154,7 @@ def collate_results(hits, entries, count=5):
return [ return [
{ {
"entry": entries[hit['corpus_id']]['raw'], "entry": entries[hit['corpus_id']]['raw'],
"score": f"{hit['cross-score']:.3f}" "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
} }
for hit for hit
in hits[0:count]] in hits[0:count]]

View File

@@ -29,7 +29,8 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# Act # Act
hits, entries = text_search.query( hits, entries = text_search.query(
query, query,
model = model.notes_search) model = model.notes_search,
rank_results=True)
results = text_search.collate_results( results = text_search.collate_results(
hits, hits,

View File

@@ -119,7 +119,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
user_query = "How to git install application?" user_query = "How to git install application?"
# Act # Act
response = client.get(f"/search?q={user_query}&n=1&t=org") response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -7,7 +7,7 @@ from math import inf
import torch import torch
# Application Packages # Application Packages
from src.search_filter import date_filter from src.search_filter.date_filter import DateFilter
def test_date_filter(): def test_date_filter():
@@ -18,99 +18,99 @@ def test_date_filter():
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}] {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail' q_with_no_date_filter = 'head tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 3 assert len(ret_emb) == 3
assert ret_entries == entries assert ret_entries == entries
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 0 assert len(ret_emb) == 0
assert ret_entries == [] assert ret_entries == []
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert ret_entries == [entries[2]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1]] assert ret_entries == [entries[1]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert ret_entries == [entries[2]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1], entries[2]] assert ret_entries == [entries[1], entries[2]]
assert len(ret_emb) == 2 assert len(ret_emb) == 2
def test_extract_date_range(): def test_extract_date_range():
assert date_filter.extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
assert date_filter.extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert date_filter.extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf] assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
assert date_filter.extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]
# Unparseable date filter specified in query # Unparseable date filter specified in query
assert date_filter.extract_date_range('head dt:"Summer of 69" tail') == None assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
# No date filter specified in query # No date filter specified in query
assert date_filter.extract_date_range('head tail') == None assert DateFilter().extract_date_range('head tail') == None
# Non intersecting date ranges # Non intersecting date ranges
assert date_filter.extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
def test_parse(): def test_parse():
test_now = datetime(1984, 4, 1, 21, 21, 21) test_now = datetime(1984, 4, 1, 21, 21, 21)
# day variations # day variations
assert date_filter.parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0)) assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
assert date_filter.parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0)) assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
assert date_filter.parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0)) assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
assert date_filter.parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0)) assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))
# week variations # week variations
assert date_filter.parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0)) assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
assert date_filter.parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0)) assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))
# month variations # month variations
assert date_filter.parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0)) assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
assert date_filter.parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0)) assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))
# year variations # year variations
assert date_filter.parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0)) assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
assert date_filter.parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0)) assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))
# specific month/date variation # specific month/date variation
assert date_filter.parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
assert date_filter.parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
def test_date_filter_regex(): def test_date_filter_regex():
dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>"today" dt:"1984-01-01"') dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
dtrange_match = re.findall(date_filter.date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail') dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>="today" dt="1984-01-01"') dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')] assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')]
dtrange_match = re.findall(date_filter.date_regex, 'dt<"multi word date" multi word tail') dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
assert dtrange_match == [('<', 'multi word date')] assert dtrange_match == [('<', 'multi word date')]
dtrange_match = re.findall(date_filter.date_regex, 'head dt<="multi word date"') dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
assert dtrange_match == [('<=', 'multi word date')] assert dtrange_match == [('<=', 'multi word date')]
dtrange_match = re.findall(date_filter.date_regex, 'head tail') dtrange_match = re.findall(DateFilter().date_regex, 'head tail')
assert dtrange_match == [] assert dtrange_match == []