diff --git a/README.org b/README.org index bcc14ab5..f4c73726 100644 --- a/README.org +++ b/README.org @@ -18,7 +18,7 @@ Load ML model, generate embeddings and expose API to query specified org-mode files #+begin_src shell - python3 src/main.py --input-files ~/Notes/Schedule.org ~/Notes/Incoming.org --verbose + python3 src/main.py --org-files ~/Notes/Schedule.org ~/Notes/Incoming.org -c sample_config.yml --verbose #+end_src ** Use diff --git a/environment.yml b/environment.yml index bca8ab1e..bd7f9f59 100644 --- a/environment.yml +++ b/environment.yml @@ -9,4 +9,5 @@ dependencies: - sentence-transformers - fastapi - uvicorn + - pyyaml - pytest \ No newline at end of file diff --git a/sample_config.yml b/sample_config.yml new file mode 100644 index 00000000..4604207b --- /dev/null +++ b/sample_config.yml @@ -0,0 +1,11 @@ +content-type: + org: + input-files: ["src/tests/data/main_readme.org", "src/tests/data/interface_emacs_readme.org"] + input-filter: null + compressed-jsonl: ".notes.json.gz" + embeddings-file: ".note_embeddings.pt" + +search-type: + asymmetric: + encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3" + cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" diff --git a/src/main.py b/src/main.py index 9f61a480..dd815f62 100644 --- a/src/main.py +++ b/src/main.py @@ -6,12 +6,13 @@ from typing import Optional # External Packages import uvicorn +import yaml from fastapi import FastAPI # Internal Packages from search_type import asymmetric from processor.org_mode.org_to_jsonl import org_to_jsonl -from utils.helpers import is_none_or_empty +from utils.helpers import is_none_or_empty, get_absolute_path, get_from_dict, merge_dicts app = FastAPI() @@ -26,7 +27,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): user_query = q results_count = n - if t == 'notes' or t == None: + if (t == 'notes' or t == None) and notes_search_enabled: # query notes hits = asymmetric.query_notes( user_query, @@ -45,35 +46,90 @@ def search(q: str, n: Optional[int] = 5, t: Optional[str] = None): @app.get('/regenerate') def regenerate(t: Optional[str] = None): - if t == 'notes' or t == None: + if (t == 'notes' or t == None) and notes_search_enabled: # Extract Entries, Generate Embeddings global corpus_embeddings global entries - entries, corpus_embeddings, _, _, _ = asymmetric.setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, regenerate=True, verbose=args.verbose) + entries, corpus_embeddings, _, _, _ = asymmetric.setup( + org_config['input-files'], + org_config['input-filter'], + pathlib.Path(org_config['compressed-jsonl']), + pathlib.Path(org_config['embeddings-file']), + regenerate=True, + verbose=args.verbose) + return {'status': 'ok', 'message': 'regeneration completed'} def cli(args=None): - if not args: + if is_none_or_empty(args): args = sys.argv[1:] # Setup Argument Parser for the Commandline Interface parser = argparse.ArgumentParser(description="Expose API for Semantic Search") - parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process") - parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process") - parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from") - parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from") - parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false") + parser.add_argument('--org-files', '-i', nargs='*', help="List of org-mode files to process") + parser.add_argument('--org-filter', type=str, default=None, help="Regex filter for org-mode files to process") + parser.add_argument('--config-file', '-c', type=pathlib.Path, help="YAML file with user configuration") + parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. Default: false") parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0") + args = parser.parse_args(args) - return parser.parse_args(args) + if not (args.config_file or args.org_files): + print(f"Require at least 1 of --org-file, --org-filter or --config-file flags to be passed from commandline") + exit(1) + + # Config Priority: Cmd Args > Config File > Default Config + args.config = default_config + if args.config_file and args.config_file.exists(): + with open(get_absolute_path(args.config_file), 'r', encoding='utf-8') as config_file: + config_from_file = yaml.safe_load(config_file) + args.config = merge_dicts(priority_dict=config_from_file, default_dict=args.config) + + if args.org_files: + args.config['content-type']['org']['input-files'] = args.org_files + + if args.org_filter: + args.config['content-type']['org']['input-filter'] = args.org_filter + + return args + + +default_config = { + 'content-type': + { + 'org': + { + 'compressed-jsonl': '.notes.jsonl.gz', + 'embeddings-file': '.note_embeddings.pt' + } + }, + 'search-type': + { + 'asymmetric': + { + 'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3", + 'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2" + } + } +} if __name__ == '__main__': args = cli() + org_config = get_from_dict(args.config, 'content-type', 'org') + + notes_search_enabled = False + if 'input-files' in org_config or 'input-filter' in org_config: + notes_search_enabled = True + entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup( + org_config['input-files'], + org_config['input-filter'], + pathlib.Path(org_config['compressed-jsonl']), + pathlib.Path(org_config['embeddings-file']), + args.regenerate, + args.verbose) - entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose) # Start Application Server uvicorn.run(app) diff --git a/src/tests/data/config.yml b/src/tests/data/config.yml new file mode 100644 index 00000000..b002b32a --- /dev/null +++ b/src/tests/data/config.yml @@ -0,0 +1,11 @@ +content-type: + org: + input-files: [ "~/first_from_config.org", "~/second_from_config.org" ] + input-filter: "*.org" + compressed-jsonl: ".notes.json.gz" + embeddings-file: ".note_embeddings.pt" + +search-type: + asymmetric: + encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3" + cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" diff --git a/src/tests/test_main.py b/src/tests/test_main.py index ddb33483..b6cdf224 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -32,34 +32,64 @@ def test_asymmetric_setup(): assert len(corpus_embeddings) == 10 -# ---------------------------------------------------------------------------------------------------- -def test_cli_default(): +def test_cli_minimal_default(): # Act - args = cli(['--input-files=tests/data/test.org']) + actual_args = cli(['--config-file=tests/data/config.yml']) # Assert - assert args.input_files == ['tests/data/test.org'] - assert args.input_filter == None - assert args.compressed_jsonl == Path('.notes.jsonl.gz') - assert args.embeddings == Path('.notes_embeddings.pt') - assert args.regenerate == False - assert args.verbose == 0 - + assert actual_args.config_file == Path('tests/data/config.yml') + assert actual_args.regenerate == False + assert actual_args.verbose == 0 # ---------------------------------------------------------------------------------------------------- -def test_cli_set_by_user(): +def test_cli_flags(): # Act - actual_args = cli(['--input-files=tests/data/test.org', - '--input-filter=tests/data/*.org', - '--compressed-jsonl=tests/data/.test.jsonl.gz', - '--embeddings=tests/data/.test_embeddings.pt', + actual_args = cli(['--config-file=tests/data/config.yml', '--regenerate', '-vvv']) # Assert - assert actual_args.input_files == ['tests/data/test.org'] - assert actual_args.input_filter == 'tests/data/*.org' - assert actual_args.compressed_jsonl == Path('tests/data/.test.jsonl.gz') - assert actual_args.embeddings == Path('tests/data/.test_embeddings.pt') + assert actual_args.config_file == Path('tests/data/config.yml') assert actual_args.regenerate == True assert actual_args.verbose == 3 + + +# ---------------------------------------------------------------------------------------------------- +def test_cli_config_from_file(): + # Act + actual_args = cli(['--config-file=tests/data/config.yml', + '--regenerate', + '-vvv']) + + # Assert + assert actual_args.config_file == Path('tests/data/config.yml') + assert actual_args.regenerate == True + assert actual_args.config is not None + assert actual_args.config['content-type']['org']['input-files'] == ['~/first_from_config.org', '~/second_from_config.org'] + assert actual_args.verbose == 3 + + +# ---------------------------------------------------------------------------------------------------- +def test_cli_config_from_cmd_args(): + "" + # Act + actual_args = cli(['--org-files=first.org']) + + # Assert + assert actual_args.org_files == ['first.org'] + assert actual_args.config_file is None + assert actual_args.config is not None + assert actual_args.config['content-type']['org']['input-files'] == ['first.org'] + + +# ---------------------------------------------------------------------------------------------------- +def test_cli_config_from_cmd_args_override_config_file(): + # Act + actual_args = cli(['--config-file=tests/data/config.yml', + '--org-files=first.org']) + + # Assert + assert actual_args.org_files == ['first.org'] + assert actual_args.config_file == Path('tests/data/config.yml') + assert actual_args.config is not None + assert actual_args.config['content-type']['org']['input-files'] == ['first.org']