Modularize, Improve API. Formalize Intermediate Text Content Format. Add Type Checking

- **Improve API Endpoints**
  - ee65a4f Merge /reload, /regenerate into single /update API endpoint
  - 9975497 Type the /search API response to better document the response schema
  - 0521ea1 Put image score breakdown under `additional` field in search response
- **Formalize Intermediate Format to Index Text Content**
  - 7e9298f Use new Text `Entry` class to track text entries in Intermediate Format
  - 02d9440 Use Base `TextToJsonl` class to standardize `<text>_to_jsonl` processors
- **Modularize API router code**
  - e42a38e Split router code into `web_client`, `api`, `api_beta` routers. Version Khoj API
  - d292bdc Remove API versioning. Premature given current state of the codebase
- **Miscellaneous**
  - c467df8 Setup `mypy` for static type checking
  - 2c54813 Remove unused imports, `embeddings` variable from text search tests
This commit is contained in:
Debanjum
2022-10-19 11:23:04 +00:00
committed by GitHub
35 changed files with 788 additions and 704 deletions

13
.mypy.ini Normal file
View File

@@ -0,0 +1,13 @@
[mypy]
strict_optional = False
ignore_missing_imports = True
install_types = True
non_interactive = True
show_error_codes = True
exclude = (?x)(
src/interface/desktop/main_window.py
| src/interface/desktop/file_browser.py
| src/interface/desktop/system_tray.py
| build/*
| tests/*
)

View File

@@ -85,7 +85,7 @@ khoj
### 3. Configure
1. Enable content types and point to files to search in the First Run Screen that pops up on app start
2. Click configure and wait. The app will load the ML model, generate embeddings and expose the search API
2. Click `Configure` and wait. The app will download ML models and index the content for search
## Use
@@ -113,7 +113,7 @@ pip install --upgrade khoj-assistant
## Miscellaneous
- The beta [chat](http://localhost:8000/beta/chat) and [search](http://localhost:8000/beta/search) API endpoints use [OpenAI API](https://openai.com/api/)
- The beta [chat](http://localhost:8000/api/beta/chat) and [search](http://localhost:8000/api/beta/search) API endpoints use [OpenAI API](https://openai.com/api/)
- It is disabled by default
- To use it add your `openai-api-key` via the app configure screen
- Warning: *If you use the above beta APIs, your query and top result(s) will be sent to OpenAI for processing*

View File

@@ -6,9 +6,9 @@ import logging
import json
# Internal Packages
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.processor.markdown.markdown_to_jsonl import markdown_to_jsonl
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.processor.ledger.beancount_to_jsonl import BeancountToJsonl
from src.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
from src.search_type import image_search, text_search
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils import state
@@ -44,7 +44,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
if (t == SearchType.Org or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings
model.orgmode_search = text_search.setup(
org_to_jsonl,
OrgToJsonl,
config.content_type.org,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
@@ -54,7 +54,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings
model.music_search = text_search.setup(
org_to_jsonl,
OrgToJsonl,
config.content_type.music,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
@@ -64,7 +64,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
# Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(
markdown_to_jsonl,
MarkdownToJsonl,
config.content_type.markdown,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
@@ -74,7 +74,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings
model.ledger_search = text_search.setup(
beancount_to_jsonl,
BeancountToJsonl,
config.content_type.ledger,
search_config=config.search_type.symmetric,
regenerate=regenerate,

View File

@@ -187,8 +187,8 @@ Use `which-key` if available, else display simple message in echo area"
(lambda (args) (format
"\n\n<h2>Score: %s Meta: %s Image: %s</h2>\n\n<a href=\"%s%s\">\n<img src=\"%s%s?%s\" width=%s height=%s>\n</a>"
(cdr (assoc 'score args))
(cdr (assoc 'metadata_score args))
(cdr (assoc 'image_score args))
(cdr (assoc 'metadata_score (assoc 'additional args)))
(cdr (assoc 'image_score (assoc 'additional args)))
khoj-server-url
(cdr (assoc 'entry args))
khoj-server-url
@@ -226,7 +226,7 @@ Use `which-key` if available, else display simple message in echo area"
(defun khoj--get-enabled-content-types ()
"Get content types enabled for search from API."
(let ((config-url (format "%s/config/data" khoj-server-url))
(let ((config-url (format "%s/api/config/data" khoj-server-url))
(url-request-method "GET"))
(with-temp-buffer
(erase-buffer)
@@ -244,7 +244,7 @@ Use `which-key` if available, else display simple message in echo area"
"Construct API Query from QUERY, SEARCH-TYPE and (optional) RERANK params."
(let ((rerank (or rerank "false"))
(encoded-query (url-hexify-string query)))
(format "%s/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count)))
(format "%s/api/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count)))
(defun khoj--query-api-and-render-results (query search-type query-url buffer-name)
"Query Khoj API using QUERY, SEARCH-TYPE, QUERY-URL.

View File

@@ -10,7 +10,7 @@ var emptyValueDefault = "🖊️";
/**
* Fetch the existing config file.
*/
fetch("/config/data")
fetch("/api/config/data")
.then(response => response.json())
.then(data => {
rawConfig = data;
@@ -26,7 +26,7 @@ fetch("/config/data")
configForm.addEventListener("submit", (event) => {
event.preventDefault();
console.log(rawConfig);
fetch("/config/data", {
fetch("/api/config/data", {
method: "POST",
credentials: "same-origin",
headers: {
@@ -46,7 +46,7 @@ regenerateButton.addEventListener("click", (event) => {
event.preventDefault();
regenerateButton.style.cursor = "progress";
regenerateButton.disabled = true;
fetch("/regenerate")
fetch("/api/update?force=true")
.then(response => response.json())
.then(data => {
regenerateButton.style.cursor = "pointer";

View File

@@ -16,7 +16,7 @@
return `
<a href="${item.entry}" class="image-link">
<img id=${item.score} src="${item.entry}?${Math.random()}"
title="Effective Score: ${item.score}, Meta: ${item.metadata_score}, Image: ${item.image_score}"
title="Effective Score: ${item.score}, Meta: ${item.additional.metadata_score}, Image: ${item.additional.image_score}"
class="image">
</a>`
}
@@ -77,8 +77,8 @@
// Generate Backend API URL to execute Search
url = type === "image"
? `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}`
: `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`;
? `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}`
: `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`;
// Execute Search and Render Results
fetch(url)
@@ -94,7 +94,7 @@
function updateIndex() {
type = document.getElementById("type").value;
fetch(`/reload?t=${type}`)
fetch(`/api/update?t=${type}`)
.then(response => response.json())
.then(data => {
console.log(data);
@@ -118,7 +118,7 @@
function populate_type_dropdown() {
// Populate type dropdown field with enabled search types only
var possible_search_types = ["org", "markdown", "ledger", "music", "image"];
fetch("/config/data")
fetch("/api/config/data")
.then(response => response.json())
.then(data => {
document.getElementById("type").innerHTML =

View File

@@ -19,7 +19,9 @@ from PyQt6.QtCore import QThread, QTimer
# Internal Packages
from src.configure import configure_server
from src.router import router
from src.routers.api import api
from src.routers.api_beta import api_beta
from src.routers.web_client import web_client
from src.utils import constants, state
from src.utils.cli import cli
from src.interface.desktop.main_window import MainWindow
@@ -29,7 +31,9 @@ from src.interface.desktop.system_tray import create_system_tray
# Initialize the Application Server
app = FastAPI()
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
app.include_router(router)
app.include_router(api, prefix="/api")
app.include_router(api_beta, prefix="/api/beta")
app.include_router(web_client)
logger = logging.getLogger('src')

View File

@@ -1,128 +1,127 @@
#!/usr/bin/env python3
# Standard Packages
import json
import glob
import re
import logging
import time
# Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import TextContentConfig
from src.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
# Define Functions
def beancount_to_jsonl(config: TextContentConfig, previous_entries=None):
# Extract required fields from config
beancount_files, beancount_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
class BeancountToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
# Extract required fields from config
beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter,self.config.compressed_jsonl
# Input Validation
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
print("At least one of beancount-files or beancount-file-filter is required to be specified")
exit(1)
# Input Validation
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
print("At least one of beancount-files or beancount-file-filter is required to be specified")
exit(1)
# Get Beancount Files to Process
beancount_files = get_beancount_files(beancount_files, beancount_file_filter)
# Get Beancount Files to Process
beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter)
# Extract Entries from specified Beancount files
start = time.time()
current_entries = convert_transactions_to_maps(*extract_beancount_transactions(beancount_files))
end = time.time()
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
# Extract Entries from specified Beancount files
start = time.time()
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
end = time.time()
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
# Process Each Entry from All Notes Files
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = convert_transaction_maps_to_jsonl(entries)
# Process Each Entry from All Notes Files
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
return entries_with_ids
return entries_with_ids
@staticmethod
def get_beancount_files(beancount_files=None, beancount_file_filters=None):
"Get Beancount files to process"
absolute_beancount_files, filtered_beancount_files = set(), set()
if beancount_files:
absolute_beancount_files = {get_absolute_path(beancount_file)
for beancount_file
in beancount_files}
if beancount_file_filters:
filtered_beancount_files = {
filtered_file
for beancount_file_filter in beancount_file_filters
for filtered_file in glob.glob(get_absolute_path(beancount_file_filter))
}
def get_beancount_files(beancount_files=None, beancount_file_filters=None):
"Get Beancount files to process"
absolute_beancount_files, filtered_beancount_files = set(), set()
if beancount_files:
absolute_beancount_files = {get_absolute_path(beancount_file)
for beancount_file
in beancount_files}
if beancount_file_filters:
filtered_beancount_files = {
filtered_file
for beancount_file_filter in beancount_file_filters
for filtered_file in glob.glob(get_absolute_path(beancount_file_filter))
all_beancount_files = sorted(absolute_beancount_files | filtered_beancount_files)
files_with_non_beancount_extensions = {
beancount_file
for beancount_file
in all_beancount_files
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
}
if any(files_with_non_beancount_extensions):
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
all_beancount_files = sorted(absolute_beancount_files | filtered_beancount_files)
logger.info(f'Processing files: {all_beancount_files}')
files_with_non_beancount_extensions = {
beancount_file
for beancount_file
in all_beancount_files
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
}
if any(files_with_non_beancount_extensions):
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
return all_beancount_files
logger.info(f'Processing files: {all_beancount_files}')
@staticmethod
def extract_beancount_transactions(beancount_files):
"Extract entries from specified Beancount files"
return all_beancount_files
# Initialize Regex for extracting Beancount Entries
transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
empty_newline = f'^[\n\r\t\ ]*$'
entries = []
transaction_to_file_map = []
for beancount_file in beancount_files:
with open(beancount_file) as f:
ledger_content = f.read()
transactions_per_file = [entry.strip(empty_escape_sequences)
for entry
in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
if re.match(transaction_regex, entry)]
transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file))
entries.extend(transactions_per_file)
return entries, dict(transaction_to_file_map)
def extract_beancount_transactions(beancount_files):
"Extract entries from specified Beancount files"
@staticmethod
def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]:
"Convert each parsed Beancount transaction into a Entry"
entries = []
for parsed_entry in parsed_entries:
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))
# Initialize Regex for extracting Beancount Entries
transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
empty_newline = f'^[\n\r\t\ ]*$'
logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
entries = []
transaction_to_file_map = []
for beancount_file in beancount_files:
with open(beancount_file) as f:
ledger_content = f.read()
transactions_per_file = [entry.strip(empty_escape_sequences)
for entry
in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
if re.match(transaction_regex, entry)]
transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file))
entries.extend(transactions_per_file)
return entries, dict(transaction_to_file_map)
return entries
def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]:
"Convert each Beancount transaction into a dictionary"
entry_maps = []
for entry in entries:
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'})
logger.info(f"Converted {len(entries)} transactions to dictionaries")
return entry_maps
def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str:
"Convert each Beancount transaction dictionary to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
@staticmethod
def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str:
"Convert each Beancount transaction entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])

View File

@@ -1,127 +1,126 @@
#!/usr/bin/env python3
# Standard Packages
import json
import glob
import re
import logging
import time
# Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import TextContentConfig
from src.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
# Define Functions
def markdown_to_jsonl(config: TextContentConfig, previous_entries=None):
# Extract required fields from config
markdown_files, markdown_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
class MarkdownToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
# Extract required fields from config
markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
print("At least one of markdown-files or markdown-file-filter is required to be specified")
exit(1)
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
print("At least one of markdown-files or markdown-file-filter is required to be specified")
exit(1)
# Get Markdown Files to Process
markdown_files = get_markdown_files(markdown_files, markdown_file_filter)
# Get Markdown Files to Process
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
# Extract Entries from specified Markdown files
start = time.time()
current_entries = convert_markdown_entries_to_maps(*extract_markdown_entries(markdown_files))
end = time.time()
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
# Extract Entries from specified Markdown files
start = time.time()
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
end = time.time()
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated entries: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated entries: {end - start} seconds")
# Process Each Entry from All Notes Files
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = convert_markdown_maps_to_jsonl(entries)
# Process Each Entry from All Notes Files
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
return entries_with_ids
return entries_with_ids
@staticmethod
def get_markdown_files(markdown_files=None, markdown_file_filters=None):
"Get Markdown files to process"
absolute_markdown_files, filtered_markdown_files = set(), set()
if markdown_files:
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
if markdown_file_filters:
filtered_markdown_files = {
filtered_file
for markdown_file_filter in markdown_file_filters
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter))
}
def get_markdown_files(markdown_files=None, markdown_file_filters=None):
"Get Markdown files to process"
absolute_markdown_files, filtered_markdown_files = set(), set()
if markdown_files:
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
if markdown_file_filters:
filtered_markdown_files = {
filtered_file
for markdown_file_filter in markdown_file_filters
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter))
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
files_with_non_markdown_extensions = {
md_file
for md_file
in all_markdown_files
if not md_file.endswith(".md") and not md_file.endswith('.markdown')
}
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
if any(files_with_non_markdown_extensions):
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
files_with_non_markdown_extensions = {
md_file
for md_file
in all_markdown_files
if not md_file.endswith(".md") and not md_file.endswith('.markdown')
}
logger.info(f'Processing files: {all_markdown_files}')
if any(files_with_non_markdown_extensions):
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
return all_markdown_files
logger.info(f'Processing files: {all_markdown_files}')
@staticmethod
def extract_markdown_entries(markdown_files):
"Extract entries by heading from specified Markdown files"
return all_markdown_files
# Regex to extract Markdown Entries by Heading
markdown_heading_regex = r'^#'
entries = []
entry_to_file_map = []
for markdown_file in markdown_files:
with open(markdown_file) as f:
markdown_content = f.read()
markdown_entries_per_file = [f'#{entry.strip(empty_escape_sequences)}'
for entry
in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
if entry.strip(empty_escape_sequences) != '']
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file))
entries.extend(markdown_entries_per_file)
def extract_markdown_entries(markdown_files):
"Extract entries by heading from specified Markdown files"
return entries, dict(entry_to_file_map)
# Regex to extract Markdown Entries by Heading
markdown_heading_regex = r'^#'
@staticmethod
def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]:
"Convert each Markdown entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))
entries = []
entry_to_file_map = []
for markdown_file in markdown_files:
with open(markdown_file) as f:
markdown_content = f.read()
markdown_entries_per_file = [f'#{entry.strip(empty_escape_sequences)}'
for entry
in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
if entry.strip(empty_escape_sequences) != '']
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file))
entries.extend(markdown_entries_per_file)
logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
return entries, dict(entry_to_file_map)
return entries
def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]:
"Convert each Markdown entries into a dictionary"
entry_maps = []
for entry in entries:
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'})
logger.info(f"Converted {len(entries)} markdown entries to dictionaries")
return entry_maps
def convert_markdown_maps_to_jsonl(entries):
"Convert each Markdown entries to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
@staticmethod
def convert_markdown_maps_to_jsonl(entries: list[Entry]):
"Convert each Markdown entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])

View File

@@ -1,7 +1,4 @@
#!/usr/bin/env python3
# Standard Packages
import json
import glob
import logging
import time
@@ -9,147 +6,146 @@ from typing import Iterable
# Internal Packages
from src.processor.org_mode import orgnode
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry
from src.utils import state
from src.utils.rawconfig import TextContentConfig
logger = logging.getLogger(__name__)
# Define Functions
def org_to_jsonl(config: TextContentConfig, previous_entries=None):
# Extract required fields from config
org_files, org_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
index_heading_entries = config.index_heading_entries
class OrgToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries: list[Entry]=None):
# Extract required fields from config
org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
index_heading_entries = self.config.index_heading_entries
# Input Validation
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
print("At least one of org-files or org-file-filter is required to be specified")
exit(1)
# Input Validation
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
print("At least one of org-files or org-file-filter is required to be specified")
exit(1)
# Get Org Files to Process
start = time.time()
org_files = get_org_files(org_files, org_file_filter)
# Get Org Files to Process
start = time.time()
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files
start = time.time()
entry_nodes, file_to_entries = extract_org_entries(org_files)
end = time.time()
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
# Extract Entries from specified Org files
start = time.time()
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
end = time.time()
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
start = time.time()
current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
end = time.time()
logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds")
start = time.time()
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
end = time.time()
logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Identify, mark and merge any new entries with previous entries
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Process Each Entry from All Notes Files
start = time.time()
entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = convert_org_entries_to_jsonl(entries)
# Process Each Entry from All Notes Files
start = time.time()
entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = self.convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
return entries_with_ids
return entries_with_ids
@staticmethod
def get_org_files(org_files=None, org_file_filters=None):
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {
get_absolute_path(org_file)
for org_file
in org_files
}
if org_file_filters:
filtered_org_files = {
filtered_file
for org_file_filter in org_file_filters
for filtered_file in glob.glob(get_absolute_path(org_file_filter))
}
def get_org_files(org_files=None, org_file_filters=None):
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {
get_absolute_path(org_file)
for org_file
in org_files
}
if org_file_filters:
filtered_org_files = {
filtered_file
for org_file_filter in org_file_filters
for filtered_file in glob.glob(get_absolute_path(org_file_filter))
}
all_org_files = sorted(absolute_org_files | filtered_org_files)
all_org_files = sorted(absolute_org_files | filtered_org_files)
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions):
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions):
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.info(f'Processing files: {all_org_files}')
logger.info(f'Processing files: {all_org_files}')
return all_org_files
return all_org_files
@staticmethod
def extract_org_entries(org_files):
"Extract entries from specified Org files"
entries = []
entry_to_file_map = []
for org_file in org_files:
org_file_entries = orgnode.makelist(str(org_file))
entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
entries.extend(org_file_entries)
return entries, dict(entry_to_file_map)
def extract_org_entries(org_files):
"Extract entries from specified Org files"
entries = []
entry_to_file_map = []
for org_file in org_files:
org_file_entries = orgnode.makelist(str(org_file))
entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
entries.extend(org_file_entries)
@staticmethod
def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]:
"Convert Org-Mode nodes into list of Entry objects"
entries: list[Entry] = []
for parsed_entry in parsed_entries:
if not parsed_entry.hasBody and not index_heading_entries:
# Ignore title notes i.e notes with just headings and empty body
continue
return entries, dict(entry_to_file_map)
def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]:
"Convert Org-Mode entries into list of dictionary"
entry_maps = []
for entry in entries:
entry_dict = dict()
if not entry.hasBody and not index_heading_entries:
# Ignore title notes i.e notes with just headings and empty body
continue
entry_dict["compiled"] = f'{entry.heading}.'
if state.verbose > 2:
logger.debug(f"Title: {entry.heading}")
if entry.tags:
tags_str = " ".join(entry.tags)
entry_dict["compiled"] += f'\t {tags_str}.'
compiled = f'{parsed_entry.heading}.'
if state.verbose > 2:
logger.debug(f"Tags: {tags_str}")
logger.debug(f"Title: {parsed_entry.heading}")
if entry.closed:
entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}')
if parsed_entry.tags:
tags_str = " ".join(parsed_entry.tags)
compiled += f'\t {tags_str}.'
if state.verbose > 2:
logger.debug(f"Tags: {tags_str}")
if entry.scheduled:
entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}')
if parsed_entry.closed:
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')
if entry.hasBody:
entry_dict["compiled"] += f'\n {entry.body}'
if state.verbose > 2:
logger.debug(f"Body: {entry.body}")
if parsed_entry.scheduled:
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
if state.verbose > 2:
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
if entry_dict:
entry_dict["raw"] = f'{entry}'
entry_dict["file"] = f'{entry_to_file_map[entry]}'
if parsed_entry.hasBody:
compiled += f'\n {parsed_entry.body}'
if state.verbose > 2:
logger.debug(f"Body: {parsed_entry.body}")
# Convert Dictionary to JSON and Append to JSONL string
entry_maps.append(entry_dict)
if compiled:
entries += [Entry(
compiled=compiled,
raw=f'{parsed_entry}',
file=f'{entry_to_file_map[parsed_entry]}')]
return entry_maps
return entries
def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str:
"Convert each Org-Mode entry to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
@staticmethod
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
"Convert each Org-Mode entry to JSON and collate as JSONL"
return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])

