diff --git a/src/configure.py b/src/configure.py index 564b7280..8c42fbd8 100644 --- a/src/configure.py +++ b/src/configure.py @@ -11,7 +11,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.search_type import image_search, text_search from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.cli import cli -from src.utils import constants +from src.utils import state from src.utils.helpers import get_absolute_path from src.utils.rawconfig import FullConfig @@ -21,19 +21,19 @@ def initialize_server(cmd_args): args = cli(cmd_args) # Stores the file path to the config file. - constants.config_file = args.config_file + state.config_file = args.config_file # Store the raw config data. - constants.config = args.config + state.config = args.config # Store the verbose flag - constants.verbose = args.verbose + state.verbose = args.verbose # Initialize the search model from Config - constants.model = initialize_search(constants.model, args.config, args.regenerate, device=constants.device, verbose=constants.verbose) + state.model = initialize_search(state.model, args.config, args.regenerate, device=state.device, verbose=state.verbose) # Initialize Processor from Config - constants.processor_config = initialize_processor(args.config, verbose=constants.verbose) + state.processor_config = initialize_processor(args.config, verbose=state.verbose) return args.host, args.port, args.socket diff --git a/src/main.py b/src/main.py index 1d4c5ad7..76f68089 100644 --- a/src/main.py +++ b/src/main.py @@ -9,9 +9,9 @@ from fastapi.staticfiles import StaticFiles from PyQt6 import QtCore, QtGui, QtWidgets # Internal Packages -from src.utils import constants from src.configure import initialize_server from src.router import router +from src.utils import constants # Initialize the Application Server diff --git a/src/router.py b/src/router.py index 28880eb7..c5bb93e6 100644 --- a/src/router.py +++ 
b/src/router.py @@ -20,7 +20,7 @@ from src.search_filter.date_filter import DateFilter from src.utils.rawconfig import FullConfig from src.utils.config import SearchType from src.utils.helpers import get_absolute_path, get_from_dict -from src.utils import constants +from src.utils import state, constants router = APIRouter() @@ -36,15 +36,15 @@ def config_page(request: Request): @router.get('/config/data', response_model=FullConfig) def config_data(): - return constants.config + return state.config @router.post('/config/data') async def config_data(updated_config: FullConfig): - constants.config = updated_config - with open(constants.config_file, 'w') as outfile: - yaml.dump(yaml.safe_load(constants.config.json(by_alias=True)), outfile) + state.config = updated_config + with open(state.config_file, 'w') as outfile: + yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile) outfile.close() - return constants.config + return state.config @router.get('/search') @lru_cache(maxsize=100) @@ -57,10 +57,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti results_count = n results = {} - if (t == SearchType.Org or t == None) and constants.model.orgmode_search: + if (t == SearchType.Org or t == None) and state.model.orgmode_search: # query org-mode notes query_start = time.time() - hits, entries = text_search.query(user_query, constants.model.orgmode_search, rank_results=r, device=constants.device, filters=[DateFilter(), ExplicitFilter()], verbose=constants.verbose) + hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -68,10 +68,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti results = text_search.collate_results(hits, entries, results_count) collate_end = time.time() - if (t == SearchType.Music or t == None) and 
constants.model.music_search: + if (t == SearchType.Music or t == None) and state.model.music_search: # query music library query_start = time.time() - hits, entries = text_search.query(user_query, constants.model.music_search, rank_results=r, device=constants.device, filters=[DateFilter(), ExplicitFilter()], verbose=constants.verbose) + hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -79,10 +79,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti results = text_search.collate_results(hits, entries, results_count) collate_end = time.time() - if (t == SearchType.Markdown or t == None) and constants.model.orgmode_search: + if (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown files query_start = time.time() - hits, entries = text_search.query(user_query, constants.model.markdown_search, rank_results=r, device=constants.device, filters=[ExplicitFilter(), DateFilter()], verbose=constants.verbose) + hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -90,10 +90,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti results = text_search.collate_results(hits, entries, results_count) collate_end = time.time() - if (t == SearchType.Ledger or t == None) and constants.model.ledger_search: + if (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions query_start = time.time() - hits, entries = text_search.query(user_query, constants.model.ledger_search, rank_results=r, device=constants.device, filters=[ExplicitFilter(), DateFilter()], verbose=constants.verbose) + hits, entries =
text_search.query(user_query, state.model.ledger_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) query_end = time.time() # collate and return results @@ -101,10 +101,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti results = text_search.collate_results(hits, entries, results_count) collate_end = time.time() - if (t == SearchType.Image or t == None) and constants.model.image_search: + if (t == SearchType.Image or t == None) and state.model.image_search: # query images query_start = time.time() - hits = image_search.query(user_query, results_count, constants.model.image_search) + hits = image_search.query(user_query, results_count, state.model.image_search) output_directory = constants.web_directory / 'images' query_end = time.time() @@ -112,13 +112,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti collate_start = time.time() results = image_search.collate_results( hits, - image_names=constants.model.image_search.image_names, + image_names=state.model.image_search.image_names, output_directory=output_directory, image_files_url='/static/images', count=results_count) collate_end = time.time() - if constants.verbose > 1: + if state.verbose > 1: print(f"Query took {query_end - query_start:.3f} seconds") print(f"Collating results took {collate_end - collate_start:.3f} seconds") @@ -127,20 +127,20 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti @router.get('/reload') def reload(t: Optional[SearchType] = None): - constants.model = initialize_search(constants.model, constants.config, regenerate=False, t=t, device=constants.device) + state.model = initialize_search(state.model, state.config, regenerate=False, t=t, device=state.device) return {'status': 'ok', 'message': 'reload completed'} @router.get('/regenerate') def regenerate(t: Optional[SearchType] = None): - constants.model = 
initialize_search(constants.model, constants.config, regenerate=True, t=t, device=constants.device) + state.model = initialize_search(state.model, state.config, regenerate=True, t=t, device=state.device) return {'status': 'ok', 'message': 'regeneration completed'} @router.get('/beta/search') def search_beta(q: str, n: Optional[int] = 1): # Extract Search Type using GPT - metadata = extract_search_type(q, api_key=constants.processor_config.conversation.openai_api_key, verbose=constants.verbose) + metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) search_type = get_from_dict(metadata, "search-type") # Search @@ -153,27 +153,27 @@ def search_beta(q: str, n: Optional[int] = 1): @router.get('/chat') def chat(q: str): # Load Conversation History - chat_session = constants.processor_config.conversation.chat_session - meta_log = constants.processor_config.conversation.meta_log + chat_session = state.processor_config.conversation.chat_session + meta_log = state.processor_config.conversation.meta_log # Converse with OpenAI GPT - metadata = understand(q, api_key=constants.processor_config.conversation.openai_api_key, verbose=constants.verbose) - if constants.verbose > 1: + metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) + if state.verbose > 1: print(f'Understood: {get_from_dict(metadata, "intent")}') if get_from_dict(metadata, "intent", "memory-type") == "notes": query = get_from_dict(metadata, "intent", "query") result_list = search(query, n=1, t=SearchType.Org) collated_result = "\n".join([item["entry"] for item in result_list]) - if constants.verbose > 1: + if state.verbose > 1: print(f'Semantically Similar Notes:\n{collated_result}') - gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=constants.processor_config.conversation.openai_api_key) + gpt_response = summarize(collated_result, summary_type="notes", 
user_query=q, api_key=state.processor_config.conversation.openai_api_key) else: - gpt_response = converse(q, chat_session, api_key=constants.processor_config.conversation.openai_api_key) + gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key) # Update Conversation History - constants.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) - constants.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', [])) + state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) + state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', [])) return {'status': 'ok', 'response': gpt_response} @@ -181,15 +181,15 @@ def chat(q: str): @router.on_event('shutdown') def shutdown_event(): # No need to create empty log file - if not (constants.processor_config and constants.processor_config.conversation and constants.processor_config.conversation.meta_log): + if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log): return - elif constants.processor_config.conversation.verbose: + elif state.processor_config.conversation.verbose: print('INFO:\tSaving conversation logs to disk...') # Summarize Conversation Logs for this Session - chat_session = constants.processor_config.conversation.chat_session - openai_api_key = constants.processor_config.conversation.openai_api_key - conversation_log = constants.processor_config.conversation.meta_log + chat_session = state.processor_config.conversation.chat_session + openai_api_key = state.processor_config.conversation.openai_api_key + conversation_log = state.processor_config.conversation.meta_log session = { "summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key), "session-start": 
conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], @@ -201,7 +201,7 @@ def shutdown_event(): conversation_log['session'] = [session] # Save Conversation Metadata Logs to Disk - conversation_logfile = get_absolute_path(constants.processor_config.conversation.conversation_logfile) + conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile) with open(conversation_logfile, "w+", encoding='utf-8') as logfile: json.dump(conversation_log, logfile) diff --git a/src/utils/constants.py b/src/utils/constants.py index be5d7464..bfb307f0 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -1,19 +1,4 @@ -# External Packages -import torch from pathlib import Path -# Internal Packages -from src.utils.config import SearchModels, ProcessorConfigModel -from src.utils.rawconfig import FullConfig - -# Application Global State -config = FullConfig() -model = SearchModels() -processor_config = ProcessorConfigModel() -config_file: Path = "" -verbose: int = 0 -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available - -# Other Constants web_directory = Path(__file__).parent.parent / 'interface/web/' empty_escape_sequences = r'\n|\r\t ' diff --git a/src/utils/state.py b/src/utils/state.py new file mode 100644 index 00000000..964fa458 --- /dev/null +++ b/src/utils/state.py @@ -0,0 +1,15 @@ +# External Packages +import torch +from pathlib import Path + +# Internal Packages +from src.utils.config import SearchModels, ProcessorConfigModel +from src.utils.rawconfig import FullConfig + +# Application Global State +config = FullConfig() +model = SearchModels() +processor_config = ProcessorConfigModel() +config_file: Path = "" +verbose: int = 0 +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 
68f12634..56610d45 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,11 @@ # Standard Packages import pytest -import torch # Internal Packages from src.search_type import image_search, text_search from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.processor.org_mode.org_to_jsonl import org_to_jsonl -from src.utils import constants +from src.utils import state @pytest.fixture(scope='session') @@ -56,7 +55,7 @@ def model_dir(search_config): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=constants.device, verbose=True) + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=state.device, verbose=True) return model_dir diff --git a/tests/test_asymmetric_search.py b/tests/test_asymmetric_search.py index cf56b449..39fed92e 100644 --- a/tests/test_asymmetric_search.py +++ b/tests/test_asymmetric_search.py @@ -2,7 +2,7 @@ from pathlib import Path # Internal Packages -from src.utils.constants import model +from src.utils.state import model from src.search_type import text_search from src.utils.rawconfig import ContentConfig, SearchConfig from src.processor.org_mode.org_to_jsonl import org_to_jsonl diff --git a/tests/test_client.py b/tests/test_client.py index 779cc8ab..85aad8d7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ import pytest # Internal Packages from src.main import app -from src.utils.constants import model, config +from src.utils.state import model, config from src.search_type import text_search, image_search from src.utils.rawconfig import ContentConfig, SearchConfig from src.processor.org_mode import org_to_jsonl diff --git a/tests/test_image_search.py b/tests/test_image_search.py index 
21b72937..4eb52048 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -6,7 +6,8 @@ from PIL import Image import pytest # Internal Packages -from src.utils.constants import model, web_directory +from src.utils.state import model +from src.utils.constants import web_directory from src.search_type import image_search from src.utils.helpers import resolve_absolute_path from src.utils.rawconfig import ContentConfig, SearchConfig