Support Incremental Search in Khoj
# Details

## Improve Search API Latency
- Improve search API latency by ~50-100x, to <100ms
- Trade accuracy for speed in the default, fast path of the /search API by not re-ranking results with the cross-encoder
- Make re-ranking of results via the cross-encoder configurable with a new `r=<true|false>` query param to the /search API
- Only deep-copy entries and embeddings to apply filters if the query contains any filter keywords

## Support Incremental Update via Khoj Emacs Frontend
- Use the default, fast path to query the /search API while the user is typing
- Upgrade to cross-encoder re-ranked results once the user goes idle (or ends the search normally)

Closes #37
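The new fast and re-ranked paths can be exercised directly against the /search API. Below is a minimal client sketch, assuming a Khoj server running at http://localhost:8000 with org-mode notes indexed; the query string is illustrative:

```python
import json
import urllib.parse
import urllib.request

SERVER = "http://localhost:8000"

def khoj_search(query, search_type="org", rerank=False, count=5):
    # r=false takes the default fast path (bi-encoder ranking only);
    # r=true re-ranks the retrieved entries with the cross-encoder
    params = urllib.parse.urlencode(
        {"q": query, "t": search_type, "r": str(rerank).lower(), "n": count})
    with urllib.request.urlopen(f"{SERVER}/search?{params}") as response:
        return json.loads(response.read())

# While the user is typing: cheap, low-latency results
draft_results = khoj_search("emacs setup notes")
# Once the user goes idle or submits the query: higher quality ranking
final_results = khoj_search("emacs setup notes", rerank=True)
```

This is the same call pattern the updated Emacs frontend uses: the fast path on every keystroke, the re-ranked path from an idle timer.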
interface/emacs/khoj.el

@@ -1,9 +1,9 @@
-;;; khoj.el --- Natural Search via Emacs
+;;; khoj.el --- Natural, Incremental Search via Emacs

 ;; Copyright (C) 2021-2022 Debanjum Singh Solanky

 ;; Author: Debanjum Singh Solanky <debanjum@gmail.com>
-;; Version: 1.0
+;; Version: 2.0
 ;; Keywords: search, org-mode, outlines, markdown, image
 ;; URL: http://github.com/debanjum/khoj/interface/emacs

@@ -26,26 +26,50 @@

 ;;; Commentary:

-;; This package provides natural language search on org-mode notes,
-;; markdown files, beancount transactions and images.
+;; This package provides a natural, incremental search interface to your
+;; org-mode notes, markdown files, beancount transactions and images.
 ;; It is a wrapper that interfaces with transformer based ML models.
-;; The models search capabilities are exposed via the Khoj HTTP API
+;; The models search capabilities are exposed via the Khoj HTTP API.

 ;;; Code:

 (require 'url)
 (require 'json)


 (defcustom khoj--server-url "http://localhost:8000"
   "Location of Khoj API server."
   :group 'khoj
   :type 'string)

 (defcustom khoj--image-width 156
-  "Width of rendered images returned by Khoj"
+  "Width of rendered images returned by Khoj."
   :group 'khoj
   :type 'integer)

+(defcustom khoj--rerank-after-idle-time 1.0
+  "Idle time (in seconds) to trigger cross-encoder to rerank incremental search results."
+  :group 'khoj
+  :type 'float)
+
+(defcustom khoj--results-count 5
+  "Number of results to get from Khoj API for each query."
+  :group 'khoj
+  :type 'integer)
+
+(defvar khoj--rerank-timer nil
+  "Idle timer to make cross-encoder re-rank incremental search results if user idle.")
+
+(defvar khoj--minibuffer-window nil
+  "Minibuffer window being used by user to enter query.")
+
+(defconst khoj--query-prompt "Khoj: "
+  "Query prompt shown to user in the minibuffer.")
+
+(defvar khoj--search-type "org"
+  "The type of content to perform search on.")
+
+
 (defun khoj--extract-entries-as-markdown (json-response query)
   "Convert json response from API to markdown entries"
   ;; remove leading (, ) or SPC from extracted entries string
@@ -118,34 +142,29 @@
       ((or (equal file-extension "markdown") (equal file-extension "md")) "markdown")
       (t "org"))))

-(defun khoj--construct-api-query (query search-type)
-  (let ((encoded-query (url-hexify-string query)))
-    (format "%s/search?q=%s&t=%s" khoj--server-url encoded-query search-type)))
+(defun khoj--construct-api-query (query search-type &optional rerank)
+  (let ((rerank (or rerank "false"))
+        (results-count (or khoj--results-count 5))
+        (encoded-query (url-hexify-string query)))
+    (format "%s/search?q=%s&t=%s&r=%s&n=%s" khoj--server-url encoded-query search-type rerank results-count)))

-;;;###autoload
-(defun khoj (query)
-  "Search your content naturally using the Khoj API"
-  (interactive "sQuery: ")
-  (let* ((default-type (khoj--buffer-name-to-search-type (buffer-name)))
-         (search-type (completing-read "Type: " '("org" "markdown" "ledger" "music" "image") nil t default-type))
-         (url (khoj--construct-api-query query search-type))
-         (buff (get-buffer-create (format "*Khoj (q:%s t:%s)*" query search-type))))
-    ;; get json response from api
-    (with-current-buffer buff
-      (let ((inhibit-read-only t))
-        (erase-buffer)
-        (url-insert-file-contents url)))
-    ;; render json response into formatted entries
-    (with-current-buffer buff
-      (let ((inhibit-read-only t)
-            (json-response (json-parse-buffer :object-type 'alist)))
-        (erase-buffer)
-        (insert
-         (cond ((or (equal search-type "org") (equal search-type "music")) (khoj--extract-entries-as-org json-response query))
-               ((equal search-type "markdown") (khoj--extract-entries-as-markdown json-response query))
-               ((equal search-type "ledger") (khoj--extract-entries-as-ledger json-response query))
-               ((equal search-type "image") (khoj--extract-entries-as-images json-response query))
-               (t (format "%s" json-response))))
+(defun khoj--query-api-and-render-results (query search-type query-url buffer-name)
+  ;; get json response from api
+  (with-current-buffer buffer-name
+    (let ((inhibit-read-only t))
+      (erase-buffer)
+      (url-insert-file-contents query-url)))
+  ;; render json response into formatted entries
+  (with-current-buffer buffer-name
+    (let ((inhibit-read-only t)
+          (json-response (json-parse-buffer :object-type 'alist)))
+      (erase-buffer)
+      (insert
+       (cond ((or (equal search-type "org") (equal search-type "music")) (khoj--extract-entries-as-org json-response query))
+             ((equal search-type "markdown") (khoj--extract-entries-as-markdown json-response query))
+             ((equal search-type "ledger") (khoj--extract-entries-as-ledger json-response query))
+             ((equal search-type "image") (khoj--extract-entries-as-images json-response query))
+             (t (format "%s" json-response))))
       (cond ((equal search-type "org") (org-mode))
             ((equal search-type "markdown") (markdown-mode))
            ((equal search-type "ledger") (beancount-mode))
@@ -154,8 +173,82 @@
             ((equal search-type "image") (progn (shr-render-region (point-min) (point-max))
                                                 (goto-char (point-min))))
             (t (fundamental-mode))))
-      (read-only-mode t))
-    (switch-to-buffer buff)))
+      (read-only-mode t)))
+
+
+;; Incremental Search on Khoj
+(defun khoj--incremental-search (&optional rerank)
+  (let* ((rerank-str (cond (rerank "true") (t "false")))
+         (search-type khoj--search-type)
+         (buffer-name (get-buffer-create (format "*Khoj (t:%s)*" search-type)))
+         (query (minibuffer-contents-no-properties))
+         (query-url (khoj--construct-api-query query search-type rerank-str)))
+    ;; Query khoj API only when user in khoj minibuffer.
+    ;; Prevents querying during recursive edits or with contents of other buffers user may jump to
+    (when (and (active-minibuffer-window) (equal (current-buffer) khoj--minibuffer-window))
+      (progn
+        (when rerank
+          (message "[Khoj]: Rerank Results"))
+        (khoj--query-api-and-render-results
+         query
+         search-type
+         query-url
+         buffer-name)))))
+
+(defun khoj--teardown-incremental-search ()
+  ;; remove advice to rerank results on normal exit from minibuffer
+  (advice-remove 'exit-minibuffer #'khoj--minibuffer-exit-advice)
+  ;; unset khoj minibuffer window
+  (setq khoj--minibuffer-window nil)
+  ;; cancel rerank timer
+  (when (timerp khoj--rerank-timer)
+    (cancel-timer khoj--rerank-timer))
+  ;; remove hooks for khoj incremental query and self
+  (remove-hook 'post-command-hook #'khoj--incremental-search)
+  (remove-hook 'minibuffer-exit-hook #'khoj--teardown-incremental-search))

+(defun khoj--minibuffer-exit-advice (&rest _args)
+  (khoj--incremental-search t))
+
+;;;###autoload
+(defun khoj ()
+  "Natural, Incremental Search for your personal notes, transactions and music using Khoj"
+  (interactive)
+  (let* ((default-type (khoj--buffer-name-to-search-type (buffer-name)))
+         (search-type (completing-read "Type: " '("org" "markdown" "ledger" "music") nil t default-type))
+         (buffer-name (get-buffer-create (format "*Khoj (t:%s)*" search-type))))
+    (setq khoj--search-type search-type)
+    ;; setup rerank to improve results once user idle for KHOJ--RERANK-AFTER-IDLE-TIME seconds
+    (setq khoj--rerank-timer (run-with-idle-timer khoj--rerank-after-idle-time t 'khoj--incremental-search t))
+    ;; switch to khoj results buffer
+    (switch-to-buffer buffer-name)
+    ;; open and setup minibuffer for incremental search
+    (minibuffer-with-setup-hook
+        (lambda ()
+          ;; set current (mini-)buffer entered as khoj minibuffer
+          ;; used to query khoj API only when user in khoj minibuffer
+          (setq khoj--minibuffer-window (current-buffer))
+          ;; rerank results on normal exit from minibuffer
+          (advice-add 'exit-minibuffer :before #'khoj--minibuffer-exit-advice)
+          (add-hook 'post-command-hook #'khoj--incremental-search) ; do khoj incremental search after every user action
+          (add-hook 'minibuffer-exit-hook #'khoj--teardown-incremental-search)) ; teardown khoj incremental search on minibuffer exit
+      (read-string khoj--query-prompt))))
+
+;;;###autoload
+(defun khoj-simple (query)
+  "Natural Search for QUERY in your personal notes, transactions, music and images using Khoj"
+  (interactive "sQuery: ")
+  (let* ((rerank "true")
+         (default-type (khoj--buffer-name-to-search-type (buffer-name)))
+         (search-type (completing-read "Type: " '("org" "markdown" "ledger" "music" "image") nil t default-type))
+         (query-url (khoj--construct-api-query query search-type rerank))
+         (buffer-name (get-buffer-create (format "*Khoj (q:%s t:%s)*" query search-type))))
+    (khoj--query-api-and-render-results
+     query
+     search-type
+     query-url
+     buffer-name)
+    (switch-to-buffer buffer-name)))

 (provide 'khoj)
src/main.py

@@ -1,5 +1,6 @@
 # Standard Packages
 import sys, json, yaml, os
+import time
 from typing import Optional

 # External Packages
@@ -20,8 +21,8 @@ from src.utils.cli import cli
 from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
 from src.utils.rawconfig import FullConfig
 from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
-from src.search_filter.explicit_filter import explicit_filter
-from src.search_filter.date_filter import date_filter
+from src.search_filter.explicit_filter import ExplicitFilter
+from src.search_filter.date_filter import DateFilter

 # Application Global State
 config = FullConfig()
@@ -58,7 +59,7 @@ async def config_data(updated_config: FullConfig):
     return config

 @app.get('/search')
-def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
+def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
     if q is None or q == '':
         print(f'No query param (q) passed in API call to initiate search')
         return {}
@@ -66,50 +67,74 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
     device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
     user_query = q
     results_count = n
+    results = {}

     if (t == SearchType.Org or t == None) and model.orgmode_search:
         # query org-mode notes
-        hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[explicit_filter, date_filter])
+        query_start = time.time()
+        hits, entries = text_search.query(user_query, model.orgmode_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
+        query_end = time.time()

         # collate and return results
-        return text_search.collate_results(hits, entries, results_count)
+        collate_start = time.time()
+        results = text_search.collate_results(hits, entries, results_count)
+        collate_end = time.time()

     if (t == SearchType.Music or t == None) and model.music_search:
         # query music library
-        hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter])
+        query_start = time.time()
+        hits, entries = text_search.query(user_query, model.music_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
+        query_end = time.time()

         # collate and return results
-        return text_search.collate_results(hits, entries, results_count)
+        collate_start = time.time()
+        results = text_search.collate_results(hits, entries, results_count)
+        collate_end = time.time()

     if (t == SearchType.Markdown or t == None) and model.orgmode_search:
         # query markdown files
-        hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[explicit_filter, date_filter])
+        query_start = time.time()
+        hits, entries = text_search.query(user_query, model.markdown_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
+        query_end = time.time()

         # collate and return results
-        return text_search.collate_results(hits, entries, results_count)
+        collate_start = time.time()
+        results = text_search.collate_results(hits, entries, results_count)
+        collate_end = time.time()

     if (t == SearchType.Ledger or t == None) and model.ledger_search:
         # query transactions
-        hits, entries = text_search.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter])
+        query_start = time.time()
+        hits, entries = text_search.query(user_query, model.ledger_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
+        query_end = time.time()

         # collate and return results
-        return text_search.collate_results(hits, entries, results_count)
+        collate_start = time.time()
+        results = text_search.collate_results(hits, entries, results_count)
+        collate_end = time.time()

     if (t == SearchType.Image or t == None) and model.image_search:
         # query images
+        query_start = time.time()
         hits = image_search.query(user_query, results_count, model.image_search)
         output_directory = f'{os.getcwd()}/{web_directory}'
+        query_end = time.time()

         # collate and return results
-        return image_search.collate_results(
+        collate_start = time.time()
+        results = image_search.collate_results(
             hits,
             image_names=model.image_search.image_names,
             output_directory=output_directory,
             static_files_url='/static',
             count=results_count)
+        collate_end = time.time()

-    else:
-        return {}
+    if verbose > 1:
+        print(f"Query took {query_end - query_start:.3f} seconds")
+        print(f"Collating results took {collate_end - collate_start:.3f} seconds")
+
+    return results


 @app.get('/reload')
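The latency trade-off between the two paths can be spot-checked against a running server; a rough timing sketch (the localhost URL and query are assumptions):

```python
import time
import urllib.request

def time_search(rerank: bool) -> float:
    # Send the same query down the fast (r=false) and re-ranked (r=true) paths
    url = f"http://localhost:8000/search?q=notes&t=org&r={str(rerank).lower()}"
    start = time.time()
    urllib.request.urlopen(url).read()
    return time.time() - start

print(f"fast path:      {time_search(False):.3f}s")
print(f"re-ranked path: {time_search(True):.3f}s")
```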
@@ -82,7 +82,7 @@ def convert_org_entries_to_jsonl(entries, verbose=0):
             continue

         entry_dict["compiled"] = f'{entry.Heading()}.'
-        if verbose > 1:
+        if verbose > 2:
             print(f"Title: {entry.Heading()}")

         if entry.Tags():
src/search_filter/date_filter.py

@@ -9,138 +9,143 @@ import torch
 import dateparser as dtparse


-# Date Range Filter Regexes
-# Example filter queries:
-# - dt>="yesterday" dt<"tomorrow"
-# - dt>="last week"
-# - dt:"2 years ago"
-date_regex = r"dt([:><=]{1,2})\"(.*?)\""
-
-
-def date_filter(query, entries, embeddings, entry_key='raw'):
-    "Find entries containing any dates that fall within date range specified in query"
-    # extract date range specified in date filter of query
-    query_daterange = extract_date_range(query)
-
-    # if no date in query, return all entries
-    if query_daterange is None:
-        return query, entries, embeddings
-
-    # remove date range filter from query
-    query = re.sub(f'\s+{date_regex}', ' ', query)
-    query = re.sub(r'\s{2,}', ' ', query).strip()  # remove multiple spaces
-
-    # find entries containing any dates that fall with date range specified in query
-    entries_to_include = set()
-    for id, entry in enumerate(entries):
-        # Extract dates from entry
-        for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
-            # Convert date string in entry to unix timestamp
-            try:
-                date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
-            except ValueError:
-                continue
-            # Check if date in entry is within date range specified in query
-            if query_daterange[0] <= date_in_entry < query_daterange[1]:
-                entries_to_include.add(id)
-                break
-
-    # delete entries (and their embeddings) marked for exclusion
-    entries_to_exclude = set(range(len(entries))) - entries_to_include
-    for id in sorted(list(entries_to_exclude), reverse=True):
-        del entries[id]
-        embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
-
-    return query, entries, embeddings
-
-
-def extract_date_range(query):
-    # find date range filter in query
-    date_range_matches = re.findall(date_regex, query)
-
-    if len(date_range_matches) == 0:
-        return None
-
-    # extract, parse natural dates ranges from date range filter passed in query
-    # e.g today maps to (start_of_day, start_of_tomorrow)
-    date_ranges_from_filter = []
-    for (cmp, date_str) in date_range_matches:
-        if parse(date_str):
-            dt_start, dt_end = parse(date_str)
-            date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
-
-    # Combine dates with their comparators to form date range intervals
-    # For e.g
-    # >=yesterday maps to [start_of_yesterday, inf)
-    # <tomorrow maps to [0, start_of_tomorrow)
-    # ---
-    effective_date_range = [0, inf]
-    date_range_considering_comparator = []
-    for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
-        if cmp == '>':
-            date_range_considering_comparator += [[dtrange_end, inf]]
-        elif cmp == '>=':
-            date_range_considering_comparator += [[dtrange_start, inf]]
-        elif cmp == '<':
-            date_range_considering_comparator += [[0, dtrange_start]]
-        elif cmp == '<=':
-            date_range_considering_comparator += [[0, dtrange_end]]
-        elif cmp == '=' or cmp == ':' or cmp == '==':
-            date_range_considering_comparator += [[dtrange_start, dtrange_end]]
-
-    # Combine above intervals (via AND/intersect)
-    # In the above example, this gives us [start_of_yesterday, start_of_tomorrow)
-    # This is the effective date range to filter entries by
-    # ---
-    for date_range in date_range_considering_comparator:
-        effective_date_range = [
-            max(effective_date_range[0], date_range[0]),
-            min(effective_date_range[1], date_range[1])]
-
-    if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
-        return None
-    else:
-        return effective_date_range
-
-
-def parse(date_str, relative_base=None):
-    "Parse date string passed in date filter of query to datetime object"
-    # clean date string to handle future date parsing by date parser
-    future_strings = ['later', 'from now', 'from today']
-    prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
-    clean_date_str = re.sub('|'.join(future_strings), '', date_str)
-
-    # parse date passed in query date filter
-    parsed_date = dtparse.parse(
-        clean_date_str,
-        settings= {
-            'RELATIVE_BASE': relative_base or datetime.now(),
-            'PREFER_DAY_OF_MONTH': 'first',
-            'PREFER_DATES_FROM': prefer_dates_from
-        })
-
-    if parsed_date is None:
-        return None
-
-    return date_to_daterange(parsed_date, date_str)
-
-
-def date_to_daterange(parsed_date, date_str):
-    "Convert parsed date to date ranges at natural granularity (day, week, month or year)"
-    start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
-
-    if 'year' in date_str:
-        return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))
-    if 'month' in date_str:
-        start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
-        next_month = start_of_month + relativedelta(months=1)
-        return (start_of_month, next_month)
-    if 'week' in date_str:
-        # if week in date string, dateparser parses it to next week start
-        # so today = end of this week
-        start_of_week = start_of_day - timedelta(days=7)
-        return (start_of_week, start_of_day)
-    else:
-        next_day = start_of_day + relativedelta(days=1)
-        return (start_of_day, next_day)
+class DateFilter:
+    # Date Range Filter Regexes
+    # Example filter queries:
+    # - dt>="yesterday" dt<"tomorrow"
+    # - dt>="last week"
+    # - dt:"2 years ago"
+    date_regex = r"dt([:><=]{1,2})\"(.*?)\""
+
+    def can_filter(self, raw_query):
+        "Check if query contains date filters"
+        return self.extract_date_range(raw_query) is not None
+
+    def filter(self, query, entries, embeddings, entry_key='raw'):
+        "Find entries containing any dates that fall within date range specified in query"
+        # extract date range specified in date filter of query
+        query_daterange = self.extract_date_range(query)
+
+        # if no date in query, return all entries
+        if query_daterange is None:
+            return query, entries, embeddings
+
+        # remove date range filter from query
+        query = re.sub(f'\s+{self.date_regex}', ' ', query)
+        query = re.sub(r'\s{2,}', ' ', query).strip()  # remove multiple spaces
+
+        # find entries containing any dates that fall with date range specified in query
+        entries_to_include = set()
+        for id, entry in enumerate(entries):
+            # Extract dates from entry
+            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
+                # Convert date string in entry to unix timestamp
+                try:
+                    date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
+                except ValueError:
+                    continue
+                # Check if date in entry is within date range specified in query
+                if query_daterange[0] <= date_in_entry < query_daterange[1]:
+                    entries_to_include.add(id)
+                    break
+
+        # delete entries (and their embeddings) marked for exclusion
+        entries_to_exclude = set(range(len(entries))) - entries_to_include
+        for id in sorted(list(entries_to_exclude), reverse=True):
+            del entries[id]
+            embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+
+        return query, entries, embeddings
+
+    def extract_date_range(self, query):
+        # find date range filter in query
+        date_range_matches = re.findall(self.date_regex, query)
+
+        if len(date_range_matches) == 0:
+            return None
+
+        # extract, parse natural dates ranges from date range filter passed in query
+        # e.g today maps to (start_of_day, start_of_tomorrow)
+        date_ranges_from_filter = []
+        for (cmp, date_str) in date_range_matches:
+            if self.parse(date_str):
+                dt_start, dt_end = self.parse(date_str)
+                date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
+
+        # Combine dates with their comparators to form date range intervals
+        # For e.g
+        # >=yesterday maps to [start_of_yesterday, inf)
+        # <tomorrow maps to [0, start_of_tomorrow)
+        # ---
+        effective_date_range = [0, inf]
+        date_range_considering_comparator = []
+        for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
+            if cmp == '>':
+                date_range_considering_comparator += [[dtrange_end, inf]]
+            elif cmp == '>=':
+                date_range_considering_comparator += [[dtrange_start, inf]]
+            elif cmp == '<':
+                date_range_considering_comparator += [[0, dtrange_start]]
+            elif cmp == '<=':
+                date_range_considering_comparator += [[0, dtrange_end]]
+            elif cmp == '=' or cmp == ':' or cmp == '==':
+                date_range_considering_comparator += [[dtrange_start, dtrange_end]]
+
+        # Combine above intervals (via AND/intersect)
+        # In the above example, this gives us [start_of_yesterday, start_of_tomorrow)
+        # This is the effective date range to filter entries by
+        # ---
+        for date_range in date_range_considering_comparator:
+            effective_date_range = [
+                max(effective_date_range[0], date_range[0]),
+                min(effective_date_range[1], date_range[1])]
+
+        if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
+            return None
+        else:
+            return effective_date_range
+
+    def parse(self, date_str, relative_base=None):
+        "Parse date string passed in date filter of query to datetime object"
+        # clean date string to handle future date parsing by date parser
+        future_strings = ['later', 'from now', 'from today']
+        prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
+        clean_date_str = re.sub('|'.join(future_strings), '', date_str)
+
+        # parse date passed in query date filter
+        parsed_date = dtparse.parse(
+            clean_date_str,
+            settings= {
+                'RELATIVE_BASE': relative_base or datetime.now(),
+                'PREFER_DAY_OF_MONTH': 'first',
+                'PREFER_DATES_FROM': prefer_dates_from
+            })
+
+        if parsed_date is None:
+            return None
+
+        return self.date_to_daterange(parsed_date, date_str)
+
+    def date_to_daterange(self, parsed_date, date_str):
+        "Convert parsed date to date ranges at natural granularity (day, week, month or year)"
+        start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
+
+        if 'year' in date_str:
+            return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))
+        if 'month' in date_str:
+            start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
+            next_month = start_of_month + relativedelta(months=1)
+            return (start_of_month, next_month)
+        if 'week' in date_str:
+            # if week in date string, dateparser parses it to next week start
+            # so today = end of this week
+            start_of_week = start_of_day - timedelta(days=7)
+            return (start_of_week, start_of_day)
+        else:
+            next_day = start_of_day + relativedelta(days=1)
+            return (start_of_day, next_day)
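The new class interface can be exercised standalone. A small sketch of DateFilter usage, assuming khoj's src package is importable and using illustrative data:

```python
import torch

from src.search_filter.date_filter import DateFilter

date_filter = DateFilter()

# can_filter() lets callers skip the expensive deep copy when no date filter is present
assert date_filter.can_filter('head dt>="last week" tail') is True
assert date_filter.can_filter('head tail') is False

# filter() strips the dt filter from the query and drops entries outside the date range
entries = [{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
embeddings = torch.randn(1, 10)
query, entries, embeddings = date_filter.filter('head dt:"1984-04-02" tail', entries, embeddings)
print(query)  # 'head tail'
```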
src/search_filter/explicit_filter.py

@@ -5,42 +5,53 @@ import re
 import torch


-def explicit_filter(raw_query, entries, embeddings, entry_key='raw'):
-    # Separate natural query from explicit required, blocked words filters
-    query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
-    required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-    blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
-
-    if len(required_words) == 0 and len(blocked_words) == 0:
-        return query, entries, embeddings
-
-    # convert each entry to a set of words
-    # split on fullstop, comma, colon, tab, newline or any brackets
-    entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-    entries_by_word_set = [set(word.lower()
-                               for word
-                               in re.split(entry_splitter, entry[entry_key])
-                               if word != "")
-                           for entry in entries]
-
-    # track id of entries to exclude
-    entries_to_exclude = set()
-
-    # mark entries that do not contain all required_words for exclusion
-    if len(required_words) > 0:
-        for id, words_in_entry in enumerate(entries_by_word_set):
-            if not required_words.issubset(words_in_entry):
-                entries_to_exclude.add(id)
-
-    # mark entries that contain any blocked_words for exclusion
-    if len(blocked_words) > 0:
-        for id, words_in_entry in enumerate(entries_by_word_set):
-            if words_in_entry.intersection(blocked_words):
-                entries_to_exclude.add(id)
-
-    # delete entries (and their embeddings) marked for exclusion
-    for id in sorted(list(entries_to_exclude), reverse=True):
-        del entries[id]
-        embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
-
-    return query, entries, embeddings
+class ExplicitFilter:
+    def can_filter(self, raw_query):
+        "Check if query contains explicit filters"
+        # Extract explicit query portion with required, blocked words to filter from natural query
+        required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
+        blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+
+        return len(required_words) != 0 or len(blocked_words) != 0
+
+    def filter(self, raw_query, entries, embeddings, entry_key='raw'):
+        "Find entries containing required and not blocked words specified in query"
+        # Separate natural query from explicit required, blocked words filters
+        query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
+        required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
+        blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+
+        if len(required_words) == 0 and len(blocked_words) == 0:
+            return query, entries, embeddings
+
+        # convert each entry to a set of words
+        # split on fullstop, comma, colon, tab, newline or any brackets
+        entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
+        entries_by_word_set = [set(word.lower()
+                                   for word
+                                   in re.split(entry_splitter, entry[entry_key])
+                                   if word != "")
+                               for entry in entries]
+
+        # track id of entries to exclude
+        entries_to_exclude = set()
+
+        # mark entries that do not contain all required_words for exclusion
+        if len(required_words) > 0:
+            for id, words_in_entry in enumerate(entries_by_word_set):
+                if not required_words.issubset(words_in_entry):
+                    entries_to_exclude.add(id)
+
+        # mark entries that contain any blocked_words for exclusion
+        if len(blocked_words) > 0:
+            for id, words_in_entry in enumerate(entries_by_word_set):
+                if words_in_entry.intersection(blocked_words):
+                    entries_to_exclude.add(id)
+
+        # delete entries (and their embeddings) marked for exclusion
+        for id in sorted(list(entries_to_exclude), reverse=True):
+            del entries[id]
+            embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+
+        return query, entries, embeddings
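ExplicitFilter follows the same can_filter/filter contract. A small usage sketch with illustrative entries, again assuming khoj's src package is importable:

```python
import torch

from src.search_filter.explicit_filter import ExplicitFilter

explicit_filter = ExplicitFilter()

# '+word' requires a word, '-word' blocks one; remaining words form the natural query
assert explicit_filter.can_filter('recipe +orange -grape') is True
assert explicit_filter.can_filter('recipe orange grape') is False

entries = [{'raw': 'orange juice recipe'}, {'raw': 'grape jam recipe'}]
embeddings = torch.randn(2, 10)
query, entries, embeddings = explicit_filter.filter('recipe +orange', entries, embeddings)
print(query)    # 'recipe'
print(entries)  # only the orange entry survives
```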
@@ -2,6 +2,7 @@
 import argparse
 import pathlib
 from copy import deepcopy
+import time

 # External Packages
 import torch
@@ -19,7 +20,7 @@ def initialize_model(search_config: TextSearchConfig):
     torch.set_num_threads(4)

     # Number of entries we want to retrieve with the bi-encoder
-    top_k = 30
+    top_k = 15

     # The bi-encoder encodes all entries to use for semantic search
     bi_encoder = load_model(
@@ -62,38 +63,71 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
     return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = []):
+def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
     "Search for entries that answer the query"
-    # Copy original embeddings, entries to filter them for query
     query = raw_query
-    corpus_embeddings = deepcopy(model.corpus_embeddings)
-    entries = deepcopy(model.entries)
+
+    # Use deep copy of original embeddings, entries to filter if query contains filters
+    start = time.time()
+    filters_in_query = [filter for filter in filters if filter.can_filter(query)]
+    if filters_in_query:
+        corpus_embeddings = deepcopy(model.corpus_embeddings)
+        entries = deepcopy(model.entries)
+    else:
+        corpus_embeddings = model.corpus_embeddings
+        entries = model.entries
+    end = time.time()
+    if verbose > 1:
+        print(f"Copy Time: {end - start:.3f} seconds")

     # Filter query, entries and embeddings before semantic search
-    for filter in filters:
-        query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
+    start = time.time()
+    for filter in filters_in_query:
+        query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
+    end = time.time()
+    if verbose > 1:
+        print(f"Filter Time: {end - start:.3f} seconds")

     if entries is None or len(entries) == 0:
         return [], []

     # Encode the query using the bi-encoder
+    start = time.time()
     question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
     question_embedding.to(device)
     question_embedding = util.normalize_embeddings(question_embedding)
+    end = time.time()
+    if verbose > 1:
+        print(f"Query Encode Time: {end - start:.3f} seconds")

     # Find relevant entries for the query
+    start = time.time()
     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
+    end = time.time()
+    if verbose > 1:
+        print(f"Search Time: {end - start:.3f} seconds")

     # Score all retrieved entries using the cross-encoder
-    cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
-    cross_scores = model.cross_encoder.predict(cross_inp)
-
-    # Store cross-encoder scores in results dictionary for ranking
-    for idx in range(len(cross_scores)):
-        hits[idx]['cross-score'] = cross_scores[idx]
+    if rank_results:
+        start = time.time()
+        cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
+        cross_scores = model.cross_encoder.predict(cross_inp)
+        end = time.time()
+        if verbose > 1:
+            print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
+
+        # Store cross-encoder scores in results dictionary for ranking
+        for idx in range(len(cross_scores)):
+            hits[idx]['cross-score'] = cross_scores[idx]

     # Order results by cross-encoder score followed by bi-encoder score
+    start = time.time()
     hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
-    hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
+    if rank_results:
+        hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
+    end = time.time()
+    if verbose > 1:
+        print(f"Rank Time: {end - start:.3f} seconds")

     return hits, entries

@@ -120,7 +154,7 @@ def collate_results(hits, entries, count=5):
     return [
         {
             "entry": entries[hit['corpus_id']]['raw'],
-            "score": f"{hit['cross-score']:.3f}"
+            "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
         }
         for hit
         in hits[0:count]]
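The rank_results switch above is the standard bi-encoder/cross-encoder split from sentence-transformers: a cheap vector search always runs, and the expensive pairwise re-scoring only runs on request. A self-contained sketch of the same two-stage scheme; it assumes the sentence-transformers package is installed, and the model names are illustrative rather than necessarily the ones khoj ships with:

```python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

entries = ["Install the app with git clone and pip install",
           "Water the plants every other morning",
           "Git rebase rewrites commit history"]
corpus_embeddings = util.normalize_embeddings(
    bi_encoder.encode(entries, convert_to_tensor=True))

def search(query, rank_results=False, top_k=3):
    question_embedding = util.normalize_embeddings(
        bi_encoder.encode([query], convert_to_tensor=True))
    # Fast path: rank by bi-encoder dot product alone
    hits = util.semantic_search(question_embedding, corpus_embeddings,
                                top_k=top_k, score_function=util.dot_score)[0]
    if rank_results:
        # Slow path: re-score each (query, entry) pair with the cross-encoder
        cross_scores = cross_encoder.predict(
            [[query, entries[hit['corpus_id']]] for hit in hits])
        for hit, cross_score in zip(hits, cross_scores):
            hit['cross-score'] = cross_score
        hits.sort(key=lambda hit: hit['cross-score'], reverse=True)
    return hits

print(search("how to install the application", rank_results=True))
```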
@@ -29,7 +29,8 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
     # Act
     hits, entries = text_search.query(
         query,
-        model = model.notes_search)
+        model = model.notes_search,
+        rank_results=True)

     results = text_search.collate_results(
         hits,
@@ -119,7 +119,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
     user_query = "How to git install application?"

     # Act
-    response = client.get(f"/search?q={user_query}&n=1&t=org")
+    response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")

     # Assert
     assert response.status_code == 200
@@ -7,7 +7,7 @@ from math import inf
 import torch

 # Application Packages
-from src.search_filter import date_filter
+from src.search_filter.date_filter import DateFilter


 def test_date_filter():
@@ -18,99 +18,99 @@ def test_date_filter():
         {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]

     q_with_no_date_filter = 'head tail'
-    ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert len(ret_emb) == 3
     assert ret_entries == entries

     q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert len(ret_emb) == 0
     assert ret_entries == []

     query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
-    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[2]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[1]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[2]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[1], entries[2]]
     assert len(ret_emb) == 2


 def test_extract_date_range():
-    assert date_filter.extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
-    assert date_filter.extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
-    assert date_filter.extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
-    assert date_filter.extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]
+    assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
+    assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
+    assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
+    assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]

     # Unparseable date filter specified in query
-    assert date_filter.extract_date_range('head dt:"Summer of 69" tail') == None
+    assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None

     # No date filter specified in query
-    assert date_filter.extract_date_range('head tail') == None
+    assert DateFilter().extract_date_range('head tail') == None

     # Non intersecting date ranges
-    assert date_filter.extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
+    assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None


 def test_parse():
     test_now = datetime(1984, 4, 1, 21, 21, 21)

     # day variations
-    assert date_filter.parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
-    assert date_filter.parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
-    assert date_filter.parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
-    assert date_filter.parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))
+    assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
+    assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
+    assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
+    assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))

     # week variations
-    assert date_filter.parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
-    assert date_filter.parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))
+    assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
+    assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))

     # month variations
-    assert date_filter.parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
-    assert date_filter.parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))
+    assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
+    assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))

     # year variations
-    assert date_filter.parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
-    assert date_filter.parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))
+    assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
+    assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))

     # specific month/date variation
-    assert date_filter.parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
-    assert date_filter.parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
+    assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
+    assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))


 def test_date_filter_regex():
-    dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
+    dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
     assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]

-    dtrange_match = re.findall(date_filter.date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
+    dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
     assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]

-    dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>="today" dt="1984-01-01"')
+    dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
     assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')]

-    dtrange_match = re.findall(date_filter.date_regex, 'dt<"multi word date" multi word tail')
+    dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
     assert dtrange_match == [('<', 'multi word date')]

-    dtrange_match = re.findall(date_filter.date_regex, 'head dt<="multi word date"')
+    dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
     assert dtrange_match == [('<=', 'multi word date')]

-    dtrange_match = re.findall(date_filter.date_regex, 'head tail')
+    dtrange_match = re.findall(DateFilter().date_regex, 'head tail')
     assert dtrange_match == []