mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Make normalizing embeddings configurable
This commit is contained in:
@@ -58,7 +58,11 @@ def extract_entries(jsonl_file) -> List[Entry]:
 
 
 def compute_embeddings(
-    entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
+    entries_with_ids: List[Tuple[int, Entry]],
+    bi_encoder: BaseEncoder,
+    embeddings_file: Path,
+    regenerate=False,
+    normalize=True,
 ):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     new_entries = []
@@ -87,8 +91,11 @@ def compute_embeddings(
     existing_embeddings = torch.tensor([], device=state.device)
     corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
 
+    if normalize:
+        # Normalize embeddings for faster lookup via dot product when querying
+        corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+
     # Save regenerated or updated embeddings to file
-    corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
     torch.save(corpus_embeddings, embeddings_file)
     logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
 
@@ -169,6 +176,7 @@ def setup(
     bi_encoder: BaseEncoder,
     regenerate: bool,
     filters: List[BaseFilter] = [],
+    normalize: bool = True,
 ) -> TextContent:
     # Map notes in text files to (compressed) JSONL formatted file
     config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
@@ -186,7 +194,7 @@ def setup(
     # Compute or Load Embeddings
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(
-        entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
+        entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
     )
 
     for filter in filters:
Reference in New Issue
Block a user