View File

@@ -0,0 +1,54 @@
# Standard Packages
from abc import ABC, abstractmethod
import hashlib
import time
import logging
# Internal Packages
from src.utils.rawconfig import Entry, TextContentConfig
logger = logging.getLogger(__name__)
class TextToJsonl(ABC):
    "Base class for processors that convert a text content source into Khoj's intermediate format of `Entry` objects"

    def __init__(self, config: TextContentConfig):
        self.config = config

    @abstractmethod
    def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]:
        "Convert content at the configured source into a list of (embedding_id, Entry) tuples"
        ...

    def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
        """Diff current entries against previous entries by content hash.

        Existing entries keep their previous id so their encoded embeddings can be
        reused; new entries are marked with id -1 to flag them for embedding
        generation. Returns existing entries (sorted by id) followed by new ones.
        """
        # Fall back to this module's logger so calling with the default logger=None
        # does not raise AttributeError on logger.debug below
        log = logger if logger is not None else logging.getLogger(__name__)

        # Hash all current and previous entries to identify new entries
        start = time.time()
        current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), current_entries))
        previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), previous_entries))
        end = time.time()
        log.debug(f"Hash previous, current entries: {end - start} seconds")

        start = time.time()
        hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
        hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))

        # All entries that did not exist in the previous set are to be added
        new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
        # All entries that exist in both current and previous sets are kept
        existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)

        # Mark new entries with -1 id to flag for later embeddings generation
        new_entries = [
            (-1, hash_to_current_entries[entry_hash])
            for entry_hash in new_entry_hashes
        ]
        # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
        existing_entries = [
            (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
            for entry_hash in existing_entry_hashes
        ]
        existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
        entries_with_ids = existing_entries_sorted + new_entries
        end = time.time()
        log.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")

        return entries_with_ids

View File

@@ -1,220 +0,0 @@
# Standard Packages
import yaml
import json
import time
import logging
from typing import Optional
# External Packages
from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
# Internal Packages
from src.configure import configure_search
from src.search_type import image_search, text_search
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.utils.rawconfig import FullConfig
from src.utils.config import SearchType
from src.utils.helpers import LRU, get_absolute_path, get_from_dict
from src.utils import state, constants
router = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
logger = logging.getLogger(__name__)
query_cache = LRU()
@router.get("/", response_class=FileResponse)
def index():
return FileResponse(constants.web_directory / "index.html")
@router.get('/config', response_class=HTMLResponse)
def config_page(request: Request):
return templates.TemplateResponse("config.html", context={'request': request})
@router.get('/config/data', response_model=FullConfig)
def config_data():
return state.config
@router.post('/config/data')
async def config_data(updated_config: FullConfig):
state.config = updated_config
with open(state.config_file, 'w') as outfile:
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close()
return state.config
@router.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
if q is None or q == '':
logger.info(f'No query param (q) passed in API call to initiate search')
return {}
# initialize variables
user_query = q.strip()
results_count = n
results = {}
query_start, query_end, collate_start, collate_end = None, None, None, None
# return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}'
if query_cache_key in state.query_cache:
logger.info(f'Return response from query cache')
return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
query_end = time.time()
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Image or t == None) and state.model.image_search:
# query images
query_start = time.time()
hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images'
query_end = time.time()
# collate and return results
collate_start = time.time()
results = image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=results_count)
collate_end = time.time()
# Cache results
state.query_cache[query_cache_key] = results
if query_start and query_end:
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results
@router.get('/reload')
def reload(t: Optional[SearchType] = None):
state.model = configure_search(state.model, state.config, regenerate=False, t=t)
return {'status': 'ok', 'message': 'reload completed'}
@router.get('/regenerate')
def regenerate(t: Optional[SearchType] = None):
state.model = configure_search(state.model, state.config, regenerate=True, t=t)
return {'status': 'ok', 'message': 'regeneration completed'}
@router.get('/beta/search')
def search_beta(q: str, n: Optional[int] = 1):
# Extract Search Type using GPT
metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type")
# Search
search_results = search(q, n=n, t=SearchType(search_type))
# Return response
return {'status': 'ok', 'result': search_results, 'type': search_type}
@router.get('/beta/chat')
def chat(q: str):
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# Converse with OpenAI GPT
metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
if state.verbose > 1:
print(f'Understood: {get_from_dict(metadata, "intent")}')
if get_from_dict(metadata, "intent", "memory-type") == "notes":
query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org)
collated_result = "\n".join([item["entry"] for item in result_list])
if state.verbose > 1:
print(f'Semantically Similar Notes:\n{collated_result}')
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key)
else:
gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key)
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', []))
return {'status': 'ok', 'response': gpt_response}
@router.on_event('shutdown')
def shutdown_event():
# No need to create empty log file
if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log):
return
elif state.processor_config.conversation.verbose:
print('INFO:\tSaving conversation logs to disk...')
# Summarize Conversation Logs for this Session
chat_session = state.processor_config.conversation.chat_session
openai_api_key = state.processor_config.conversation.openai_api_key
conversation_log = state.processor_config.conversation.meta_log
session = {
"summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"])
}
if 'session' in conversation_log:
conversation_log['session'].append(session)
else:
conversation_log['session'] = [session]
# Save Conversation Metadata Logs to Disk
conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile)
with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
json.dump(conversation_log, logfile)
print('INFO:\tConversation logs saved to disk.')

