mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 05:39:11 +00:00
Add Search Config for Symmetric Model. Save Model to Disk
This commit is contained in:
@@ -10,18 +10,31 @@ import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig
|
||||
from src.utils.rawconfig import SymmetricConfig, TextSearchConfig
|
||||
|
||||
|
||||
def initialize_model():
|
||||
def initialize_model(search_config: SymmetricConfig):
|
||||
"Initialize model for symmetric semantic search. That is, where query of similar size to results"
|
||||
torch.set_num_threads(4)
|
||||
bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search
|
||||
top_k = 30 # Number of entries we want to retrieve with the bi-encoder
|
||||
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality
|
||||
|
||||
# Number of entries we want to retrieve with the bi-encoder
|
||||
top_k = 30
|
||||
|
||||
# The bi-encoder encodes all entries to use for semantic search
|
||||
bi_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.encoder,
|
||||
model_type = SentenceTransformer)
|
||||
|
||||
# The cross-encoder re-ranks the results to improve quality
|
||||
cross_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.cross_encoder,
|
||||
model_type = CrossEncoder)
|
||||
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
|
||||
@@ -141,9 +154,9 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
|
||||
def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model()
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
||||
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
|
||||
|
||||
Reference in New Issue
Block a user