Put global state variables into separate state module

- Variables storing app and device state aren't constants.
  Do not mix them with actual constants like empty_escape_sequences, web_directory
This commit is contained in:
Debanjum Singh Solanky
2022-08-06 03:05:35 +03:00
parent bc423d8f76
commit 7b04978f52
9 changed files with 64 additions and 64 deletions

View File

@@ -11,7 +11,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_type import image_search, text_search from src.search_type import image_search, text_search
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils.cli import cli from src.utils.cli import cli
from src.utils import constants from src.utils import state
from src.utils.helpers import get_absolute_path from src.utils.helpers import get_absolute_path
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
@@ -21,19 +21,19 @@ def initialize_server(cmd_args):
args = cli(cmd_args) args = cli(cmd_args)
# Stores the file path to the config file. # Stores the file path to the config file.
constants.config_file = args.config_file state.config_file = args.config_file
# Store the raw config data. # Store the raw config data.
constants.config = args.config state.config = args.config
# Store the verbose flag # Store the verbose flag
constants.verbose = args.verbose state.verbose = args.verbose
# Initialize the search model from Config # Initialize the search model from Config
constants.model = initialize_search(constants.model, args.config, args.regenerate, device=constants.device, verbose=constants.verbose) state.model = initialize_search(state.model, args.config, args.regenerate, device=state.device, verbose=state.verbose)
# Initialize Processor from Config # Initialize Processor from Config
constants.processor_config = initialize_processor(args.config, verbose=constants.verbose) state.processor_config = initialize_processor(args.config, verbose=state.verbose)
return args.host, args.port, args.socket return args.host, args.port, args.socket

View File

@@ -9,9 +9,9 @@ from fastapi.staticfiles import StaticFiles
from PyQt6 import QtCore, QtGui, QtWidgets from PyQt6 import QtCore, QtGui, QtWidgets
# Internal Packages # Internal Packages
from src.utils import constants
from src.configure import initialize_server from src.configure import initialize_server
from src.router import router from src.router import router
from src.utils import constants
# Initialize the Application Server # Initialize the Application Server

View File

