Merge Symmetric, Asymmetric Search Types into a single Text Search Type

- The code for both text search types was mostly the same.
  They had been kept separate for expedience while experimenting
- The minor differences were reconciled and the two merged into a single
  text_search type
- This simplifies the app and makes it easier to process other
  text types (see the sketch below)
Debanjum Singh Solanky
2022-07-21 18:05:43 +04:00
parent 0917f1574d
commit 0602d018c0
9 changed files with 52 additions and 324 deletions
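
For context on the new call pattern: the merged module takes the type-specific text-to-JSONL converter as its first argument, so every text content type shares the same setup, query and collate code. A minimal sketch, pieced together from the diffs below rather than copied from the commit; it assumes a loaded FullConfig named config and the converters already in the repo:

    # One text_search module, parameterized by the converter for each content type
    from src.search_type import text_search
    from src.processor.org_mode.org_to_jsonl import org_to_jsonl
    from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl

    # Notes keep using the asymmetric search config, ledger entries the symmetric one
    notes_model = text_search.setup(org_to_jsonl, config.content_type.org,
                                    search_config=config.search_type.asymmetric,
                                    regenerate=False, device='cpu')
    ledger_model = text_search.setup(beancount_to_jsonl, config.content_type.ledger,
                                     search_config=config.search_type.symmetric,
                                     regenerate=False)

    # Querying and collating results is now identical for every text type
    hits, entries = text_search.query("groceries last month", ledger_model)
    results = text_search.collate_results(hits, entries, count=5)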

View File

@@ -11,7 +11,9 @@ from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# Internal Packages
from src.search_type import asymmetric, symmetric_ledger, image_search
from src.search_type import image_search, text_search
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.utils.helpers import get_absolute_path, get_from_dict
from src.utils.cli import cli
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
@@ -66,24 +68,24 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Notes or t == None) and model.notes_search:
# query notes
hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter])
hits, entries = text_search.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter])
# collate and return results
return asymmetric.collate_results(hits, entries, results_count)
return text_search.collate_results(hits, entries, results_count)
if (t == SearchType.Music or t == None) and model.music_search:
# query music library
hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter])
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter])
# collate and return results
return asymmetric.collate_results(hits, entries, results_count)
return text_search.collate_results(hits, entries, results_count)
if (t == SearchType.Ledger or t == None) and model.ledger_search:
# query transactions
hits, entries = symmetric_ledger.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter])
hits, entries = text_search.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter])
# collate and return results
return symmetric_ledger.collate_results(hits, entries, results_count)
return text_search.collate_results(hits, entries, results_count)
if (t == SearchType.Image or t == None) and model.image_search:
# query images
@@ -163,17 +165,17 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
# Initialize Org Notes Search
if (t == SearchType.Notes or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings
model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
model.notes_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings
model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, device=device, verbose=verbose)
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings
model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:
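
The point of passing the converter in, rather than hard-coding org_to_jsonl or beancount_to_jsonl inside separate modules, is that supporting another text format only needs a converter with the same signature. A hypothetical sketch, not part of this commit; plaintext_to_jsonl and the commented-out config entry are illustrative names only:

    import gzip
    import json

    def plaintext_to_jsonl(input_files, input_filter, output_jsonl, verbose=0):
        "Hypothetical converter: map plain text files to the compressed JSONL that text_search reads"
        # input_filter is ignored in this sketch; each file becomes one entry with
        # the 'compiled' and 'raw' keys that text_search.extract_entries expects
        entries = []
        for path in input_files or []:
            with open(path) as f:
                text = f.read()
            entries.append({'compiled': text, 'raw': text})
        with gzip.open(output_jsonl, 'wt') as jsonl_file:
            jsonl_file.write('\n'.join(json.dumps(entry) for entry in entries))

    # Plugged into the shared setup exactly like the existing converters
    # (config.content_type.plaintext is assumed here, not defined by this commit):
    # model.plaintext_search = text_search.setup(plaintext_to_jsonl, config.content_type.plaintext,
    #                                            search_config=config.search_type.asymmetric,
    #                                            regenerate=True)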

View File

@@ -1,97 +0,0 @@
import pandas as pd
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import argparse
import os
def create_index(
model,
dataset_path,
index_path,
column_name,
recreate):
# Load Dataset
dataset = pd.read_csv(dataset_path)
# Clean Dataset
dataset = dataset.dropna()
dataset[column_name] = dataset[column_name].str.strip()
# Create Index or Load it if it already exists
if os.path.exists(index_path) and not recreate:
index = faiss.read_index(index_path)
else:
# Create Embedding Vectors of Documents
embeddings = model.encode(dataset[column_name].to_list(), show_progress_bar=True)
embeddings = np.array([embedding for embedding in embeddings]).astype("float32")
index = faiss.IndexIDMap(
faiss.IndexFlatL2(
embeddings.shape[1]))
index.add_with_ids(embeddings, dataset.index.values)
faiss.write_index(index, index_path)
return index, dataset
def resolve_column(dataset, Id, column):
return [list(dataset[dataset.index == idx][column]) for idx in Id[0]]
def vector_search(query, index, dataset, column_name, num_results=10):
query_vector = np.array(query).astype("float32")
D, Id = index.search(query_vector, k=num_results)
return zip(D[0], Id[0], resolve_column(dataset, Id, column_name))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Find most suitable match based on users exclude, include preferences")
parser.add_argument('positives', type=str, help="Terms to find closest match to")
parser.add_argument('--negatives', '-n', type=str, help="Terms to find farthest match from")
parser.add_argument('--recreate', action='store_true', default=False, help="Recreate index at index_path from dataset at dataset path")
parser.add_argument('--index', type=str, default="./.faiss_index", help="Path to index for storing vector embeddings")
parser.add_argument('--dataset', type=str, default="./.dataset", help="Path to dataset to generate index from")
parser.add_argument('--column', type=str, default="DATA", help="Name of dataset column to index")
parser.add_argument('--num_results', type=int, default=10, help="Number of most suitable matches to show")
parser.add_argument('--model_name', type=str, default='all-MiniLM-L6-v2', help="Specify name of the SentenceTransformer model to use for encoding")
args = parser.parse_args()
model = SentenceTransformer(args.model_name)
if args.positives and not args.negatives:
# Get index, create it from dataset if doesn't exist
index, dataset = create_index(model, args.dataset, args.index, args.column, args.recreate)
# Create vector to represent user's stated positive preference
preference_vector = model.encode([args.positives])
# Find and display most suitable matches for users preferences in the dataset
results = vector_search(preference_vector, index, dataset, args.column, args.num_results)
print("Most Suitable Matches:")
for similarity, id_, data in results:
print(f"Id: {id_}\nSimilarity: {similarity}\n{args.column}: {data[0]}")
elif args.positives and args.negatives:
# Get index, create it from dataset if doesn't exist
index, dataset = create_index(model, args.dataset, args.index, args.column, args.recreate)
# Create vector to represent user's stated preference
positives_vector = np.array(model.encode([args.positives])).astype("float32")
negatives_vector = np.array(model.encode([args.negatives])).astype("float32")
# preference_vector = np.mean([positives_vector, -1 * negatives_vector], axis=0)
preference_vector = np.add(positives_vector, -1 * negatives_vector)
# Find and display most suitable matches for users preferences in the dataset
results = vector_search(preference_vector, index, dataset, args.column, args.num_results)
print("Most Suitable Matches:")
for similarity, id_, data in results:
print(f"Id: {id_}\nSimilarity: {similarity}\n{args.column}: {data[0]}")

View File

@@ -1,170 +0,0 @@
# Standard Packages
import argparse
import pathlib
from copy import deepcopy
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.utils.config import TextSearchModel
from src.utils.rawconfig import SymmetricSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl
def initialize_model(search_config: SymmetricSearchConfig):
"Initialize model for symmetric semantic search. That is, where query of similar size to results"
torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder
top_k = 30
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.cross_encoder,
model_type = CrossEncoder)
return bi_encoder, cross_encoder, top_k
def extract_entries(notesfile, verbose=0):
"Load entries from compressed jsonl"
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
for entry
in load_jsonl(notesfile, verbose=verbose)]
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
# Load pre-computed embeddings from file if exists
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
if verbose > 0:
print(f"Loaded embeddings from {embeddings_file}")
else: # Else compute the corpus_embeddings from scratch, which can take a while
corpus_embeddings = bi_encoder.encode(entries, convert_to_tensor=True, show_progress_bar=True)
torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
if verbose > 0:
print(f"Computed embeddings and saved them to {embeddings_file}")
return corpus_embeddings
def query(raw_query, model: TextSearchModel, filters=[]):
"Search all notes for entries that answer the query"
# Copy original embeddings, entries to filter them for query
query = raw_query
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
# Filter query, entries and embeddings before semantic search
for filter in filters:
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
if entries is None or len(entries) == 0:
return [], []
# Encode the query using the bi-encoder
question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
# Find relevant entries for the query
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0]
# Score all retrieved entries using the cross-encoder
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
# Order results by cross encoder score followed by biencoder score
hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score
return hits, entries
def render_results(hits, entries, count=5, display_biencoder_results=False):
"Render the Results returned by Search for the Query"
if display_biencoder_results:
# Output of top hits from bi-encoder
print("\n-------------------------\n")
print(f"Top-{count} Bi-Encoder Retrieval hits")
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits[0:count]:
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
# Output of top hits from re-ranker
print("\n-------------------------\n")
print(f"Top-{count} Cross-Encoder Re-ranker hits")
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
for hit in hits[0:count]:
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
def collate_results(hits, entries, count=5):
return [
{
"entry": entries[hit['corpus_id']]['raw'],
"score": f"{hit['cross-score']:.3f}"
}
for hit
in hits[0:count]]
def setup(config: TextContentConfig, search_config: SymmetricSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
beancount_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
# Extract Entries
entries = extract_entries(config.compressed_jsonl, verbose)
top_k = min(len(entries), top_k)
# Compute or Load Embeddings
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
if __name__ == '__main__':
# Setup Argument Parser
parser = argparse.ArgumentParser(description="Map Beancount transactions into (compressed) JSONL format")
parser.add_argument('--input-files', '-i', nargs='*', help="List of Beancount files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Beancount files to process")
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".transactions.jsonl.gz"), help="Compressed JSONL formatted transactions file to compute embeddings from")
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".transaction_embeddings.pt"), help="File to save/load model embeddings to/from")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from Beancount files. Default: false")
parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
args = parser.parse_args()
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
# Run User Queries on Entries in Interactive Mode
while args.interactive:
# get query from user
user_query = input("Enter your query: ")
if user_query == "exit":
exit(0)
# query
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
# render results
render_results(hits, entries, count=args.results_count)

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Standard Packages
import argparse
import pathlib
@@ -11,14 +9,13 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils.config import TextSearchModel
from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl
def initialize_model(search_config: AsymmetricSearchConfig):
"Initialize model for assymetric semantic search. That is, where query smaller than results"
def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder
@@ -39,11 +36,11 @@ def initialize_model(search_config: AsymmetricSearchConfig):
return bi_encoder, cross_encoder, top_k
def extract_entries(notesfile, verbose=0):
def extract_entries(jsonl_file, verbose=0):
"Load entries from compressed jsonl"
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
for entry
in load_jsonl(notesfile, verbose=verbose)]
in load_jsonl(jsonl_file, verbose=verbose)]
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
@@ -65,9 +62,8 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), filters: list = []):
"Search all notes for entries that answer the query"
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = []):
"Search for entries that answer the query"
# Copy original embeddings, entries to filter them for query
query = raw_query
corpus_embeddings = deepcopy(model.corpus_embeddings)
@@ -130,13 +126,13 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device=torch.device('cpu'), verbose: bool=False) -> TextSearchModel:
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in Org-Mode files to (compressed) JSONL formatted file
# Map notes in text files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
# Extract Entries
entries = extract_entries(config.compressed_jsonl, verbose)
@@ -150,12 +146,12 @@ def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, rege
if __name__ == '__main__':
# Setup Argument Parser
parser = argparse.ArgumentParser(description="Map Org-Mode notes into (compressed) JSONL format")
parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from")
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false")
parser = argparse.ArgumentParser(description="Map Text files into (compressed) JSONL format")
parser.add_argument('--input-files', '-i', nargs='*', help="List of Text files to process")
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Text files to process")
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path("text.jsonl.gz"), help="Compressed JSONL to compute embeddings from")
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path("text_embeddings.pt"), help="File to save/load model embeddings to/from")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from text files. Default: false")
parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")

View File

@@ -36,7 +36,7 @@ conda activate khoj
cd {get_absolute(args.script_dir)}
# Act
python3 search_types/asymmetric.py -j {get_absolute(args.model_dir)}/notes.jsonl.gz -e {get_absolute(args.model_dir)}/notes_embeddings.pt -n 5 --interactive
python3 search_types/text_search.py -j {get_absolute(args.model_dir)}/notes.jsonl.gz -e {get_absolute(args.model_dir)}/notes_embeddings.pt -n 5 --interactive
'''
search_cmd_content = f'''#!/bin/bash

View File

@@ -32,12 +32,7 @@ class ContentConfig(ConfigBase):
image: Optional[ImageContentConfig]
music: Optional[TextContentConfig]
class SymmetricSearchConfig(ConfigBase):
encoder: Optional[str]
cross_encoder: Optional[str]
model_directory: Optional[Path]
class AsymmetricSearchConfig(ConfigBase):
class TextSearchConfig(ConfigBase):
encoder: Optional[str]
cross_encoder: Optional[str]
model_directory: Optional[Path]
@@ -47,8 +42,8 @@ class ImageSearchConfig(ConfigBase):
model_directory: Optional[Path]
class SearchConfig(ConfigBase):
asymmetric: Optional[AsymmetricSearchConfig]
symmetric: Optional[SymmetricSearchConfig]
asymmetric: Optional[TextSearchConfig]
symmetric: Optional[TextSearchConfig]
image: Optional[ImageSearchConfig]
class ConversationProcessorConfig(ConfigBase):
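
With the two config classes collapsed into one, the symmetric/asymmetric split is now just a matter of which encoder models a TextSearchConfig points at. A sketch mirroring the test fixture later in this diff (model_directory is optional in the schema above and left unset here):

    from src.utils.rawconfig import SearchConfig, TextSearchConfig

    # Same schema for both; only the encoders differ. Symmetric search expects queries
    # shaped like the indexed entries, asymmetric expects short queries against longer entries.
    search_config = SearchConfig()
    search_config.symmetric = TextSearchConfig(
        encoder = "sentence-transformers/all-MiniLM-L6-v2",
        cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2")
    search_config.asymmetric = TextSearchConfig(
        encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
        cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2")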

View File

@@ -3,8 +3,9 @@ import pytest
import torch
# Internal Packages
from src.search_type import asymmetric, image_search
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, SymmetricSearchConfig, AsymmetricSearchConfig, ImageSearchConfig
from src.search_type import image_search, text_search
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
@pytest.fixture(scope='session')
@@ -13,13 +14,13 @@ def search_config(tmp_path_factory):
search_config = SearchConfig()
search_config.asymmetric = SymmetricSearchConfig(
search_config.symmetric = TextSearchConfig(
encoder = "sentence-transformers/all-MiniLM-L6-v2",
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory = model_dir
)
search_config.asymmetric = AsymmetricSearchConfig(
search_config.asymmetric = TextSearchConfig(
encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory = model_dir
@@ -55,7 +56,7 @@ def model_dir(search_config):
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True)
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=device, verbose=True)
return model_dir

View File

@@ -3,8 +3,9 @@ from pathlib import Path
# Internal Packages
from src.main import model
from src.search_type import asymmetric
from src.search_type import text_search
from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
# Test
@@ -12,7 +13,7 @@ from src.utils.rawconfig import ContentConfig, SearchConfig
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True)
notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
# Assert
assert len(notes_model.entries) == 10
@@ -22,15 +23,15 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
query = "How to git install application?"
# Act
hits, entries = asymmetric.query(
hits, entries = text_search.query(
query,
model = model.notes_search)
results = asymmetric.collate_results(
results = text_search.collate_results(
hits,
entries,
count=1)
@@ -44,7 +45,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
initial_notes_model= asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -57,11 +58,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True)
regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
# Assert
assert len(regenerated_notes_model.entries) == 11

View File

@@ -8,9 +8,9 @@ import pytest
# Internal Packages
from src.main import app, model, config
from src.search_type import asymmetric, image_search
from src.utils.helpers import resolve_absolute_path
from src.search_type import text_search, image_search
from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode import org_to_jsonl
# Arrange
@@ -115,7 +115,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application?"
# Act
@@ -131,7 +131,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application? +Emacs"
# Act
@@ -147,7 +147,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application? -clone"
# Act