mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Create and use a context manager to time code
Use the timer context manager in all places where code was being timed - Benefits - Deduplicate timing code scattered across codebase. - Provides single place to manage perf timing code - Use consistent timing log patterns
This commit is contained in:
@@ -6,7 +6,7 @@ import time
|
|||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.processor.text_to_jsonl import TextToJsonl
|
from src.processor.text_to_jsonl import TextToJsonl
|
||||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||||
from src.utils.constants import empty_escape_sequences
|
from src.utils.constants import empty_escape_sequences
|
||||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||||
from src.utils.rawconfig import Entry
|
from src.utils.rawconfig import Entry
|
||||||
@@ -30,38 +30,30 @@ class BeancountToJsonl(TextToJsonl):
|
|||||||
beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter)
|
beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter)
|
||||||
|
|
||||||
# Extract Entries from specified Beancount files
|
# Extract Entries from specified Beancount files
|
||||||
start = time.time()
|
with timer("Parse transactions from Beancount files into dictionaries", logger):
|
||||||
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
|
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
|
|
||||||
|
|
||||||
# Split entries by max tokens supported by model
|
# Split entries by max tokens supported by model
|
||||||
start = time.time()
|
with timer("Split entries by max token size supported by model", logger):
|
||||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
start = time.time()
|
with timer("Identify new or updated transaction", logger):
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
entries_with_ids = list(enumerate(current_entries))
|
entries_with_ids = list(enumerate(current_entries))
|
||||||
else:
|
else:
|
||||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
|
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
with timer("Write transactions to JSONL file", logger):
|
||||||
start = time.time()
|
# Process Each Entry from All Notes Files
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||||
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
|
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
# Compress JSONL formatted Data
|
||||||
if output_file.suffix == ".gz":
|
if output_file.suffix == ".gz":
|
||||||
compress_jsonl_data(jsonl_data, output_file)
|
compress_jsonl_data(jsonl_data, output_file)
|
||||||
elif output_file.suffix == ".jsonl":
|
elif output_file.suffix == ".jsonl":
|
||||||
dump_jsonl(jsonl_data, output_file)
|
dump_jsonl(jsonl_data, output_file)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
|
|
||||||
|
|
||||||
return entries_with_ids
|
return entries_with_ids
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import time
|
|||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.processor.text_to_jsonl import TextToJsonl
|
from src.processor.text_to_jsonl import TextToJsonl
|
||||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||||
from src.utils.constants import empty_escape_sequences
|
from src.utils.constants import empty_escape_sequences
|
||||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||||
from src.utils.rawconfig import Entry
|
from src.utils.rawconfig import Entry
|
||||||
@@ -30,38 +30,30 @@ class MarkdownToJsonl(TextToJsonl):
|
|||||||
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
|
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
|
||||||
|
|
||||||
# Extract Entries from specified Markdown files
|
# Extract Entries from specified Markdown files
|
||||||
start = time.time()
|
with timer("Parse entries from Markdown files into dictionaries", logger):
|
||||||
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
|
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
|
|
||||||
|
|
||||||
# Split entries by max tokens supported by model
|
# Split entries by max tokens supported by model
|
||||||
start = time.time()
|
with timer("Split entries by max token size supported by model", logger):
|
||||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
start = time.time()
|
with timer("Identify new or updated entries", logger):
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
entries_with_ids = list(enumerate(current_entries))
|
entries_with_ids = list(enumerate(current_entries))
|
||||||
else:
|
else:
|
||||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Identify new or updated entries: {end - start} seconds")
|
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
with timer("Write markdown entries to JSONL file", logger):
|
||||||
start = time.time()
|
# Process Each Entry from All Notes Files
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
# Compress JSONL formatted Data
|
||||||
if output_file.suffix == ".gz":
|
if output_file.suffix == ".gz":
|
||||||
compress_jsonl_data(jsonl_data, output_file)
|
compress_jsonl_data(jsonl_data, output_file)
|
||||||
elif output_file.suffix == ".jsonl":
|
elif output_file.suffix == ".jsonl":
|
||||||
dump_jsonl(jsonl_data, output_file)
|
dump_jsonl(jsonl_data, output_file)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
|
|
||||||
|
|
||||||
return entries_with_ids
|
return entries_with_ids
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Iterable
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.processor.org_mode import orgnode
|
from src.processor.org_mode import orgnode
|
||||||
from src.processor.text_to_jsonl import TextToJsonl
|
from src.processor.text_to_jsonl import TextToJsonl
|
||||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||||
from src.utils.rawconfig import Entry
|
from src.utils.rawconfig import Entry
|
||||||
from src.utils import state
|
from src.utils import state
|
||||||
@@ -29,24 +29,18 @@ class OrgToJsonl(TextToJsonl):
|
|||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# Get Org Files to Process
|
# Get Org Files to Process
|
||||||
start = time.time()
|
with timer("Get org files to process", logger):
|
||||||
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
|
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
|
||||||
|
|
||||||
# Extract Entries from specified Org files
|
# Extract Entries from specified Org files
|
||||||
start = time.time()
|
with timer("Parse entries from org files into OrgNode objects", logger):
|
||||||
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
|
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
|
|
||||||
|
|
||||||
start = time.time()
|
with timer("Convert OrgNodes into list of entries", logger):
|
||||||
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")
|
|
||||||
|
|
||||||
start = time.time()
|
with timer("Split entries by max token size supported by model", logger):
|
||||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
@@ -55,17 +49,15 @@ class OrgToJsonl(TextToJsonl):
|
|||||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
# Process Each Entry from All Notes Files
|
||||||
start = time.time()
|
with timer("Write org entries to JSONL file", logger):
|
||||||
entries = map(lambda entry: entry[1], entries_with_ids)
|
entries = map(lambda entry: entry[1], entries_with_ids)
|
||||||
jsonl_data = self.convert_org_entries_to_jsonl(entries)
|
jsonl_data = self.convert_org_entries_to_jsonl(entries)
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
# Compress JSONL formatted Data
|
||||||
if output_file.suffix == ".gz":
|
if output_file.suffix == ".gz":
|
||||||
compress_jsonl_data(jsonl_data, output_file)
|
compress_jsonl_data(jsonl_data, output_file)
|
||||||
elif output_file.suffix == ".jsonl":
|
elif output_file.suffix == ".jsonl":
|
||||||
dump_jsonl(jsonl_data, output_file)
|
dump_jsonl(jsonl_data, output_file)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
|
|
||||||
|
|
||||||
return entries_with_ids
|
return entries_with_ids
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import hashlib
|
|||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from src.utils.helpers import timer
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.rawconfig import Entry, TextContentConfig
|
from src.utils.rawconfig import Entry, TextContentConfig
|
||||||
@@ -40,35 +41,31 @@ class TextToJsonl(ABC):
|
|||||||
|
|
||||||
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
|
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
|
||||||
# Hash all current and previous entries to identify new entries
|
# Hash all current and previous entries to identify new entries
|
||||||
start = time.time()
|
with timer("Hash previous, current entries", logger):
|
||||||
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
||||||
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
|
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Hash previous, current entries: {end - start} seconds")
|
|
||||||
|
|
||||||
start = time.time()
|
with timer("Identify, Mark, Combine new, existing entries", logger):
|
||||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||||
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
|
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
|
||||||
|
|
||||||
# All entries that did not exist in the previous set are to be added
|
# All entries that did not exist in the previous set are to be added
|
||||||
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
|
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
|
||||||
# All entries that exist in both current and previous sets are kept
|
# All entries that exist in both current and previous sets are kept
|
||||||
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
||||||
|
|
||||||
# Mark new entries with -1 id to flag for later embeddings generation
|
# Mark new entries with -1 id to flag for later embeddings generation
|
||||||
new_entries = [
|
new_entries = [
|
||||||
(-1, hash_to_current_entries[entry_hash])
|
(-1, hash_to_current_entries[entry_hash])
|
||||||
for entry_hash in new_entry_hashes
|
for entry_hash in new_entry_hashes
|
||||||
]
|
]
|
||||||
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
||||||
existing_entries = [
|
existing_entries = [
|
||||||
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
||||||
for entry_hash in existing_entry_hashes
|
for entry_hash in existing_entry_hashes
|
||||||
]
|
]
|
||||||
|
|
||||||
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
|
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
|
||||||
entries_with_ids = existing_entries_sorted + new_entries
|
entries_with_ids = existing_entries_sorted + new_entries
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
|
|
||||||
|
|
||||||
return entries_with_ids
|
return entries_with_ids
|
||||||
@@ -10,6 +10,7 @@ from fastapi import APIRouter
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.configure import configure_processor, configure_search
|
from src.configure import configure_processor, configure_search
|
||||||
from src.search_type import image_search, text_search
|
from src.search_type import image_search, text_search
|
||||||
|
from src.utils.helpers import timer
|
||||||
from src.utils.rawconfig import FullConfig, SearchResponse
|
from src.utils.rawconfig import FullConfig, SearchResponse
|
||||||
from src.utils.config import SearchType
|
from src.utils.config import SearchType
|
||||||
from src.utils import state, constants
|
from src.utils import state, constants
|
||||||
@@ -47,7 +48,6 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
# initialize variables
|
# initialize variables
|
||||||
user_query = q.strip()
|
user_query = q.strip()
|
||||||
results_count = n
|
results_count = n
|
||||||
query_start, query_end, collate_start, collate_end = None, None, None, None
|
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
query_cache_key = f'{user_query}-{n}-{t}-{r}'
|
query_cache_key = f'{user_query}-{n}-{t}-{r}'
|
||||||
@@ -57,73 +57,58 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||||||
|
|
||||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||||
# query org-mode notes
|
# query org-mode notes
|
||||||
query_start = time.time()
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
|
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
|
||||||
query_end = time.time()
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
collate_start = time.time()
|
with timer("Collating results took", logger):
|
||||||
results = text_search.collate_results(hits, entries, results_count)
|
results = text_search.collate_results(hits, entries, results_count)
|
||||||
collate_end = time.time()
|
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
query_start = time.time()
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
|
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
|
||||||
query_end = time.time()
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
collate_start = time.time()
|
with timer("Collating results took", logger):
|
||||||
results = text_search.collate_results(hits, entries, results_count)
|
results = text_search.collate_results(hits, entries, results_count)
|
||||||
collate_end = time.time()
|
|
||||||
|
|
||||||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||||
# query markdown files
|
# query markdown files
|
||||||
query_start = time.time()
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
|
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
|
||||||
query_end = time.time()
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
collate_start = time.time()
|
with timer("Collating results took", logger):
|
||||||
results = text_search.collate_results(hits, entries, results_count)
|
results = text_search.collate_results(hits, entries, results_count)
|
||||||
collate_end = time.time()
|
|
||||||
|
|
||||||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
query_start = time.time()
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
|
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
|
||||||
query_end = time.time()
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
collate_start = time.time()
|
with timer("Collating results took", logger):
|
||||||
results = text_search.collate_results(hits, entries, results_count)
|
results = text_search.collate_results(hits, entries, results_count)
|
||||||
collate_end = time.time()
|
|
||||||
|
|
||||||
if (t == SearchType.Image or t == None) and state.model.image_search:
|
if (t == SearchType.Image or t == None) and state.model.image_search:
|
||||||
# query images
|
# query images
|
||||||
query_start = time.time()
|
with timer("Query took", logger):
|
||||||
hits = image_search.query(user_query, results_count, state.model.image_search)
|
hits = image_search.query(user_query, results_count, state.model.image_search)
|
||||||
output_directory = constants.web_directory / 'images'
|
output_directory = constants.web_directory / 'images'
|
||||||
query_end = time.time()
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
collate_start = time.time()
|
with timer("Collating results took", logger):
|
||||||
results = image_search.collate_results(
|
results = image_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
image_names=state.model.image_search.image_names,
|
image_names=state.model.image_search.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url='/static/images',
|
image_files_url='/static/images',
|
||||||
count=results_count)
|
count=results_count)
|
||||||
collate_end = time.time()
|
|
||||||
|
|
||||||
# Cache results
|
# Cache results
|
||||||
state.query_cache[query_cache_key] = results
|
state.query_cache[query_cache_key] = results
|
||||||
|
|
||||||
if query_start and query_end:
|
|
||||||
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
|
|
||||||
if collate_start and collate_end:
|
|
||||||
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import dateparser as dtparse
|
|||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.search_filter.base_filter import BaseFilter
|
from src.search_filter.base_filter import BaseFilter
|
||||||
from src.utils.helpers import LRU
|
from src.utils.helpers import LRU, timer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -34,19 +34,16 @@ class DateFilter(BaseFilter):
|
|||||||
|
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
def load(self, entries, *args, **kwargs):
|
||||||
start = time.time()
|
with timer("Created date filter index", logger):
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
# Extract dates from entry
|
# Extract dates from entry
|
||||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
|
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
|
||||||
# Convert date string in entry to unix timestamp
|
# Convert date string in entry to unix timestamp
|
||||||
try:
|
try:
|
||||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
self.date_to_entry_ids[date_in_entry].add(id)
|
self.date_to_entry_ids[date_in_entry].add(id)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Created date filter index: {end - start} seconds")
|
|
||||||
|
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
"Check if query contains date filters"
|
"Check if query contains date filters"
|
||||||
@@ -56,10 +53,8 @@ class DateFilter(BaseFilter):
|
|||||||
def apply(self, query, entries):
|
def apply(self, query, entries):
|
||||||
"Find entries containing any dates that fall within date range specified in query"
|
"Find entries containing any dates that fall within date range specified in query"
|
||||||
# extract date range specified in date filter of query
|
# extract date range specified in date filter of query
|
||||||
start = time.time()
|
with timer("Extract date range to filter from query", logger):
|
||||||
query_daterange = self.extract_date_range(query)
|
query_daterange = self.extract_date_range(query)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Extract date range to filter from query: {end - start} seconds")
|
|
||||||
|
|
||||||
# if no date in query, return all entries
|
# if no date in query, return all entries
|
||||||
if query_daterange is None:
|
if query_daterange is None:
|
||||||
@@ -80,14 +75,12 @@ class DateFilter(BaseFilter):
|
|||||||
self.load(entries)
|
self.load(entries)
|
||||||
|
|
||||||
# find entries containing any dates that fall with date range specified in query
|
# find entries containing any dates that fall with date range specified in query
|
||||||
start = time.time()
|
with timer("Mark entries satisfying filter", logger):
|
||||||
entries_to_include = set()
|
entries_to_include = set()
|
||||||
for date_in_entry in self.date_to_entry_ids.keys():
|
for date_in_entry in self.date_to_entry_ids.keys():
|
||||||
# Check if date in entry is within date range specified in query
|
# Check if date in entry is within date range specified in query
|
||||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
||||||
entries_to_include |= self.date_to_entry_ids[date_in_entry]
|
entries_to_include |= self.date_to_entry_ids[date_in_entry]
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
|
||||||
|
|
||||||
# cache results
|
# cache results
|
||||||
self.cache[cache_key] = entries_to_include
|
self.cache[cache_key] = entries_to_include
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from collections import defaultdict
|
|||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.search_filter.base_filter import BaseFilter
|
from src.search_filter.base_filter import BaseFilter
|
||||||
from src.utils.helpers import LRU
|
from src.utils.helpers import LRU, timer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -22,32 +22,28 @@ class FileFilter(BaseFilter):
|
|||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
def load(self, entries, *args, **kwargs):
|
||||||
start = time.time()
|
with timer("Created file filter index", logger):
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Created file filter index: {end - start} seconds")
|
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
return re.search(self.file_filter_regex, raw_query) is not None
|
return re.search(self.file_filter_regex, raw_query) is not None
|
||||||
|
|
||||||
def apply(self, query, entries):
|
def apply(self, query, entries):
|
||||||
# Extract file filters from raw query
|
# Extract file filters from raw query
|
||||||
start = time.time()
|
with timer("Extract files_to_search from query", logger):
|
||||||
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
||||||
if not raw_files_to_search:
|
if not raw_files_to_search:
|
||||||
return query, set(range(len(entries)))
|
return query, set(range(len(entries)))
|
||||||
|
|
||||||
# Convert simple file filters with no path separator into regex
|
# Convert simple file filters with no path separator into regex
|
||||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||||
files_to_search = []
|
files_to_search = []
|
||||||
for file in sorted(raw_files_to_search):
|
for file in sorted(raw_files_to_search):
|
||||||
if '/' not in file and '\\' not in file and '*' not in file:
|
if '/' not in file and '\\' not in file and '*' not in file:
|
||||||
files_to_search += [f'*{file}']
|
files_to_search += [f'*{file}']
|
||||||
else:
|
else:
|
||||||
files_to_search += [file]
|
files_to_search += [file]
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
|
|
||||||
|
|
||||||
# Return item from cache if exists
|
# Return item from cache if exists
|
||||||
query = re.sub(self.file_filter_regex, '', query).strip()
|
query = re.sub(self.file_filter_regex, '', query).strip()
|
||||||
@@ -61,17 +57,13 @@ class FileFilter(BaseFilter):
|
|||||||
self.load(entries, regenerate=False)
|
self.load(entries, regenerate=False)
|
||||||
|
|
||||||
# Mark entries that contain any blocked_words for exclusion
|
# Mark entries that contain any blocked_words for exclusion
|
||||||
start = time.time()
|
with timer("Mark entries satisfying filter", logger):
|
||||||
|
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
|
||||||
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
|
for entry_file in self.file_to_entry_map.keys()
|
||||||
for entry_file in self.file_to_entry_map.keys()
|
for search_file in files_to_search
|
||||||
for search_file in files_to_search
|
if fnmatch.fnmatch(entry_file, search_file)], set())
|
||||||
if fnmatch.fnmatch(entry_file, search_file)], set())
|
if not included_entry_indices:
|
||||||
if not included_entry_indices:
|
return query, {}
|
||||||
return query, {}
|
|
||||||
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
|
||||||
|
|
||||||
# Cache results
|
# Cache results
|
||||||
self.cache[cache_key] = included_entry_indices
|
self.cache[cache_key] = included_entry_indices
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from collections import defaultdict
|
|||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.search_filter.base_filter import BaseFilter
|
from src.search_filter.base_filter import BaseFilter
|
||||||
from src.utils.helpers import LRU
|
from src.utils.helpers import LRU, timer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -24,17 +24,15 @@ class WordFilter(BaseFilter):
|
|||||||
|
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
def load(self, entries, *args, **kwargs):
|
||||||
start = time.time()
|
with timer("Created word filter index", logger):
|
||||||
self.cache = {} # Clear cache on filter (re-)load
|
self.cache = {} # Clear cache on filter (re-)load
|
||||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
||||||
# Create map of words to entries they exist in
|
# Create map of words to entries they exist in
|
||||||
for entry_index, entry in enumerate(entries):
|
for entry_index, entry in enumerate(entries):
|
||||||
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
||||||
if word == '':
|
if word == '':
|
||||||
continue
|
continue
|
||||||
self.word_to_entry_index[word].add(entry_index)
|
self.word_to_entry_index[word].add(entry_index)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Created word filter index: {end - start} seconds")
|
|
||||||
|
|
||||||
return self.word_to_entry_index
|
return self.word_to_entry_index
|
||||||
|
|
||||||
@@ -50,14 +48,10 @@ class WordFilter(BaseFilter):
|
|||||||
def apply(self, query, entries):
|
def apply(self, query, entries):
|
||||||
"Find entries containing required and not blocked words specified in query"
|
"Find entries containing required and not blocked words specified in query"
|
||||||
# Separate natural query from required, blocked words filters
|
# Separate natural query from required, blocked words filters
|
||||||
start = time.time()
|
with timer("Extract required, blocked filters from query", logger):
|
||||||
|
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
|
||||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
|
|
||||||
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
|
|
||||||
|
|
||||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||||
return query, set(range(len(entries)))
|
return query, set(range(len(entries)))
|
||||||
@@ -72,20 +66,16 @@ class WordFilter(BaseFilter):
|
|||||||
if not self.word_to_entry_index:
|
if not self.word_to_entry_index:
|
||||||
self.load(entries, regenerate=False)
|
self.load(entries, regenerate=False)
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
|
|
||||||
# mark entries that contain all required_words for inclusion
|
# mark entries that contain all required_words for inclusion
|
||||||
entries_with_all_required_words = set(range(len(entries)))
|
with timer("Mark entries satisfying filter", logger):
|
||||||
if len(required_words) > 0:
|
entries_with_all_required_words = set(range(len(entries)))
|
||||||
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
if len(required_words) > 0:
|
||||||
|
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
||||||
|
|
||||||
# mark entries that contain any blocked_words for exclusion
|
# mark entries that contain any blocked_words for exclusion
|
||||||
entries_with_any_blocked_words = set()
|
entries_with_any_blocked_words = set()
|
||||||
if len(blocked_words) > 0:
|
if len(blocked_words) > 0:
|
||||||
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
|
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
|
||||||
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
|
||||||
|
|
||||||
# get entries satisfying inclusion and exclusion filters
|
# get entries satisfying inclusion and exclusion filters
|
||||||
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from tqdm import trange
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
|
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
|
||||||
from src.utils.config import ImageSearchModel
|
from src.utils.config import ImageSearchModel
|
||||||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||||
|
|
||||||
@@ -147,27 +147,21 @@ def query(raw_query, count, model: ImageSearchModel):
|
|||||||
logger.info(f"Find Images by Text: {query}")
|
logger.info(f"Find Images by Text: {query}")
|
||||||
|
|
||||||
# Now we encode the query (which can either be an image or a text string)
|
# Now we encode the query (which can either be an image or a text string)
|
||||||
start = time.time()
|
with timer("Query Encode Time", logger):
|
||||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
|
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||||
start = time.time()
|
with timer("Search Time", logger):
|
||||||
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
||||||
for result
|
for result
|
||||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Search Time: {end - start:.3f} seconds")
|
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||||
if model.image_metadata_embeddings:
|
if model.image_metadata_embeddings:
|
||||||
start = time.time()
|
with timer("Metadata Search Time", logger):
|
||||||
metadata_hits = {result['corpus_id']: result['score']
|
metadata_hits = {result['corpus_id']: result['score']
|
||||||
for result
|
for result
|
||||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
|
|
||||||
|
|
||||||
# Sum metadata, image scores of the highest ranked images
|
# Sum metadata, image scores of the highest ranked images
|
||||||
for corpus_id, score in metadata_hits.items():
|
for corpus_id, score in metadata_hits.items():
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from src.search_filter.base_filter import BaseFilter
|
|||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils import state
|
from src.utils import state
|
||||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||||
from src.utils.config import TextSearchModel
|
from src.utils.config import TextSearchModel
|
||||||
from src.utils.models import BaseEncoder
|
from src.utils.models import BaseEncoder
|
||||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
||||||
@@ -96,6 +96,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
|||||||
|
|
||||||
# Filter query, entries and embeddings before semantic search
|
# Filter query, entries and embeddings before semantic search
|
||||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
||||||
|
|
||||||
# If no entries left after filtering, return empty results
|
# If no entries left after filtering, return empty results
|
||||||
if entries is None or len(entries) == 0:
|
if entries is None or len(entries) == 0:
|
||||||
return [], []
|
return [], []
|
||||||
@@ -105,17 +106,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
|||||||
return hits, entries
|
return hits, entries
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
start = time.time()
|
with timer("Query Encode Time", logger, state.device):
|
||||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||||
question_embedding = util.normalize_embeddings(question_embedding)
|
question_embedding = util.normalize_embeddings(question_embedding)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
start = time.time()
|
with timer("Search Time", logger, state.device):
|
||||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results:
|
||||||
@@ -170,36 +167,29 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
|
|||||||
|
|
||||||
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
|
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
|
||||||
'''Filter query, entries and embeddings before semantic search'''
|
'''Filter query, entries and embeddings before semantic search'''
|
||||||
start_filter = time.time()
|
|
||||||
included_entry_indices = set(range(len(entries)))
|
|
||||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
|
||||||
for filter in filters_in_query:
|
|
||||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
|
||||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
|
||||||
|
|
||||||
# Get entries (and associated embeddings) satisfying all filters
|
with timer("Total Filter Time", logger, state.device):
|
||||||
if not included_entry_indices:
|
included_entry_indices = set(range(len(entries)))
|
||||||
return '', [], torch.tensor([], device=state.device)
|
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||||
else:
|
for filter in filters_in_query:
|
||||||
start = time.time()
|
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||||
entries = [entries[id] for id in included_entry_indices]
|
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
|
||||||
|
|
||||||
end_filter = time.time()
|
# Get entries (and associated embeddings) satisfying all filters
|
||||||
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
|
if not included_entry_indices:
|
||||||
|
return '', [], torch.tensor([], device=state.device)
|
||||||
|
else:
|
||||||
|
entries = [entries[id] for id in included_entry_indices]
|
||||||
|
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
||||||
|
|
||||||
return query, entries, corpus_embeddings
|
return query, entries, corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
|
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||||
'''Score all retrieved entries using the cross-encoder'''
|
'''Score all retrieved entries using the cross-encoder'''
|
||||||
start = time.time()
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||||
cross_scores = cross_encoder.predict(cross_inp)
|
cross_scores = cross_encoder.predict(cross_inp)
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
@@ -210,12 +200,10 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
|
|||||||
|
|
||||||
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
|
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
|
||||||
'''Order results by cross-encoder score followed by bi-encoder score'''
|
'''Order results by cross-encoder score followed by bi-encoder score'''
|
||||||
start = time.time()
|
with timer("Rank Time", logger, state.device):
|
||||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||||
if rank_results:
|
if rank_results:
|
||||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
|
|
||||||
@@ -223,11 +211,12 @@ def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
|
|||||||
'''Deduplicate entries by raw entry text before showing to users
|
'''Deduplicate entries by raw entry text before showing to users
|
||||||
Compiled entries are split by max tokens supported by ML models.
|
Compiled entries are split by max tokens supported by ML models.
|
||||||
This can result in duplicate hits, entries shown to user.'''
|
This can result in duplicate hits, entries shown to user.'''
|
||||||
start = time.time()
|
|
||||||
seen, original_hits_count = set(), len(hits)
|
with timer("Deduplication Time", logger, state.device):
|
||||||
hits = [hit for hit in hits
|
seen, original_hits_count = set(), len(hits)
|
||||||
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
hits = [hit for hit in hits
|
||||||
duplicate_hits = original_hits_count - len(hits)
|
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
||||||
end = time.time()
|
duplicate_hits = original_hits_count - len(hits)
|
||||||
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
|
|
||||||
|
logger.debug(f"Removed {duplicate_hits} duplicates")
|
||||||
return hits
|
return hits
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
from __future__ import annotations # to avoid quoting type hints
|
from __future__ import annotations # to avoid quoting type hints
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import torch
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from time import perf_counter
|
||||||
from typing import Optional, Union, TYPE_CHECKING
|
from typing import Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -81,6 +83,25 @@ def get_class_by_name(name: str) -> object:
|
|||||||
return getattr(import_module(module_name), class_name)
|
return getattr(import_module(module_name), class_name)
|
||||||
|
|
||||||
|
|
||||||
|
class timer:
|
||||||
|
'''Context manager to log time taken for a block of code to run'''
|
||||||
|
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
|
||||||
|
self.message = message
|
||||||
|
self.logger = logger
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start = perf_counter()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *_):
|
||||||
|
elapsed = perf_counter() - self.start
|
||||||
|
if self.device is None:
|
||||||
|
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
|
||||||
|
else:
|
||||||
|
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
|
||||||
|
|
||||||
|
|
||||||
class LRU(OrderedDict):
|
class LRU(OrderedDict):
|
||||||
def __init__(self, *args, capacity=128, **kwargs):
|
def __init__(self, *args, capacity=128, **kwargs):
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
|
|||||||
Reference in New Issue
Block a user