mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Deep copy entries, embeddings in filters. Defer till actual filtering
- Only the filter knows when entries and embeddings are to be manipulated, so move the responsibility for deep copying them, before they are manipulated, to the filters
- Create the deep copy in the filters. This avoids creating a deep copy of entries and embeddings when filter results are being loaded from cache, etc.
This commit is contained in:
@@ -3,6 +3,7 @@ import re
|
|||||||
from datetime import timedelta, datetime
|
from datetime import timedelta, datetime
|
||||||
from dateutil.relativedelta import relativedelta, MO
|
from dateutil.relativedelta import relativedelta, MO
|
||||||
from math import inf
|
from math import inf
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
@@ -31,19 +32,23 @@ class DateFilter:
|
|||||||
return self.extract_date_range(raw_query) is not None
|
return self.extract_date_range(raw_query) is not None
|
||||||
|
|
||||||
|
|
||||||
def apply(self, query, entries, embeddings):
|
def apply(self, query, raw_entries, raw_embeddings):
|
||||||
"Find entries containing any dates that fall within date range specified in query"
|
"Find entries containing any dates that fall within date range specified in query"
|
||||||
# extract date range specified in date filter of query
|
# extract date range specified in date filter of query
|
||||||
query_daterange = self.extract_date_range(query)
|
query_daterange = self.extract_date_range(query)
|
||||||
|
|
||||||
# if no date in query, return all entries
|
# if no date in query, return all entries
|
||||||
if query_daterange is None:
|
if query_daterange is None:
|
||||||
return query, entries, embeddings
|
return query, raw_entries, raw_embeddings
|
||||||
|
|
||||||
# remove date range filter from query
|
# remove date range filter from query
|
||||||
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
||||||
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
|
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
|
||||||
|
|
||||||
|
# deep copy original embeddings, entries before filtering
|
||||||
|
embeddings= deepcopy(raw_embeddings)
|
||||||
|
entries = deepcopy(raw_entries)
|
||||||
|
|
||||||
# find entries containing any dates that fall with date range specified in query
|
# find entries containing any dates that fall with date range specified in query
|
||||||
entries_to_include = set()
|
entries_to_include = set()
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import re
|
|||||||
import time
|
import time
|
||||||
import pickle
|
import pickle
|
||||||
import logging
|
import logging
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
@@ -61,7 +62,7 @@ class ExplicitFilter:
|
|||||||
return len(required_words) != 0 or len(blocked_words) != 0
|
return len(required_words) != 0 or len(blocked_words) != 0
|
||||||
|
|
||||||
|
|
||||||
def apply(self, raw_query, entries, embeddings):
|
def apply(self, raw_query, raw_entries, raw_embeddings):
|
||||||
"Find entries containing required and not blocked words specified in query"
|
"Find entries containing required and not blocked words specified in query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from explicit required, blocked words filters
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -83,6 +84,13 @@ class ExplicitFilter:
|
|||||||
entries, embeddings = self.cache[cache_key]
|
entries, embeddings = self.cache[cache_key]
|
||||||
return query, entries, embeddings
|
return query, entries, embeddings
|
||||||
|
|
||||||
|
# deep copy original embeddings, entries before filtering
|
||||||
|
start = time.time()
|
||||||
|
embeddings= deepcopy(raw_embeddings)
|
||||||
|
entries = deepcopy(raw_entries)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds")
|
||||||
|
|
||||||
if not self.entries_by_word_set:
|
if not self.entries_by_word_set:
|
||||||
self.load(entries, regenerate=False)
|
self.load(entries, regenerate=False)
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import argparse
|
|||||||
import pathlib
|
import pathlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
@@ -77,22 +76,11 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
|
|||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query = raw_query
|
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||||
|
|
||||||
# Use deep copy of original embeddings, entries to filter if query contains filters
|
|
||||||
start = time.time()
|
|
||||||
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
|
||||||
if filters_in_query:
|
|
||||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
|
||||||
entries = deepcopy(model.entries)
|
|
||||||
else:
|
|
||||||
corpus_embeddings = model.corpus_embeddings
|
|
||||||
entries = model.entries
|
|
||||||
end = time.time()
|
|
||||||
logger.debug(f"Copy Time: {end - start:.3f} seconds")
|
|
||||||
|
|
||||||
# Filter query, entries and embeddings before semantic search
|
# Filter query, entries and embeddings before semantic search
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
||||||
for filter in filters_in_query:
|
for filter in filters_in_query:
|
||||||
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
|
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
|||||||
Reference in New Issue
Block a user