129
src/routers/api.py Normal file
View File

@@ -0,0 +1,129 @@
# Standard Packages
import yaml
import time
import logging
from typing import Optional
# External Packages
from fastapi import APIRouter
# Internal Packages
from src.configure import configure_search
from src.search_type import image_search, text_search
from src.utils.rawconfig import FullConfig, SearchResponse
from src.utils.config import SearchType
from src.utils import state, constants
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
# Create Routes
@api.get('/config/data', response_model=FullConfig)
def get_config_data():
    "Return the current in-memory Khoj configuration"
    return state.config
@api.post('/config/data')
async def set_config_data(updated_config: FullConfig):
    "Replace the in-memory Khoj configuration and persist it to the config file on disk"
    state.config = updated_config
    # Round-trip through YAML-safe types to drop model-specific tags before dumping.
    # The `with` block closes the file; no explicit close() needed.
    with open(state.config_file, 'w') as outfile:
        yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
    return state.config
@api.get('/search', response_model=list[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
    """Search indexed content.

    q: query string. n: maximum number of results to return. t: restrict the
    search to a single search type (every configured type is tried when None).
    r: re-rank hits with the cross-encoder (text search types only).
    """
    results: list[SearchResponse] = []
    if q is None or q == '':
        logger.info(f'No query param (q) passed in API call to initiate search')
        return results

    # initialize variables
    user_query = q.strip()
    results_count = n
    query_start, query_end, collate_start, collate_end = None, None, None, None

    # return cached results, if available
    query_cache_key = f'{user_query}-{n}-{t}-{r}'
    if query_cache_key in state.query_cache:
        logger.info(f'Return response from query cache')
        return state.query_cache[query_cache_key]

    # NOTE(review): each matching branch below overwrites `results`, so when t is
    # None the response holds hits only from the last configured content type that
    # matched — confirm this is intended rather than accumulating across types
    if (t == SearchType.Org or t == None) and state.model.orgmode_search:
        # query org-mode notes
        query_start = time.time()
        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    if (t == SearchType.Music or t == None) and state.model.music_search:
        # query music library
        query_start = time.time()
        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
        # query markdown files
        query_start = time.time()
        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
        # query transactions
        query_start = time.time()
        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = text_search.collate_results(hits, entries, results_count)
        collate_end = time.time()

    if (t == SearchType.Image or t == None) and state.model.image_search:
        # query images
        query_start = time.time()
        hits = image_search.query(user_query, results_count, state.model.image_search)
        output_directory = constants.web_directory / 'images'
        query_end = time.time()

        # collate and return results
        collate_start = time.time()
        results = image_search.collate_results(
            hits,
            image_names=state.model.image_search.image_names,
            output_directory=output_directory,
            image_files_url='/static/images',
            count=results_count)
        collate_end = time.time()

    # Cache results
    state.query_cache[query_cache_key] = results

    if query_start and query_end:
        logger.debug(f"Query took {query_end - query_start:.3f} seconds")
    if collate_start and collate_end:
        logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")

    return results
@api.get('/update')
def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
    "Refresh the search index; force=True regenerates embeddings from scratch, t limits the update to one search type"
    state.model = configure_search(state.model, state.config, regenerate=force, t=t)
    return {'status': 'ok', 'message': 'index updated'}

88
src/routers/api_beta.py Normal file
View File

@@ -0,0 +1,88 @@
# Standard Packages
import json
import logging
from typing import Optional
# External Packages
from fastapi import APIRouter
# Internal Packages
from src.routers.api import search
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict
from src.utils import state
# Initialize Router
api_beta = APIRouter()
logger = logging.getLogger(__name__)
# Create Routes
@api_beta.get('/search')
def search_beta(q: str, n: Optional[int] = 1):
    "Infer the search type from the natural-language query using GPT, then run a normal search with it"
    # Extract Search Type using GPT
    metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
    search_type = get_from_dict(metadata, "search-type")

    # Search
    search_results = search(q, n=n, t=SearchType(search_type))

    # Return response
    return {'status': 'ok', 'result': search_results, 'type': search_type}
@api_beta.get('/chat')
def chat(q: str):
    "Chat with Khoj: ground the GPT response in semantically similar notes when the query intent targets notes"
    # Load Conversation History
    chat_session = state.processor_config.conversation.chat_session
    meta_log = state.processor_config.conversation.meta_log

    # Converse with OpenAI GPT
    metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
    logger.debug(f'Understood: {get_from_dict(metadata, "intent")}')

    if get_from_dict(metadata, "intent", "memory-type") == "notes":
        query = get_from_dict(metadata, "intent", "query")
        result_list = search(query, n=1, t=SearchType.Org)
        # search returns SearchResponse models now, so read the `entry`
        # attribute instead of subscripting the items like dicts
        collated_result = "\n".join([item.entry for item in result_list])
        logger.debug(f'Semantically Similar Notes:\n{collated_result}')
        gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key)
    else:
        gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key)

    # Update Conversation History
    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
    state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', []))

    return {'status': 'ok', 'response': gpt_response}
