mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Add typing to text_search. Reformat code to set existing_embeddings
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
@@ -57,7 +58,7 @@ def extract_entries(jsonl_file) -> list[Entry]:
|
|||||||
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file, regenerate=False):
|
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
new_entries = []
|
new_entries = []
|
||||||
# Load pre-computed embeddings from file if exists and update them if required
|
# Load pre-computed embeddings from file if exists and update them if required
|
||||||
@@ -70,7 +71,10 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: Ba
|
|||||||
if new_entries:
|
if new_entries:
|
||||||
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||||
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
||||||
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) if existing_entry_ids else torch.tensor([], device=state.device)
|
if existing_entry_ids:
|
||||||
|
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device))
|
||||||
|
else:
|
||||||
|
existing_embeddings = torch.tensor([], device=state.device)
|
||||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||||
# Else compute the corpus embeddings from scratch
|
# Else compute the corpus embeddings from scratch
|
||||||
else:
|
else:
|
||||||
@@ -86,7 +90,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: Ba
|
|||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False):
|
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> tuple[list[dict], list[Entry]]:
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user