mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-03 13:19:16 +00:00
Create and use a context manager to time code
Use the timer context manager in all places where code was being timed - Benefits - Deduplicate timing code scattered across codebase. - Provides single place to manage perf timing code - Use consistent timing log patterns
This commit is contained in:
@@ -6,7 +6,7 @@ import time
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.text_to_jsonl import TextToJsonl
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from src.utils.rawconfig import Entry
|
||||
@@ -30,38 +30,30 @@ class BeancountToJsonl(TextToJsonl):
|
||||
beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter)
|
||||
|
||||
# Extract Entries from specified Beancount files
|
||||
start = time.time()
|
||||
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
|
||||
end = time.time()
|
||||
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
|
||||
with timer("Parse transactions from Beancount files into dictionaries", logger):
|
||||
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
start = time.time()
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
end = time.time()
|
||||
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
start = time.time()
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
end = time.time()
|
||||
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
|
||||
with timer("Identify new or updated transaction", logger):
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
start = time.time()
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
|
||||
with timer("Write transactions to JSONL file", logger):
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
end = time.time()
|
||||
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.text_to_jsonl import TextToJsonl
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from src.utils.rawconfig import Entry
|
||||
@@ -30,38 +30,30 @@ class MarkdownToJsonl(TextToJsonl):
|
||||
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
|
||||
|
||||
# Extract Entries from specified Markdown files
|
||||
start = time.time()
|
||||
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
|
||||
end = time.time()
|
||||
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
|
||||
with timer("Parse entries from Markdown files into dictionaries", logger):
|
||||
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
start = time.time()
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
end = time.time()
|
||||
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
start = time.time()
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
end = time.time()
|
||||
logger.debug(f"Identify new or updated entries: {end - start} seconds")
|
||||
with timer("Identify new or updated entries", logger):
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
start = time.time()
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||
with timer("Write markdown entries to JSONL file", logger):
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
end = time.time()
|
||||
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Iterable
|
||||
# Internal Packages
|
||||
from src.processor.org_mode import orgnode
|
||||
from src.processor.text_to_jsonl import TextToJsonl
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from src.utils.rawconfig import Entry
|
||||
from src.utils import state
|
||||
@@ -29,24 +29,18 @@ class OrgToJsonl(TextToJsonl):
|
||||
exit(1)
|
||||
|
||||
# Get Org Files to Process
|
||||
start = time.time()
|
||||
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
|
||||
with timer("Get org files to process", logger):
|
||||
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
|
||||
|
||||
# Extract Entries from specified Org files
|
||||
start = time.time()
|
||||
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
|
||||
end = time.time()
|
||||
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
|
||||
with timer("Parse entries from org files into OrgNode objects", logger):
|
||||
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
|
||||
|
||||
start = time.time()
|
||||
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||
end = time.time()
|
||||
logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")
|
||||
with timer("Convert OrgNodes into list of entries", logger):
|
||||
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||
|
||||
start = time.time()
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
end = time.time()
|
||||
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
if not previous_entries:
|
||||
@@ -55,17 +49,15 @@ class OrgToJsonl(TextToJsonl):
|
||||
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
start = time.time()
|
||||
entries = map(lambda entry: entry[1], entries_with_ids)
|
||||
jsonl_data = self.convert_org_entries_to_jsonl(entries)
|
||||
with timer("Write org entries to JSONL file", logger):
|
||||
entries = map(lambda entry: entry[1], entries_with_ids)
|
||||
jsonl_data = self.convert_org_entries_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
end = time.time()
|
||||
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import hashlib
|
||||
import time
|
||||
import logging
|
||||
from typing import Callable
|
||||
from src.utils.helpers import timer
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.rawconfig import Entry, TextContentConfig
|
||||
@@ -40,35 +41,31 @@ class TextToJsonl(ABC):
|
||||
|
||||
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
|
||||
# Hash all current and previous entries to identify new entries
|
||||
start = time.time()
|
||||
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
||||
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
|
||||
end = time.time()
|
||||
logger.debug(f"Hash previous, current entries: {end - start} seconds")
|
||||
with timer("Hash previous, current entries", logger):
|
||||
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
||||
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
|
||||
|
||||
start = time.time()
|
||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
|
||||
with timer("Identify, Mark, Combine new, existing entries", logger):
|
||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
|
||||
|
||||
# All entries that did not exist in the previous set are to be added
|
||||
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
|
||||
# All entries that exist in both current and previous sets are kept
|
||||
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
||||
# All entries that did not exist in the previous set are to be added
|
||||
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
|
||||
# All entries that exist in both current and previous sets are kept
|
||||
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
||||
|
||||
# Mark new entries with -1 id to flag for later embeddings generation
|
||||
new_entries = [
|
||||
(-1, hash_to_current_entries[entry_hash])
|
||||
for entry_hash in new_entry_hashes
|
||||
]
|
||||
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
||||
existing_entries = [
|
||||
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
||||
for entry_hash in existing_entry_hashes
|
||||
]
|
||||
# Mark new entries with -1 id to flag for later embeddings generation
|
||||
new_entries = [
|
||||
(-1, hash_to_current_entries[entry_hash])
|
||||
for entry_hash in new_entry_hashes
|
||||
]
|
||||
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
||||
existing_entries = [
|
||||
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
||||
for entry_hash in existing_entry_hashes
|
||||
]
|
||||
|
||||
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
|
||||
entries_with_ids = existing_entries_sorted + new_entries
|
||||
end = time.time()
|
||||
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
|
||||
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
|
||||
entries_with_ids = existing_entries_sorted + new_entries
|
||||
|
||||
return entries_with_ids
|
||||
@@ -10,6 +10,7 @@ from fastapi import APIRouter
|
||||
# Internal Packages
|
||||
from src.configure import configure_processor, configure_search
|
||||
from src.search_type import image_search, text_search
|
||||
from src.utils.helpers import timer
|
||||
from src.utils.rawconfig import FullConfig, SearchResponse
|
||||
from src.utils.config import SearchType
|
||||
from src.utils import state, constants
|
||||
@@ -47,7 +48,6 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n
|
||||
query_start, query_end, collate_start, collate_end = None, None, None, None
|
||||
|
||||
# return cached results, if available
|
||||
query_cache_key = f'{user_query}-{n}-{t}-{r}'
|
||||
@@ -57,73 +57,58 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
|
||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||
# query org-mode notes
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
|
||||
|
||||
# collate and return results
|
||||
collate_start = time.time()
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
collate_end = time.time()
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||
# query music library
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
|
||||
|
||||
# collate and return results
|
||||
collate_start = time.time()
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
collate_end = time.time()
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||
# query markdown files
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
|
||||
|
||||
# collate and return results
|
||||
collate_start = time.time()
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
collate_end = time.time()
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||
# query transactions
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
with timer("Query took", logger):
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
|
||||
|
||||
# collate and return results
|
||||
collate_start = time.time()
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
collate_end = time.time()
|
||||
with timer("Collating results took", logger):
|
||||
results = text_search.collate_results(hits, entries, results_count)
|
||||
|
||||
if (t == SearchType.Image or t == None) and state.model.image_search:
|
||||
# query images
|
||||
query_start = time.time()
|
||||
hits = image_search.query(user_query, results_count, state.model.image_search)
|
||||
output_directory = constants.web_directory / 'images'
|
||||
query_end = time.time()
|
||||
with timer("Query took", logger):
|
||||
hits = image_search.query(user_query, results_count, state.model.image_search)
|
||||
output_directory = constants.web_directory / 'images'
|
||||
|
||||
# collate and return results
|
||||
collate_start = time.time()
|
||||
results = image_search.collate_results(
|
||||
hits,
|
||||
image_names=state.model.image_search.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url='/static/images',
|
||||
count=results_count)
|
||||
collate_end = time.time()
|
||||
with timer("Collating results took", logger):
|
||||
results = image_search.collate_results(
|
||||
hits,
|
||||
image_names=state.model.image_search.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url='/static/images',
|
||||
count=results_count)
|
||||
|
||||
# Cache results
|
||||
state.query_cache[query_cache_key] = results
|
||||
|
||||
if query_start and query_end:
|
||||
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
|
||||
if collate_start and collate_end:
|
||||
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import dateparser as dtparse
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.helpers import LRU
|
||||
from src.utils.helpers import LRU, timer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -34,19 +34,16 @@ class DateFilter(BaseFilter):
|
||||
|
||||
|
||||
def load(self, entries, *args, **kwargs):
|
||||
start = time.time()
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
|
||||
# Convert date string in entry to unix timestamp
|
||||
try:
|
||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
self.date_to_entry_ids[date_in_entry].add(id)
|
||||
end = time.time()
|
||||
logger.debug(f"Created date filter index: {end - start} seconds")
|
||||
|
||||
with timer("Created date filter index", logger):
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
|
||||
# Convert date string in entry to unix timestamp
|
||||
try:
|
||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
self.date_to_entry_ids[date_in_entry].add(id)
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains date filters"
|
||||
@@ -56,10 +53,8 @@ class DateFilter(BaseFilter):
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
start = time.time()
|
||||
query_daterange = self.extract_date_range(query)
|
||||
end = time.time()
|
||||
logger.debug(f"Extract date range to filter from query: {end - start} seconds")
|
||||
with timer("Extract date range to filter from query", logger):
|
||||
query_daterange = self.extract_date_range(query)
|
||||
|
||||
# if no date in query, return all entries
|
||||
if query_daterange is None:
|
||||
@@ -80,14 +75,12 @@ class DateFilter(BaseFilter):
|
||||
self.load(entries)
|
||||
|
||||
# find entries containing any dates that fall with date range specified in query
|
||||
start = time.time()
|
||||
entries_to_include = set()
|
||||
for date_in_entry in self.date_to_entry_ids.keys():
|
||||
# Check if date in entry is within date range specified in query
|
||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
||||
entries_to_include |= self.date_to_entry_ids[date_in_entry]
|
||||
end = time.time()
|
||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
||||
with timer("Mark entries satisfying filter", logger):
|
||||
entries_to_include = set()
|
||||
for date_in_entry in self.date_to_entry_ids.keys():
|
||||
# Check if date in entry is within date range specified in query
|
||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
||||
entries_to_include |= self.date_to_entry_ids[date_in_entry]
|
||||
|
||||
# cache results
|
||||
self.cache[cache_key] = entries_to_include
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections import defaultdict
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.helpers import LRU
|
||||
from src.utils.helpers import LRU, timer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,32 +22,28 @@ class FileFilter(BaseFilter):
|
||||
self.cache = LRU()
|
||||
|
||||
def load(self, entries, *args, **kwargs):
|
||||
start = time.time()
|
||||
for id, entry in enumerate(entries):
|
||||
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
||||
end = time.time()
|
||||
logger.debug(f"Created file filter index: {end - start} seconds")
|
||||
with timer("Created file filter index", logger):
|
||||
for id, entry in enumerate(entries):
|
||||
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
return re.search(self.file_filter_regex, raw_query) is not None
|
||||
|
||||
def apply(self, query, entries):
|
||||
# Extract file filters from raw query
|
||||
start = time.time()
|
||||
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
||||
if not raw_files_to_search:
|
||||
return query, set(range(len(entries)))
|
||||
with timer("Extract files_to_search from query", logger):
|
||||
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
||||
if not raw_files_to_search:
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# Convert simple file filters with no path separator into regex
|
||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||
files_to_search = []
|
||||
for file in sorted(raw_files_to_search):
|
||||
if '/' not in file and '\\' not in file and '*' not in file:
|
||||
files_to_search += [f'*{file}']
|
||||
else:
|
||||
files_to_search += [file]
|
||||
end = time.time()
|
||||
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
|
||||
# Convert simple file filters with no path separator into regex
|
||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||
files_to_search = []
|
||||
for file in sorted(raw_files_to_search):
|
||||
if '/' not in file and '\\' not in file and '*' not in file:
|
||||
files_to_search += [f'*{file}']
|
||||
else:
|
||||
files_to_search += [file]
|
||||
|
||||
# Return item from cache if exists
|
||||
query = re.sub(self.file_filter_regex, '', query).strip()
|
||||
@@ -61,17 +57,13 @@ class FileFilter(BaseFilter):
|
||||
self.load(entries, regenerate=False)
|
||||
|
||||
# Mark entries that contain any blocked_words for exclusion
|
||||
start = time.time()
|
||||
|
||||
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
|
||||
for entry_file in self.file_to_entry_map.keys()
|
||||
for search_file in files_to_search
|
||||
if fnmatch.fnmatch(entry_file, search_file)], set())
|
||||
if not included_entry_indices:
|
||||
return query, {}
|
||||
|
||||
end = time.time()
|
||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
||||
with timer("Mark entries satisfying filter", logger):
|
||||
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
|
||||
for entry_file in self.file_to_entry_map.keys()
|
||||
for search_file in files_to_search
|
||||
if fnmatch.fnmatch(entry_file, search_file)], set())
|
||||
if not included_entry_indices:
|
||||
return query, {}
|
||||
|
||||
# Cache results
|
||||
self.cache[cache_key] = included_entry_indices
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections import defaultdict
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.helpers import LRU
|
||||
from src.utils.helpers import LRU, timer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,17 +24,15 @@ class WordFilter(BaseFilter):
|
||||
|
||||
|
||||
def load(self, entries, *args, **kwargs):
|
||||
start = time.time()
|
||||
self.cache = {} # Clear cache on filter (re-)load
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
||||
# Create map of words to entries they exist in
|
||||
for entry_index, entry in enumerate(entries):
|
||||
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
||||
if word == '':
|
||||
continue
|
||||
self.word_to_entry_index[word].add(entry_index)
|
||||
end = time.time()
|
||||
logger.debug(f"Created word filter index: {end - start} seconds")
|
||||
with timer("Created word filter index", logger):
|
||||
self.cache = {} # Clear cache on filter (re-)load
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
||||
# Create map of words to entries they exist in
|
||||
for entry_index, entry in enumerate(entries):
|
||||
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
||||
if word == '':
|
||||
continue
|
||||
self.word_to_entry_index[word].add(entry_index)
|
||||
|
||||
return self.word_to_entry_index
|
||||
|
||||
@@ -50,14 +48,10 @@ class WordFilter(BaseFilter):
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from required, blocked words filters
|
||||
start = time.time()
|
||||
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
|
||||
|
||||
end = time.time()
|
||||
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
|
||||
with timer("Extract required, blocked filters from query", logger):
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, set(range(len(entries)))
|
||||
@@ -72,20 +66,16 @@ class WordFilter(BaseFilter):
|
||||
if not self.word_to_entry_index:
|
||||
self.load(entries, regenerate=False)
|
||||
|
||||
start = time.time()
|
||||
|
||||
# mark entries that contain all required_words for inclusion
|
||||
entries_with_all_required_words = set(range(len(entries)))
|
||||
if len(required_words) > 0:
|
||||
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
||||
with timer("Mark entries satisfying filter", logger):
|
||||
entries_with_all_required_words = set(range(len(entries)))
|
||||
if len(required_words) > 0:
|
||||
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
||||
|
||||
# mark entries that contain any blocked_words for exclusion
|
||||
entries_with_any_blocked_words = set()
|
||||
if len(blocked_words) > 0:
|
||||
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
|
||||
|
||||
end = time.time()
|
||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
||||
# mark entries that contain any blocked_words for exclusion
|
||||
entries_with_any_blocked_words = set()
|
||||
if len(blocked_words) > 0:
|
||||
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
|
||||
|
||||
# get entries satisfying inclusion and exclusion filters
|
||||
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
||||
|
||||
@@ -13,7 +13,7 @@ from tqdm import trange
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
|
||||
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
|
||||
from src.utils.config import ImageSearchModel
|
||||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||
|
||||
@@ -147,27 +147,21 @@ def query(raw_query, count, model: ImageSearchModel):
|
||||
logger.info(f"Find Images by Text: {query}")
|
||||
|
||||
# Now we encode the query (which can either be an image or a text string)
|
||||
start = time.time()
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
end = time.time()
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
|
||||
with timer("Query Encode Time", logger):
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
start = time.time()
|
||||
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds")
|
||||
with timer("Search Time", logger):
|
||||
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||
if model.image_metadata_embeddings:
|
||||
start = time.time()
|
||||
metadata_hits = {result['corpus_id']: result['score']
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
|
||||
with timer("Metadata Search Time", logger):
|
||||
metadata_hits = {result['corpus_id']: result['score']
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||
|
||||
# Sum metadata, image scores of the highest ranked images
|
||||
for corpus_id, score in metadata_hits.items():
|
||||
|
||||
@@ -12,7 +12,7 @@ from src.search_filter.base_filter import BaseFilter
|
||||
|
||||
# Internal Packages
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.models import BaseEncoder
|
||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
||||
@@ -96,6 +96,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
||||
|
||||
# If no entries left after filtering, return empty results
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
@@ -105,17 +106,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
||||
return hits, entries
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
start = time.time()
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
end = time.time()
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
|
||||
# Find relevant entries for the query
|
||||
start = time.time()
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
end = time.time()
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Search Time", logger, state.device):
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
@@ -170,36 +167,29 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
|
||||
|
||||
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
|
||||
'''Filter query, entries and embeddings before semantic search'''
|
||||
start_filter = time.time()
|
||||
included_entry_indices = set(range(len(entries)))
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
for filter in filters_in_query:
|
||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||
|
||||
# Get entries (and associated embeddings) satisfying all filters
|
||||
if not included_entry_indices:
|
||||
return '', [], torch.tensor([], device=state.device)
|
||||
else:
|
||||
start = time.time()
|
||||
entries = [entries[id] for id in included_entry_indices]
|
||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
||||
end = time.time()
|
||||
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
||||
with timer("Total Filter Time", logger, state.device):
|
||||
included_entry_indices = set(range(len(entries)))
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
for filter in filters_in_query:
|
||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||
|
||||
end_filter = time.time()
|
||||
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
|
||||
# Get entries (and associated embeddings) satisfying all filters
|
||||
if not included_entry_indices:
|
||||
return '', [], torch.tensor([], device=state.device)
|
||||
else:
|
||||
entries = [entries[id] for id in included_entry_indices]
|
||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
||||
|
||||
return query, entries, corpus_embeddings
|
||||
|
||||
|
||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||
'''Score all retrieved entries using the cross-encoder'''
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
@@ -210,12 +200,10 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
|
||||
|
||||
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
|
||||
'''Order results by cross-encoder score followed by bi-encoder score'''
|
||||
start = time.time()
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Rank Time", logger, state.device):
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
return hits
|
||||
|
||||
|
||||
@@ -223,11 +211,12 @@ def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||
'''Deduplicate entries by raw entry text before showing to users
|
||||
Compiled entries are split by max tokens supported by ML models.
|
||||
This can result in duplicate hits, entries shown to user.'''
|
||||
start = time.time()
|
||||
seen, original_hits_count = set(), len(hits)
|
||||
hits = [hit for hit in hits
|
||||
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
||||
duplicate_hits = original_hits_count - len(hits)
|
||||
end = time.time()
|
||||
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
|
||||
|
||||
with timer("Deduplication Time", logger, state.device):
|
||||
seen, original_hits_count = set(), len(hits)
|
||||
hits = [hit for hit in hits
|
||||
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
||||
duplicate_hits = original_hits_count - len(hits)
|
||||
|
||||
logger.debug(f"Removed {duplicate_hits} duplicates")
|
||||
return hits
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
from __future__ import annotations # to avoid quoting type hints
|
||||
import logging
|
||||
import sys
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
from os.path import join
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Optional, Union, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -81,6 +83,25 @@ def get_class_by_name(name: str) -> object:
|
||||
return getattr(import_module(module_name), class_name)
|
||||
|
||||
|
||||
class timer:
|
||||
'''Context manager to log time taken for a block of code to run'''
|
||||
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
|
||||
self.message = message
|
||||
self.logger = logger
|
||||
self.device = device
|
||||
|
||||
def __enter__(self):
|
||||
self.start = perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
elapsed = perf_counter() - self.start
|
||||
if self.device is None:
|
||||
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
|
||||
else:
|
||||
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
|
||||
|
||||
|
||||
class LRU(OrderedDict):
|
||||
def __init__(self, *args, capacity=128, **kwargs):
|
||||
self.capacity = capacity
|
||||
|
||||
Reference in New Issue
Block a user