From 02d944030f04c0dc07f855340179ed87b33294b0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 14 Sep 2022 10:53:43 +0300 Subject: [PATCH 01/10] Use Base TextToJsonl class to standardize _to_jsonl processors - Start standardizing implementation of the `text_to_jsonl' processors - `text_to_jsonl' scripts already had a shared structure - This change starts to codify that implicit structure - Benefits - Ease adding more `text_to_jsonl' processors - Allow merging shared functionality - Help with type hinting - Drawbacks - Lower agility to change. But this was already an implicit issue as the text_to_jsonl processors got more deeply wired into the app --- src/configure.py | 14 +- src/processor/ledger/beancount_to_jsonl.py | 185 ++++++++-------- src/processor/markdown/markdown_to_jsonl.py | 179 ++++++++-------- src/processor/org_mode/org_to_jsonl.py | 221 ++++++++++---------- src/processor/text_to_jsonl.py | 14 ++ src/search_type/text_search.py | 6 +- tests/conftest.py | 4 +- tests/test_beancount_to_jsonl.py | 19 +- tests/test_client.py | 10 +- tests/test_markdown_to_jsonl.py | 19 +- tests/test_org_to_jsonl.py | 18 +- tests/test_text_search.py | 20 +- 12 files changed, 364 insertions(+), 345 deletions(-) create mode 100644 src/processor/text_to_jsonl.py diff --git a/src/configure.py b/src/configure.py index d920a614..495ac313 100644 --- a/src/configure.py +++ b/src/configure.py @@ -6,9 +6,9 @@ import logging import json # Internal Packages -from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl -from src.processor.markdown.markdown_to_jsonl import markdown_to_jsonl -from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.processor.ledger.beancount_to_jsonl import BeancountToJsonl +from src.processor.markdown.markdown_to_jsonl import MarkdownToJsonl +from src.processor.org_mode.org_to_jsonl import OrgToJsonl from src.search_type import image_search, text_search from src.utils.config import SearchType, SearchModels, 
ProcessorConfigModel, ConversationProcessorConfigModel from src.utils import state @@ -44,7 +44,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, if (t == SearchType.Org or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings model.orgmode_search = text_search.setup( - org_to_jsonl, + OrgToJsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, @@ -54,7 +54,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings model.music_search = text_search.setup( - org_to_jsonl, + OrgToJsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, @@ -64,7 +64,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, if (t == SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings model.markdown_search = text_search.setup( - markdown_to_jsonl, + MarkdownToJsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, @@ -74,7 +74,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings model.ledger_search = text_search.setup( - beancount_to_jsonl, + BeancountToJsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py index 7b8b9bba..d54b7e1b 100644 --- a/src/processor/ledger/beancount_to_jsonl.py +++ b/src/processor/ledger/beancount_to_jsonl.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - # Standard Packages import json import glob @@ -8,121 +6,122 @@ import logging import time 
# Internal Packages +from src.processor.text_to_jsonl import TextToJsonl from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data -from src.utils.rawconfig import TextContentConfig logger = logging.getLogger(__name__) -# Define Functions -def beancount_to_jsonl(config: TextContentConfig, previous_entries=None): - # Extract required fields from config - beancount_files, beancount_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl +class BeancountToJsonl(TextToJsonl): + # Define Functions + def process(self, previous_entries=None): + # Extract required fields from config + beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter,self.config.compressed_jsonl - # Input Validation - if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter): - print("At least one of beancount-files or beancount-file-filter is required to be specified") - exit(1) + # Input Validation + if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter): + print("At least one of beancount-files or beancount-file-filter is required to be specified") + exit(1) - # Get Beancount Files to Process - beancount_files = get_beancount_files(beancount_files, beancount_file_filter) + # Get Beancount Files to Process + beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter) - # Extract Entries from specified Beancount files - start = time.time() - current_entries = convert_transactions_to_maps(*extract_beancount_transactions(beancount_files)) - end = time.time() - logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds") + # Extract Entries from specified Beancount files + start = time.time() + current_entries = 
BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files)) + end = time.time() + logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds") - # Identify, mark and merge any new entries with previous entries - start = time.time() - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) - end = time.time() - logger.debug(f"Identify new or updated transaction: {end - start} seconds") + # Identify, mark and merge any new entries with previous entries + start = time.time() + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + end = time.time() + logger.debug(f"Identify new or updated transaction: {end - start} seconds") - # Process Each Entry from All Notes Files - start = time.time() - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = convert_transaction_maps_to_jsonl(entries) + # Process Each Entry from All Notes Files + start = time.time() + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries) - # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) - end = time.time() - logger.debug(f"Write transactions to JSONL file: {end - start} seconds") + # Compress JSONL formatted Data + if output_file.suffix == ".gz": + compress_jsonl_data(jsonl_data, output_file) + elif output_file.suffix == ".jsonl": + dump_jsonl(jsonl_data, output_file) + end = time.time() + logger.debug(f"Write transactions to JSONL file: {end - start} seconds") - return entries_with_ids 
+ return entries_with_ids + @staticmethod + def get_beancount_files(beancount_files=None, beancount_file_filters=None): + "Get Beancount files to process" + absolute_beancount_files, filtered_beancount_files = set(), set() + if beancount_files: + absolute_beancount_files = {get_absolute_path(beancount_file) + for beancount_file + in beancount_files} + if beancount_file_filters: + filtered_beancount_files = { + filtered_file + for beancount_file_filter in beancount_file_filters + for filtered_file in glob.glob(get_absolute_path(beancount_file_filter)) + } -def get_beancount_files(beancount_files=None, beancount_file_filters=None): - "Get Beancount files to process" - absolute_beancount_files, filtered_beancount_files = set(), set() - if beancount_files: - absolute_beancount_files = {get_absolute_path(beancount_file) - for beancount_file - in beancount_files} - if beancount_file_filters: - filtered_beancount_files = { - filtered_file - for beancount_file_filter in beancount_file_filters - for filtered_file in glob.glob(get_absolute_path(beancount_file_filter)) + all_beancount_files = sorted(absolute_beancount_files | filtered_beancount_files) + + files_with_non_beancount_extensions = { + beancount_file + for beancount_file + in all_beancount_files + if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount") } + if any(files_with_non_beancount_extensions): + print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}") - all_beancount_files = sorted(absolute_beancount_files | filtered_beancount_files) + logger.info(f'Processing files: {all_beancount_files}') - files_with_non_beancount_extensions = { - beancount_file - for beancount_file - in all_beancount_files - if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount") - } - if any(files_with_non_beancount_extensions): - print(f"[Warning] There maybe non beancount files in the input set: 
{files_with_non_beancount_extensions}") + return all_beancount_files - logger.info(f'Processing files: {all_beancount_files}') + @staticmethod + def extract_beancount_transactions(beancount_files): + "Extract entries from specified Beancount files" - return all_beancount_files + # Initialize Regex for extracting Beancount Entries + transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] ' + empty_newline = f'^[\n\r\t\ ]*$' + entries = [] + transaction_to_file_map = [] + for beancount_file in beancount_files: + with open(beancount_file) as f: + ledger_content = f.read() + transactions_per_file = [entry.strip(empty_escape_sequences) + for entry + in re.split(empty_newline, ledger_content, flags=re.MULTILINE) + if re.match(transaction_regex, entry)] + transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file)) + entries.extend(transactions_per_file) + return entries, dict(transaction_to_file_map) -def extract_beancount_transactions(beancount_files): - "Extract entries from specified Beancount files" + @staticmethod + def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]: + "Convert each Beancount transaction into a dictionary" + entry_maps = [] + for entry in entries: + entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'}) - # Initialize Regex for extracting Beancount Entries - transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] 
' - empty_newline = f'^[\n\r\t\ ]*$' + logger.info(f"Converted {len(entries)} transactions to dictionaries") - entries = [] - transaction_to_file_map = [] - for beancount_file in beancount_files: - with open(beancount_file) as f: - ledger_content = f.read() - transactions_per_file = [entry.strip(empty_escape_sequences) - for entry - in re.split(empty_newline, ledger_content, flags=re.MULTILINE) - if re.match(transaction_regex, entry)] - transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file)) - entries.extend(transactions_per_file) - return entries, dict(transaction_to_file_map) + return entry_maps - -def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]: - "Convert each Beancount transaction into a dictionary" - entry_maps = [] - for entry in entries: - entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'}) - - logger.info(f"Converted {len(entries)} transactions to dictionaries") - - return entry_maps - - -def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str: - "Convert each Beancount transaction dictionary to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + @staticmethod + def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str: + "Convert each Beancount transaction dictionary to JSON and collate as JSONL" + return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py index 22f5ea17..48fbbdf9 100644 --- a/src/processor/markdown/markdown_to_jsonl.py +++ b/src/processor/markdown/markdown_to_jsonl.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - # Standard Packages import json import glob @@ -8,120 +6,121 @@ import logging import time # Internal Packages +from src.processor.text_to_jsonl import TextToJsonl from 
src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data -from src.utils.rawconfig import TextContentConfig logger = logging.getLogger(__name__) -# Define Functions -def markdown_to_jsonl(config: TextContentConfig, previous_entries=None): - # Extract required fields from config - markdown_files, markdown_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl +class MarkdownToJsonl(TextToJsonl): + # Define Functions + def process(self, previous_entries=None): + # Extract required fields from config + markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl - # Input Validation - if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter): - print("At least one of markdown-files or markdown-file-filter is required to be specified") - exit(1) + # Input Validation + if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter): + print("At least one of markdown-files or markdown-file-filter is required to be specified") + exit(1) - # Get Markdown Files to Process - markdown_files = get_markdown_files(markdown_files, markdown_file_filter) + # Get Markdown Files to Process + markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter) - # Extract Entries from specified Markdown files - start = time.time() - current_entries = convert_markdown_entries_to_maps(*extract_markdown_entries(markdown_files)) - end = time.time() - logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds") + # Extract Entries from specified Markdown files + start = time.time() + current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files)) + end = time.time() + logger.debug(f"Parse 
entries from Markdown files into dictionaries: {end - start} seconds") - # Identify, mark and merge any new entries with previous entries - start = time.time() - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) - end = time.time() - logger.debug(f"Identify new or updated entries: {end - start} seconds") + # Identify, mark and merge any new entries with previous entries + start = time.time() + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + end = time.time() + logger.debug(f"Identify new or updated entries: {end - start} seconds") - # Process Each Entry from All Notes Files - start = time.time() - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = convert_markdown_maps_to_jsonl(entries) + # Process Each Entry from All Notes Files + start = time.time() + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) - # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) - end = time.time() - logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds") + # Compress JSONL formatted Data + if output_file.suffix == ".gz": + compress_jsonl_data(jsonl_data, output_file) + elif output_file.suffix == ".jsonl": + dump_jsonl(jsonl_data, output_file) + end = time.time() + logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds") - return entries_with_ids + return entries_with_ids + @staticmethod + def get_markdown_files(markdown_files=None, markdown_file_filters=None): + "Get Markdown files to process" + 
absolute_markdown_files, filtered_markdown_files = set(), set() + if markdown_files: + absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files} + if markdown_file_filters: + filtered_markdown_files = { + filtered_file + for markdown_file_filter in markdown_file_filters + for filtered_file in glob.glob(get_absolute_path(markdown_file_filter)) + } -def get_markdown_files(markdown_files=None, markdown_file_filters=None): - "Get Markdown files to process" - absolute_markdown_files, filtered_markdown_files = set(), set() - if markdown_files: - absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files} - if markdown_file_filters: - filtered_markdown_files = { - filtered_file - for markdown_file_filter in markdown_file_filters - for filtered_file in glob.glob(get_absolute_path(markdown_file_filter)) + all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files) + + files_with_non_markdown_extensions = { + md_file + for md_file + in all_markdown_files + if not md_file.endswith(".md") and not md_file.endswith('.markdown') } - all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files) + if any(files_with_non_markdown_extensions): + logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}") - files_with_non_markdown_extensions = { - md_file - for md_file - in all_markdown_files - if not md_file.endswith(".md") and not md_file.endswith('.markdown') - } + logger.info(f'Processing files: {all_markdown_files}') - if any(files_with_non_markdown_extensions): - logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}") + return all_markdown_files - logger.info(f'Processing files: {all_markdown_files}') + @staticmethod + def extract_markdown_entries(markdown_files): + "Extract entries by heading from specified Markdown files" - return all_markdown_files + 
# Regex to extract Markdown Entries by Heading + markdown_heading_regex = r'^#' + entries = [] + entry_to_file_map = [] + for markdown_file in markdown_files: + with open(markdown_file) as f: + markdown_content = f.read() + markdown_entries_per_file = [f'#{entry.strip(empty_escape_sequences)}' + for entry + in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE) + if entry.strip(empty_escape_sequences) != ''] + entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file)) + entries.extend(markdown_entries_per_file) -def extract_markdown_entries(markdown_files): - "Extract entries by heading from specified Markdown files" + return entries, dict(entry_to_file_map) - # Regex to extract Markdown Entries by Heading - markdown_heading_regex = r'^#' + @staticmethod + def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]: + "Convert each Markdown entries into a dictionary" + entry_maps = [] + for entry in entries: + entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'}) - entries = [] - entry_to_file_map = [] - for markdown_file in markdown_files: - with open(markdown_file) as f: - markdown_content = f.read() - markdown_entries_per_file = [f'#{entry.strip(empty_escape_sequences)}' - for entry - in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE) - if entry.strip(empty_escape_sequences) != ''] - entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file)) - entries.extend(markdown_entries_per_file) + logger.info(f"Converted {len(entries)} markdown entries to dictionaries") - return entries, dict(entry_to_file_map) + return entry_maps - -def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]: - "Convert each Markdown entries into a dictionary" - entry_maps = [] - for entry in entries: - entry_maps.append({'compiled': entry, 'raw': entry, 'file': 
f'{entry_to_file_map[entry]}'}) - - logger.info(f"Converted {len(entries)} markdown entries to dictionaries") - - return entry_maps - - -def convert_markdown_maps_to_jsonl(entries): - "Convert each Markdown entries to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + @staticmethod + def convert_markdown_maps_to_jsonl(entries): + "Convert each Markdown entries to JSON and collate as JSONL" + return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index 43f4acef..c4c18ce9 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - # Standard Packages import json import glob @@ -9,147 +7,148 @@ from typing import Iterable # Internal Packages from src.processor.org_mode import orgnode +from src.processor.text_to_jsonl import TextToJsonl from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils import state -from src.utils.rawconfig import TextContentConfig logger = logging.getLogger(__name__) -# Define Functions -def org_to_jsonl(config: TextContentConfig, previous_entries=None): - # Extract required fields from config - org_files, org_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl - index_heading_entries = config.index_heading_entries +class OrgToJsonl(TextToJsonl): + # Define Functions + def process(self, previous_entries=None): + # Extract required fields from config + org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl + index_heading_entries = self.config.index_heading_entries - # Input Validation - if is_none_or_empty(org_files) and 
is_none_or_empty(org_file_filter): - print("At least one of org-files or org-file-filter is required to be specified") - exit(1) + # Input Validation + if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter): + print("At least one of org-files or org-file-filter is required to be specified") + exit(1) - # Get Org Files to Process - start = time.time() - org_files = get_org_files(org_files, org_file_filter) + # Get Org Files to Process + start = time.time() + org_files = OrgToJsonl.get_org_files(org_files, org_file_filter) - # Extract Entries from specified Org files - start = time.time() - entry_nodes, file_to_entries = extract_org_entries(org_files) - end = time.time() - logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds") + # Extract Entries from specified Org files + start = time.time() + entry_nodes, file_to_entries = self.extract_org_entries(org_files) + end = time.time() + logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds") - start = time.time() - current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) - end = time.time() - logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds") + start = time.time() + current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) + end = time.time() + logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds") - # Identify, mark and merge any new entries with previous entries - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + # Identify, mark and merge any new entries with previous entries + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = mark_entries_for_update(current_entries, previous_entries, 
key='compiled', logger=logger) - # Process Each Entry from All Notes Files - start = time.time() - entries = map(lambda entry: entry[1], entries_with_ids) - jsonl_data = convert_org_entries_to_jsonl(entries) + # Process Each Entry from All Notes Files + start = time.time() + entries = map(lambda entry: entry[1], entries_with_ids) + jsonl_data = self.convert_org_entries_to_jsonl(entries) - # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) - end = time.time() - logger.debug(f"Write org entries to JSONL file: {end - start} seconds") + # Compress JSONL formatted Data + if output_file.suffix == ".gz": + compress_jsonl_data(jsonl_data, output_file) + elif output_file.suffix == ".jsonl": + dump_jsonl(jsonl_data, output_file) + end = time.time() + logger.debug(f"Write org entries to JSONL file: {end - start} seconds") - return entries_with_ids + return entries_with_ids + @staticmethod + def get_org_files(org_files=None, org_file_filters=None): + "Get Org files to process" + absolute_org_files, filtered_org_files = set(), set() + if org_files: + absolute_org_files = { + get_absolute_path(org_file) + for org_file + in org_files + } + if org_file_filters: + filtered_org_files = { + filtered_file + for org_file_filter in org_file_filters + for filtered_file in glob.glob(get_absolute_path(org_file_filter)) + } -def get_org_files(org_files=None, org_file_filters=None): - "Get Org files to process" - absolute_org_files, filtered_org_files = set(), set() - if org_files: - absolute_org_files = { - get_absolute_path(org_file) - for org_file - in org_files - } - if org_file_filters: - filtered_org_files = { - filtered_file - for org_file_filter in org_file_filters - for filtered_file in glob.glob(get_absolute_path(org_file_filter)) - } + all_org_files = sorted(absolute_org_files | filtered_org_files) - all_org_files = sorted(absolute_org_files | 
filtered_org_files) + files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")} + if any(files_with_non_org_extensions): + logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}") - files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")} - if any(files_with_non_org_extensions): - logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}") + logger.info(f'Processing files: {all_org_files}') - logger.info(f'Processing files: {all_org_files}') + return all_org_files - return all_org_files + @staticmethod + def extract_org_entries(org_files): + "Extract entries from specified Org files" + entries = [] + entry_to_file_map = [] + for org_file in org_files: + org_file_entries = orgnode.makelist(str(org_file)) + entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries)) + entries.extend(org_file_entries) + return entries, dict(entry_to_file_map) -def extract_org_entries(org_files): - "Extract entries from specified Org files" - entries = [] - entry_to_file_map = [] - for org_file in org_files: - org_file_entries = orgnode.makelist(str(org_file)) - entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries)) - entries.extend(org_file_entries) + @staticmethod + def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]: + "Convert Org-Mode entries into list of dictionary" + entry_maps = [] + for entry in entries: + entry_dict = dict() - return entries, dict(entry_to_file_map) + if not entry.hasBody and not index_heading_entries: + # Ignore title notes i.e notes with just headings and empty body + continue - -def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]: - "Convert Org-Mode entries into list of dictionary" - entry_maps = [] 
- for entry in entries: - entry_dict = dict() - - if not entry.hasBody and not index_heading_entries: - # Ignore title notes i.e notes with just headings and empty body - continue - - entry_dict["compiled"] = f'{entry.heading}.' - if state.verbose > 2: - logger.debug(f"Title: {entry.heading}") - - if entry.tags: - tags_str = " ".join(entry.tags) - entry_dict["compiled"] += f'\t {tags_str}.' + entry_dict["compiled"] = f'{entry.heading}.' if state.verbose > 2: - logger.debug(f"Tags: {tags_str}") + logger.debug(f"Title: {entry.heading}") - if entry.closed: - entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.' - if state.verbose > 2: - logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}') + if entry.tags: + tags_str = " ".join(entry.tags) + entry_dict["compiled"] += f'\t {tags_str}.' + if state.verbose > 2: + logger.debug(f"Tags: {tags_str}") - if entry.scheduled: - entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.' - if state.verbose > 2: - logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}') + if entry.closed: + entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.' + if state.verbose > 2: + logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}') - if entry.hasBody: - entry_dict["compiled"] += f'\n {entry.body}' - if state.verbose > 2: - logger.debug(f"Body: {entry.body}") + if entry.scheduled: + entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.' 
+ if state.verbose > 2: + logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}') - if entry_dict: - entry_dict["raw"] = f'{entry}' - entry_dict["file"] = f'{entry_to_file_map[entry]}' + if entry.hasBody: + entry_dict["compiled"] += f'\n {entry.body}' + if state.verbose > 2: + logger.debug(f"Body: {entry.body}") - # Convert Dictionary to JSON and Append to JSONL string - entry_maps.append(entry_dict) + if entry_dict: + entry_dict["raw"] = f'{entry}' + entry_dict["file"] = f'{entry_to_file_map[entry]}' - return entry_maps + # Convert Dictionary to JSON and Append to JSONL string + entry_maps.append(entry_dict) + return entry_maps -def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str: - "Convert each Org-Mode entry to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + @staticmethod + def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str: + "Convert each Org-Mode entry to JSON and collate as JSONL" + return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py new file mode 100644 index 00000000..e59c5fb1 --- /dev/null +++ b/src/processor/text_to_jsonl.py @@ -0,0 +1,14 @@ +# Standard Packages +from abc import ABC, abstractmethod +from typing import Iterable + +# Internal Packages +from src.utils.rawconfig import TextContentConfig + + +class TextToJsonl(ABC): + def __init__(self, config: TextContentConfig): + self.config = config + + @abstractmethod + def process(self, previous_entries: Iterable[tuple[int, dict]]=None) -> list[tuple[int, dict]]: ... 
diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index d4d8a9d4..ff7d9c43 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -1,10 +1,12 @@ # Standard Packages import logging import time +from typing import Type # External Packages import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util +from src.processor.text_to_jsonl import TextToJsonl from src.search_filter.base_filter import BaseFilter # Internal Packages @@ -179,14 +181,14 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel: +def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) # Map notes in text files to (compressed) JSONL formatted file config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None - entries_with_indices = text_to_jsonl(config, previous_entries) + entries_with_indices = text_to_jsonl(config).process(previous_entries) # Extract Updated Entries entries = extract_entries(config.compressed_jsonl) diff --git a/tests/conftest.py b/tests/conftest.py index f6c0a7ea..103a28e8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from src.search_type import image_search, text_search from src.utils.config import SearchType from src.utils.helpers import resolve_absolute_path from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig -from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from 
src.processor.org_mode.org_to_jsonl import OrgToJsonl from src.search_filter.date_filter import DateFilter from src.search_filter.word_filter import WordFilter from src.search_filter.file_filter import FileFilter @@ -60,6 +60,6 @@ def content_config(tmp_path_factory, search_config: SearchConfig): embeddings_file = content_dir.joinpath('note_embeddings.pt')) filters = [DateFilter(), WordFilter(), FileFilter()] - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) return content_config diff --git a/tests/test_beancount_to_jsonl.py b/tests/test_beancount_to_jsonl.py index 51a4dffd..2c1cb9e6 100644 --- a/tests/test_beancount_to_jsonl.py +++ b/tests/test_beancount_to_jsonl.py @@ -2,7 +2,7 @@ import json # Internal Packages -from src.processor.ledger.beancount_to_jsonl import extract_beancount_transactions, convert_transactions_to_maps, convert_transaction_maps_to_jsonl, get_beancount_files +from src.processor.ledger.beancount_to_jsonl import BeancountToJsonl def test_no_transactions_in_file(tmp_path): @@ -16,10 +16,11 @@ def test_no_transactions_in_file(tmp_path): # Act # Extract Entries from specified Beancount files - entry_nodes, file_to_entries = extract_beancount_transactions(beancount_files=[beancount_file]) + entry_nodes, file_to_entries = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file]) # Process Each Entry from All Beancount Files - jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entry_nodes, file_to_entries)) + jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl( + BeancountToJsonl.convert_transactions_to_maps(entry_nodes, file_to_entries)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -38,10 +39,11 @@ Assets:Test:Test -1.00 KES # Act # Extract Entries from 
specified Beancount files - entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file]) + entries, entry_to_file_map = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file]) # Process Each Entry from All Beancount Files - jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map)) + jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl( + BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -65,10 +67,11 @@ Assets:Test:Test -1.00 KES # Act # Extract Entries from specified Beancount files - entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file]) + entries, entry_to_file_map = BeancountToJsonl.extract_beancount_transactions(beancount_files=[beancount_file]) # Process Each Entry from All Beancount Files - jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map)) + jsonl_string = BeancountToJsonl.convert_transaction_maps_to_jsonl( + BeancountToJsonl.convert_transactions_to_maps(entries, entry_to_file_map)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -96,7 +99,7 @@ def test_get_beancount_files(tmp_path): input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount'] # Act - extracted_org_files = get_beancount_files(input_files, input_filter) + extracted_org_files = BeancountToJsonl.get_beancount_files(input_files, input_filter) # Assert assert len(extracted_org_files) == 5 diff --git a/tests/test_client.py b/tests/test_client.py index d405a044..96fa2c01 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,7 +12,7 @@ from src.main import app from src.utils.state import model, config from src.search_type import text_search, image_search from src.utils.rawconfig import 
ContentConfig, SearchConfig -from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.processor.org_mode.org_to_jsonl import OrgToJsonl from src.search_filter.word_filter import WordFilter from src.search_filter.file_filter import FileFilter @@ -118,7 +118,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) user_query = quote("How to git install application?") # Act @@ -135,7 +135,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter(), FileFilter()] - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = quote('+"Emacs" file:"*.org"') # Act @@ -152,7 +152,7 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = quote('How to git install application? 
+"Emacs"') # Act @@ -169,7 +169,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = quote('How to git install application? -"clone"') # Act diff --git a/tests/test_markdown_to_jsonl.py b/tests/test_markdown_to_jsonl.py index 89c471d8..c4c72688 100644 --- a/tests/test_markdown_to_jsonl.py +++ b/tests/test_markdown_to_jsonl.py @@ -2,7 +2,7 @@ import json # Internal Packages -from src.processor.markdown.markdown_to_jsonl import extract_markdown_entries, convert_markdown_maps_to_jsonl, convert_markdown_entries_to_maps, get_markdown_files +from src.processor.markdown.markdown_to_jsonl import MarkdownToJsonl def test_markdown_file_with_no_headings_to_jsonl(tmp_path): @@ -16,10 +16,11 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path): # Act # Extract Entries from specified Markdown files - entry_nodes, file_to_entries = extract_markdown_entries(markdown_files=[markdownfile]) + entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile]) # Process Each Entry from All Notes Files - jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entry_nodes, file_to_entries)) + jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( + MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -37,10 +38,11 @@ def test_single_markdown_entry_to_jsonl(tmp_path): # Act # Extract Entries from specified Markdown files - entries, 
entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile]) + entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile]) # Process Each Entry from All Notes Files - jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map)) + jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( + MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -62,10 +64,11 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path): # Act # Extract Entries from specified Markdown files - entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile]) + entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=[markdownfile]) # Process Each Entry from All Notes Files - jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map)) + jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl( + MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -93,7 +96,7 @@ def test_get_markdown_files(tmp_path): input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown'] # Act - extracted_org_files = get_markdown_files(input_files, input_filter) + extracted_org_files = MarkdownToJsonl.get_markdown_files(input_files, input_filter) # Assert assert len(extracted_org_files) == 5 diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py index 8a2f58ba..2dbedcd0 100644 --- a/tests/test_org_to_jsonl.py +++ b/tests/test_org_to_jsonl.py @@ -2,7 +2,7 @@ import json # Internal Packages -from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries, get_org_files +from 
src.processor.org_mode.org_to_jsonl import OrgToJsonl from src.utils.helpers import is_none_or_empty @@ -21,8 +21,8 @@ def test_configure_heading_entry_to_jsonl(tmp_path): for index_heading_entries in [True, False]: # Act # Extract entries into jsonl from specified Org files - jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries( - *extract_org_entries(org_files=[orgfile]), + jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries( + *OrgToJsonl.extract_org_entries(org_files=[orgfile]), index_heading_entries=index_heading_entries)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] @@ -49,10 +49,10 @@ def test_entry_with_body_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entries, entry_to_file_map = extract_org_entries(org_files=[orgfile]) + entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(entries, entry_to_file_map)) + jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # Assert @@ -70,11 +70,11 @@ def test_file_with_no_headings_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile]) + entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries) - jsonl_string = convert_org_entries_to_jsonl(entries) + entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries) + jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries) jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()] # 
Assert @@ -102,7 +102,7 @@ def test_get_org_files(tmp_path): input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org'] # Act - extracted_org_files = get_org_files(input_files, input_filter) + extracted_org_files = OrgToJsonl.get_org_files(input_files, input_filter) # Assert assert len(extracted_org_files) == 5 diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 6744566d..584c07b9 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -9,7 +9,7 @@ import pytest from src.utils.state import model from src.search_type import text_search from src.utils.rawconfig import ContentConfig, SearchConfig -from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.processor.org_mode.org_to_jsonl import OrgToJsonl # Test @@ -24,7 +24,7 @@ def test_asymmetric_setup_with_missing_file_raises_error(content_config: Content # Act # Generate notes embeddings during asymmetric setup with pytest.raises(FileNotFoundError): - text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True) + text_search.setup(OrgToJsonl, new_org_content_config, search_config.asymmetric, regenerate=True) # ---------------------------------------------------------------------------------------------------- @@ -39,7 +39,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentCo # Act # Generate notes embeddings during asymmetric setup with pytest.raises(ValueError, match=r'^No valid entries found*'): - text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True) + text_search.setup(OrgToJsonl, new_org_content_config, search_config.asymmetric, regenerate=True) # Cleanup # delete created test file @@ -50,7 +50,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentCo def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = 
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) # Assert assert len(notes_model.entries) == 10 @@ -60,7 +60,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) query = "How to git install application?" # Act @@ -83,7 +83,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig): # Arrange - initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model= text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -96,11 +96,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") # regenerate notes jsonl, model embeddings and model to include entry from new file - regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + regenerated_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, 
regenerate=True) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) # Assert assert len(regenerated_notes_model.entries) == 11 @@ -119,7 +119,7 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig): # Arrange - initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -133,7 +133,7 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search # Act # update embeddings, entries with the newly added note - initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) # verify new entry added in updated embeddings, entries assert len(initial_notes_model.entries) == 11 From ee65a4f2c79291c60d88c3d6160be97bb6ab1ac1 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 14 Sep 2022 14:01:09 +0300 Subject: [PATCH 02/10] Merge /reload, /regenerate into single /update API endpoint - Pass force=true to /update API to force regenerating index from scratch - Otherwise calls to the /update API endpoint will result in an incremental update to index --- src/router.py | 14 ++++---------- 1 
file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/router.py b/src/router.py index d73ca331..27b237dd 100644 --- a/src/router.py +++ b/src/router.py @@ -137,16 +137,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti return results -@router.get('/reload') -def reload(t: Optional[SearchType] = None): - state.model = configure_search(state.model, state.config, regenerate=False, t=t) - return {'status': 'ok', 'message': 'reload completed'} - - -@router.get('/regenerate') -def regenerate(t: Optional[SearchType] = None): - state.model = configure_search(state.model, state.config, regenerate=True, t=t) - return {'status': 'ok', 'message': 'regeneration completed'} +@router.get('/update') +def update(t: Optional[SearchType] = None, force: Optional[bool] = False): + state.model = configure_search(state.model, state.config, regenerate=force, t=t) + return {'status': 'ok', 'message': 'index updated completed'} @router.get('/beta/search') From e42a38e82501e624df1d175899465c10837b0bff Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 14 Sep 2022 21:22:20 +0300 Subject: [PATCH 03/10] Version Khoj API, Update frontends, tests and docs to reflect it - Split router.py into v1.0, beta and frontend (no-prefix) api modules under new router package. 
Version tag in main.py via prefix - Update frontends to use the versioned api endpoints - Update tests to work with versioned api endpoints - Update docs to mentioned, reference only versioned api endpoints --- Readme.md | 4 +- src/interface/emacs/khoj.el | 4 +- src/interface/web/assets/config.js | 6 +- src/interface/web/index.html | 8 +-- src/main.py | 8 ++- src/routers/api_beta.py | 89 +++++++++++++++++++++++ src/{router.py => routers/api_v1_0.py} | 99 ++------------------------ src/routers/frontend.py | 25 +++++++ tests/data/markdown/main_readme.md | 5 +- tests/data/org/main_readme.org | 4 +- tests/test_client.py | 26 +++---- 11 files changed, 154 insertions(+), 124 deletions(-) create mode 100644 src/routers/api_beta.py rename src/{router.py => routers/api_v1_0.py} (51%) create mode 100644 src/routers/frontend.py diff --git a/Readme.md b/Readme.md index da09b2ce..14794746 100644 --- a/Readme.md +++ b/Readme.md @@ -85,7 +85,7 @@ khoj ### 3. Configure 1. Enable content types and point to files to search in the First Run Screen that pops up on app start -2. Click configure and wait. The app will load ML model, generates embeddings and expose the search API +2. Click `Configure` and wait. 
The app will download ML models and index the content for search ## Use @@ -113,7 +113,7 @@ pip install --upgrade khoj-assistant ## Miscellaneous -- The beta [chat](http://localhost:8000/beta/chat) and [search](http://localhost:8000/beta/search) API endpoints use [OpenAI API](https://openai.com/api/) +- The beta [chat](http://localhost:8000/api/beta/chat) and [search](http://localhost:8000/api/beta/search) API endpoints use [OpenAI API](https://openai.com/api/) - It is disabled by default - To use it add your `openai-api-key` via the app configure screen - Warning: *If you use the above beta APIs, your query and top result(s) will be sent to OpenAI for processing* diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index c34fb88e..c05f98c6 100644 --- a/src/interface/emacs/khoj.el +++ b/src/interface/emacs/khoj.el @@ -226,7 +226,7 @@ Use `which-key` if available, else display simple message in echo area" (defun khoj--get-enabled-content-types () "Get content types enabled for search from API." - (let ((config-url (format "%s/config/data" khoj-server-url))) + (let ((config-url (format "%s/api/v1.0/config/data" khoj-server-url))) (with-temp-buffer (erase-buffer) (url-insert-file-contents config-url) @@ -243,7 +243,7 @@ Use `which-key` if available, else display simple message in echo area" "Construct API Query from QUERY, SEARCH-TYPE and (optional) RERANK params." (let ((rerank (or rerank "false")) (encoded-query (url-hexify-string query))) - (format "%s/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count))) + (format "%s/api/v1.0/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count))) (defun khoj--query-api-and-render-results (query search-type query-url buffer-name) "Query Khoj API using QUERY, SEARCH-TYPE, QUERY-URL. 
diff --git a/src/interface/web/assets/config.js b/src/interface/web/assets/config.js index 30ab6858..90412e1c 100644 --- a/src/interface/web/assets/config.js +++ b/src/interface/web/assets/config.js @@ -10,7 +10,7 @@ var emptyValueDefault = "🖊️"; /** * Fetch the existing config file. */ -fetch("/config/data") +fetch("/api/v1.0/config/data") .then(response => response.json()) .then(data => { rawConfig = data; @@ -26,7 +26,7 @@ fetch("/config/data") configForm.addEventListener("submit", (event) => { event.preventDefault(); console.log(rawConfig); - fetch("/config/data", { + fetch("/api/v1.0/config/data", { method: "POST", credentials: "same-origin", headers: { @@ -46,7 +46,7 @@ regenerateButton.addEventListener("click", (event) => { event.preventDefault(); regenerateButton.style.cursor = "progress"; regenerateButton.disabled = true; - fetch("/regenerate") + fetch("/api/v1.0/update?force=true") .then(response => response.json()) .then(data => { regenerateButton.style.cursor = "pointer"; diff --git a/src/interface/web/index.html b/src/interface/web/index.html index a74041fa..5dc76b87 100644 --- a/src/interface/web/index.html +++ b/src/interface/web/index.html @@ -77,8 +77,8 @@ // Generate Backend API URL to execute Search url = type === "image" - ? `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}` - : `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`; + ? 
`/api/v1.0/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}` + : `/api/v1.0/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`; // Execute Search and Render Results fetch(url) @@ -94,7 +94,7 @@ function updateIndex() { type = document.getElementById("type").value; - fetch(`/reload?t=${type}`) + fetch(`/api/v1.0/update?t=${type}`) .then(response => response.json()) .then(data => { console.log(data); @@ -118,7 +118,7 @@ function populate_type_dropdown() { // Populate type dropdown field with enabled search types only var possible_search_types = ["org", "markdown", "ledger", "music", "image"]; - fetch("/config/data") + fetch("/api/v1.0/config/data") .then(response => response.json()) .then(data => { document.getElementById("type").innerHTML = diff --git a/src/main.py b/src/main.py index 378758b2..13d15674 100644 --- a/src/main.py +++ b/src/main.py @@ -19,7 +19,9 @@ from PyQt6.QtCore import QThread, QTimer # Internal Packages from src.configure import configure_server -from src.router import router +from src.routers.api_v1_0 import api_v1_0 +from src.routers.api_beta import api_beta +from src.routers.frontend import frontend_router from src.utils import constants, state from src.utils.cli import cli from src.interface.desktop.main_window import MainWindow @@ -29,7 +31,9 @@ from src.interface.desktop.system_tray import create_system_tray # Initialize the Application Server app = FastAPI() app.mount("/static", StaticFiles(directory=constants.web_directory), name="static") -app.include_router(router) +app.include_router(api_v1_0, prefix="/api/v1.0") +app.include_router(api_beta, prefix="/api/beta") +app.include_router(frontend_router) logger = logging.getLogger('src') diff --git a/src/routers/api_beta.py b/src/routers/api_beta.py new file mode 100644 index 00000000..389025b9 --- /dev/null +++ b/src/routers/api_beta.py @@ -0,0 +1,89 @@ +# Standard Packages +import json +import logging +from typing import Optional + +# External 
Packages +from fastapi import APIRouter + +# Internal Packages +from src.routers.api_v1_0 import search +from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize +from src.utils.config import SearchType +from src.utils.helpers import get_absolute_path, get_from_dict +from src.utils import state + + +api_beta = APIRouter() +logger = logging.getLogger(__name__) + + +@api_beta.get('/search') +def search_beta(q: str, n: Optional[int] = 1): + # Extract Search Type using GPT + metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) + search_type = get_from_dict(metadata, "search-type") + + # Search + search_results = search(q, n=n, t=SearchType(search_type)) + + # Return response + return {'status': 'ok', 'result': search_results, 'type': search_type} + + +@api_beta.get('/chat') +def chat(q: str): + # Load Conversation History + chat_session = state.processor_config.conversation.chat_session + meta_log = state.processor_config.conversation.meta_log + + # Converse with OpenAI GPT + metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) + if state.verbose > 1: + print(f'Understood: {get_from_dict(metadata, "intent")}') + + if get_from_dict(metadata, "intent", "memory-type") == "notes": + query = get_from_dict(metadata, "intent", "query") + result_list = search(query, n=1, t=SearchType.Org) + collated_result = "\n".join([item["entry"] for item in result_list]) + if state.verbose > 1: + print(f'Semantically Similar Notes:\n{collated_result}') + gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key) + else: + gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key) + + # Update Conversation History + state.processor_config.conversation.chat_session = 
message_to_prompt(q, chat_session, gpt_message=gpt_response) + state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', [])) + + return {'status': 'ok', 'response': gpt_response} + + +@api_beta.on_event('shutdown') +def shutdown_event(): + # No need to create empty log file + if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log): + return + elif state.processor_config.conversation.verbose: + print('INFO:\tSaving conversation logs to disk...') + + # Summarize Conversation Logs for this Session + chat_session = state.processor_config.conversation.chat_session + openai_api_key = state.processor_config.conversation.openai_api_key + conversation_log = state.processor_config.conversation.meta_log + session = { + "summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key), + "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], + "session-end": len(conversation_log["chat"]) + } + if 'session' in conversation_log: + conversation_log['session'].append(session) + else: + conversation_log['session'] = [session] + + # Save Conversation Metadata Logs to Disk + conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile) + with open(conversation_logfile, "w+", encoding='utf-8') as logfile: + json.dump(conversation_log, logfile) + + print('INFO:\tConversation logs saved to disk.') diff --git a/src/router.py b/src/routers/api_v1_0.py similarity index 51% rename from src/router.py rename to src/routers/api_v1_0.py index 27b237dd..616dbc17 100644 --- a/src/router.py +++ b/src/routers/api_v1_0.py @@ -1,45 +1,29 @@ # Standard Packages import yaml -import json import time import logging from typing import Optional # External Packages from fastapi import APIRouter -from fastapi import Request -from fastapi.responses import HTMLResponse, FileResponse -from 
fastapi.templating import Jinja2Templates # Internal Packages from src.configure import configure_search from src.search_type import image_search, text_search -from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.utils.rawconfig import FullConfig from src.utils.config import SearchType -from src.utils.helpers import LRU, get_absolute_path, get_from_dict from src.utils import state, constants -router = APIRouter() -templates = Jinja2Templates(directory=constants.web_directory) +api_v1_0 = APIRouter() logger = logging.getLogger(__name__) -query_cache = LRU() -@router.get("/", response_class=FileResponse) -def index(): - return FileResponse(constants.web_directory / "index.html") - -@router.get('/config', response_class=HTMLResponse) -def config_page(request: Request): - return templates.TemplateResponse("config.html", context={'request': request}) - -@router.get('/config/data', response_model=FullConfig) +@api_v1_0.get('/config/data', response_model=FullConfig) def config_data(): return state.config -@router.post('/config/data') +@api_v1_0.post('/config/data') async def config_data(updated_config: FullConfig): state.config = updated_config with open(state.config_file, 'w') as outfile: @@ -47,7 +31,7 @@ async def config_data(updated_config: FullConfig): outfile.close() return state.config -@router.get('/search') +@api_v1_0.get('/search') def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): if q is None or q == '': logger.info(f'No query param (q) passed in API call to initiate search') @@ -137,78 +121,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti return results -@router.get('/update') +@api_v1_0.get('/update') def update(t: Optional[SearchType] = None, force: Optional[bool] = False): state.model = configure_search(state.model, state.config, regenerate=force, t=t) - return {'status': 'ok', 
'message': 'index updated completed'} - - -@router.get('/beta/search') -def search_beta(q: str, n: Optional[int] = 1): - # Extract Search Type using GPT - metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) - search_type = get_from_dict(metadata, "search-type") - - # Search - search_results = search(q, n=n, t=SearchType(search_type)) - - # Return response - return {'status': 'ok', 'result': search_results, 'type': search_type} - - -@router.get('/beta/chat') -def chat(q: str): - # Load Conversation History - chat_session = state.processor_config.conversation.chat_session - meta_log = state.processor_config.conversation.meta_log - - # Converse with OpenAI GPT - metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) - if state.verbose > 1: - print(f'Understood: {get_from_dict(metadata, "intent")}') - - if get_from_dict(metadata, "intent", "memory-type") == "notes": - query = get_from_dict(metadata, "intent", "query") - result_list = search(query, n=1, t=SearchType.Org) - collated_result = "\n".join([item["entry"] for item in result_list]) - if state.verbose > 1: - print(f'Semantically Similar Notes:\n{collated_result}') - gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key) - else: - gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key) - - # Update Conversation History - state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) - state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', [])) - - return {'status': 'ok', 'response': gpt_response} - - -@router.on_event('shutdown') -def shutdown_event(): - # No need to create empty log file - if not (state.processor_config and state.processor_config.conversation 
and state.processor_config.conversation.meta_log): - return - elif state.processor_config.conversation.verbose: - print('INFO:\tSaving conversation logs to disk...') - - # Summarize Conversation Logs for this Session - chat_session = state.processor_config.conversation.chat_session - openai_api_key = state.processor_config.conversation.openai_api_key - conversation_log = state.processor_config.conversation.meta_log - session = { - "summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key), - "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], - "session-end": len(conversation_log["chat"]) - } - if 'session' in conversation_log: - conversation_log['session'].append(session) - else: - conversation_log['session'] = [session] - - # Save Conversation Metadata Logs to Disk - conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile) - with open(conversation_logfile, "w+", encoding='utf-8') as logfile: - json.dump(conversation_log, logfile) - - print('INFO:\tConversation logs saved to disk.') + return {'status': 'ok', 'message': 'index updated'} diff --git a/src/routers/frontend.py b/src/routers/frontend.py new file mode 100644 index 00000000..8ed5d6ee --- /dev/null +++ b/src/routers/frontend.py @@ -0,0 +1,25 @@ +# Standard Packages +import logging + +# External Packages +from fastapi import APIRouter +from fastapi import Request +from fastapi.responses import HTMLResponse, FileResponse +from fastapi.templating import Jinja2Templates + +# Internal Packages +from src.utils import constants + + +frontend_router = APIRouter() +templates = Jinja2Templates(directory=constants.web_directory) +logger = logging.getLogger(__name__) + + +@frontend_router.get("/", response_class=FileResponse) +def index(): + return FileResponse(constants.web_directory / "index.html") + +@frontend_router.get('/config', response_class=HTMLResponse) +def config_page(request: Request): + return 
templates.TemplateResponse("config.html", context={'request': request}) diff --git a/tests/data/markdown/main_readme.md b/tests/data/markdown/main_readme.md index 45e289b1..7f626319 100644 --- a/tests/data/markdown/main_readme.md +++ b/tests/data/markdown/main_readme.md @@ -43,9 +43,8 @@ just generating embeddings* - **Khoj via API** - See [Khoj API Docs](http://localhost:8000/docs) - - [Query](http://localhost:8000/search?q=%22what%20is%20the%20meaning%20of%20life%22) - - [Regenerate - Embeddings](http://localhost:8000/regenerate?t=ledger) + - [Query](http://localhost:8000/api/v1.0/search?q=%22what%20is%20the%20meaning%20of%20life%22) + - [Update Index](http://localhost:8000/api/v1.0/update?t=ledger) - [Configure Application](https://localhost:8000/ui) - **Khoj via Emacs** - [Install](https://github.com/debanjum/khoj/tree/master/src/interface/emacs#installation) diff --git a/tests/data/org/main_readme.org b/tests/data/org/main_readme.org index 917562e2..4f63801a 100644 --- a/tests/data/org/main_readme.org +++ b/tests/data/org/main_readme.org @@ -27,8 +27,8 @@ - Run ~M-x khoj ~ or Call ~C-c C-s~ - *Khoj via API* - - Query: ~GET~ [[http://localhost:8000/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/search?q="What is the meaning of life"]] - - Regenerate Embeddings: ~GET~ [[http://localhost:8000/regenerate][http://localhost:8000/regenerate]] + - Query: ~GET~ [[http://localhost:8000/api/v1.0/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/api/v1.0/search?q="What is the meaning of life"]] + - Update Index: ~GET~ [[http://localhost:8000/api/v1.0/update][http://localhost:8000/api/v1.0/update]] - [[http://localhost:8000/docs][Khoj API Docs]] - *Call Khoj via Python Script Directly* diff --git a/tests/test_client.py b/tests/test_client.py index 96fa2c01..c17e7edd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -28,7 +28,7 @@ def test_search_with_invalid_content_type(): user_query = quote("How to call 
Khoj from Emacs?") # Act - response = client.get(f"/search?q={user_query}&t=invalid_content_type") + response = client.get(f"/api/v1.0/search?q={user_query}&t=invalid_content_type") # Assert assert response.status_code == 422 @@ -43,29 +43,29 @@ def test_search_with_valid_content_type(content_config: ContentConfig, search_co # config.content_type.image = search_config.image for content_type in ["org", "markdown", "ledger", "music"]: # Act - response = client.get(f"/search?q=random&t={content_type}") + response = client.get(f"/api/v1.0/search?q=random&t={content_type}") # Assert assert response.status_code == 200 # ---------------------------------------------------------------------------------------------------- -def test_reload_with_invalid_content_type(): +def test_update_with_invalid_content_type(): # Act - response = client.get(f"/reload?t=invalid_content_type") + response = client.get(f"/api/v1.0/update?t=invalid_content_type") # Assert assert response.status_code == 422 # ---------------------------------------------------------------------------------------------------- -def test_reload_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig): +def test_update_with_valid_content_type(content_config: ContentConfig, search_config: SearchConfig): # Arrange config.content_type = content_config config.search_type = search_config for content_type in ["org", "markdown", "ledger", "music"]: # Act - response = client.get(f"/reload?t={content_type}") + response = client.get(f"/api/v1.0/update?t={content_type}") # Assert assert response.status_code == 200 @@ -73,7 +73,7 @@ def test_reload_with_valid_content_type(content_config: ContentConfig, search_co # ---------------------------------------------------------------------------------------------------- def test_regenerate_with_invalid_content_type(): # Act - response = client.get(f"/regenerate?t=invalid_content_type") + response = 
client.get(f"/api/v1.0/update?force=true&t=invalid_content_type") # Assert assert response.status_code == 422 @@ -87,7 +87,7 @@ def test_regenerate_with_valid_content_type(content_config: ContentConfig, searc for content_type in ["org", "markdown", "ledger", "music", "image"]: # Act - response = client.get(f"/regenerate?t={content_type}") + response = client.get(f"/api/v1.0/update?force=true&t={content_type}") # Assert assert response.status_code == 200 @@ -104,7 +104,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig for query, expected_image_name in query_expected_image_pairs: # Act - response = client.get(f"/search?q={query}&n=1&t=image") + response = client.get(f"/api/v1.0/search?q={query}&n=1&t=image") # Assert assert response.status_code == 200 @@ -122,7 +122,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig user_query = quote("How to git install application?") # Act - response = client.get(f"/search?q={user_query}&n=1&t=org&r=true") + response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org&r=true") # Assert assert response.status_code == 200 @@ -139,7 +139,7 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co user_query = quote('+"Emacs" file:"*.org"') # Act - response = client.get(f"/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 @@ -156,7 +156,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ user_query = quote('How to git install application? +"Emacs"') # Act - response = client.get(f"/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 @@ -173,7 +173,7 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_ user_query = quote('How to git install application? 
-"clone"') # Act - response = client.get(f"/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 From 0521ea10d66e1b6f0f90b13450361ada446717a9 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 15 Sep 2022 13:44:00 +0300 Subject: [PATCH 04/10] Put image score breakdown under `additional' field in search response - Update web, emacs interfaces to consume the scores from new schema --- src/interface/emacs/khoj.el | 4 ++-- src/interface/web/index.html | 2 +- src/search_type/image_search.py | 17 +++++++++++------ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index c05f98c6..e5b05605 100644 --- a/src/interface/emacs/khoj.el +++ b/src/interface/emacs/khoj.el @@ -187,8 +187,8 @@ Use `which-key` if available, else display simple message in echo area" (lambda (args) (format "\n\n

Score: %s Meta: %s Image: %s

\n\n\n\n" (cdr (assoc 'score args)) - (cdr (assoc 'metadata_score args)) - (cdr (assoc 'image_score args)) + (cdr (assoc 'metadata_score (assoc 'additional args))) + (cdr (assoc 'image_score (assoc 'additional args))) khoj-server-url (cdr (assoc 'entry args)) khoj-server-url diff --git a/src/interface/web/index.html b/src/interface/web/index.html index 5dc76b87..3d940ee3 100644 --- a/src/interface/web/index.html +++ b/src/interface/web/index.html @@ -16,7 +16,7 @@ return ` ` } diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index a86cc42c..d19a063a 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -220,12 +220,17 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= shutil.copy(source_path, target_path) # Add the image metadata to the results - results += [{ - "entry": f'{image_files_url}/{target_image_name}', - "score": f"{hit['score']:.9f}", - "image_score": f"{hit['image_score']:.9f}", - "metadata_score": f"{hit['metadata_score']:.9f}", - }] + results += [ + { + "entry": f'{image_files_url}/{target_image_name}', + "score": f"{hit['score']:.9f}", + "additional": + { + "image_score": f"{hit['image_score']:.9f}", + "metadata_score": f"{hit['metadata_score']:.9f}", + } + } + ] return results From 99754970abba637e32e5f781861e3d23ac17eb83 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 15 Sep 2022 13:57:20 +0300 Subject: [PATCH 05/10] Type the /search API response to better document the response schema - Both Text, Image Search were already giving list of entry, score - This change just concretizes this change and exposes this in the API documentation (i.e OpenAPI, Swagger, Redocs) --- src/routers/api_v1_0.py | 8 ++++---- src/search_type/image_search.py | 10 +++++----- src/search_type/text_search.py | 8 ++++---- src/utils/rawconfig.py | 5 +++++ 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/routers/api_v1_0.py b/src/routers/api_v1_0.py 
index 616dbc17..b6dea695 100644 --- a/src/routers/api_v1_0.py +++ b/src/routers/api_v1_0.py @@ -10,7 +10,7 @@ from fastapi import APIRouter # Internal Packages from src.configure import configure_search from src.search_type import image_search, text_search -from src.utils.rawconfig import FullConfig +from src.utils.rawconfig import FullConfig, SearchResponse from src.utils.config import SearchType from src.utils import state, constants @@ -31,16 +31,16 @@ async def config_data(updated_config: FullConfig): outfile.close() return state.config -@api_v1_0.get('/search') +@api_v1_0.get('/search', response_model=list[SearchResponse]) def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): + results: list[SearchResponse] = [] if q is None or q == '': logger.info(f'No query param (q) passed in API call to initiate search') - return {} + return results # initialize variables user_query = q.strip() results_count = n - results = {} query_start, query_end, collate_start, collate_end = None, None, None, None # return cached results, if available diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index d19a063a..e04bbe49 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -15,7 +15,7 @@ import torch # Internal Packages from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model from src.utils.config import ImageSearchModel -from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig +from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse # Create Logger @@ -203,8 +203,8 @@ def render_results(hits, image_names, image_directory, count): img.show() -def collate_results(hits, image_names, output_directory, image_files_url, count=5): - results = [] +def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]: + results: list[SearchResponse] = [] for 
index, hit in enumerate(hits[:count]): source_path = image_names[hit['corpus_id']] @@ -220,7 +220,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= shutil.copy(source_path, target_path) # Add the image metadata to the results - results += [ + results += [SearchResponse.parse_obj( { "entry": f'{image_files_url}/{target_image_name}', "score": f"{hit['score']:.9f}", @@ -230,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= "metadata_score": f"{hit['metadata_score']:.9f}", } } - ] + )] return results diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index ff7d9c43..009f39b9 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter from src.utils import state from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.config import TextSearchModel -from src.utils.rawconfig import TextSearchConfig, TextContentConfig +from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl @@ -171,12 +171,12 @@ def render_results(hits, entries, count=5, display_biencoder_results=False): print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}") -def collate_results(hits, entries, count=5): - return [ +def collate_results(hits, entries, count=5) -> list[SearchResponse]: + return [SearchResponse.parse_obj( { "entry": entries[hit['corpus_id']]['raw'], "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}" - } + }) for hit in hits[0:count]] diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 2c708569..84aadc0a 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -71,3 +71,8 @@ class FullConfig(ConfigBase): content_type: Optional[ContentConfig] search_type: 
Optional[SearchConfig] processor: Optional[ProcessorConfig] + +class SearchResponse(ConfigBase): + entry: str + score: str + additional: Optional[dict] \ No newline at end of file From 7e9298f31576ebb2b18f85369c8b0e8d88a2f0c7 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 15 Sep 2022 23:34:43 +0300 Subject: [PATCH 06/10] Use new Text Entry class to track text entries in Intermediate Format - Context - The app maintains all text content in a standard, intermediate format - The intermediate format was loaded, passed around as a dictionary for easier, faster updates to the intermediate format schema initially - The intermediate format is reasonably stable now, given it's usage by all 3 text content types currently implemented - Changes - Concretize text entries into `Entries' class instead of using dictionaries - Code is updated to load, pass around entries as `Entries' objects instead of as dictionaries - `text_search' and `text_to_jsonl' methods are annotated with type hints for the new `Entries' type - Code and Tests referencing entries are updated to use class style access patterns instead of the previous dictionary access patterns - Move `mark_entries_for_update' method into `TextToJsonl' base class - This is a more natural location for the method as it is only (to be) used by `text_to_jsonl' classes - Avoid circular reference issues on importing `Entries' class --- src/processor/ledger/beancount_to_jsonl.py | 26 ++++----- src/processor/markdown/markdown_to_jsonl.py | 24 ++++---- src/processor/org_mode/org_to_jsonl.py | 65 ++++++++++----------- src/processor/text_to_jsonl.py | 46 ++++++++++++++- src/search_filter/date_filter.py | 2 +- src/search_filter/file_filter.py | 2 +- src/search_filter/word_filter.py | 2 +- src/search_type/text_search.py | 26 ++++----- src/utils/helpers.py | 37 ------------ src/utils/rawconfig.py | 27 ++++++++- tests/test_date_filter.py | 9 +-- tests/test_file_filter.py | 12 ++-- tests/test_image_search.py | 2 +- 
tests/test_text_search.py | 2 +- tests/test_word_filter.py | 10 ++-- 15 files changed, 161 insertions(+), 131 deletions(-) diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py index d54b7e1b..ccad97da 100644 --- a/src/processor/ledger/beancount_to_jsonl.py +++ b/src/processor/ledger/beancount_to_jsonl.py @@ -1,5 +1,4 @@ # Standard Packages -import json import glob import re import logging @@ -7,9 +6,10 @@ import time # Internal Packages from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update +from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import Entry logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class BeancountToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) end = time.time() logger.debug(f"Identify new or updated transaction: {end - start} seconds") @@ -111,17 +111,17 @@ class BeancountToJsonl(TextToJsonl): return entries, dict(transaction_to_file_map) @staticmethod - def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]: - "Convert each Beancount transaction into a dictionary" - entry_maps = [] - for entry in entries: - entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'}) + def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]: + "Convert each parsed Beancount transaction into a Entry" + entries = [] + for parsed_entry in parsed_entries: + 
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}')) - logger.info(f"Converted {len(entries)} transactions to dictionaries") + logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries") - return entry_maps + return entries @staticmethod - def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str: - "Convert each Beancount transaction dictionary to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str: + "Convert each Beancount transaction entry to JSON and collate as JSONL" + return ''.join([f'{entry.to_json()}\n' for entry in entries]) diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py index 48fbbdf9..5c4d660d 100644 --- a/src/processor/markdown/markdown_to_jsonl.py +++ b/src/processor/markdown/markdown_to_jsonl.py @@ -1,5 +1,4 @@ # Standard Packages -import json import glob import re import logging @@ -7,9 +6,10 @@ import time # Internal Packages from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update +from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import Entry logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class MarkdownToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) end = time.time() logger.debug(f"Identify new or updated entries: {end - start} seconds") @@ 
-110,17 +110,17 @@ class MarkdownToJsonl(TextToJsonl): return entries, dict(entry_to_file_map) @staticmethod - def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]: + def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]: "Convert each Markdown entries into a dictionary" - entry_maps = [] - for entry in entries: - entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'}) + entries = [] + for parsed_entry in parsed_entries: + entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}')) - logger.info(f"Converted {len(entries)} markdown entries to dictionaries") + logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries") - return entry_maps + return entries @staticmethod - def convert_markdown_maps_to_jsonl(entries): - "Convert each Markdown entries to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + def convert_markdown_maps_to_jsonl(entries: list[Entry]): + "Convert each Markdown entry to JSON and collate as JSONL" + return ''.join([f'{entry.to_json()}\n' for entry in entries]) diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index c4c18ce9..52441a99 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -1,5 +1,4 @@ # Standard Packages -import json import glob import logging import time @@ -8,8 +7,9 @@ from typing import Iterable # Internal Packages from src.processor.org_mode import orgnode from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update +from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import Entry from src.utils import state 
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) class OrgToJsonl(TextToJsonl): # Define Functions - def process(self, previous_entries=None): + def process(self, previous_entries: list[Entry]=None): # Extract required fields from config org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl index_heading_entries = self.config.index_heading_entries @@ -47,7 +47,7 @@ class OrgToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) # Process Each Entry from All Notes Files start = time.time() @@ -104,51 +104,48 @@ class OrgToJsonl(TextToJsonl): return entries, dict(entry_to_file_map) @staticmethod - def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]: - "Convert Org-Mode entries into list of dictionary" - entry_maps = [] - for entry in entries: - entry_dict = dict() - - if not entry.hasBody and not index_heading_entries: + def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]: + "Convert Org-Mode nodes into list of Entry objects" + entries: list[Entry] = [] + for parsed_entry in parsed_entries: + if not parsed_entry.hasBody and not index_heading_entries: # Ignore title notes i.e notes with just headings and empty body continue - entry_dict["compiled"] = f'{entry.heading}.' + compiled = f'{parsed_entry.heading}.' if state.verbose > 2: - logger.debug(f"Title: {entry.heading}") + logger.debug(f"Title: {parsed_entry.heading}") - if entry.tags: - tags_str = " ".join(entry.tags) - entry_dict["compiled"] += f'\t {tags_str}.' 
+ if parsed_entry.tags: + tags_str = " ".join(parsed_entry.tags) + compiled += f'\t {tags_str}.' if state.verbose > 2: logger.debug(f"Tags: {tags_str}") - if entry.closed: - entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.' + if parsed_entry.closed: + compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.' if state.verbose > 2: - logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}') + logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}') - if entry.scheduled: - entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.' + if parsed_entry.scheduled: + compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.' if state.verbose > 2: - logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}') + logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}') - if entry.hasBody: - entry_dict["compiled"] += f'\n {entry.body}' + if parsed_entry.hasBody: + compiled += f'\n {parsed_entry.body}' if state.verbose > 2: - logger.debug(f"Body: {entry.body}") + logger.debug(f"Body: {parsed_entry.body}") - if entry_dict: - entry_dict["raw"] = f'{entry}' - entry_dict["file"] = f'{entry_to_file_map[entry]}' + if compiled: + entries += [Entry( + compiled=compiled, + raw=f'{parsed_entry}', + file=f'{entry_to_file_map[parsed_entry]}')] - # Convert Dictionary to JSON and Append to JSONL string - entry_maps.append(entry_dict) - - return entry_maps + return entries @staticmethod - def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str: + def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str: "Convert each Org-Mode entry to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries]) diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py index e59c5fb1..a8153f52 100644 --- 
a/src/processor/text_to_jsonl.py +++ b/src/processor/text_to_jsonl.py @@ -1,9 +1,14 @@ # Standard Packages from abc import ABC, abstractmethod -from typing import Iterable +import hashlib +import time +import logging # Internal Packages -from src.utils.rawconfig import TextContentConfig +from src.utils.rawconfig import Entry, TextContentConfig + + +logger = logging.getLogger(__name__) class TextToJsonl(ABC): @@ -11,4 +16,39 @@ class TextToJsonl(ABC): self.config = config @abstractmethod - def process(self, previous_entries: Iterable[tuple[int, dict]]=None) -> list[tuple[int, dict]]: ... + def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ... + + def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]: + # Hash all current and previous entries to identify new entries + start = time.time() + current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), current_entries)) + previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), previous_entries)) + end = time.time() + logger.debug(f"Hash previous, current entries: {end - start} seconds") + + start = time.time() + hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) + hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) + + # All entries that did not exist in the previous set are to be added + new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) + # All entries that exist in both current and previous sets are kept + existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) + + # Mark new entries with -1 id to flag for later embeddings generation + new_entries = [ + (-1, hash_to_current_entries[entry_hash]) + for entry_hash in new_entry_hashes + ] + # Set id of existing entries to their previous ids to reuse 
their existing encoded embeddings + existing_entries = [ + (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) + for entry_hash in existing_entry_hashes + ] + + existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) + entries_with_ids = existing_entries_sorted + new_entries + end = time.time() + logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds") + + return entries_with_ids \ No newline at end of file diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 22a66068..00b829ac 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -37,7 +37,7 @@ class DateFilter(BaseFilter): start = time.time() for id, entry in enumerate(entries): # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)): # Convert date string in entry to unix timestamp try: date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 41f80274..84b520c0 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -24,7 +24,7 @@ class FileFilter(BaseFilter): def load(self, entries, *args, **kwargs): start = time.time() for id, entry in enumerate(entries): - self.file_to_entry_map[entry[self.entry_key]].add(id) + self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) end = time.time() logger.debug(f"Created file filter index: {end - start} seconds") diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index e040ceee..ff9f9ee5 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -29,7 +29,7 @@ class WordFilter(BaseFilter): entry_splitter = r',|\.| 
|\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' # Create map of words to entries they exist in for entry_index, entry in enumerate(entries): - for word in re.split(entry_splitter, entry[self.entry_key].lower()): + for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): if word == '': continue self.word_to_entry_index[word].add(entry_index) diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 009f39b9..8b29c517 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter from src.utils import state from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.config import TextSearchModel -from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig +from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry from src.utils.jsonl import load_jsonl @@ -50,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig): return bi_encoder, cross_encoder, top_k -def extract_entries(jsonl_file): +def extract_entries(jsonl_file) -> list[Entry]: "Load entries from compressed jsonl" - return load_jsonl(jsonl_file) + return list(map(Entry.from_dict, load_jsonl(jsonl_file))) -def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False): +def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" new_entries = [] # Load pre-computed embeddings from file if exists and update them if required @@ -64,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate logger.info(f"Loaded embeddings from {embeddings_file}") # Encode any new entries in the corpus and update corpus embeddings - new_entries = 
[entry['compiled'] for id, entry in entries_with_ids if id is None] + new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1] if new_entries: new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) - existing_entry_ids = [id for id, _ in entries_with_ids if id is not None] + existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor() corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) # Else compute the corpus embeddings from scratch else: - new_entries = [entry['compiled'] for _, entry in entries_with_ids] + new_entries = [entry.compiled for _, entry in entries_with_ids] corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) # Save regenerated or updated embeddings to file @@ -133,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): # Score all retrieved entries using the cross-encoder if rank_results: start = time.time() - cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] + cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] cross_scores = model.cross_encoder.predict(cross_inp) end = time.time() logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") @@ -153,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): return hits, entries -def render_results(hits, entries, count=5, display_biencoder_results=False): +def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False): "Render the Results returned by Search for the Query" if display_biencoder_results: # Output of top hits from bi-encoder @@ -161,20 +161,20 @@ def render_results(hits, entries, count=5, 
display_biencoder_results=False): print(f"Top-{count} Bi-Encoder Retrieval hits") hits = sorted(hits, key=lambda x: x['score'], reverse=True) for hit in hits[0:count]: - print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}") + print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}") # Output of top hits from re-ranker print("\n-------------------------\n") print(f"Top-{count} Cross-Encoder Re-ranker hits") hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) for hit in hits[0:count]: - print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}") + print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}") -def collate_results(hits, entries, count=5) -> list[SearchResponse]: +def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]: return [SearchResponse.parse_obj( { - "entry": entries[hit['corpus_id']]['raw'], + "entry": entries[hit['corpus_id']].raw, "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}" }) for hit diff --git a/src/utils/helpers.py b/src/utils/helpers.py index df1899f9..8425a8fa 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,8 +1,6 @@ # Standard Packages from pathlib import Path import sys -import time -import hashlib from os.path import join from collections import OrderedDict from typing import Optional, Union @@ -83,38 +81,3 @@ class LRU(OrderedDict): oldest = next(iter(self)) del self[oldest] - -def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None): - # Hash all current and previous entries to identify new entries - start = time.time() - current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries)) - previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), 
previous_entries)) - end = time.time() - logger.debug(f"Hash previous, current entries: {end - start} seconds") - - start = time.time() - hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) - hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) - - # All entries that did not exist in the previous set are to be added - new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) - # All entries that exist in both current and previous sets are kept - existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) - - # Mark new entries with no ids for later embeddings generation - new_entries = [ - (None, hash_to_current_entries[entry_hash]) - for entry_hash in new_entry_hashes - ] - # Set id of existing entries to their previous ids to reuse their existing encoded embeddings - existing_entries = [ - (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) - for entry_hash in existing_entry_hashes - ] - - existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) - entries_with_ids = existing_entries_sorted + new_entries - end = time.time() - logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds") - - return entries_with_ids \ No newline at end of file diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 84aadc0a..165be0d1 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -1,4 +1,5 @@ # System Packages +import json from pathlib import Path from typing import List, Optional @@ -75,4 +76,28 @@ class FullConfig(ConfigBase): class SearchResponse(ConfigBase): entry: str score: str - additional: Optional[dict] \ No newline at end of file + additional: Optional[dict] + +class Entry(): + raw: str + compiled: str + file: Optional[str] + + def __init__(self, raw: str = None, compiled: str = None, file: Optional[str] = None): + self.raw = raw + self.compiled = compiled + self.file = file + + def 
to_json(self) -> str: + return json.dumps(self.__dict__, ensure_ascii=False) + + def __repr__(self) -> str: + return self.__dict__.__repr__() + + @classmethod + def from_dict(cls, dictionary: dict): + return cls( + raw=dictionary['raw'], + compiled=dictionary['compiled'], + file=dictionary.get('file', None) + ) \ No newline at end of file diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 345c5c4f..59ef697c 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -8,14 +8,15 @@ import torch # Application Packages from src.search_filter.date_filter import DateFilter +from src.utils.rawconfig import Entry def test_date_filter(): - embeddings = torch.randn(3, 10) entries = [ - {'compiled': '', 'raw': 'Entry with no date'}, - {'compiled': '', 'raw': 'April Fools entry: 1984-04-01'}, - {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] + Entry(compiled='', raw='Entry with no date'), + Entry(compiled='', raw='April Fools entry: 1984-04-01'), + Entry(compiled='', raw='Entry with date:1984-04-02') + ] q_with_no_date_filter = 'head tail' ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 3f9c22b3..e6c17299 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -3,6 +3,7 @@ import torch # Application Packages from src.search_filter.file_filter import FileFilter +from src.utils.rawconfig import Entry def test_no_file_filter(): @@ -104,9 +105,10 @@ def test_multiple_file_filter(): def arrange_content(): embeddings = torch.randn(4, 10) entries = [ - {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'}, - {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, - {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'}, - {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] + Entry(compiled='', raw='First Entry', file= 'file 1.org'), + Entry(compiled='', raw='Second Entry', file= 'file2.org'), 
+ Entry(compiled='', raw='Third Entry', file= 'file 1.org'), + Entry(compiled='', raw='Fourth Entry', file= 'file2.org') + ] - return embeddings, entries + return entries diff --git a/tests/test_image_search.py b/tests/test_image_search.py index e1a56b44..97911164 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -70,7 +70,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig image_files_url='/static/images', count=1) - actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name) + actual_image_path = output_directory.joinpath(Path(results[0].entry).name) actual_image = Image.open(actual_image_path) expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name)) diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 584c07b9..e05831a1 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -76,7 +76,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # Assert # Actual_data should contain "Khoj via Emacs" entry - search_result = results[0]["entry"] + search_result = results[0].entry assert "git clone" in search_result diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index db23c2c6..58069b24 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -1,6 +1,7 @@ # Application Packages from src.search_filter.word_filter import WordFilter from src.utils.config import SearchType +from src.utils.rawconfig import Entry def test_no_word_filter(): @@ -69,9 +70,10 @@ def test_word_include_and_exclude_filter(): def arrange_content(): entries = [ - {'compiled': '', 'raw': 'Minimal Entry'}, - {'compiled': '', 'raw': 'Entry with exclude_word'}, - {'compiled': '', 'raw': 'Entry with include_word'}, - {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] + Entry(compiled='', raw='Minimal Entry'), + Entry(compiled='', raw='Entry with exclude_word'), + 
Entry(compiled='', raw='Entry with include_word'), + Entry(compiled='', raw='Entry with include_word and exclude_word') + ] return entries From 2c548133f3aecb8f74d5fe2c8a18422695a9dd43 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 16 Sep 2022 00:06:21 +0300 Subject: [PATCH 07/10] Remove unused imports, `embeddings' variable from text search tests --- tests/test_date_filter.py | 3 --- tests/test_file_filter.py | 16 ++++++---------- tests/test_image_search.py | 3 --- tests/test_word_filter.py | 1 - 4 files changed, 6 insertions(+), 17 deletions(-) diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 59ef697c..bc656701 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -3,9 +3,6 @@ import re from datetime import datetime from math import inf -# External Packages -import torch - # Application Packages from src.search_filter.date_filter import DateFilter from src.utils.rawconfig import Entry diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index e6c17299..28b0367f 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -1,6 +1,3 @@ -# External Packages -import torch - # Application Packages from src.search_filter.file_filter import FileFilter from src.utils.rawconfig import Entry @@ -9,7 +6,7 @@ from src.utils.rawconfig import Entry def test_no_file_filter(): # Arrange file_filter = FileFilter() - embeddings, entries = arrange_content() + entries = arrange_content() q_with_no_filter = 'head tail' # Act @@ -25,7 +22,7 @@ def test_no_file_filter(): def test_file_filter_with_non_existent_file(): # Arrange file_filter = FileFilter() - embeddings, entries = arrange_content() + entries = arrange_content() q_with_no_filter = 'head file:"nonexistent.org" tail' # Act @@ -41,7 +38,7 @@ def test_file_filter_with_non_existent_file(): def test_single_file_filter(): # Arrange file_filter = FileFilter() - embeddings, entries = arrange_content() + entries = arrange_content() 
q_with_no_filter = 'head file:"file 1.org" tail' # Act @@ -57,7 +54,7 @@ def test_single_file_filter(): def test_file_filter_with_partial_match(): # Arrange file_filter = FileFilter() - embeddings, entries = arrange_content() + entries = arrange_content() q_with_no_filter = 'head file:"1.org" tail' # Act @@ -73,7 +70,7 @@ def test_file_filter_with_partial_match(): def test_file_filter_with_regex_match(): # Arrange file_filter = FileFilter() - embeddings, entries = arrange_content() + entries = arrange_content() q_with_no_filter = 'head file:"*.org" tail' # Act @@ -89,7 +86,7 @@ def test_file_filter_with_regex_match(): def test_multiple_file_filter(): # Arrange file_filter = FileFilter() - embeddings, entries = arrange_content() + entries = arrange_content() q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"' # Act @@ -103,7 +100,6 @@ def test_multiple_file_filter(): def arrange_content(): - embeddings = torch.randn(4, 10) entries = [ Entry(compiled='', raw='First Entry', file= 'file 1.org'), Entry(compiled='', raw='Second Entry', file= 'file2.org'), diff --git a/tests/test_image_search.py b/tests/test_image_search.py index 97911164..374168f5 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -2,9 +2,6 @@ from pathlib import Path from PIL import Image -# External Packages -import pytest - # Internal Packages from src.utils.state import model from src.utils.constants import web_directory diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 58069b24..2e662fd0 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -1,6 +1,5 @@ # Application Packages from src.search_filter.word_filter import WordFilter -from src.utils.config import SearchType from src.utils.rawconfig import Entry From d292bdcc119e1d8a939235ede15eab7331e21eb0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 8 Oct 2022 13:16:08 +0300 Subject: [PATCH 08/10] Do not version API. 
Premature given current state of the codebase - Reason - All clients that currently consume the API are part of Khoj - Any breaking API changes will be fixed in clients immediately - So decoupling client from API is not required - This removes the burden of maintaining multiple versions of the API --- src/interface/emacs/khoj.el | 4 ++-- src/interface/web/assets/config.js | 6 +++--- src/interface/web/index.html | 8 ++++---- src/main.py | 4 ++-- src/routers/{api_v1_0.py => api.py} | 14 +++++++------- src/routers/api_beta.py | 2 +- tests/data/markdown/main_readme.md | 4 ++-- tests/data/org/main_readme.org | 4 ++-- tests/test_client.py | 22 +++++++++++----------- 9 files changed, 34 insertions(+), 34 deletions(-) rename src/routers/{api_v1_0.py => api.py} (94%) diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index e5b05605..d3d6c6f9 100644 --- a/src/interface/emacs/khoj.el +++ b/src/interface/emacs/khoj.el @@ -226,7 +226,7 @@ Use `which-key` if available, else display simple message in echo area" (defun khoj--get-enabled-content-types () "Get content types enabled for search from API." - (let ((config-url (format "%s/api/v1.0/config/data" khoj-server-url))) + (let ((config-url (format "%s/api/config/data" khoj-server-url))) (with-temp-buffer (erase-buffer) (url-insert-file-contents config-url) @@ -243,7 +243,7 @@ Use `which-key` if available, else display simple message in echo area" "Construct API Query from QUERY, SEARCH-TYPE and (optional) RERANK params." (let ((rerank (or rerank "false")) (encoded-query (url-hexify-string query))) - (format "%s/api/v1.0/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count))) + (format "%s/api/search?q=%s&t=%s&r=%s&n=%s" khoj-server-url encoded-query search-type rerank khoj-results-count))) (defun khoj--query-api-and-render-results (query search-type query-url buffer-name) "Query Khoj API using QUERY, SEARCH-TYPE, QUERY-URL. 
diff --git a/src/interface/web/assets/config.js b/src/interface/web/assets/config.js index 90412e1c..965df9bc 100644 --- a/src/interface/web/assets/config.js +++ b/src/interface/web/assets/config.js @@ -10,7 +10,7 @@ var emptyValueDefault = "🖊️"; /** * Fetch the existing config file. */ -fetch("/api/v1.0/config/data") +fetch("/api/config/data") .then(response => response.json()) .then(data => { rawConfig = data; @@ -26,7 +26,7 @@ fetch("/api/v1.0/config/data") configForm.addEventListener("submit", (event) => { event.preventDefault(); console.log(rawConfig); - fetch("/api/v1.0/config/data", { + fetch("/api/config/data", { method: "POST", credentials: "same-origin", headers: { @@ -46,7 +46,7 @@ regenerateButton.addEventListener("click", (event) => { event.preventDefault(); regenerateButton.style.cursor = "progress"; regenerateButton.disabled = true; - fetch("/api/v1.0/update?force=true") + fetch("/api/update?force=true") .then(response => response.json()) .then(data => { regenerateButton.style.cursor = "pointer"; diff --git a/src/interface/web/index.html b/src/interface/web/index.html index 3d940ee3..c3360d8f 100644 --- a/src/interface/web/index.html +++ b/src/interface/web/index.html @@ -77,8 +77,8 @@ // Generate Backend API URL to execute Search url = type === "image" - ? `/api/v1.0/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}` - : `/api/v1.0/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`; + ? 
`/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}` + : `/api/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`; // Execute Search and Render Results fetch(url) @@ -94,7 +94,7 @@ function updateIndex() { type = document.getElementById("type").value; - fetch(`/api/v1.0/update?t=${type}`) + fetch(`/api/update?t=${type}`) .then(response => response.json()) .then(data => { console.log(data); @@ -118,7 +118,7 @@ function populate_type_dropdown() { // Populate type dropdown field with enabled search types only var possible_search_types = ["org", "markdown", "ledger", "music", "image"]; - fetch("/api/v1.0/config/data") + fetch("/api/config/data") .then(response => response.json()) .then(data => { document.getElementById("type").innerHTML = diff --git a/src/main.py b/src/main.py index 13d15674..33320285 100644 --- a/src/main.py +++ b/src/main.py @@ -19,7 +19,7 @@ from PyQt6.QtCore import QThread, QTimer # Internal Packages from src.configure import configure_server -from src.routers.api_v1_0 import api_v1_0 +from src.routers.api import api from src.routers.api_beta import api_beta from src.routers.frontend import frontend_router from src.utils import constants, state @@ -31,7 +31,7 @@ from src.interface.desktop.system_tray import create_system_tray # Initialize the Application Server app = FastAPI() app.mount("/static", StaticFiles(directory=constants.web_directory), name="static") -app.include_router(api_v1_0, prefix="/api/v1.0") +app.include_router(api, prefix="/api") app.include_router(api_beta, prefix="/api/beta") app.include_router(frontend_router) diff --git a/src/routers/api_v1_0.py b/src/routers/api.py similarity index 94% rename from src/routers/api_v1_0.py rename to src/routers/api.py index b6dea695..c8d4fa3a 100644 --- a/src/routers/api_v1_0.py +++ b/src/routers/api.py @@ -15,23 +15,23 @@ from src.utils.config import SearchType from src.utils import state, constants -api_v1_0 = APIRouter() +api = APIRouter() 
logger = logging.getLogger(__name__) -@api_v1_0.get('/config/data', response_model=FullConfig) -def config_data(): +@api.get('/config/data', response_model=FullConfig) +def get_config_data(): return state.config -@api_v1_0.post('/config/data') -async def config_data(updated_config: FullConfig): +@api.post('/config/data') +async def set_config_data(updated_config: FullConfig): state.config = updated_config with open(state.config_file, 'w') as outfile: yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile) outfile.close() return state.config -@api_v1_0.get('/search', response_model=list[SearchResponse]) +@api.get('/search', response_model=list[SearchResponse]) def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): results: list[SearchResponse] = [] if q is None or q == '': @@ -121,7 +121,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti return results -@api_v1_0.get('/update') +@api.get('/update') def update(t: Optional[SearchType] = None, force: Optional[bool] = False): state.model = configure_search(state.model, state.config, regenerate=force, t=t) return {'status': 'ok', 'message': 'index updated'} diff --git a/src/routers/api_beta.py b/src/routers/api_beta.py index 389025b9..28d9cc2b 100644 --- a/src/routers/api_beta.py +++ b/src/routers/api_beta.py @@ -7,7 +7,7 @@ from typing import Optional from fastapi import APIRouter # Internal Packages -from src.routers.api_v1_0 import search +from src.routers.api import search from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.utils.config import SearchType from src.utils.helpers import get_absolute_path, get_from_dict diff --git a/tests/data/markdown/main_readme.md b/tests/data/markdown/main_readme.md index 7f626319..14d97a97 100644 --- a/tests/data/markdown/main_readme.md +++ b/tests/data/markdown/main_readme.md @@ -43,8 +43,8 @@ just 
generating embeddings* - **Khoj via API** - See [Khoj API Docs](http://localhost:8000/docs) - - [Query](http://localhost:8000/api/v1.0/search?q=%22what%20is%20the%20meaning%20of%20life%22) - - [Update Index](http://localhost:8000/api/v1.0/update?t=ledger) + - [Query](http://localhost:8000/api/search?q=%22what%20is%20the%20meaning%20of%20life%22) + - [Update Index](http://localhost:8000/api/update?t=ledger) - [Configure Application](https://localhost:8000/ui) - **Khoj via Emacs** - [Install](https://github.com/debanjum/khoj/tree/master/src/interface/emacs#installation) diff --git a/tests/data/org/main_readme.org b/tests/data/org/main_readme.org index 4f63801a..48c1bfd5 100644 --- a/tests/data/org/main_readme.org +++ b/tests/data/org/main_readme.org @@ -27,8 +27,8 @@ - Run ~M-x khoj ~ or Call ~C-c C-s~ - *Khoj via API* - - Query: ~GET~ [[http://localhost:8000/api/v1.0/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/api/v1.0/search?q="What is the meaning of life"]] - - Update Index: ~GET~ [[http://localhost:8000/api/v1.0/update][http://localhost:8000/api/v1.0/update]] + - Query: ~GET~ [[http://localhost:8000/api/search?q=%22what%20is%20the%20meaning%20of%20life%22][http://localhost:8000/api/search?q="What is the meaning of life"]] + - Update Index: ~GET~ [[http://localhost:8000/api/update][http://localhost:8000/api/update]] - [[http://localhost:8000/docs][Khoj API Docs]] - *Call Khoj via Python Script Directly* diff --git a/tests/test_client.py b/tests/test_client.py index c17e7edd..d3dde245 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -28,7 +28,7 @@ def test_search_with_invalid_content_type(): user_query = quote("How to call Khoj from Emacs?") # Act - response = client.get(f"/api/v1.0/search?q={user_query}&t=invalid_content_type") + response = client.get(f"/api/search?q={user_query}&t=invalid_content_type") # Assert assert response.status_code == 422 @@ -43,7 +43,7 @@ def 
test_search_with_valid_content_type(content_config: ContentConfig, search_co # config.content_type.image = search_config.image for content_type in ["org", "markdown", "ledger", "music"]: # Act - response = client.get(f"/api/v1.0/search?q=random&t={content_type}") + response = client.get(f"/api/search?q=random&t={content_type}") # Assert assert response.status_code == 200 @@ -51,7 +51,7 @@ def test_search_with_valid_content_type(content_config: ContentConfig, search_co # ---------------------------------------------------------------------------------------------------- def test_update_with_invalid_content_type(): # Act - response = client.get(f"/api/v1.0/update?t=invalid_content_type") + response = client.get(f"/api/update?t=invalid_content_type") # Assert assert response.status_code == 422 @@ -65,7 +65,7 @@ def test_update_with_valid_content_type(content_config: ContentConfig, search_co for content_type in ["org", "markdown", "ledger", "music"]: # Act - response = client.get(f"/api/v1.0/update?t={content_type}") + response = client.get(f"/api/update?t={content_type}") # Assert assert response.status_code == 200 @@ -73,7 +73,7 @@ def test_update_with_valid_content_type(content_config: ContentConfig, search_co # ---------------------------------------------------------------------------------------------------- def test_regenerate_with_invalid_content_type(): # Act - response = client.get(f"/api/v1.0/update?force=true&t=invalid_content_type") + response = client.get(f"/api/update?force=true&t=invalid_content_type") # Assert assert response.status_code == 422 @@ -87,7 +87,7 @@ def test_regenerate_with_valid_content_type(content_config: ContentConfig, searc for content_type in ["org", "markdown", "ledger", "music", "image"]: # Act - response = client.get(f"/api/v1.0/update?force=true&t={content_type}") + response = client.get(f"/api/update?force=true&t={content_type}") # Assert assert response.status_code == 200 @@ -104,7 +104,7 @@ def 
test_image_search(content_config: ContentConfig, search_config: SearchConfig for query, expected_image_name in query_expected_image_pairs: # Act - response = client.get(f"/api/v1.0/search?q={query}&n=1&t=image") + response = client.get(f"/api/search?q={query}&n=1&t=image") # Assert assert response.status_code == 200 @@ -122,7 +122,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig user_query = quote("How to git install application?") # Act - response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org&r=true") + response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true") # Assert assert response.status_code == 200 @@ -139,7 +139,7 @@ def test_notes_search_with_only_filters(content_config: ContentConfig, search_co user_query = quote('+"Emacs" file:"*.org"') # Act - response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 @@ -156,7 +156,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ user_query = quote('How to git install application? +"Emacs"') # Act - response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 @@ -173,7 +173,7 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_ user_query = quote('How to git install application? 
-"clone"') # Act - response = client.get(f"/api/v1.0/search?q={user_query}&n=1&t=org") + response = client.get(f"/api/search?q={user_query}&n=1&t=org") # Assert assert response.status_code == 200 From c467df8fa39fba0e4a68c60ef7cd707373c19719 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 8 Oct 2022 17:33:13 +0300 Subject: [PATCH 09/10] Setup `mypy' for static type checking --- .mypy.ini | 13 +++++++++++++ src/utils/yaml.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 .mypy.ini diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 00000000..205d50d6 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,13 @@ +[mypy] +strict_optional = False +ignore_missing_imports = True +install_types = True +non_interactive = True +show_error_codes = True +exclude = (?x)( + src/interface/desktop/main_window.py + | src/interface/desktop/file_browser.py + | src/interface/desktop/system_tray.py + | build/* + | tests/* + ) diff --git a/src/utils/yaml.py b/src/utils/yaml.py index a70c6f76..07f7cd87 100644 --- a/src/utils/yaml.py +++ b/src/utils/yaml.py @@ -9,7 +9,7 @@ from src.utils.rawconfig import FullConfig # Do not emit tags when dumping to YAML -yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None +yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None # type: ignore[assignment] def save_config_to_file(yaml_config: dict, yaml_config_file: Path): From e1b5a8792056d676eb95e96a3db73b00552a53ad Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 19 Oct 2022 16:15:57 +0530 Subject: [PATCH 10/10] Rename Frontend Router to Web Client. 
Fix logger usage in routers - Use logger in api_beta router instead of print statements - Remove unused logger in web client router --- src/main.py | 4 ++-- src/routers/api.py | 2 ++ src/routers/api_beta.py | 13 ++++++------- src/routers/{frontend.py => web_client.py} | 12 +++++------- 4 files changed, 15 insertions(+), 16 deletions(-) rename src/routers/{frontend.py => web_client.py} (69%) diff --git a/src/main.py b/src/main.py index 33320285..4f522572 100644 --- a/src/main.py +++ b/src/main.py @@ -21,7 +21,7 @@ from PyQt6.QtCore import QThread, QTimer from src.configure import configure_server from src.routers.api import api from src.routers.api_beta import api_beta -from src.routers.frontend import frontend_router +from src.routers.web_client import web_client from src.utils import constants, state from src.utils.cli import cli from src.interface.desktop.main_window import MainWindow @@ -33,7 +33,7 @@ app = FastAPI() app.mount("/static", StaticFiles(directory=constants.web_directory), name="static") app.include_router(api, prefix="/api") app.include_router(api_beta, prefix="/api/beta") -app.include_router(frontend_router) +app.include_router(web_client) logger = logging.getLogger('src') diff --git a/src/routers/api.py b/src/routers/api.py index c8d4fa3a..c8347f03 100644 --- a/src/routers/api.py +++ b/src/routers/api.py @@ -15,10 +15,12 @@ from src.utils.config import SearchType from src.utils import state, constants +# Initialize Router api = APIRouter() logger = logging.getLogger(__name__) +# Create Routes @api.get('/config/data', response_model=FullConfig) def get_config_data(): return state.config diff --git a/src/routers/api_beta.py b/src/routers/api_beta.py index 28d9cc2b..6425630d 100644 --- a/src/routers/api_beta.py +++ b/src/routers/api_beta.py @@ -14,10 +14,12 @@ from src.utils.helpers import get_absolute_path, get_from_dict from src.utils import state +# Initialize Router api_beta = APIRouter() logger = logging.getLogger(__name__) +# Create Routes 
@api_beta.get('/search') def search_beta(q: str, n: Optional[int] = 1): # Extract Search Type using GPT @@ -39,15 +41,13 @@ def chat(q: str): # Converse with OpenAI GPT metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) - if state.verbose > 1: - print(f'Understood: {get_from_dict(metadata, "intent")}') + logger.debug(f'Understood: {get_from_dict(metadata, "intent")}') if get_from_dict(metadata, "intent", "memory-type") == "notes": query = get_from_dict(metadata, "intent", "query") result_list = search(query, n=1, t=SearchType.Org) collated_result = "\n".join([item["entry"] for item in result_list]) - if state.verbose > 1: - print(f'Semantically Similar Notes:\n{collated_result}') + logger.debug(f'Semantically Similar Notes:\n{collated_result}') gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key) else: gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key) @@ -64,8 +64,7 @@ def shutdown_event(): # No need to create empty log file if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log): return - elif state.processor_config.conversation.verbose: - print('INFO:\tSaving conversation logs to disk...') + logger.debug('INFO:\tSaving conversation logs to disk...') # Summarize Conversation Logs for this Session chat_session = state.processor_config.conversation.chat_session @@ -86,4 +85,4 @@ def shutdown_event(): with open(conversation_logfile, "w+", encoding='utf-8') as logfile: json.dump(conversation_log, logfile) - print('INFO:\tConversation logs saved to disk.') + logger.info('INFO:\tConversation logs saved to disk.') diff --git a/src/routers/frontend.py b/src/routers/web_client.py similarity index 69% rename from src/routers/frontend.py rename to src/routers/web_client.py index 8ed5d6ee..0c2b8628 100644 --- 
a/src/routers/frontend.py +++ b/src/routers/web_client.py @@ -1,6 +1,3 @@ -# Standard Packages -import logging - # External Packages from fastapi import APIRouter from fastapi import Request @@ -11,15 +8,16 @@ from fastapi.templating import Jinja2Templates from src.utils import constants -frontend_router = APIRouter() +# Initialize Router +web_client = APIRouter() templates = Jinja2Templates(directory=constants.web_directory) -logger = logging.getLogger(__name__) -@frontend_router.get("/", response_class=FileResponse) +# Create Routes +@web_client.get("/", response_class=FileResponse) def index(): return FileResponse(constants.web_directory / "index.html") -@frontend_router.get('/config', response_class=HTMLResponse) +@web_client.get('/config', response_class=HTMLResponse) def config_page(request: Request): return templates.TemplateResponse("config.html", context={'request': request})