Modularize Router Code, Improve API. Formalize Intermediate Text Content Format. Add Type Checking

- **Improve API Endpoints**
  - ee65a4f Merge /reload, /regenerate into single /update API endpoint (usage example below)
  - 9975497 Type the /search API response to better document the response schema
  - 0521ea1 Put image score breakdown under `additional` field in search response
- **Formalize Intermediate Format to Index Text Content**
  - 7e9298f Use new Text `Entry` class to track text entries in Intermediate Format
  - 02d9440 Use Base `TextToJsonl` class to standardize `<text>_to_jsonl` processors
- **Modularize API router code**
  - e42a38e Split router code into `web_client`, `api`, `api_beta` routers. Version Khoj API
  - d292bdc Remove API versioning. Premature given current state of the codebase
- **Miscellaneous**
  - c467df8 Setup `mypy` for static type checking
  - 2c54813 Remove unused imports, `embeddings` variable from text search tests
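
The merged update endpoint and the typed search response can be exercised together. A minimal sketch, assuming a local Khoj server on port 8000 and the `requests` library; the endpoint paths and response fields are taken from the diffs below:

```python
import requests

# Incrementally update the image index; pass force=true to regenerate from
# scratch (replaces the old /reload and /regenerate endpoints)
requests.get("http://localhost:8000/api/update", params={"t": "image", "force": "false"})

# Query the typed /api/search endpoint
results = requests.get(
    "http://localhost:8000/api/search",
    params={"q": "dog playing in the snow", "t": "image", "n": 3},
).json()

# Each result is a SearchResponse: entry, score, and a type-specific `additional` map
for result in results:
    print(result["entry"], result["score"], result["additional"]["image_score"])
```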
Committed by Debanjum via GitHub on 2022-10-19 11:23:04 +00:00
35 changed files with 788 additions and 704 deletions

.mypy.ini (new file, +13 lines)

@@ -0,0 +1,13 @@
[mypy]
strict_optional = False
ignore_missing_imports = True
install_types = True
non_interactive = True
show_error_codes = True
exclude = (?x)(
src/interface/desktop/main_window.py
| src/interface/desktop/file_browser.py
| src/interface/desktop/system_tray.py
| build/*
| tests/*
)
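
With this file at the repository root, running `mypy src` should pick up the config automatically; `install_types` together with `non_interactive` lets mypy install missing type stubs without prompting, and the `exclude` regex keeps the Qt desktop interface, build output and tests out of the check.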

README.md

@@ -85,7 +85,7 @@ khoj
 ### 3. Configure
 1. Enable content types and point to files to search in the First Run Screen that pops up on app start
-2. Click configure and wait. The app will load ML model, generates embeddings and expose the search API
+2. Click `Configure` and wait. The app will download ML models and index the content for search

 ## Use
@@ -113,7 +113,7 @@ pip install --upgrade khoj-assistant
 ## Miscellaneous
-- The beta [chat](http://localhost:8000/beta/chat) and [search](http://localhost:8000/beta/search) API endpoints use [OpenAI API](https://openai.com/api/)
+- The beta [chat](http://localhost:8000/api/beta/chat) and [search](http://localhost:8000/api/beta/search) API endpoints use [OpenAI API](https://openai.com/api/)
 - It is disabled by default
 - To use it add your `openai-api-key` via the app configure screen
 - Warning: *If you use the above beta APIs, your query and top result(s) will be sent to OpenAI for processing*

src/configure.py

@@ -6,9 +6,9 @@ import logging
 import json

 # Internal Packages
-from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
-from src.processor.markdown.markdown_to_jsonl import markdown_to_jsonl
-from src.processor.org_mode.org_to_jsonl import org_to_jsonl
+from src.processor.ledger.beancount_to_jsonl import BeancountToJsonl
+from src.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
+from src.processor.org_mode.org_to_jsonl import OrgToJsonl
 from src.search_type import image_search, text_search
 from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
 from src.utils import state
@@ -44,7 +44,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     if (t == SearchType.Org or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
         model.orgmode_search = text_search.setup(
-            org_to_jsonl,
+            OrgToJsonl,
             config.content_type.org,
             search_config=config.search_type.asymmetric,
             regenerate=regenerate,
@@ -54,7 +54,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
         model.music_search = text_search.setup(
-            org_to_jsonl,
+            OrgToJsonl,
             config.content_type.music,
             search_config=config.search_type.asymmetric,
             regenerate=regenerate,
@@ -64,7 +64,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
         # Extract Entries, Generate Markdown Embeddings
         model.markdown_search = text_search.setup(
-            markdown_to_jsonl,
+            MarkdownToJsonl,
             config.content_type.markdown,
             search_config=config.search_type.asymmetric,
             regenerate=regenerate,
@@ -74,7 +74,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
         # Extract Entries, Generate Ledger Embeddings
         model.ledger_search = text_search.setup(
-            beancount_to_jsonl,
+            BeancountToJsonl,
             config.content_type.ledger,
             search_config=config.search_type.symmetric,
             regenerate=regenerate,

src/interface/emacs/khoj.el

@@ -187,8 +187,8 @@ Use `which-key` if available, else display simple message in echo area"
          (lambda (args) (format
                          "\n\n<h2>Score: %s Meta: %s Image: %s</h2>\n\n<a href=\"%s%s\">\n<img src=\"%s%s?%s\" width=%s height=%s>\n</a>"
                          (cdr (assoc 'score args))
-                         (cdr (assoc 'metadata_score args))
-                         (cdr (assoc 'image_score args))
+                         (cdr (assoc 'metadata_score (assoc 'additional args)))
+                         (cdr (assoc 'image_score (assoc 'additional args)))
                          khoj-server-url
                          (cdr (assoc 'entry args))
                          khoj-server-url
@@ -226,7 +226,7 @@ Use `which-key` if available, else display simple message in echo area"
 (defun khoj--get-enabled-content-types ()
   "Get content types enabled for search from API."
-  (let ((config-url (format "%s/config/data" khoj-server-url))
+  (let ((config-url (format "%s/api/config/data" khoj-server-url))
         (url-request-method "GET"))
     (with-temp-buffer
       (erase-buffer)
@@ -244,7 +244,7 @@ Use `which-key` if available, else display simple message in echo area"
   "Construct API Query from QUERY, SEARCH-TYPE and (optional) RERANK params."
   (let ((rerank (or rerank "false"))
         (encoded-query (url-hexify-string query)))
-    (format "%s/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count)))
+    (format "%s/api/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count)))

 (defun khoj--query-api-and-render-results (query search-type query-url buffer-name)
   "Query Khoj API using QUERY, SEARCH-TYPE, QUERY-URL.

Web interface: config page script (path not shown)

@@ -10,7 +10,7 @@ var emptyValueDefault = "🖊️";
 /**
  * Fetch the existing config file.
  */
-fetch("/config/data")
+fetch("/api/config/data")
     .then(response => response.json())
     .then(data => {
         rawConfig = data;
@@ -26,7 +26,7 @@ fetch("/config/data")
 configForm.addEventListener("submit", (event) => {
     event.preventDefault();
     console.log(rawConfig);
-    fetch("/config/data", {
+    fetch("/api/config/data", {
         method: "POST",
         credentials: "same-origin",
         headers: {
@@ -46,7 +46,7 @@ regenerateButton.addEventListener("click", (event) => {
     event.preventDefault();
     regenerateButton.style.cursor = "progress";
     regenerateButton.disabled = true;
-    fetch("/regenerate")
+    fetch("/api/update?force=true")
         .then(response => response.json())
         .then(data => {
             regenerateButton.style.cursor = "pointer";

Web interface: search page script (path not shown)

@@ -16,7 +16,7 @@
     return `
     <a href="${item.entry}" class="image-link">
         <img id=${item.score} src="${item.entry}?${Math.random()}"
-            title="Effective Score: ${item.score}, Meta: ${item.metadata_score}, Image: ${item.image_score}"
+            title="Effective Score: ${item.score}, Meta: ${item.additional.metadata_score}, Image: ${item.additional.image_score}"
             class="image">
     </a>`
 }
@@ -77,8 +77,8 @@
     // Generate Backend API URL to execute Search
     url = type === "image"
-        ? `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}`
-        : `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`;
+        ? `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}`
+        : `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`;

     // Execute Search and Render Results
     fetch(url)
@@ -94,7 +94,7 @@
 function updateIndex() {
     type = document.getElementById("type").value;
-    fetch(`/reload?t=${type}`)
+    fetch(`/api/update?t=${type}`)
         .then(response => response.json())
         .then(data => {
             console.log(data);
@@ -118,7 +118,7 @@
 function populate_type_dropdown() {
     // Populate type dropdown field with enabled search types only
     var possible_search_types = ["org", "markdown", "ledger", "music", "image"];
-    fetch("/config/data")
+    fetch("/api/config/data")
         .then(response => response.json())
         .then(data => {
             document.getElementById("type").innerHTML =

src/main.py

@@ -19,7 +19,9 @@ from PyQt6.QtCore import QThread, QTimer
 # Internal Packages
 from src.configure import configure_server
-from src.router import router
+from src.routers.api import api
+from src.routers.api_beta import api_beta
+from src.routers.web_client import web_client
 from src.utils import constants, state
 from src.utils.cli import cli
 from src.interface.desktop.main_window import MainWindow
@@ -29,7 +31,9 @@ from src.interface.desktop.system_tray import create_system_tray
 # Initialize the Application Server
 app = FastAPI()
 app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
-app.include_router(router)
+app.include_router(api, prefix="/api")
+app.include_router(api_beta, prefix="/api/beta")
+app.include_router(web_client)

 logger = logging.getLogger('src')
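
With these prefixes in place, the core endpoints move from `/search`, `/config/data` and `/reload` to `/api/search`, `/api/config/data` and `/api/update`, and the beta endpoints to `/api/beta/search` and `/api/beta/chat`, while the web client keeps serving `/` and `/config`. That is why every `fetch` and `format` call in the interface diffs above gains an `/api` prefix.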

src/processor/ledger/beancount_to_jsonl.py

@@ -1,26 +1,25 @@
-#!/usr/bin/env python3
-
 # Standard Packages
-import json
 import glob
 import re
 import logging
 import time

 # Internal Packages
-from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
+from src.processor.text_to_jsonl import TextToJsonl
+from src.utils.helpers import get_absolute_path, is_none_or_empty
 from src.utils.constants import empty_escape_sequences
 from src.utils.jsonl import dump_jsonl, compress_jsonl_data
-from src.utils.rawconfig import TextContentConfig
+from src.utils.rawconfig import Entry

 logger = logging.getLogger(__name__)

-# Define Functions
-def beancount_to_jsonl(config: TextContentConfig, previous_entries=None):
-    # Extract required fields from config
-    beancount_files, beancount_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
+class BeancountToJsonl(TextToJsonl):
+    # Define Functions
+    def process(self, previous_entries=None):
+        # Extract required fields from config
+        beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl

         # Input Validation
         if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
@@ -28,11 +27,11 @@ def beancount_to_jsonl(config: TextContentConfig, previous_entries=None):
             exit(1)

         # Get Beancount Files to Process
-    beancount_files = get_beancount_files(beancount_files, beancount_file_filter)
+        beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter)

         # Extract Entries from specified Beancount files
         start = time.time()
-    current_entries = convert_transactions_to_maps(*extract_beancount_transactions(beancount_files))
+        current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
         end = time.time()
         logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
@@ -41,14 +40,14 @@ def beancount_to_jsonl(config: TextContentConfig, previous_entries=None):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-    entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
         end = time.time()
         logger.debug(f"Identify new or updated transaction: {end - start} seconds")

         # Process Each Entry from All Notes Files
         start = time.time()
         entries = list(map(lambda entry: entry[1], entries_with_ids))
-    jsonl_data = convert_transaction_maps_to_jsonl(entries)
+        jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)

         # Compress JSONL formatted Data
         if output_file.suffix == ".gz":
@@ -60,8 +59,8 @@ def beancount_to_jsonl(config: TextContentConfig, previous_entries=None):
         return entries_with_ids

-def get_beancount_files(beancount_files=None, beancount_file_filters=None):
+    @staticmethod
+    def get_beancount_files(beancount_files=None, beancount_file_filters=None):
         "Get Beancount files to process"
         absolute_beancount_files, filtered_beancount_files = set(), set()
         if beancount_files:
@@ -90,8 +89,8 @@ def get_beancount_files(beancount_files=None, beancount_file_filters=None):
         return all_beancount_files

-def extract_beancount_transactions(beancount_files):
+    @staticmethod
+    def extract_beancount_transactions(beancount_files):
         "Extract entries from specified Beancount files"
         # Initialize Regex for extracting Beancount Entries
@@ -111,18 +110,18 @@ def extract_beancount_transactions(beancount_files):
         entries.extend(transactions_per_file)
         return entries, dict(transaction_to_file_map)

-def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]:
-    "Convert each Beancount transaction into a dictionary"
-    entry_maps = []
-    for entry in entries:
-        entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'})
-    logger.info(f"Converted {len(entries)} transactions to dictionaries")
-    return entry_maps
+    @staticmethod
+    def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]:
+        "Convert each parsed Beancount transaction into an Entry"
+        entries = []
+        for parsed_entry in parsed_entries:
+            entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))
+        logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
+        return entries

-def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str:
-    "Convert each Beancount transaction dictionary to JSON and collate as JSONL"
-    return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
+    @staticmethod
+    def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str:
+        "Convert each Beancount transaction entry to JSON and collate as JSONL"
+        return ''.join([f'{entry.to_json()}\n' for entry in entries])

src/processor/markdown/markdown_to_jsonl.py

@@ -1,26 +1,25 @@
-#!/usr/bin/env python3
-
 # Standard Packages
-import json
 import glob
 import re
 import logging
 import time

 # Internal Packages
-from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
+from src.processor.text_to_jsonl import TextToJsonl
+from src.utils.helpers import get_absolute_path, is_none_or_empty
 from src.utils.constants import empty_escape_sequences
 from src.utils.jsonl import dump_jsonl, compress_jsonl_data
-from src.utils.rawconfig import TextContentConfig
+from src.utils.rawconfig import Entry

 logger = logging.getLogger(__name__)

-# Define Functions
-def markdown_to_jsonl(config: TextContentConfig, previous_entries=None):
-    # Extract required fields from config
-    markdown_files, markdown_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
+class MarkdownToJsonl(TextToJsonl):
+    # Define Functions
+    def process(self, previous_entries=None):
+        # Extract required fields from config
+        markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl

         # Input Validation
         if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
@@ -28,11 +27,11 @@ def markdown_to_jsonl(config: TextContentConfig, previous_entries=None):
             exit(1)

         # Get Markdown Files to Process
-    markdown_files = get_markdown_files(markdown_files, markdown_file_filter)
+        markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)

         # Extract Entries from specified Markdown files
         start = time.time()
-    current_entries = convert_markdown_entries_to_maps(*extract_markdown_entries(markdown_files))
+        current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
         end = time.time()
         logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
@@ -41,14 +40,14 @@ def markdown_to_jsonl(config: TextContentConfig, previous_entries=None):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-    entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
         end = time.time()
         logger.debug(f"Identify new or updated entries: {end - start} seconds")

         # Process Each Entry from All Notes Files
         start = time.time()
         entries = list(map(lambda entry: entry[1], entries_with_ids))
-    jsonl_data = convert_markdown_maps_to_jsonl(entries)
+        jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)

         # Compress JSONL formatted Data
         if output_file.suffix == ".gz":
@@ -60,8 +59,8 @@ def markdown_to_jsonl(config: TextContentConfig, previous_entries=None):
         return entries_with_ids

-def get_markdown_files(markdown_files=None, markdown_file_filters=None):
+    @staticmethod
+    def get_markdown_files(markdown_files=None, markdown_file_filters=None):
         "Get Markdown files to process"
         absolute_markdown_files, filtered_markdown_files = set(), set()
         if markdown_files:
@@ -89,8 +88,8 @@ def get_markdown_files(markdown_files=None, markdown_file_filters=None):
         return all_markdown_files

-def extract_markdown_entries(markdown_files):
+    @staticmethod
+    def extract_markdown_entries(markdown_files):
         "Extract entries by heading from specified Markdown files"
         # Regex to extract Markdown Entries by Heading
@@ -110,18 +109,18 @@ def extract_markdown_entries(markdown_files):
         return entries, dict(entry_to_file_map)

-def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]:
-    "Convert each Markdown entry into a dictionary"
-    entry_maps = []
-    for entry in entries:
-        entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'})
-    logger.info(f"Converted {len(entries)} markdown entries to dictionaries")
-    return entry_maps
+    @staticmethod
+    def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]:
+        "Convert each parsed Markdown entry into an Entry"
+        entries = []
+        for parsed_entry in parsed_entries:
+            entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))
+        logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
+        return entries

-def convert_markdown_maps_to_jsonl(entries):
-    "Convert each Markdown entry to JSON and collate as JSONL"
-    return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
+    @staticmethod
+    def convert_markdown_maps_to_jsonl(entries: list[Entry]):
+        "Convert each Markdown entry to JSON and collate as JSONL"
+        return ''.join([f'{entry.to_json()}\n' for entry in entries])

src/processor/org_mode/org_to_jsonl.py

@@ -1,7 +1,4 @@
-#!/usr/bin/env python3
-
 # Standard Packages
-import json
 import glob
 import logging
 import time
@@ -9,20 +6,22 @@ from typing import Iterable

 # Internal Packages
 from src.processor.org_mode import orgnode
-from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
+from src.processor.text_to_jsonl import TextToJsonl
+from src.utils.helpers import get_absolute_path, is_none_or_empty
 from src.utils.jsonl import dump_jsonl, compress_jsonl_data
+from src.utils.rawconfig import Entry
 from src.utils import state
-from src.utils.rawconfig import TextContentConfig

 logger = logging.getLogger(__name__)

-# Define Functions
-def org_to_jsonl(config: TextContentConfig, previous_entries=None):
-    # Extract required fields from config
-    org_files, org_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
-    index_heading_entries = config.index_heading_entries
+class OrgToJsonl(TextToJsonl):
+    # Define Functions
+    def process(self, previous_entries: list[Entry]=None):
+        # Extract required fields from config
+        org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
+        index_heading_entries = self.config.index_heading_entries

         # Input Validation
         if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
@@ -31,16 +30,16 @@ def org_to_jsonl(config: TextContentConfig, previous_entries=None):
         # Get Org Files to Process
         start = time.time()
-    org_files = get_org_files(org_files, org_file_filter)
+        org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)

         # Extract Entries from specified Org files
         start = time.time()
-    entry_nodes, file_to_entries = extract_org_entries(org_files)
+        entry_nodes, file_to_entries = self.extract_org_entries(org_files)
         end = time.time()
         logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")

         start = time.time()
-    current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
+        current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
         end = time.time()
         logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds")
@@ -48,12 +47,12 @@ def org_to_jsonl(config: TextContentConfig, previous_entries=None):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-    entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)

         # Process Each Entry from All Notes Files
         start = time.time()
         entries = map(lambda entry: entry[1], entries_with_ids)
-    jsonl_data = convert_org_entries_to_jsonl(entries)
+        jsonl_data = self.convert_org_entries_to_jsonl(entries)

         # Compress JSONL formatted Data
         if output_file.suffix == ".gz":
@@ -65,8 +64,8 @@ def org_to_jsonl(config: TextContentConfig, previous_entries=None):
         return entries_with_ids

-def get_org_files(org_files=None, org_file_filters=None):
+    @staticmethod
+    def get_org_files(org_files=None, org_file_filters=None):
         "Get Org files to process"
         absolute_org_files, filtered_org_files = set(), set()
         if org_files:
@@ -92,8 +91,8 @@ def get_org_files(org_files=None, org_file_filters=None):
         return all_org_files

-def extract_org_entries(org_files):
+    @staticmethod
+    def extract_org_entries(org_files):
         "Extract entries from specified Org files"
         entries = []
         entry_to_file_map = []
@@ -104,52 +103,49 @@ def extract_org_entries(org_files):
         return entries, dict(entry_to_file_map)

-def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]:
-    "Convert Org-Mode entries into list of dictionary"
-    entry_maps = []
-    for entry in entries:
-        entry_dict = dict()
-        if not entry.hasBody and not index_heading_entries:
-            # Ignore title notes i.e notes with just headings and empty body
-            continue
-        entry_dict["compiled"] = f'{entry.heading}.'
-        if state.verbose > 2:
-            logger.debug(f"Title: {entry.heading}")
-        if entry.tags:
-            tags_str = " ".join(entry.tags)
-            entry_dict["compiled"] += f'\t {tags_str}.'
-            if state.verbose > 2:
-                logger.debug(f"Tags: {tags_str}")
-        if entry.closed:
-            entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.'
-            if state.verbose > 2:
-                logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}')
-        if entry.scheduled:
-            entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.'
-            if state.verbose > 2:
-                logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}')
-        if entry.hasBody:
-            entry_dict["compiled"] += f'\n {entry.body}'
-            if state.verbose > 2:
-                logger.debug(f"Body: {entry.body}")
-        if entry_dict:
-            entry_dict["raw"] = f'{entry}'
-            entry_dict["file"] = f'{entry_to_file_map[entry]}'
-            # Convert Dictionary to JSON and Append to JSONL string
-            entry_maps.append(entry_dict)
-    return entry_maps
+    @staticmethod
+    def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]:
+        "Convert Org-Mode nodes into list of Entry objects"
+        entries: list[Entry] = []
+        for parsed_entry in parsed_entries:
+            if not parsed_entry.hasBody and not index_heading_entries:
+                # Ignore title notes i.e notes with just headings and empty body
+                continue
+            compiled = f'{parsed_entry.heading}.'
+            if state.verbose > 2:
+                logger.debug(f"Title: {parsed_entry.heading}")
+            if parsed_entry.tags:
+                tags_str = " ".join(parsed_entry.tags)
+                compiled += f'\t {tags_str}.'
+                if state.verbose > 2:
+                    logger.debug(f"Tags: {tags_str}")
+            if parsed_entry.closed:
+                compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
+                if state.verbose > 2:
+                    logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')
+            if parsed_entry.scheduled:
+                compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
+                if state.verbose > 2:
+                    logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
+            if parsed_entry.hasBody:
+                compiled += f'\n {parsed_entry.body}'
+                if state.verbose > 2:
+                    logger.debug(f"Body: {parsed_entry.body}")
+            if compiled:
+                entries += [Entry(
+                    compiled=compiled,
+                    raw=f'{parsed_entry}',
+                    file=f'{entry_to_file_map[parsed_entry]}')]
+        return entries

-def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str:
-    "Convert each Org-Mode entry to JSON and collate as JSONL"
-    return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
+    @staticmethod
+    def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
+        "Convert each Org-Mode entry to JSON and collate as JSONL"
+        return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])
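
The `Entry` class these processors now construct lives in `src/utils/rawconfig.py`, which is not part of this diff. Based on how it is used above (`Entry(compiled=..., raw=..., file=...)`, `entry.to_json()`, `Entry.from_dict(...)`), it is presumably a small model along these lines; a hypothetical reconstruction, not the actual definition:

```python
import json

class Entry:
    """Sketch of the Entry model the processors appear to rely on."""
    def __init__(self, compiled: str = None, raw: str = None, file: str = None):
        self.compiled = compiled  # text rendered for embedding generation
        self.raw = raw            # original entry text
        self.file = file          # source file the entry came from

    def to_json(self) -> str:
        # Serialize to a single JSON line, preserving non-ASCII characters
        return json.dumps(self.__dict__, ensure_ascii=False)

    @classmethod
    def from_dict(cls, data: dict) -> "Entry":
        return cls(compiled=data.get("compiled"), raw=data.get("raw"), file=data.get("file"))
```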

src/processor/text_to_jsonl.py (new file, +54 lines)

@@ -0,0 +1,54 @@
# Standard Packages
from abc import ABC, abstractmethod
import hashlib
import time
import logging
# Internal Packages
from src.utils.rawconfig import Entry, TextContentConfig
logger = logging.getLogger(__name__)
class TextToJsonl(ABC):
def __init__(self, config: TextContentConfig):
self.config = config
@abstractmethod
def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ...
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries
start = time.time()
current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), current_entries))
previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), previous_entries))
end = time.time()
logger.debug(f"Hash previous, current entries: {end - start} seconds")
start = time.time()
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
# All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with -1 id to flag for later embeddings generation
new_entries = [
(-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
end = time.time()
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
return entries_with_ids
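
Concrete processors plug into this base class by implementing `process`. A stripped-down sketch of the pattern the three processors above follow; `PlaintextToJsonl` and its one-entry-per-file parsing are hypothetical, illustrative only:

```python
import logging

from src.processor.text_to_jsonl import TextToJsonl
from src.utils.rawconfig import Entry

logger = logging.getLogger(__name__)


class PlaintextToJsonl(TextToJsonl):
    # Hypothetical processor following the same contract as the processors above
    def process(self, previous_entries: list[Entry] = None) -> list[tuple[int, Entry]]:
        # Parse one Entry per configured input file
        # (real processors split files into finer-grained entries)
        current_entries = []
        for path in self.config.input_files or []:
            with open(path) as f:
                raw = f.read()
            current_entries.append(Entry(compiled=raw, raw=raw, file=path))

        # First run: every entry is new, so just enumerate
        if not previous_entries:
            return list(enumerate(current_entries))
        # Otherwise keep previous ids for unchanged entries and mark new
        # entries with -1, so only the new ones get re-embedded
        return self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
```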

src/router.py (deleted, -220 lines)

@@ -1,220 +0,0 @@
# Standard Packages
import yaml
import json
import time
import logging
from typing import Optional
# External Packages
from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
# Internal Packages
from src.configure import configure_search
from src.search_type import image_search, text_search
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.utils.rawconfig import FullConfig
from src.utils.config import SearchType
from src.utils.helpers import LRU, get_absolute_path, get_from_dict
from src.utils import state, constants
router = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
logger = logging.getLogger(__name__)
query_cache = LRU()
@router.get("/", response_class=FileResponse)
def index():
return FileResponse(constants.web_directory / "index.html")
@router.get('/config', response_class=HTMLResponse)
def config_page(request: Request):
return templates.TemplateResponse("config.html", context={'request': request})
@router.get('/config/data', response_model=FullConfig)
def config_data():
return state.config
@router.post('/config/data')
async def config_data(updated_config: FullConfig):
state.config = updated_config
with open(state.config_file, 'w') as outfile:
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close()
return state.config
@router.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
if q is None or q == '':
logger.info(f'No query param (q) passed in API call to initiate search')
return {}
# initialize variables
user_query = q.strip()
results_count = n
results = {}
query_start, query_end, collate_start, collate_end = None, None, None, None
# return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}'
if query_cache_key in state.query_cache:
logger.info(f'Return response from query cache')
return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Image or t == None) and state.model.image_search:
# query images
query_start = time.time()
hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images'
query_end = time.time()
# collate and return results
collate_start = time.time()
results = image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=results_count)
collate_end = time.time()
# Cache results
state.query_cache[query_cache_key] = results
if query_start and query_end:
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results
@router.get('/reload')
def reload(t: Optional[SearchType] = None):
state.model = configure_search(state.model, state.config, regenerate=False, t=t)
return {'status': 'ok', 'message': 'reload completed'}
@router.get('/regenerate')
def regenerate(t: Optional[SearchType] = None):
state.model = configure_search(state.model, state.config, regenerate=True, t=t)
return {'status': 'ok', 'message': 'regeneration completed'}
@router.get('/beta/search')
def search_beta(q: str, n: Optional[int] = 1):
# Extract Search Type using GPT
metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type")
# Search
search_results = search(q, n=n, t=SearchType(search_type))
# Return response
return {'status': 'ok', 'result': search_results, 'type': search_type}
@router.get('/beta/chat')
def chat(q: str):
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# Converse with OpenAI GPT
metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
if state.verbose > 1:
print(f'Understood: {get_from_dict(metadata, "intent")}')
if get_from_dict(metadata, "intent", "memory-type") == "notes":
query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org)
collated_result = "\n".join([item["entry"] for item in result_list])
if state.verbose > 1:
print(f'Semantically Similar Notes:\n{collated_result}')
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key)
else:
gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key)
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', []))
return {'status': 'ok', 'response': gpt_response}
@router.on_event('shutdown')
def shutdown_event():
# No need to create empty log file
if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log):
return
elif state.processor_config.conversation.verbose:
print('INFO:\tSaving conversation logs to disk...')
# Summarize Conversation Logs for this Session
chat_session = state.processor_config.conversation.chat_session
openai_api_key = state.processor_config.conversation.openai_api_key
conversation_log = state.processor_config.conversation.meta_log
session = {
"summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"])
}
if 'session' in conversation_log:
conversation_log['session'].append(session)
else:
conversation_log['session'] = [session]
# Save Conversation Metadata Logs to Disk
conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile)
with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
json.dump(conversation_log, logfile)
print('INFO:\tConversation logs saved to disk.')
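
The 220 lines deleted above are not lost: the config and search endpoints reappear in src/routers/api.py, the GPT-backed search and chat endpoints in src/routers/api_beta.py, and the HTML-serving routes in src/routers/web_client.py, as the three new files below show.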

src/routers/api.py (new file, +129 lines)

@@ -0,0 +1,129 @@
# Standard Packages
import yaml
import time
import logging
from typing import Optional
# External Packages
from fastapi import APIRouter
# Internal Packages
from src.configure import configure_search
from src.search_type import image_search, text_search
from src.utils.rawconfig import FullConfig, SearchResponse
from src.utils.config import SearchType
from src.utils import state, constants
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
# Create Routes
@api.get('/config/data', response_model=FullConfig)
def get_config_data():
return state.config
@api.post('/config/data')
async def set_config_data(updated_config: FullConfig):
state.config = updated_config
with open(state.config_file, 'w') as outfile:
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close()
return state.config
@api.get('/search', response_model=list[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
results: list[SearchResponse] = []
if q is None or q == '':
logger.info(f'No query param (q) passed in API call to initiate search')
return results
# initialize variables
user_query = q.strip()
results_count = n
query_start, query_end, collate_start, collate_end = None, None, None, None
# return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}'
if query_cache_key in state.query_cache:
logger.info(f'Return response from query cache')
return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Image or t == None) and state.model.image_search:
# query images
query_start = time.time()
hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images'
query_end = time.time()
# collate and return results
collate_start = time.time()
results = image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=results_count)
collate_end = time.time()
# Cache results
state.query_cache[query_cache_key] = results
if query_start and query_end:
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results
@api.get('/update')
def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
state.model = configure_search(state.model, state.config, regenerate=force, t=t)
return {'status': 'ok', 'message': 'index updated'}

src/routers/api_beta.py (new file, +88 lines)

@@ -0,0 +1,88 @@
# Standard Packages
import json
import logging
from typing import Optional
# External Packages
from fastapi import APIRouter
# Internal Packages
from src.routers.api import search
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict
from src.utils import state
# Initialize Router
api_beta = APIRouter()
logger = logging.getLogger(__name__)
# Create Routes
@api_beta.get('/search')
def search_beta(q: str, n: Optional[int] = 1):
# Extract Search Type using GPT
metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type")
# Search
search_results = search(q, n=n, t=SearchType(search_type))
# Return response
return {'status': 'ok', 'result': search_results, 'type': search_type}
@api_beta.get('/chat')
def chat(q: str):
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# Converse with OpenAI GPT
metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
logger.debug(f'Understood: {get_from_dict(metadata, "intent")}')
if get_from_dict(metadata, "intent", "memory-type") == "notes":
query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org)
collated_result = "\n".join([item["entry"] for item in result_list])
logger.debug(f'Semantically Similar Notes:\n{collated_result}')
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key)
else:
gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key)
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', []))
return {'status': 'ok', 'response': gpt_response}
@api_beta.on_event('shutdown')
def shutdown_event():
# No need to create empty log file
if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log):
return
logger.debug('INFO:\tSaving conversation logs to disk...')
# Summarize Conversation Logs for this Session
chat_session = state.processor_config.conversation.chat_session
openai_api_key = state.processor_config.conversation.openai_api_key
conversation_log = state.processor_config.conversation.meta_log
session = {
"summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"])
}
if 'session' in conversation_log:
conversation_log['session'].append(session)
else:
conversation_log['session'] = [session]
# Save Conversation Metadata Logs to Disk
conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile)
with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
json.dump(conversation_log, logfile)
logger.info('INFO:\tConversation logs saved to disk.')

src/routers/web_client.py (new file, +23 lines)

@@ -0,0 +1,23 @@
# External Packages
from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
# Internal Packages
from src.utils import constants
# Initialize Router
web_client = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
# Create Routes
@web_client.get("/", response_class=FileResponse)
def index():
return FileResponse(constants.web_directory / "index.html")
@web_client.get('/config', response_class=HTMLResponse)
def config_page(request: Request):
return templates.TemplateResponse("config.html", context={'request': request})

src/search_filter/date_filter.py

@@ -37,7 +37,7 @@ class DateFilter(BaseFilter):
         start = time.time()
         for id, entry in enumerate(entries):
             # Extract dates from entry
-            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
+            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
                 # Convert date string in entry to unix timestamp
                 try:
                     date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()

src/search_filter/file_filter.py

@@ -24,7 +24,7 @@ class FileFilter(BaseFilter):
     def load(self, entries, *args, **kwargs):
         start = time.time()
         for id, entry in enumerate(entries):
-            self.file_to_entry_map[entry[self.entry_key]].add(id)
+            self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
         end = time.time()
         logger.debug(f"Created file filter index: {end - start} seconds")

src/search_filter/word_filter.py

@@ -29,7 +29,7 @@ class WordFilter(BaseFilter):
         entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
         # Create map of words to entries they exist in
         for entry_index, entry in enumerate(entries):
-            for word in re.split(entry_splitter, entry[self.entry_key].lower()):
+            for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
                 if word == '':
                     continue
                 self.word_to_entry_index[word].add(entry_index)

src/search_type/image_search.py

@@ -15,7 +15,7 @@ import torch
 # Internal Packages
 from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
 from src.utils.config import ImageSearchModel
-from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
+from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse

 # Create Logger
@@ -203,8 +203,8 @@ def render_results(hits, image_names, image_directory, count):
         img.show()

-def collate_results(hits, image_names, output_directory, image_files_url, count=5):
-    results = []
+def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]:
+    results: list[SearchResponse] = []

     for index, hit in enumerate(hits[:count]):
         source_path = image_names[hit['corpus_id']]
@@ -220,12 +220,17 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
             shutil.copy(source_path, target_path)

         # Add the image metadata to the results
-        results += [{
-            "entry": f'{image_files_url}/{target_image_name}',
-            "score": f"{hit['score']:.9f}",
-            "image_score": f"{hit['image_score']:.9f}",
-            "metadata_score": f"{hit['metadata_score']:.9f}",
-        }]
+        results += [SearchResponse.parse_obj(
+            {
+                "entry": f'{image_files_url}/{target_image_name}',
+                "score": f"{hit['score']:.9f}",
+                "additional":
+                {
+                    "image_score": f"{hit['image_score']:.9f}",
+                    "metadata_score": f"{hit['metadata_score']:.9f}",
+                }
+            }
+        )]

     return results
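
The resulting wire format for a single image result, with the score breakdown nested under `additional`; file name and score values here are illustrative:

```python
# Illustrative shape of one image SearchResponse after this change
image_result = {
    "entry": "/static/images/image_42.jpg",  # hypothetical file name
    "score": "0.376510620",
    "additional": {
        "image_score": "0.351220250",
        "metadata_score": "0.025290370",
    },
}
```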

src/search_type/text_search.py

@@ -1,17 +1,19 @@
 # Standard Packages
 import logging
 import time
+from typing import Type

 # External Packages
 import torch
 from sentence_transformers import SentenceTransformer, CrossEncoder, util

+from src.processor.text_to_jsonl import TextToJsonl
 from src.search_filter.base_filter import BaseFilter

 # Internal Packages
 from src.utils import state
 from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
 from src.utils.config import TextSearchModel
-from src.utils.rawconfig import TextSearchConfig, TextContentConfig
+from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
 from src.utils.jsonl import load_jsonl
@@ -48,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig):
     return bi_encoder, cross_encoder, top_k

-def extract_entries(jsonl_file):
+def extract_entries(jsonl_file) -> list[Entry]:
     "Load entries from compressed jsonl"
-    return load_jsonl(jsonl_file)
+    return list(map(Entry.from_dict, load_jsonl(jsonl_file)))

-def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
+def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     new_entries = []
     # Load pre-computed embeddings from file if exists and update them if required
@@ -62,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
         logger.info(f"Loaded embeddings from {embeddings_file}")

     # Encode any new entries in the corpus and update corpus embeddings
-    new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
+    new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
     if new_entries:
         new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
-        existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
+        existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
         existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
         corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
     # Else compute the corpus embeddings from scratch
     else:
-        new_entries = [entry['compiled'] for _, entry in entries_with_ids]
+        new_entries = [entry.compiled for _, entry in entries_with_ids]
         corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)

     # Save regenerated or updated embeddings to file
@@ -131,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
     # Score all retrieved entries using the cross-encoder
     if rank_results:
         start = time.time()
-        cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
+        cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
         cross_scores = model.cross_encoder.predict(cross_inp)
         end = time.time()
         logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
@@ -151,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
     return hits, entries

-def render_results(hits, entries, count=5, display_biencoder_results=False):
+def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False):
     "Render the Results returned by Search for the Query"
     if display_biencoder_results:
         # Output of top hits from bi-encoder
@@ -159,34 +161,34 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
         print(f"Top-{count} Bi-Encoder Retrieval hits")
         hits = sorted(hits, key=lambda x: x['score'], reverse=True)
         for hit in hits[0:count]:
-            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
+            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}")

     # Output of top hits from re-ranker
     print("\n-------------------------\n")
     print(f"Top-{count} Cross-Encoder Re-ranker hits")
     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
     for hit in hits[0:count]:
-        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
+        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}")

-def collate_results(hits, entries, count=5):
-    return [
+def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
+    return [SearchResponse.parse_obj(
         {
-            "entry": entries[hit['corpus_id']]['raw'],
+            "entry": entries[hit['corpus_id']].raw,
             "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
-        }
+        })
     for hit
     in hits[0:count]]

-def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
+def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)

     # Map notes in text files to (compressed) JSONL formatted file
     config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
     previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
-    entries_with_indices = text_to_jsonl(config, previous_entries)
+    entries_with_indices = text_to_jsonl(config).process(previous_entries)

     # Extract Updated Entries
     entries = extract_entries(config.compressed_jsonl)
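The sentinel change above (`-1` instead of `None` for new entries) keeps the id/entry tuples uniformly typed as `list[tuple[int, Entry]]`. A minimal sketch of how `compute_embeddings` interprets those ids; the entry texts are stand-ins, and the tuple list mimics what a processor's `process()` call would return:

from src.utils.rawconfig import Entry

# Existing entries keep their previous embedding row index; new entries get -1
entries_with_ids = [
    (0, Entry(raw='old note', compiled='old note')),
    (1, Entry(raw='another old note', compiled='another old note')),
    (-1, Entry(raw='freshly added note', compiled='freshly added note')),
]

# Only the -1 entries need encoding; rows 0 and 1 are reused from the saved tensor
to_encode = [entry.compiled for id, entry in entries_with_ids if id == -1]
reusable_rows = [id for id, _ in entries_with_ids if id != -1]
assert to_encode == ['freshly added note'] and reusable_rows == [0, 1]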

View File

@@ -1,8 +1,6 @@
 # Standard Packages
 from pathlib import Path
 import sys
-import time
-import hashlib
 from os.path import join
 from collections import OrderedDict
 from typing import Optional, Union
@@ -83,38 +81,3 @@ class LRU(OrderedDict):
             oldest = next(iter(self))
             del self[oldest]

-def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None):
-    # Hash all current and previous entries to identify new entries
-    start = time.time()
-    current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries))
-    previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries))
-    end = time.time()
-    logger.debug(f"Hash previous, current entries: {end - start} seconds")
-
-    start = time.time()
-    hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
-    hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
-
-    # All entries that did not exist in the previous set are to be added
-    new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
-    # All entries that exist in both current and previous sets are kept
-    existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
-
-    # Mark new entries with no ids for later embeddings generation
-    new_entries = [
-        (None, hash_to_current_entries[entry_hash])
-        for entry_hash in new_entry_hashes
-    ]
-    # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
-    existing_entries = [
-        (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
-        for entry_hash in existing_entry_hashes
-    ]
-    existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
-    entries_with_ids = existing_entries_sorted + new_entries
-    end = time.time()
-    logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
-
-    return entries_with_ids
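Since `setup` now calls `text_to_jsonl(config).process(previous_entries)`, the content-hash diffing removed here presumably lives with the `TextToJsonl` processors. The core idea, stripped of timing and logging and updated to the new `-1` sentinel, is just this sketch (a restatement of the removed logic, not the committed replacement):

import hashlib

def mark_entries_for_update(current_entries, previous_entries, key='compiled'):
    # Content-hash entries so unchanged ones can be matched across indexing runs
    def entry_hash(entry):
        return hashlib.md5(bytes(entry[key], encoding='utf-8')).hexdigest()

    current_hashes = [entry_hash(e) for e in current_entries]
    previous_hashes = [entry_hash(e) for e in previous_entries]
    current_by_hash = dict(zip(current_hashes, current_entries))
    previous_by_hash = dict(zip(previous_hashes, previous_entries))

    # Entries seen before keep their old embedding row id so their embeddings
    # can be reused; entries not seen before are marked -1 for encoding
    existing = sorted(
        (previous_hashes.index(h), previous_by_hash[h])
        for h in set(current_hashes) & set(previous_hashes))
    new = [(-1, current_by_hash[h]) for h in set(current_hashes) - set(previous_hashes)]
    return existing + new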

View File

@@ -1,4 +1,5 @@
 # System Packages
+import json
 from pathlib import Path
 from typing import List, Optional
@@ -71,3 +72,32 @@ class FullConfig(ConfigBase):
     content_type: Optional[ContentConfig]
     search_type: Optional[SearchConfig]
     processor: Optional[ProcessorConfig]

+class SearchResponse(ConfigBase):
+    entry: str
+    score: str
+    additional: Optional[dict]
+
+class Entry():
+    raw: str
+    compiled: str
+    file: Optional[str]
+
+    def __init__(self, raw: str = None, compiled: str = None, file: Optional[str] = None):
+        self.raw = raw
+        self.compiled = compiled
+        self.file = file
+
+    def to_json(self) -> str:
+        return json.dumps(self.__dict__, ensure_ascii=False)
+
+    def __repr__(self) -> str:
+        return self.__dict__.__repr__()
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        return cls(
+            raw=dictionary['raw'],
+            compiled=dictionary['compiled'],
+            file=dictionary.get('file', None)
+        )
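A quick round-trip sketch of the new `Entry` helpers, serializing to a jsonl line and rehydrating on load (the note text and `notes.org` filename are illustrative):

import json
from src.utils.rawconfig import Entry

entry = Entry(raw='* Note heading', compiled='Note heading', file='notes.org')

# Serialize for the compressed jsonl index, then rehydrate on load
line = entry.to_json()
restored = Entry.from_dict(json.loads(line))
assert restored.compiled == entry.compiled and restored.file == 'notes.org'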

View File

@@ -9,7 +9,7 @@ from src.utils.rawconfig import FullConfig
 # Do not emit tags when dumping to YAML
-yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None
+yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None  # type: ignore[assignment]

 def save_config_to_file(yaml_config: dict, yaml_config_file: Path):

View File

@@ -6,7 +6,7 @@ from src.search_type import image_search, text_search
 from src.utils.config import SearchType
 from src.utils.helpers import resolve_absolute_path
 from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
-from src.processor.org_mode.org_to_jsonl import org_to_jsonl
+from src.processor.org_mode.org_to_jsonl import OrgToJsonl
 from src.search_filter.date_filter import DateFilter
 from src.search_filter.word_filter import WordFilter
 from src.search_filter.file_filter import FileFilter
@@ -60,6 +60,6 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
         embeddings_file = content_dir.joinpath('note_embeddings.pt'))

     filters = [DateFilter(), WordFilter(), FileFilter()]
-    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
+    text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)

     return content_config

View File

@@ -43,9 +43,8 @@ just generating embeddings*
 - **Khoj via API**
   - See [Khoj API Docs](http://localhost:8000/docs)
-  - [Query](http://localhost:8000/search?q=%22what%20is%20the%20meaning%20of%20life%22)
-  - [Regenerate
-    Embeddings](http://localhost:8000/regenerate?t=ledger)
+  - [Query](http://localhost:8000/api/search?q=%22what%20is%20the%20meaning%20of%20life%22)
+  - [Update Index](http://localhost:8000/api/update?t=ledger)
   - [Configure Application](https://localhost:8000/ui)
 - **Khoj via Emacs**
   - [Install](https://github.com/debanjum/khoj/tree/master/src/interface/emacs#installation)

View File

@@ -27,8 +27,8 @@
   - Run ~M-x khoj <user-query>~ or Call ~C-c C-s~
 - *Khoj via API*
-  - Query: ~GET~ [[http://localhost:8000/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/search?q="What is the meaning of life"]]
-  - Regenerate Embeddings: ~GET~ [[http://localhost:8000/regenerate][http://localhost:8000/regenerate]]
+  - Query: ~GET~ [[http://localhost:8000/api/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/api/search?q="What is the meaning of life"]]
+  - Update Index: ~GET~ [[http://localhost:8000/api/update][http://localhost:8000/api/update]]
   - [[http://localhost:8000/docs][Khoj API Docs]]
 - *Call Khoj via Python Script Directly*

View File

@@ -2,7 +2,7 @@
 import json

 # Internal Packages
-from src.processor.ledger.beancount_to_jsonl import extract_beancount_transactions, convert_transactions_to_maps, convert_transaction_maps_to_jsonl, get_beancount_files
+from src.processor.ledger.beancount_to_jsonl import BeancountToJsonl

 def test_no_transactions_in_file(tmp_path):
@@ -16,10 +16,11 @@ def test_no_transactions_in_file(tmp_path):
     # Act
     # Extract Entries from specified Beancount files
-    entry_nodes, file_to_entries = extract_beancount_transactions(beancount_files=[beancount_file])
+    entry_nodes, file_to_entries = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file])

     # Process Each Entry from All Beancount Files
-    jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entry_nodes, file_to_entries))
+    jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
+        BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries))
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -38,10 +39,11 @@ Assets:Test:Test -1.00 KES
     # Act
     # Extract Entries from specified Beancount files
-    entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
+    entries, entry_to_file_map = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file])

     # Process Each Entry from All Beancount Files
-    jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
+    jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
+        BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -65,10 +67,11 @@ Assets:Test:Test -1.00 KES
     # Act
     # Extract Entries from specified Beancount files
-    entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
+    entries, entry_to_file_map = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file])

     # Process Each Entry from All Beancount Files
-    jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
+    jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
+        BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -96,7 +99,7 @@ def test_get_beancount_files(tmp_path):
     input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount']

     # Act
-    extracted_org_files = get_beancount_files(input_files, input_filter)
+    extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter)

     # Assert
     assert len(extracted_org_files) == 5

View File

@@ -12,7 +12,7 @@ from src.main import app
 from src.utils.state import model, config
 from src.search_type import text_search, image_search
 from src.utils.rawconfig import ContentConfig, SearchConfig
-from src.processor.org_mode.org_to_jsonl import org_to_jsonl
+from src.processor.org_mode.org_to_jsonl import OrgToJsonl
 from src.search_filter.word_filter import WordFilter
 from src.search_filter.file_filter import FileFilter
@@ -28,7 +28,7 @@ def test_search_with_invalid_content_type():
     user_query = quote("How to call Khoj from Emacs?")

     # Act
-    response = client.get(f"/search?q={user_query}&t=invalid_content_type")
+    response = client.get(f"/api/search?q={user_query}&t=invalid_content_type")

     # Assert
     assert response.status_code == 422
@@ -43,29 +43,29 @@ def test_search_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig):
     # config.content_type.image = search_config.image
     for content_type in ["org", "markdown", "ledger", "music"]:
         # Act
-        response = client.get(f"/search?q=random&t={content_type}")
+        response = client.get(f"/api/search?q=random&t={content_type}")
         # Assert
         assert response.status_code == 200

 # ----------------------------------------------------------------------------------------------------
-def test_reload_with_invalid_content_type():
+def test_update_with_invalid_content_type():
     # Act
-    response = client.get(f"/reload?t=invalid_content_type")
+    response = client.get(f"/api/update?t=invalid_content_type")

     # Assert
     assert response.status_code == 422

 # ----------------------------------------------------------------------------------------------------
-def test_reload_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig):
+def test_update_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
     config.content_type = content_config
     config.search_type = search_config

     for content_type in ["org", "markdown", "ledger", "music"]:
         # Act
-        response = client.get(f"/reload?t={content_type}")
+        response = client.get(f"/api/update?t={content_type}")
         # Assert
         assert response.status_code == 200
@@ -73,7 +73,7 @@ def test_reload_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig):
 # ----------------------------------------------------------------------------------------------------
 def test_regenerate_with_invalid_content_type():
     # Act
-    response = client.get(f"/regenerate?t=invalid_content_type")
+    response = client.get(f"/api/update?force=true&t=invalid_content_type")

     # Assert
     assert response.status_code == 422
@@ -87,7 +87,7 @@ def test_regenerate_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig):
     for content_type in ["org", "markdown", "ledger", "music", "image"]:
         # Act
-        response = client.get(f"/regenerate?t={content_type}")
+        response = client.get(f"/api/update?force=true&t={content_type}")
         # Assert
         assert response.status_code == 200
@@ -104,7 +104,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
     for query, expected_image_name in query_expected_image_pairs:
         # Act
-        response = client.get(f"/search?q={query}&n=1&t=image")
+        response = client.get(f"/api/search?q={query}&n=1&t=image")

         # Assert
         assert response.status_code == 200
@@ -118,11 +118,11 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
     user_query = quote("How to git install application?")

     # Act
-    response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
+    response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true")

     # Assert
     assert response.status_code == 200
@@ -135,11 +135,11 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
 def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
     filters = [WordFilter(), FileFilter()]
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
+    model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
     user_query = quote('+"Emacs" file:"*.org"')

     # Act
-    response = client.get(f"/search?q={user_query}&n=1&t=org")
+    response = client.get(f"/api/search?q={user_query}&n=1&t=org")

     # Assert
     assert response.status_code == 200
@@ -152,11 +152,11 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
 def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
     filters = [WordFilter()]
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
+    model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
     user_query = quote('How to git install application? +"Emacs"')

     # Act
-    response = client.get(f"/search?q={user_query}&n=1&t=org")
+    response = client.get(f"/api/search?q={user_query}&n=1&t=org")

     # Assert
     assert response.status_code == 200
@@ -169,11 +169,11 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
 def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
     filters = [WordFilter()]
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
+    model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
     user_query = quote('How to git install application? -"clone"')

     # Act
-    response = client.get(f"/search?q={user_query}&n=1&t=org")
+    response = client.get(f"/api/search?q={user_query}&n=1&t=org")

     # Assert
     assert response.status_code == 200
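Outside the test client, the merged endpoint can be exercised the same way against a running server. A sketch using only the standard library, assuming Khoj is serving on localhost:8000 with a ledger and org notes configured:

from urllib.parse import quote
from urllib.request import urlopen

base = "http://localhost:8000/api"

# Incremental index update for one content type
urlopen(f"{base}/update?t=ledger")

# Force a full regeneration (replaces the old /regenerate endpoint)
urlopen(f"{base}/update?force=true&t=ledger")

# Query with cross-encoder re-ranking enabled
query = quote("what is the meaning of life")
results = urlopen(f"{base}/search?q={query}&n=1&t=org&r=true").read()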

View File

@@ -3,19 +3,17 @@ import re
 from datetime import datetime
 from math import inf

-# External Packages
-import torch
-
 # Application Packages
 from src.search_filter.date_filter import DateFilter
+from src.utils.rawconfig import Entry

 def test_date_filter():
-    embeddings = torch.randn(3, 10)
     entries = [
-        {'compiled': '', 'raw': 'Entry with no date'},
-        {'compiled': '', 'raw': 'April Fools entry: 1984-04-01'},
-        {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
+        Entry(compiled='', raw='Entry with no date'),
+        Entry(compiled='', raw='April Fools entry: 1984-04-01'),
+        Entry(compiled='', raw='Entry with date:1984-04-02')
+    ]

     q_with_no_date_filter = 'head tail'
     ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)

View File

@@ -1,14 +1,12 @@
-# External Packages
-import torch
-
 # Application Packages
 from src.search_filter.file_filter import FileFilter
+from src.utils.rawconfig import Entry

 def test_no_file_filter():
     # Arrange
     file_filter = FileFilter()
-    embeddings, entries = arrange_content()
+    entries = arrange_content()
     q_with_no_filter = 'head tail'

     # Act
@@ -24,7 +22,7 @@ def test_no_file_filter():
 def test_file_filter_with_non_existent_file():
     # Arrange
     file_filter = FileFilter()
-    embeddings, entries = arrange_content()
+    entries = arrange_content()
     q_with_no_filter = 'head file:"nonexistent.org" tail'

     # Act
@@ -40,7 +38,7 @@ def test_file_filter_with_non_existent_file():
 def test_single_file_filter():
     # Arrange
     file_filter = FileFilter()
-    embeddings, entries = arrange_content()
+    entries = arrange_content()
     q_with_no_filter = 'head file:"file 1.org" tail'

     # Act
@@ -56,7 +54,7 @@ def test_single_file_filter():
 def test_file_filter_with_partial_match():
     # Arrange
     file_filter = FileFilter()
-    embeddings, entries = arrange_content()
+    entries = arrange_content()
     q_with_no_filter = 'head file:"1.org" tail'

     # Act
@@ -72,7 +70,7 @@ def test_file_filter_with_partial_match():
 def test_file_filter_with_regex_match():
     # Arrange
     file_filter = FileFilter()
-    embeddings, entries = arrange_content()
+    entries = arrange_content()
     q_with_no_filter = 'head file:"*.org" tail'

     # Act
@@ -88,7 +86,7 @@ def test_file_filter_with_regex_match():
 def test_multiple_file_filter():
     # Arrange
     file_filter = FileFilter()
-    embeddings, entries = arrange_content()
+    entries = arrange_content()
     q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'

     # Act
@@ -102,11 +100,11 @@ def test_multiple_file_filter():
 def arrange_content():
-    embeddings = torch.randn(4, 10)
     entries = [
-        {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'},
-        {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
-        {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'},
-        {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
+        Entry(compiled='', raw='First Entry', file='file 1.org'),
+        Entry(compiled='', raw='Second Entry', file='file2.org'),
+        Entry(compiled='', raw='Third Entry', file='file 1.org'),
+        Entry(compiled='', raw='Fourth Entry', file='file2.org')
+    ]

-    return embeddings, entries
+    return entries

View File

@@ -2,9 +2,6 @@
 from pathlib import Path
 from PIL import Image

-# External Packages
-import pytest
-
 # Internal Packages
 from src.utils.state import model
 from src.utils.constants import web_directory
@@ -70,7 +67,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
         image_files_url='/static/images',
         count=1)

-    actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name)
+    actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
     actual_image = Image.open(actual_image_path)
     expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))

View File

@@ -2,7 +2,7 @@
 import json

 # Internal Packages
-from src.processor.markdown.markdown_to_jsonl import extract_markdown_entries, convert_markdown_maps_to_jsonl, convert_markdown_entries_to_maps, get_markdown_files
+from src.processor.markdown.markdown_to_jsonl import MarkdownToJsonl

 def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
@@ -16,10 +16,11 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
     # Act
     # Extract Entries from specified Markdown files
-    entry_nodes, file_to_entries = extract_markdown_entries(markdown_files=[markdownfile])
+    entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])

     # Process Each Entry from All Notes Files
-    jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
+    jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
+        MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -37,10 +38,11 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
     # Act
     # Extract Entries from specified Markdown files
-    entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
+    entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])

     # Process Each Entry from All Notes Files
-    jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
+    jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
+        MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -62,10 +64,11 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
     # Act
     # Extract Entries from specified Markdown files
-    entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
+    entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])

     # Process Each Entry from All Notes Files
-    jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
+    jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
+        MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -93,7 +96,7 @@ def test_get_markdown_files(tmp_path):
     input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown']

     # Act
-    extracted_org_files = get_markdown_files(input_files, input_filter)
+    extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)

     # Assert
     assert len(extracted_org_files) == 5

View File

@@ -2,7 +2,7 @@
 import json

 # Internal Packages
-from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries, get_org_files
+from src.processor.org_mode.org_to_jsonl import OrgToJsonl
 from src.utils.helpers import is_none_or_empty
@@ -21,8 +21,8 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
     for index_heading_entries in [True, False]:
         # Act
         # Extract entries into jsonl from specified Org files
-        jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(
-            *extract_org_entries(org_files=[orgfile]),
+        jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(
+            *OrgToJsonl.extract_org_entries(org_files=[orgfile]),
             index_heading_entries=index_heading_entries))
         jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -49,10 +49,10 @@ def test_entry_with_body_to_jsonl(tmp_path):
     # Act
     # Extract Entries from specified Org files
-    entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
+    entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])

     # Process Each Entry from All Notes Files
-    jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(entries, entry_to_file_map))
+    jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map))
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -70,11 +70,11 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
     # Act
     # Extract Entries from specified Org files
-    entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
+    entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=[orgfile])

     # Process Each Entry from All Notes Files
-    entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
-    jsonl_string = convert_org_entries_to_jsonl(entries)
+    entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
+    jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries)
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]

     # Assert
@@ -102,7 +102,7 @@ def test_get_org_files(tmp_path):
     input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org']

     # Act
-    extracted_org_files = get_org_files(input_files, input_filter)
+    extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)

     # Assert
     assert len(extracted_org_files) == 5

View File

@@ -9,7 +9,7 @@ import pytest
 from src.utils.state import model
 from src.search_type import text_search
 from src.utils.rawconfig import ContentConfig, SearchConfig
-from src.processor.org_mode.org_to_jsonl import org_to_jsonl
+from src.processor.org_mode.org_to_jsonl import OrgToJsonl

 # Test
@@ -24,7 +24,7 @@ def test_asymmetric_setup_with_missing_file_raises_error(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # Generate notes embeddings during asymmetric setup
     with pytest.raises(FileNotFoundError):
-        text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
+        text_search.setup(OrgToJsonl, new_org_content_config, search_config.asymmetric, regenerate=True)

 # ----------------------------------------------------------------------------------------------------
@@ -39,7 +39,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # Generate notes embeddings during asymmetric setup
     with pytest.raises(ValueError, match=r'^No valid entries found*'):
-        text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
+        text_search.setup(OrgToJsonl, new_org_content_config, search_config.asymmetric, regenerate=True)

     # Cleanup
     # delete created test file
@@ -50,7 +50,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentConfig, search_config: SearchConfig):
 def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # Regenerate notes embeddings during asymmetric setup
-    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)

     # Assert
     assert len(notes_model.entries) == 10
@@ -60,7 +60,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
     query = "How to git install application?"

     # Act
@@ -76,14 +76,14 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     # Assert
     # Actual_data should contain "Khoj via Emacs" entry
-    search_result = results[0]["entry"]
+    search_result = results[0].entry
     assert "git clone" in search_result

 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)

     assert len(initial_notes_model.entries) == 10
     assert len(initial_notes_model.corpus_embeddings) == 10
@@ -96,11 +96,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")

     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    regenerated_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)

     # Act
     # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
-    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)

     # Assert
     assert len(regenerated_notes_model.entries) == 11
@@ -119,7 +119,7 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
 # ----------------------------------------------------------------------------------------------------
 def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)

     assert len(initial_notes_model.entries) == 10
     assert len(initial_notes_model.corpus_embeddings) == 10
@@ -133,7 +133,7 @@ def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # update embeddings, entries with the newly added note
-    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)

     # verify new entry added in updated embeddings, entries
     assert len(initial_notes_model.entries) == 11

View File

@@ -1,6 +1,6 @@
 # Application Packages
 from src.search_filter.word_filter import WordFilter
-from src.utils.config import SearchType
+from src.utils.rawconfig import Entry

 def test_no_word_filter():
@@ -69,9 +69,10 @@ def test_word_include_and_exclude_filter():
 def arrange_content():
     entries = [
-        {'compiled': '', 'raw': 'Minimal Entry'},
-        {'compiled': '', 'raw': 'Entry with exclude_word'},
-        {'compiled': '', 'raw': 'Entry with include_word'},
-        {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
+        Entry(compiled='', raw='Minimal Entry'),
+        Entry(compiled='', raw='Entry with exclude_word'),
+        Entry(compiled='', raw='Entry with include_word'),
+        Entry(compiled='', raw='Entry with include_word and exclude_word')
+    ]
     return entries