Merge Symmetric, Asymmetric Search Types into a single Text Search Type
- The code for the two text search types was mostly the same; it was earlier written this way for expedience while experimenting
- The minor differences were reconciled and merged into a single text_search type
- This simplifies the app and makes it easier to process other text types
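
In effect, all text content types now share one setup/query/collate flow; only the text-to-jsonl processor passed into setup differs per type. A minimal sketch of the resulting call pattern (initialize_text_search and search_text below are illustrative helpers, not part of this commit; config is assumed to be a loaded FullConfig, as in src/main.py):

    from src.search_type import text_search
    from src.processor.org_mode.org_to_jsonl import org_to_jsonl
    from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl

    def initialize_text_search(config):
        # Same generic setup for notes and transactions; the processor
        # argument maps each format to the shared (compressed) jsonl shape
        notes_model = text_search.setup(
            org_to_jsonl, config.content_type.org,
            search_config=config.search_type.asymmetric, regenerate=False)
        ledger_model = text_search.setup(
            beancount_to_jsonl, config.content_type.ledger,
            search_config=config.search_type.symmetric, regenerate=False)
        return notes_model, ledger_model

    def search_text(user_query, model):
        # Query and collate are shared across all text content types
        hits, entries = text_search.query(user_query, model)
        return text_search.collate_results(hits, entries, count=5)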
src/main.py
@@ -11,7 +11,9 @@ from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates

 # Internal Packages
-from src.search_type import asymmetric, symmetric_ledger, image_search
+from src.search_type import image_search, text_search
+from src.processor.org_mode.org_to_jsonl import org_to_jsonl
+from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
 from src.utils.helpers import get_absolute_path, get_from_dict
 from src.utils.cli import cli
 from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
@@ -66,24 +68,24 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):

     if (t == SearchType.Notes or t == None) and model.notes_search:
         # query notes
-        hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter])
+        hits, entries = text_search.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter])

         # collate and return results
-        return asymmetric.collate_results(hits, entries, results_count)
+        return text_search.collate_results(hits, entries, results_count)

     if (t == SearchType.Music or t == None) and model.music_search:
         # query music library
-        hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter])
+        hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter])

         # collate and return results
-        return asymmetric.collate_results(hits, entries, results_count)
+        return text_search.collate_results(hits, entries, results_count)

     if (t == SearchType.Ledger or t == None) and model.ledger_search:
         # query transactions
-        hits, entries = symmetric_ledger.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter])
+        hits, entries = text_search.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter])

         # collate and return results
-        return symmetric_ledger.collate_results(hits, entries, results_count)
+        return text_search.collate_results(hits, entries, results_count)

     if (t == SearchType.Image or t == None) and model.image_search:
         # query transactions
@@ -163,17 +165,17 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
     # Initialize Org Notes Search
     if (t == SearchType.Notes or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
-        model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
+        model.notes_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)

     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
-        model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
+        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)

     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
         # Extract Entries, Generate Ledger Embeddings
-        model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
+        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)

     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
@@ -1,97 +0,0 @@
-import pandas as pd
-import faiss
-import numpy as np
-
-from sentence_transformers import SentenceTransformer
-
-import argparse
-import os
-
-def create_index(
-        model,
-        dataset_path,
-        index_path,
-        column_name,
-        recreate):
-    # Load Dataset
-    dataset = pd.read_csv(dataset_path)
-
-    # Clean Dataset
-    dataset = dataset.dropna()
-    dataset[column_name] = dataset[column_name].str.strip()
-
-    # Create Index or Load it if it already exists
-    if os.path.exists(index_path) and not recreate:
-        index = faiss.read_index(index_path)
-    else:
-        # Create Embedding Vectors of Documents
-        embeddings = model.encode(dataset[column_name].to_list(), show_progress_bar=True)
-        embeddings = np.array([embedding for embedding in embeddings]).astype("float32")
-
-        index = faiss.IndexIDMap(
-            faiss.IndexFlatL2(
-                embeddings.shape[1]))
-
-        index.add_with_ids(embeddings, dataset.index.values)
-
-        faiss.write_index(index, index_path)
-
-    return index, dataset
-
-
-def resolve_column(dataset, Id, column):
-    return [list(dataset[dataset.index == idx][column]) for idx in Id[0]]
-
-
-def vector_search(query, index, dataset, column_name, num_results=10):
-    query_vector = np.array(query).astype("float32")
-    D, Id = index.search(query_vector, k=num_results)
-
-    return zip(D[0], Id[0], resolve_column(dataset, Id, column_name))
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="Find most suitable match based on users exclude, include preferences")
-    parser.add_argument('positives', type=str, help="Terms to find closest match to")
-    parser.add_argument('--negatives', '-n', type=str, help="Terms to find farthest match from")
-
-    parser.add_argument('--recreate', action='store_true', default=False, help="Recreate index at index_path from dataset at dataset path")
-    parser.add_argument('--index', type=str, default="./.faiss_index", help="Path to index for storing vector embeddings")
-    parser.add_argument('--dataset', type=str, default="./.dataset", help="Path to dataset to generate index from")
-    parser.add_argument('--column', type=str, default="DATA", help="Name of dataset column to index")
-    parser.add_argument('--num_results', type=int, default=10, help="Number of most suitable matches to show")
-    parser.add_argument('--model_name', type=str, default='all-MiniLM-L6-v2', help="Specify name of the SentenceTransformer model to use for encoding")
-    args = parser.parse_args()
-
-    model = SentenceTransformer(args.model_name)
-
-    if args.positives and not args.negatives:
-        # Get index, create it from dataset if doesn't exist
-        index, dataset = create_index(model, args.dataset, args.index, args.column, args.recreate)
-
-        # Create vector to represent user's stated positive preference
-        preference_vector = model.encode([args.positives])
-
-        # Find and display most suitable matches for users preferences in the dataset
-        results = vector_search(preference_vector, index, dataset, args.column, args.num_results)
-
-        print("Most Suitable Matches:")
-        for similarity, id_, data in results:
-            print(f"Id: {id_}\nSimilarity: {similarity}\n{args.column}: {data[0]}")
-
-    elif args.positives and args.negatives:
-        # Get index, create it from dataset if doesn't exist
-        index, dataset = create_index(model, args.dataset, args.index, args.column, args.recreate)
-
-        # Create vector to represent user's stated preference
-        positives_vector = np.array(model.encode([args.positives])).astype("float32")
-        negatives_vector = np.array(model.encode([args.negatives])).astype("float32")
-
-        # preference_vector = np.mean([positives_vector, -1 * negatives_vector], axis=0)
-        preference_vector = np.add(positives_vector, -1 * negatives_vector)
-
-        # Find and display most suitable matches for users preferences in the dataset
-        results = vector_search(preference_vector, index, dataset, args.column, args.num_results)

-        print("Most Suitable Matches:")
-        for similarity, id_, data in results:
-            print(f"Id: {id_}\nSimilarity: {similarity}\n{args.column}: {data[0]}")
@@ -1,170 +0,0 @@
-# Standard Packages
-import argparse
-import pathlib
-from copy import deepcopy
-
-# External Packages
-import torch
-from sentence_transformers import SentenceTransformer, CrossEncoder, util
-
-# Internal Packages
-from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
-from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
-from src.utils.config import TextSearchModel
-from src.utils.rawconfig import SymmetricSearchConfig, TextContentConfig
-from src.utils.jsonl import load_jsonl
-
-
-def initialize_model(search_config: SymmetricSearchConfig):
-    "Initialize model for symmetric semantic search. That is, where query of similar size to results"
-    torch.set_num_threads(4)
-
-    # Number of entries we want to retrieve with the bi-encoder
-    top_k = 30
-
-    # The bi-encoder encodes all entries to use for semantic search
-    bi_encoder = load_model(
-        model_dir = search_config.model_directory,
-        model_name = search_config.encoder,
-        model_type = SentenceTransformer)
-
-    # The cross-encoder re-ranks the results to improve quality
-    cross_encoder = load_model(
-        model_dir = search_config.model_directory,
-        model_name = search_config.cross_encoder,
-        model_type = CrossEncoder)
-
-    return bi_encoder, cross_encoder, top_k
-
-
-def extract_entries(notesfile, verbose=0):
-    "Load entries from compressed jsonl"
-    return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
-            for entry
-            in load_jsonl(notesfile, verbose=verbose)]
-
-
-def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
-    "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
-    # Load pre-computed embeddings from file if exists
-    if resolve_absolute_path(embeddings_file).exists() and not regenerate:
-        corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
-        if verbose > 0:
-            print(f"Loaded embeddings from {embeddings_file}")
-
-    else:  # Else compute the corpus_embeddings from scratch, which can take a while
-        corpus_embeddings = bi_encoder.encode(entries, convert_to_tensor=True, show_progress_bar=True)
-        torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
-        if verbose > 0:
-            print(f"Computed embeddings and saved them to {embeddings_file}")
-
-    return corpus_embeddings
-
-
-def query(raw_query, model: TextSearchModel, filters=[]):
-    "Search all notes for entries that answer the query"
-    # Copy original embeddings, entries to filter them for query
-    query = raw_query
-    corpus_embeddings = deepcopy(model.corpus_embeddings)
-    entries = deepcopy(model.entries)
-
-    # Filter query, entries and embeddings before semantic search
-    for filter in filters:
-        query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
-    if entries is None or len(entries) == 0:
-        return [], []
-
-    # Encode the query using the bi-encoder
-    question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
-
-    # Find relevant entries for the query
-    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0]
-
-    # Score all retrieved entries using the cross-encoder
-    cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
-    cross_scores = model.cross_encoder.predict(cross_inp)
-
-    # Store cross-encoder scores in results dictionary for ranking
-    for idx in range(len(cross_scores)):
-        hits[idx]['cross-score'] = cross_scores[idx]
-
-    # Order results by cross encoder score followed by biencoder score
-    hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score
-    hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score
-
-    return hits, entries
-
-
-def render_results(hits, entries, count=5, display_biencoder_results=False):
-    "Render the Results returned by Search for the Query"
-    if display_biencoder_results:
-        # Output of top hits from bi-encoder
-        print("\n-------------------------\n")
-        print(f"Top-{count} Bi-Encoder Retrieval hits")
-        hits = sorted(hits, key=lambda x: x['score'], reverse=True)
-        for hit in hits[0:count]:
-            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
-
-    # Output of top hits from re-ranker
-    print("\n-------------------------\n")
-    print(f"Top-{count} Cross-Encoder Re-ranker hits")
-    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-    for hit in hits[0:count]:
-        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
-
-
-def collate_results(hits, entries, count=5):
-    return [
-        {
-            "entry": entries[hit['corpus_id']]['raw'],
-            "score": f"{hit['cross-score']:.3f}"
-        }
-        for hit
-        in hits[0:count]]
-
-
-def setup(config: TextContentConfig, search_config: SymmetricSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
-    # Initialize Model
-    bi_encoder, cross_encoder, top_k = initialize_model(search_config)
-
-    # Map notes in Org-Mode files to (compressed) JSONL formatted file
-    if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
-        beancount_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
-
-    # Extract Entries
-    entries = extract_entries(config.compressed_jsonl, verbose)
-    top_k = min(len(entries), top_k)
-
-    # Compute or Load Embeddings
-    corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
-
-    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
-
-
-if __name__ == '__main__':
-    # Setup Argument Parser
-    parser = argparse.ArgumentParser(description="Map Beancount transactions into (compressed) JSONL format")
-    parser.add_argument('--input-files', '-i', nargs='*', help="List of Beancount files to process")
-    parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Beancount files to process")
-    parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".transactions.jsonl.gz"), help="Compressed JSONL formatted transactions file to compute embeddings from")
-    parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".transaction_embeddings.pt"), help="File to save/load model embeddings to/from")
-    parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from Beancount files. Default: false")
-    parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
-    parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
-    parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
-    args = parser.parse_args()
-
-    entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
-
-    # Run User Queries on Entries in Interactive Mode
-    while args.interactive:
-        # get query from user
-        user_query = input("Enter your query: ")
-        if user_query == "exit":
-            exit(0)
-
-        # query
-        hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
-
-        # render results
-        render_results(hits, entries, count=args.results_count)
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Standard Packages
 import argparse
 import pathlib
@@ -11,14 +9,13 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util

 # Internal Packages
 from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
-from src.processor.org_mode.org_to_jsonl import org_to_jsonl
 from src.utils.config import TextSearchModel
-from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig
+from src.utils.rawconfig import TextSearchConfig, TextContentConfig
 from src.utils.jsonl import load_jsonl


-def initialize_model(search_config: AsymmetricSearchConfig):
-    "Initialize model for assymetric semantic search. That is, where query smaller than results"
+def initialize_model(search_config: TextSearchConfig):
+    "Initialize model for semantic search on text"
     torch.set_num_threads(4)

     # Number of entries we want to retrieve with the bi-encoder
@@ -39,11 +36,11 @@ def initialize_model(search_config: AsymmetricSearchConfig):
     return bi_encoder, cross_encoder, top_k


-def extract_entries(notesfile, verbose=0):
+def extract_entries(jsonl_file, verbose=0):
     "Load entries from compressed jsonl"
     return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
             for entry
-            in load_jsonl(notesfile, verbose=verbose)]
+            in load_jsonl(jsonl_file, verbose=verbose)]


 def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
@@ -65,9 +62,8 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
     return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []):
-    "Search all notes for entries that answer the query"
-
+def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = []):
+    "Search for entries that answer the query"
     # Copy original embeddings, entries to filter them for query
     query = raw_query
     corpus_embeddings = deepcopy(model.corpus_embeddings)
@@ -130,13 +126,13 @@ def collate_results(hits, entries, count=5):
     in hits[0:count]]


-def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device=torch.device('cpu'), verbose: bool=False) -> TextSearchModel:
+def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)

-    # Map notes in Org-Mode files to (compressed) JSONL formatted file
+    # Map notes in text files to (compressed) JSONL formatted file
     if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
-        org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
+        text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)

     # Extract Entries
     entries = extract_entries(config.compressed_jsonl, verbose)
@@ -150,12 +146,12 @@ def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, rege

 if __name__ == '__main__':
     # Setup Argument Parser
-    parser = argparse.ArgumentParser(description="Map Org-Mode notes into (compressed) JSONL format")
-    parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
-    parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
-    parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from")
-    parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from")
-    parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false")
+    parser = argparse.ArgumentParser(description="Map Text files into (compressed) JSONL format")
+    parser.add_argument('--input-files', '-i', nargs='*', help="List of Text files to process")
+    parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Text files to process")
+    parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path("text.jsonl.gz"), help="Compressed JSONL to compute embeddings from")
+    parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path("text_embeddings.pt"), help="File to save/load model embeddings to/from")
+    parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from text files. Default: false")
     parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
     parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
     parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
@@ -36,7 +36,7 @@ conda activate khoj
 cd {get_absolute(args.script_dir)}

 # Act
-python3 search_types/asymmetric.py -j {get_absolute(args.model_dir)}/notes.jsonl.gz -e {get_absolute(args.model_dir)}/notes_embeddings.pt -n 5 --interactive
+python3 search_types/text_search.py -j {get_absolute(args.model_dir)}/notes.jsonl.gz -e {get_absolute(args.model_dir)}/notes_embeddings.pt -n 5 --interactive
 '''

 search_cmd_content = f'''#!/bin/bash
@@ -32,12 +32,7 @@ class ContentConfig(ConfigBase):
     image: Optional[ImageContentConfig]
     music: Optional[TextContentConfig]

-class SymmetricSearchConfig(ConfigBase):
-    encoder: Optional[str]
-    cross_encoder: Optional[str]
-    model_directory: Optional[Path]
-
-class AsymmetricSearchConfig(ConfigBase):
+class TextSearchConfig(ConfigBase):
     encoder: Optional[str]
     cross_encoder: Optional[str]
     model_directory: Optional[Path]
@@ -47,8 +42,8 @@ class ImageSearchConfig(ConfigBase):
     model_directory: Optional[Path]

 class SearchConfig(ConfigBase):
-    asymmetric: Optional[AsymmetricSearchConfig]
-    symmetric: Optional[SymmetricSearchConfig]
+    asymmetric: Optional[TextSearchConfig]
+    symmetric: Optional[TextSearchConfig]
     image: Optional[ImageSearchConfig]

 class ConversationProcessorConfig(ConfigBase):
@@ -3,8 +3,9 @@ import pytest
 import torch

 # Internal Packages
-from src.search_type import asymmetric, image_search
-from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, SymmetricSearchConfig, AsymmetricSearchConfig, ImageSearchConfig
+from src.search_type import image_search, text_search
+from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
+from src.processor.org_mode.org_to_jsonl import org_to_jsonl


 @pytest.fixture(scope='session')
@@ -13,13 +14,13 @@ def search_config(tmp_path_factory):

     search_config = SearchConfig()

-    search_config.asymmetric = SymmetricSearchConfig(
+    search_config.symmetric = TextSearchConfig(
         encoder = "sentence-transformers/all-MiniLM-L6-v2",
         cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
         model_directory = model_dir
     )

-    search_config.asymmetric = AsymmetricSearchConfig(
+    search_config.asymmetric = TextSearchConfig(
         encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
         cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
         model_directory = model_dir
@@ -55,7 +56,7 @@ def model_dir(search_config):
         compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
         embeddings_file = model_dir.joinpath('note_embeddings.pt'))

-    asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True)
+    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True)

     return model_dir

@@ -3,8 +3,9 @@ from pathlib import Path

 # Internal Packages
 from src.main import model
-from src.search_type import asymmetric
+from src.search_type import text_search
 from src.utils.rawconfig import ContentConfig, SearchConfig
+from src.processor.org_mode.org_to_jsonl import org_to_jsonl


 # Test
@@ -12,7 +13,7 @@ from src.utils.rawconfig import ContentConfig, SearchConfig
 def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # Regenerate notes embeddings during asymmetric setup
-    notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True)
+    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)

     # Assert
     assert len(notes_model.entries) == 10
@@ -22,15 +23,15 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
+    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
     query = "How to git install application?"

     # Act
-    hits, entries = asymmetric.query(
+    hits, entries = text_search.query(
         query,
         model = model.notes_search)

-    results = asymmetric.collate_results(
+    results = text_search.collate_results(
         hits,
         entries,
         count=1)
@@ -44,7 +45,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    initial_notes_model= asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)

     assert len(initial_notes_model.entries) == 10
     assert len(initial_notes_model.corpus_embeddings) == 10
@@ -57,11 +58,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")

     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    regenerated_notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True)
+    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)

     # Act
     # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
-    initial_notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)

     # Assert
     assert len(regenerated_notes_model.entries) == 11
@@ -8,9 +8,9 @@ import pytest

 # Internal Packages
 from src.main import app, model, config
-from src.search_type import asymmetric, image_search
-from src.utils.helpers import resolve_absolute_path
+from src.search_type import text_search, image_search
 from src.utils.rawconfig import ContentConfig, SearchConfig
+from src.processor.org_mode import org_to_jsonl


 # Arrange
@@ -115,7 +115,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
+    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
     user_query = "How to git install application?"

     # Act
@@ -131,7 +131,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
+    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
     user_query = "How to git install application? +Emacs"

     # Act
@@ -147,7 +147,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
+    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
     user_query = "How to git install application? -clone"

     # Act