@api_beta.on_event('shutdown')
def shutdown_event():
    "Summarize this session's conversation and persist the conversation logs to disk on server shutdown"
    # No need to create empty log file
    if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log):
        return
    logger.debug('INFO:\tSaving conversation logs to disk...')

    # Summarize Conversation Logs for this Session
    chat_session = state.processor_config.conversation.chat_session
    openai_api_key = state.processor_config.conversation.openai_api_key
    conversation_log = state.processor_config.conversation.meta_log
    session = {
        "summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key),
        # Resume where the previously saved session ended; 0 when this is the first session
        "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
        "session-end": len(conversation_log["chat"])
    }
    if 'session' in conversation_log:
        conversation_log['session'].append(session)
    else:
        conversation_log['session'] = [session]

    # Save Conversation Metadata Logs to Disk
    conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile)
    with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
        json.dump(conversation_log, logfile)

    logger.info('INFO:\tConversation logs saved to disk.')

23
src/routers/web_client.py Normal file
View File

@@ -0,0 +1,23 @@
# External Packages
from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
# Internal Packages
from src.utils import constants
# Initialize Router
web_client = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
# Create Routes
@web_client.get("/", response_class=FileResponse)
def index():
    "Serve the web client's home page"
    return FileResponse(constants.web_directory / "index.html")
