mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
Wrap asymmetric search model into SearchModels. Test notes search end-to-end
- Wrap asymmetric search model parameters into AsymmetricSearchModel class
- Create wrapper for all search type models. Put notes search model into it
- Test notes search end-to-end from client API layer to results. Use model built on test data
This commit is contained in:
@@ -17,6 +17,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
# Internal Packages
|
||||
from utils.helpers import get_absolute_path, resolve_absolute_path
|
||||
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||
from utils.config import AsymmetricSearchModel
|
||||
|
||||
|
||||
def initialize_model():
|
||||
@@ -64,7 +65,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
|
||||
def query_notes(raw_query: str, model: AsymmetricSearchModel):
|
||||
"Search all notes for entries that answer the query"
|
||||
# Separate natural query from explicit required, blocked words filters
|
||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||
@@ -72,20 +73,22 @@ def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
||||
question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
|
||||
|
||||
# Find relevant entries for the query
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
||||
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
|
||||
hits = hits[0] # Get the hits for the first query
|
||||
|
||||
# Filter results using explicit filters
|
||||
hits = explicit_filter(hits, [entry[0] for entry in entries], required_words, blocked_words)
|
||||
hits = explicit_filter(hits,
|
||||
[entry[0] for entry in model.entries],
|
||||
required_words,blocked_words)
|
||||
if hits is None or len(hits) == 0:
|
||||
return hits
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
cross_inp = [[query, model.entries[hit['corpus_id']][0]] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
@@ -161,7 +164,7 @@ def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=Fa
|
||||
# Compute or Load Embeddings
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
|
||||
|
||||
return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k
|
||||
return AsymmetricSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user