@@ -20,7 +20,7 @@ from src.search_filter.date_filter import DateFilter
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
from src.utils.config import SearchType from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict from src.utils.helpers import get_absolute_path, get_from_dict
from src.utils import constants from src.utils import state, constants
router = APIRouter() router = APIRouter()
@@ -36,15 +36,15 @@ def config_page(request: Request):
@router.get('/config/data', response_model=FullConfig) @router.get('/config/data', response_model=FullConfig)
def config_data(): def config_data():
return constants.config return state.config
@router.post('/config/data') @router.post('/config/data')
async def config_data(updated_config: FullConfig): async def config_data(updated_config: FullConfig):
constants.config = updated_config state.config = updated_config
with open(constants.config_file, 'w') as outfile: with open(state.config_file, 'w') as outfile:
yaml.dump(yaml.safe_load(constants.config.json(by_alias=True)), outfile) yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close() outfile.close()
return constants.config return state.config
@router.get('/search') @router.get('/search')
@lru_cache(maxsize=100) @lru_cache(maxsize=100)
@@ -57,10 +57,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results_count = n results_count = n
results = {} results = {}
if (t == SearchType.Org or t == None) and constants.model.orgmode_search: if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes # query org-mode notes
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, constants.model.orgmode_search, rank_results=r, device=constants.device, filters=[DateFilter(), ExplicitFilter()], verbose=constants.verbose) hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@@ -68,10 +68,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time() collate_end = time.time()
if (t == SearchType.Music or t == None) and constants.model.music_search: if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library # query music library
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, constants.model.music_search, rank_results=r, device=constants.device, filters=[DateFilter(), ExplicitFilter()], verbose=constants.verbose) hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, device=state.device, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@@ -79,10 +79,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time() collate_end = time.time()
if (t == SearchType.Markdown or t == None) and constants.model.orgmode_search: if (t == SearchType.Markdown or t == None) and state.model.orgmode_search:
# query markdown files # query markdown files
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, constants.model.markdown_search, rank_results=r, device=constants.device, filters=[ExplicitFilter(), DateFilter()], verbose=constants.verbose) hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@@ -90,10 +90,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time() collate_end = time.time()
if (t == SearchType.Ledger or t == None) and constants.model.ledger_search: if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions # query transactions
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, constants.model.ledger_search, rank_results=r, device=constants.device, filters=[ExplicitFilter(), DateFilter()], verbose=constants.verbose) hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, device=state.device, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@@ -101,10 +101,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time() collate_end = time.time()
if (t == SearchType.Image or t == None) and constants.model.image_search: if (t == SearchType.Image or t == None) and state.model.image_search:
# query images # query images
query_start = time.time() query_start = time.time()
hits = image_search.query(user_query, results_count, constants.model.image_search) hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images' output_directory = constants.web_directory / 'images'
query_end = time.time() query_end = time.time()
@@ -112,13 +112,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
collate_start = time.time() collate_start = time.time()
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
image_names=constants.model.image_search.image_names, image_names=state.model.image_search.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url='/static/images', image_files_url='/static/images',
count=results_count) count=results_count)
collate_end = time.time() collate_end = time.time()
if constants.verbose > 1: if state.verbose > 1:
print(f"Query took {query_end - query_start:.3f} seconds") print(f"Query took {query_end - query_start:.3f} seconds")
print(f"Collating results took {collate_end - collate_start:.3f} seconds") print(f"Collating results took {collate_end - collate_start:.3f} seconds")
@@ -127,20 +127,20 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
@router.get('/reload') @router.get('/reload')
def reload(t: Optional[SearchType] = None): def reload(t: Optional[SearchType] = None):
constants.model = initialize_search(constants.model, constants.config, regenerate=False, t=t, device=constants.device) state.model = initialize_search(state.model, state.config, regenerate=False, t=t, device=state.device)
return {'status': 'ok', 'message': 'reload completed'} return {'status': 'ok', 'message': 'reload completed'}
@router.get('/regenerate') @router.get('/regenerate')
def regenerate(t: Optional[SearchType] = None): def regenerate(t: Optional[SearchType] = None):
constants.model = initialize_search(constants.model, constants.config, regenerate=True, t=t, device=constants.device) state.model = initialize_search(state.model, state.config, regenerate=True, t=t, device=state.device)
return {'status': 'ok', 'message': 'regeneration completed'} return {'status': 'ok', 'message': 'regeneration completed'}
@router.get('/beta/search') @router.get('/beta/search')
def search_beta(q: str, n: Optional[int] = 1): def search_beta(q: str, n: Optional[int] = 1):
# Extract Search Type using GPT # Extract Search Type using GPT
metadata = extract_search_type(q, api_key=constants.processor_config.conversation.openai_api_key, verbose=constants.verbose) metadata = extract_search_type(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type") search_type = get_from_dict(metadata, "search-type")
# Search # Search
@@ -153,27 +153,27 @@ def search_beta(q: str, n: Optional[int] = 1):
@router.get('/chat') @router.get('/chat')
def chat(q: str): def chat(q: str):
# Load Conversation History # Load Conversation History
chat_session = constants.processor_config.conversation.chat_session chat_session = state.processor_config.conversation.chat_session
meta_log = constants.processor_config.conversation.meta_log meta_log = state.processor_config.conversation.meta_log
# Converse with OpenAI GPT # Converse with OpenAI GPT
metadata = understand(q, api_key=constants.processor_config.conversation.openai_api_key, verbose=constants.verbose) metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose)
if constants.verbose > 1: if state.verbose > 1:
print(f'Understood: {get_from_dict(metadata, "intent")}') print(f'Understood: {get_from_dict(metadata, "intent")}')
if get_from_dict(metadata, "intent", "memory-type") == "notes": if get_from_dict(metadata, "intent", "memory-type") == "notes":
query = get_from_dict(metadata, "intent", "query") query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org) result_list = search(query, n=1, t=SearchType.Org)
collated_result = "\n".join([item["entry"] for item in result_list]) collated_result = "\n".join([item["entry"] for item in result_list])
if constants.verbose > 1: if state.verbose > 1:
print(f'Semantically Similar Notes:\n{collated_result}') print(f'Semantically Similar Notes:\n{collated_result}')
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=constants.processor_config.conversation.openai_api_key) gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key)
else: else:
gpt_response = converse(q, chat_session, api_key=constants.processor_config.conversation.openai_api_key) gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key)
# Update Conversation History # Update Conversation History
constants.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
constants.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', [])) state.processor_config.conversation.meta_log['chat'] = message_to_log(q, metadata, gpt_response, meta_log.get('chat', []))
return {'status': 'ok', 'response': gpt_response} return {'status': 'ok', 'response': gpt_response}
@@ -181,15 +181,15 @@ def chat(q: str):
@router.on_event('shutdown') @router.on_event('shutdown')
def shutdown_event(): def shutdown_event():
# No need to create empty log file # No need to create empty log file
if not (constants.processor_config and constants.processor_config.conversation and constants.processor_config.conversation.meta_log): if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log):
return return
elif constants.processor_config.conversation.verbose: elif state.processor_config.conversation.verbose:
print('INFO:\tSaving conversation logs to disk...') print('INFO:\tSaving conversation logs to disk...')
# Summarize Conversation Logs for this Session # Summarize Conversation Logs for this Session
chat_session = constants.processor_config.conversation.chat_session chat_session = state.processor_config.conversation.chat_session
openai_api_key = constants.processor_config.conversation.openai_api_key openai_api_key = state.processor_config.conversation.openai_api_key
conversation_log = constants.processor_config.conversation.meta_log conversation_log = state.processor_config.conversation.meta_log
session = { session = {
"summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key), "summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
@@ -201,7 +201,7 @@ def shutdown_event():
conversation_log['session'] = [session] conversation_log['session'] = [session]
# Save Conversation Metadata Logs to Disk # Save Conversation Metadata Logs to Disk
conversation_logfile = get_absolute_path(constants.processor_config.conversation.conversation_logfile) conversation_logfile = get_absolute_path(state.processor_config.conversation.conversation_logfile)
with open(conversation_logfile, "w+", encoding='utf-8') as logfile: with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
json.dump(conversation_log, logfile) json.dump(conversation_log, logfile)

View File

@@ -1,19 +1,4 @@
# External Packages
import torch
from pathlib import Path from pathlib import Path
# Internal Packages
from src.utils.config import SearchModels, ProcessorConfigModel
from src.utils.rawconfig import FullConfig
# Application Global State
config = FullConfig()
model = SearchModels()
processor_config = ProcessorConfigModel()
config_file: Path = ""
verbose: int = 0
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available
# Other Constants
web_directory = Path(__file__).parent.parent / 'interface/web/' web_directory = Path(__file__).parent.parent / 'interface/web/'
empty_escape_sequences = r'\n|\r\t ' empty_escape_sequences = r'\n|\r\t '

15
src/utils/state.py Normal file
View File

@@ -0,0 +1,15 @@
# External Packages
import torch
from pathlib import Path
# Internal Packages
from src.utils.config import SearchModels, ProcessorConfigModel
from src.utils.rawconfig import FullConfig
# Application Global State
# Mutable, process-wide singletons. Other modules import this module and
# read/write these attributes directly (e.g. the router sets state.config,
# configure sets state.model/state.processor_config at startup).
config = FullConfig()  # raw user configuration; persisted back to config_file on update
model = SearchModels()  # initialized search models (org, markdown, ledger, image, ...)
processor_config = ProcessorConfigModel()  # processor config (e.g. conversation/OpenAI settings)
config_file: Path = ""  # path to the YAML config file; NOTE(review): annotated Path but default is the str "" — confirm intended
verbose: int = 0  # verbosity level; values > 1 enable debug prints
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available

View File

@@ -1,12 +1,11 @@
# Standard Packages # Standard Packages
import pytest import pytest
import torch
# Internal Packages # Internal Packages
from src.search_type import image_search, text_search from src.search_type import image_search, text_search
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils import constants from src.utils import state
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
@@ -56,7 +55,7 @@ def model_dir(search_config):
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('note_embeddings.pt')) embeddings_file = model_dir.joinpath('note_embeddings.pt'))
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=constants.device, verbose=True) text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, device=state.device, verbose=True)
return model_dir return model_dir

View File

@@ -2,7 +2,7 @@
from pathlib import Path from pathlib import Path
# Internal Packages # Internal Packages
from src.utils.constants import model from src.utils.state import model
from src.search_type import text_search from src.search_type import text_search
from src.utils.rawconfig import ContentConfig, SearchConfig from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.processor.org_mode.org_to_jsonl import org_to_jsonl

View File

@@ -8,7 +8,7 @@ import pytest
# Internal Packages # Internal Packages
from src.main import app from src.main import app
from src.utils.constants import model, config from src.utils.state import model, config
from src.search_type import text_search, image_search from src.search_type import text_search, image_search
from src.utils.rawconfig import ContentConfig, SearchConfig from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode import org_to_jsonl from src.processor.org_mode import org_to_jsonl

View File

@@ -6,7 +6,8 @@ from PIL import Image
import pytest import pytest
# Internal Packages # Internal Packages
from src.utils.constants import model, web_directory from src.utils.state import model
from src.utils.constants import web_directory
from src.search_type import image_search from src.search_type import image_search
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import ContentConfig, SearchConfig from src.utils.rawconfig import ContentConfig, SearchConfig