@web_client.get('/config', response_class=HTMLResponse)
def config_page(request: Request):
    "Render the configuration page from its Jinja2 template"
    return templates.TemplateResponse("config.html", context={'request': request})

View File

@@ -37,7 +37,7 @@ class DateFilter(BaseFilter):
start = time.time()
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()

View File

@@ -24,7 +24,7 @@ class FileFilter(BaseFilter):
def load(self, entries, *args, **kwargs):
start = time.time()
for id, entry in enumerate(entries):
self.file_to_entry_map[entry[self.entry_key]].add(id)
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
end = time.time()
logger.debug(f"Created file filter index: {end - start} seconds")

View File

@@ -29,7 +29,7 @@ class WordFilter(BaseFilter):
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, entry[self.entry_key].lower()):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '':
continue
self.word_to_entry_index[word].add(entry_index)

View File

@@ -15,7 +15,7 @@ import torch
# Internal Packages
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
# Create Logger
@@ -203,8 +203,8 @@ def render_results(hits, image_names, image_directory, count):
img.show()
def collate_results(hits, image_names, output_directory, image_files_url, count=5):
results = []
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]:
results: list[SearchResponse] = []
for index, hit in enumerate(hits[:count]):
source_path = image_names[hit['corpus_id']]
@@ -220,12 +220,17 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
shutil.copy(source_path, target_path)
# Add the image metadata to the results
results += [{
"entry": f'{image_files_url}/{target_image_name}',
"score": f"{hit['score']:.9f}",
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
}]
results += [SearchResponse.parse_obj(
{
"entry": f'{image_files_url}/{target_image_name}',
"score": f"{hit['score']:.9f}",
"additional":
{
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
}
}
)]
return results

