mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 21:29:12 +00:00
Use new Text Entry class to track text entries in Intermediate Format
- Context
- The app maintains all text content in a standard, intermediate format
- The intermediate format was loaded, passed around as a dictionary
for easier, faster updates to the intermediate format schema initially
- The intermediate format is reasonably stable now, given its usage
by all 3 text content types currently implemented
- Changes
- Concretize text entries into the `Entry' class instead of using dictionaries
- Code is updated to load and pass around entries as `Entry' objects
instead of as dictionaries
- `text_search' and `text_to_jsonl' methods are annotated with
type hints for the new `Entry' type
- Code and Tests referencing entries are updated to use class style
access patterns instead of the previous dictionary access patterns
- Move `mark_entries_for_update' method into `TextToJsonl' base class
- This is a more natural location for the method as it is only
(to be) used by `text_to_jsonl' classes
- Avoid circular reference issues on importing the `Entry' class
This commit is contained in:
@@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig
|
||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
||||
from src.utils.jsonl import load_jsonl
|
||||
|
||||
|
||||
@@ -50,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig):
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
|
||||
def extract_entries(jsonl_file):
|
||||
def extract_entries(jsonl_file) -> list[Entry]:
|
||||
"Load entries from compressed jsonl"
|
||||
return load_jsonl(jsonl_file)
|
||||
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
||||
|
||||
|
||||
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
|
||||
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
@@ -64,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate
|
||||
logger.info(f"Loaded embeddings from {embeddings_file}")
|
||||
|
||||
# Encode any new entries in the corpus and update corpus embeddings
|
||||
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
|
||||
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
|
||||
if new_entries:
|
||||
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
|
||||
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
||||
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
|
||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||
# Else compute the corpus embeddings from scratch
|
||||
else:
|
||||
new_entries = [entry['compiled'] for _, entry in entries_with_ids]
|
||||
new_entries = [entry.compiled for _, entry in entries_with_ids]
|
||||
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
|
||||
# Save regenerated or updated embeddings to file
|
||||
@@ -133,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
@@ -153,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||
return hits, entries
|
||||
|
||||
|
||||
def render_results(hits, entries, count=5, display_biencoder_results=False):
|
||||
def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False):
|
||||
"Render the Results returned by Search for the Query"
|
||||
if display_biencoder_results:
|
||||
# Output of top hits from bi-encoder
|
||||
@@ -161,20 +161,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
|
||||
print(f"Top-{count} Bi-Encoder Retrieval hits")
|
||||
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
||||
for hit in hits[0:count]:
|
||||
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
|
||||
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}")
|
||||
|
||||
# Output of top hits from re-ranker
|
||||
print("\n-------------------------\n")
|
||||
print(f"Top-{count} Cross-Encoder Re-ranker hits")
|
||||
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
||||
for hit in hits[0:count]:
|
||||
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
|
||||
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}")
|
||||
|
||||
|
||||
def collate_results(hits, entries, count=5) -> list[SearchResponse]:
|
||||
def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
|
||||
return [SearchResponse.parse_obj(
|
||||
{
|
||||
"entry": entries[hit['corpus_id']]['raw'],
|
||||
"entry": entries[hit['corpus_id']].raw,
|
||||
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
|
||||
})
|
||||
for hit
|
||||
|
||||
Reference in New Issue
Block a user