View File

@@ -1,17 +1,19 @@
# Standard Packages
import logging
import time
from typing import Type
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.processor.text_to_jsonl import TextToJsonl
from src.search_filter.base_filter import BaseFilter
# Internal Packages
from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
from src.utils.jsonl import load_jsonl
@@ -48,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig):
return bi_encoder, cross_encoder, top_k
def extract_entries(jsonl_file):
def extract_entries(jsonl_file) -> list[Entry]:
"Load entries from compressed jsonl"
return load_jsonl(jsonl_file)
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
@@ -62,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate
logger.info(f"Loaded embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
if new_entries:
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
# Else compute the corpus embeddings from scratch
else:
new_entries = [entry['compiled'] for _, entry in entries_with_ids]
new_entries = [entry.compiled for _, entry in entries_with_ids]
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
# Save regenerated or updated embeddings to file
@@ -131,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
# Score all retrieved entries using the cross-encoder
if rank_results:
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
@@ -151,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
return hits, entries
def render_results(hits, entries, count=5, display_biencoder_results=False):
def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False):
"Render the Results returned by Search for the Query"
if display_biencoder_results:
# Output of top hits from bi-encoder
@@ -159,34 +161,34 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
print(f"Top-{count} Bi-Encoder Retrieval hits")
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits[0:count]:
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}")
# Output of top hits from re-ranker
print("\n-------------------------\n")
print(f"Top-{count} Cross-Encoder Re-ranker hits")
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
for hit in hits[0:count]:
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}")
def collate_results(hits, entries, count=5):
return [
def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
return [SearchResponse.parse_obj(
{
"entry": entries[hit['corpus_id']]['raw'],
"entry": entries[hit['corpus_id']].raw,
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
}
})
for hit
in hits[0:count]]
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
entries_with_indices = text_to_jsonl(config, previous_entries)
entries_with_indices = text_to_jsonl(config).process(previous_entries)
# Extract Updated Entries
entries = extract_entries(config.compressed_jsonl)

View File

@@ -1,8 +1,6 @@
# Standard Packages
from pathlib import Path
import sys
import time
import hashlib
from os.path import join
from collections import OrderedDict
from typing import Optional, Union
@@ -83,38 +81,3 @@ class LRU(OrderedDict):
oldest = next(iter(self))
del self[oldest]
def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None):
# Hash all current and previous entries to identify new entries
start = time.time()
current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries))
previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries))
end = time.time()
logger.debug(f"Hash previous, current entries: {end - start} seconds")
start = time.time()
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
# All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with no ids for later embeddings generation
new_entries = [
(None, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
end = time.time()
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
return entries_with_ids

View File

@@ -1,4 +1,5 @@
# System Packages
import json
from pathlib import Path
from typing import List, Optional
@@ -71,3 +72,32 @@ class FullConfig(ConfigBase):
content_type: Optional[ContentConfig]
search_type: Optional[SearchConfig]
processor: Optional[ProcessorConfig]
class SearchResponse(ConfigBase):
    """Schema of a single result item returned by the /api/search endpoint."""
    # Matched entry text rendered for display
    entry: str
    # Relevance score of the entry for the query, serialized as a string
    score: str
    # Optional search-type-specific metadata, e.g. image score breakdown for image search
    additional: Optional[dict]
class Entry:
    """A text entry in the intermediate format used to index text content.

    Tracks the raw source text, its compiled (search-ready) form, and
    optionally the file it was extracted from.
    """
    raw: str
    compiled: str
    file: Optional[str]

    def __init__(self, raw: Optional[str] = None, compiled: Optional[str] = None, file: Optional[str] = None):
        self.raw = raw
        self.compiled = compiled
        self.file = file

    def to_json(self) -> str:
        """Serialize the entry to a JSON string, preserving non-ASCII text."""
        return json.dumps(self.__dict__, ensure_ascii=False)

    def __repr__(self) -> str:
        return self.__dict__.__repr__()

    @classmethod
    def from_dict(cls, dictionary: dict) -> "Entry":
        """Construct an Entry from a dict; `raw` and `compiled` are required, `file` is optional."""
        return cls(
            raw=dictionary['raw'],
            compiled=dictionary['compiled'],
            file=dictionary.get('file', None)
        )

View File

@@ -9,7 +9,7 @@ from src.utils.rawconfig import FullConfig
# Do not emit tags when dumping to YAML
yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None
yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None # type: ignore[assignment]
def save_config_to_file(yaml_config: dict, yaml_config_file: Path):

View File

@@ -6,7 +6,7 @@ from src.search_type import image_search, text_search
from src.utils.config import SearchType
from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
from src.search_filter.date_filter import DateFilter
from src.search_filter.word_filter import WordFilter
from src.search_filter.file_filter import FileFilter
@@ -60,6 +60,6 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
embeddings_file = content_dir.joinpath('note_embeddings.pt'))
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
return content_config

View File

@@ -43,9 +43,8 @@ just generating embeddings*
- **Khoj via API**
- See [Khoj API Docs](http://localhost:8000/docs)
- [Query](http://localhost:8000/search?q=%22what%20is%20the%20meaning%20of%20life%22)
- [Regenerate
Embeddings](http://localhost:8000/regenerate?t=ledger)
- [Query](http://localhost:8000/api/search?q=%22what%20is%20the%20meaning%20of%20life%22)
- [Update Index](http://localhost:8000/api/update?t=ledger)
- [Configure Application](http://localhost:8000/ui)
- **Khoj via Emacs**
- [Install](https://github.com/debanjum/khoj/tree/master/src/interface/emacs#installation)

View File

@@ -27,8 +27,8 @@
- Run ~M-x khoj <user-query>~ or Call ~C-c C-s~
- *Khoj via API*
- Query: ~GET~ [[http://localhost:8000/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/search?q="What is the meaning of life"]]
- Regenerate Embeddings: ~GET~ [[http://localhost:8000/regenerate][http://localhost:8000/regenerate]]
- Query: ~GET~ [[http://localhost:8000/api/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/api/search?q="What is the meaning of life"]]
- Update Index: ~GET~ [[http://localhost:8000/api/update][http://localhost:8000/api/update]]
- [[http://localhost:8000/docs][Khoj API Docs]]
- *Call Khoj via Python Script Directly*

View File

@@ -2,7 +2,7 @@
import json
# Internal Packages
from src.processor.ledger.beancount_to_jsonl import extract_beancount_transactions, convert_transactions_to_maps, convert_transaction_maps_to_jsonl, get_beancount_files
from src.processor.ledger.beancount_to_jsonl import BeancountToJsonl
def test_no_transactions_in_file(tmp_path):
@@ -16,10 +16,11 @@ def test_no_transactions_in_file(tmp_path):
# Act
# Extract Entries from specified Beancount files
entry_nodes, file_to_entries = extract_beancount_transactions(beancount_files=[beancount_file])
entry_nodes, file_to_entries = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file])
# Process Each Entry from All Beancount Files
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entry_nodes, file_to_entries))
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -38,10 +39,11 @@ Assets:Test:Test -1.00 KES
# Act
# Extract Entries from specified Beancount files
entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
entries, entry_to_file_map = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file])
# Process Each Entry from All Beancount Files
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -65,10 +67,11 @@ Assets:Test:Test -1.00 KES
# Act
# Extract Entries from specified Beancount files
entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
entries, entry_to_file_map = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file])
# Process Each Entry from All Beancount Files
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl(
BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -96,7 +99,7 @@ def test_get_beancount_files(tmp_path):
input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount']
# Act
extracted_org_files = get_beancount_files(input_files, input_filter)
extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5

View File

@@ -12,7 +12,7 @@ from src.main import app
from src.utils.state import model, config
from src.search_type import text_search, image_search
from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
from src.search_filter.word_filter import WordFilter
from src.search_filter.file_filter import FileFilter
@@ -28,7 +28,7 @@ def test_search_with_invalid_content_type():
user_query = quote("How to call Khoj from Emacs?")
# Act
response = client.get(f"/search?q={user_query}&t=invalid_content_type")
response = client.get(f"/api/search?q={user_query}&t=invalid_content_type")
# Assert
assert response.status_code == 422
@@ -43,29 +43,29 @@ def test_search_with_valid_content_type(content_config: ContentConfig, search_co
# config.content_type.image = search_config.image
for content_type in ["org", "markdown", "ledger", "music"]:
# Act
response = client.get(f"/search?q=random&t={content_type}")
response = client.get(f"/api/search?q=random&t={content_type}")
# Assert
assert response.status_code == 200
# ----------------------------------------------------------------------------------------------------
def test_reload_with_invalid_content_type():
def test_update_with_invalid_content_type():
# Act
response = client.get(f"/reload?t=invalid_content_type")
response = client.get(f"/api/update?t=invalid_content_type")
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
def test_reload_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig):
def test_update_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
config.content_type = content_config
config.search_type = search_config
for content_type in ["org", "markdown", "ledger", "music"]:
# Act
response = client.get(f"/reload?t={content_type}")
response = client.get(f"/api/update?t={content_type}")
# Assert
assert response.status_code == 200
@@ -73,7 +73,7 @@ def test_reload_with_valid_content_type(content_config: ContentConfig, search_co
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_invalid_content_type():
# Act
response = client.get(f"/regenerate?t=invalid_content_type")
response = client.get(f"/api/update?force=true&t=invalid_content_type")
# Assert
assert response.status_code == 422
@@ -87,7 +87,7 @@ def test_regenerate_with_valid_content_type(content_config: ContentConfig, searc
for content_type in ["org", "markdown", "ledger", "music", "image"]:
# Act
response = client.get(f"/regenerate?t={content_type}")
response = client.get(f"/api/update?force=true&t={content_type}")
# Assert
assert response.status_code == 200
@@ -104,7 +104,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
for query, expected_image_name in query_expected_image_pairs:
# Act
response = client.get(f"/search?q={query}&n=1&t=image")
response = client.get(f"/api/search?q={query}&n=1&t=image")
# Assert
assert response.status_code == 200
@@ -118,11 +118,11 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true")
# Assert
assert response.status_code == 200
@@ -135,11 +135,11 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter(), FileFilter()]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = quote('+"Emacs" file:"*.org"')
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
@@ -152,11 +152,11 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = quote('How to git install application? +"Emacs"')
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
@@ -169,11 +169,11 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = quote('How to git install application? -"clone"')
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200

View File

@@ -3,19 +3,17 @@ import re
from datetime import datetime
from math import inf
# External Packages
import torch
# Application Packages
from src.search_filter.date_filter import DateFilter
from src.utils.rawconfig import Entry
def test_date_filter():
embeddings = torch.randn(3, 10)
entries = [
{'compiled': '', 'raw': 'Entry with no date'},
{'compiled': '', 'raw': 'April Fools entry: 1984-04-01'},
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
Entry(compiled='', raw='Entry with no date'),
Entry(compiled='', raw='April Fools entry: 1984-04-01'),
Entry(compiled='', raw='Entry with date:1984-04-02')
]
q_with_no_date_filter = 'head tail'
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)

View File

@@ -1,14 +1,12 @@
# External Packages
import torch
# Application Packages
from src.search_filter.file_filter import FileFilter
from src.utils.rawconfig import Entry
def test_no_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
entries = arrange_content()
q_with_no_filter = 'head tail'
# Act
@@ -24,7 +22,7 @@ def test_no_file_filter():
def test_file_filter_with_non_existent_file():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act
@@ -40,7 +38,7 @@ def test_file_filter_with_non_existent_file():
def test_single_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail'
# Act
@@ -56,7 +54,7 @@ def test_single_file_filter():
def test_file_filter_with_partial_match():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail'
# Act
@@ -72,7 +70,7 @@ def test_file_filter_with_partial_match():
def test_file_filter_with_regex_match():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail'
# Act
@@ -88,7 +86,7 @@ def test_file_filter_with_regex_match():
def test_multiple_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act
@@ -102,11 +100,11 @@ def test_multiple_file_filter():
def arrange_content():
embeddings = torch.randn(4, 10)
entries = [
{'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'},
{'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
{'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'},
{'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
Entry(compiled='', raw='First Entry', file= 'file 1.org'),
Entry(compiled='', raw='Second Entry', file= 'file2.org'),
Entry(compiled='', raw='Third Entry', file= 'file 1.org'),
Entry(compiled='', raw='Fourth Entry', file= 'file2.org')
]
return embeddings, entries
return entries

View File

@@ -2,9 +2,6 @@
from pathlib import Path
from PIL import Image
# External Packages
import pytest
# Internal Packages
from src.utils.state import model
from src.utils.constants import web_directory
@@ -70,7 +67,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
image_files_url='/static/images',
count=1)
actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name)
actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
actual_image = Image.open(actual_image_path)
expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))

View File

@@ -2,7 +2,7 @@
import json
# Internal Packages
from src.processor.markdown.markdown_to_jsonl import extract_markdown_entries, convert_markdown_maps_to_jsonl, convert_markdown_entries_to_maps, get_markdown_files
from src.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
@@ -16,10 +16,11 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entry_nodes, file_to_entries = extract_markdown_entries(markdown_files=[markdownfile])
entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -37,10 +38,11 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -62,10 +64,11 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -93,7 +96,7 @@ def test_get_markdown_files(tmp_path):
input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown']
# Act
extracted_org_files = get_markdown_files(input_files, input_filter)
extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5

View File

@@ -2,7 +2,7 @@
import json
# Internal Packages
from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries, get_org_files
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
from src.utils.helpers import is_none_or_empty
@@ -21,8 +21,8 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
for index_heading_entries in [True, False]:
# Act
# Extract entries into jsonl from specified Org files
jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(
*extract_org_entries(org_files=[orgfile]),
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(
*OrgToJsonl.extract_org_entries(org_files=[orgfile]),
index_heading_entries=index_heading_entries))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -49,10 +49,10 @@ def test_entry_with_body_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(entries, entry_to_file_map))
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map))
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -70,11 +70,11 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = convert_org_entries_to_jsonl(entries)
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -102,7 +102,7 @@ def test_get_org_files(tmp_path):
input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org']
# Act
extracted_org_files = get_org_files(input_files, input_filter)
extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5

View File

@@ -9,7 +9,7 @@ import pytest
from src.utils.state import model
from src.search_type import text_search
from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.processor.org_mode.org_to_jsonl import OrgToJsonl
# Test
@@ -24,7 +24,7 @@ def test_asymmetric_setup_with_missing_file_raises_error(content_config: Content
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(FileNotFoundError):
text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
text_search.setup(OrgToJsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
# ----------------------------------------------------------------------------------------------------
@@ -39,7 +39,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentCo
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r'^No valid entries found*'):
text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
text_search.setup(OrgToJsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
# Cleanup
# delete created test file
@@ -50,7 +50,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentCo
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
# Assert
assert len(notes_model.entries) == 10
@@ -60,7 +60,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
query = "How to git install application?"
# Act
@@ -76,14 +76,14 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# Assert
# Actual_data should contain "Khoj via Emacs" entry
search_result = results[0]["entry"]
search_result = results[0].entry
assert "git clone" in search_result
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -96,11 +96,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
regenerated_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
# Assert
assert len(regenerated_notes_model.entries) == 11
@@ -119,7 +119,7 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -133,7 +133,7 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act
# update embeddings, entries with the newly added note
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
# verify new entry added in updated embeddings, entries
assert len(initial_notes_model.entries) == 11

View File

@@ -1,6 +1,6 @@
# Application Packages
from src.search_filter.word_filter import WordFilter
from src.utils.config import SearchType
from src.utils.rawconfig import Entry
def test_no_word_filter():
@@ -69,9 +69,10 @@ def test_word_include_and_exclude_filter():
def arrange_content():
entries = [
{'compiled': '', 'raw': 'Minimal Entry'},
{'compiled': '', 'raw': 'Entry with exclude_word'},
{'compiled': '', 'raw': 'Entry with include_word'},
{'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
Entry(compiled='', raw='Minimal Entry'),
Entry(compiled='', raw='Entry with exclude_word'),
Entry(compiled='', raw='Entry with include_word'),
Entry(compiled='', raw='Entry with include_word and exclude_word')
]
return entries