Use python standard logging framework for app logs

- Stop passing verbose flag around app methods
- Minor remap of verbosity levels to match python logging framework levels
  - verbose = 0 maps to logging.WARN
  - verbose = 1 maps to logging.INFO
  - verbose >=2 maps to logging.DEBUG
- Minor clean-up of app: unused modules, conversation file opening
This commit is contained in:
Debanjum Singh Solanky
2022-09-03 14:43:32 +03:00
parent d0531c3064
commit 094bd18e57
10 changed files with 184 additions and 155 deletions

View File

@@ -1,8 +1,8 @@
# System Packages # System Packages
import sys import sys
import logging
# External Packages # External Packages
import torch
import json import json
# Internal Packages # Internal Packages
@@ -12,10 +12,13 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_type import image_search, text_search from src.search_type import image_search, text_search
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils import state from src.utils import state
from src.utils.helpers import get_absolute_path from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import FullConfig, ProcessorConfig from src.utils.rawconfig import FullConfig, ProcessorConfig
logger = logging.getLogger(__name__)
def configure_server(args, required=False): def configure_server(args, required=False):
if args.config is None: if args.config is None:
if required: if required:
@@ -27,42 +30,42 @@ def configure_server(args, required=False):
state.config = args.config state.config = args.config
# Initialize the search model from Config # Initialize the search model from Config
state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose) state.model = configure_search(state.model, state.config, args.regenerate)
# Initialize Processor from Config # Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose) state.processor_config = configure_processor(args.config.processor)
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0): def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None):
# Initialize Org Notes Search # Initialize Org Notes Search
if (t == SearchType.Org or t == None) and config.content_type.org: if (t == SearchType.Org or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings # Extract Entries, Generate Notes Embeddings
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Org Music Search # Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music: if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings # Extract Entries, Generate Music Embeddings
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Markdown Search # Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown: if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
# Extract Entries, Generate Markdown Embeddings # Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose) model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Ledger Search # Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger: if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings # Extract Entries, Generate Ledger Embeddings
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose) model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
# Initialize Image Search # Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image: if (t == SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings # Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose) model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate)
return model return model
def configure_processor(processor_config: ProcessorConfig, verbose: int): def configure_processor(processor_config: ProcessorConfig):
if not processor_config: if not processor_config:
return return
@@ -70,27 +73,23 @@ def configure_processor(processor_config: ProcessorConfig, verbose: int):
# Initialize Conversation Processor # Initialize Conversation Processor
if processor_config.conversation: if processor_config.conversation:
processor.conversation = configure_conversation_processor(processor_config.conversation, verbose) processor.conversation = configure_conversation_processor(processor_config.conversation)
return processor return processor
def configure_conversation_processor(conversation_processor_config, verbose: int): def configure_conversation_processor(conversation_processor_config):
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config, verbose) conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
conversation_logfile = conversation_processor.conversation_logfile if conversation_logfile.is_file():
if conversation_processor.verbose:
print('INFO:\tLoading conversation logs from disk...')
if conversation_logfile.expanduser().absolute().is_file():
# Load Metadata Logs from Conversation Logfile # Load Metadata Logs from Conversation Logfile
with open(get_absolute_path(conversation_logfile), 'r') as f: with conversation_logfile.open('r') as f:
conversation_processor.meta_log = json.load(f) conversation_processor.meta_log = json.load(f)
logger.info('Conversation logs loaded from disk.')
print('INFO:\tConversation logs loaded from disk.')
else: else:
# Initialize Conversation Logs # Initialize Conversation Logs
conversation_processor.meta_log = {} conversation_processor.meta_log = {}
conversation_processor.chat_session = "" conversation_processor.chat_session = ""
return conversation_processor return conversation_processor

View File

@@ -2,6 +2,7 @@
import os import os
import signal import signal
import sys import sys
import logging
from platform import system from platform import system
# External Packages # External Packages
@@ -25,6 +26,33 @@ app = FastAPI()
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static") app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
app.include_router(router) app.include_router(router)
logger = logging.getLogger('src')
class CustomFormatter(logging.Formatter):
blue = "\x1b[1;34m"
green = "\x1b[1;32m"
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
format = "%(levelname)s: %(asctime)s: %(name)s | %(message)s"
FORMATS = {
logging.DEBUG: blue + format + reset,
logging.INFO: green + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
def run(): def run():
# Turn Tokenizers Parallelism Off. App does not support it. # Turn Tokenizers Parallelism Off. App does not support it.
@@ -33,6 +61,20 @@ def run():
# Load config from CLI # Load config from CLI
state.cli_args = sys.argv[1:] state.cli_args = sys.argv[1:]
args = cli(state.cli_args) args = cli(state.cli_args)
# Setup Logger
# logging.basicConfig(format='%(levelname)s: %(asctime)s : %(module)s | %(message)s', datefmt='%d-%m-%Y %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
logger.addHandler(ch)
if args.verbose == 0:
logger.setLevel(logging.WARN)
elif args.verbose == 1:
logger.setLevel(logging.INFO)
elif args.verbose >= 2:
logger.setLevel(logging.DEBUG)
logger.info("Starting Khoj...")
set_state(args) set_state(args)
if args.no_gui: if args.no_gui:

View File

@@ -6,6 +6,7 @@ import argparse
import pathlib import pathlib
import glob import glob
import re import re
import logging
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.helpers import get_absolute_path, is_none_or_empty
@@ -13,32 +14,35 @@ from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
logger = logging.getLogger(__name__)
# Define Functions # Define Functions
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, verbose=0): def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file):
# Input Validation # Input Validation
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter): if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
print("At least one of beancount-files or beancount-file-filter is required to be specified") print("At least one of beancount-files or beancount-file-filter is required to be specified")
exit(1) exit(1)
# Get Beancount Files to Process # Get Beancount Files to Process
beancount_files = get_beancount_files(beancount_files, beancount_file_filter, verbose) beancount_files = get_beancount_files(beancount_files, beancount_file_filter)
# Extract Entries from specified Beancount files # Extract Entries from specified Beancount files
entries = extract_beancount_entries(beancount_files) entries = extract_beancount_entries(beancount_files)
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
jsonl_data = convert_beancount_entries_to_jsonl(entries, verbose=verbose) jsonl_data = convert_beancount_entries_to_jsonl(entries)
# Compress JSONL formatted Data # Compress JSONL formatted Data
if output_file.suffix == ".gz": if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose) compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl": elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose) dump_jsonl(jsonl_data, output_file)
return entries return entries
def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbose=0): def get_beancount_files(beancount_files=None, beancount_file_filter=None):
"Get Beancount files to process" "Get Beancount files to process"
absolute_beancount_files, filtered_beancount_files = set(), set() absolute_beancount_files, filtered_beancount_files = set(), set()
if beancount_files: if beancount_files:
@@ -57,8 +61,7 @@ def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbos
if any(files_with_non_beancount_extensions): if any(files_with_non_beancount_extensions):
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}") print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
if verbose > 0: logger.info(f'Processing files: {all_beancount_files}')
print(f'Processing files: {all_beancount_files}')
return all_beancount_files return all_beancount_files
@@ -82,7 +85,7 @@ def extract_beancount_entries(beancount_files):
return entries return entries
def convert_beancount_entries_to_jsonl(entries, verbose=0): def convert_beancount_entries_to_jsonl(entries):
"Convert each Beancount transaction to JSON and collate as JSONL" "Convert each Beancount transaction to JSON and collate as JSONL"
jsonl = '' jsonl = ''
for entry in entries: for entry in entries:
@@ -90,8 +93,7 @@ def convert_beancount_entries_to_jsonl(entries, verbose=0):
# Convert Dictionary to JSON and Append to JSONL string # Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
if verbose > 0: logger.info(f"Converted {len(entries)} to jsonl format")
print(f"Converted {len(entries)} to jsonl format")
return jsonl return jsonl

View File

@@ -6,6 +6,7 @@ import argparse
import pathlib import pathlib
import glob import glob
import re import re
import logging
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.helpers import get_absolute_path, is_none_or_empty
@@ -13,32 +14,35 @@ from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
logger = logging.getLogger(__name__)
# Define Functions # Define Functions
def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file, verbose=0): def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file):
# Input Validation # Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter): if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
print("At least one of markdown-files or markdown-file-filter is required to be specified") print("At least one of markdown-files or markdown-file-filter is required to be specified")
exit(1) exit(1)
# Get Markdown Files to Process # Get Markdown Files to Process
markdown_files = get_markdown_files(markdown_files, markdown_file_filter, verbose) markdown_files = get_markdown_files(markdown_files, markdown_file_filter)
# Extract Entries from specified Markdown files # Extract Entries from specified Markdown files
entries = extract_markdown_entries(markdown_files) entries = extract_markdown_entries(markdown_files)
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
jsonl_data = convert_markdown_entries_to_jsonl(entries, verbose=verbose) jsonl_data = convert_markdown_entries_to_jsonl(entries)
# Compress JSONL formatted Data # Compress JSONL formatted Data
if output_file.suffix == ".gz": if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose) compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl": elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose) dump_jsonl(jsonl_data, output_file)
return entries return entries
def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0): def get_markdown_files(markdown_files=None, markdown_file_filter=None):
"Get Markdown files to process" "Get Markdown files to process"
absolute_markdown_files, filtered_markdown_files = set(), set() absolute_markdown_files, filtered_markdown_files = set(), set()
if markdown_files: if markdown_files:
@@ -56,10 +60,9 @@ def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0
} }
if any(files_with_non_markdown_extensions): if any(files_with_non_markdown_extensions):
print(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}") logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
if verbose > 0: logger.info(f'Processing files: {all_markdown_files}')
print(f'Processing files: {all_markdown_files}')
return all_markdown_files return all_markdown_files
@@ -81,7 +84,7 @@ def extract_markdown_entries(markdown_files):
return entries return entries
def convert_markdown_entries_to_jsonl(entries, verbose=0): def convert_markdown_entries_to_jsonl(entries):
"Convert each Markdown entries to JSON and collate as JSONL" "Convert each Markdown entries to JSON and collate as JSONL"
jsonl = '' jsonl = ''
for entry in entries: for entry in entries:
@@ -89,8 +92,7 @@ def convert_markdown_entries_to_jsonl(entries, verbose=0):
# Convert Dictionary to JSON and Append to JSONL string # Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
if verbose > 0: logger.info(f"Converted {len(entries)} to jsonl format")
print(f"Converted {len(entries)} to jsonl format")
return jsonl return jsonl

View File

@@ -6,40 +6,43 @@ import json
import argparse import argparse
import pathlib import pathlib
import glob import glob
import logging
# Internal Packages # Internal Packages
from src.processor.org_mode import orgnode from src.processor.org_mode import orgnode
from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
logger = logging.getLogger(__name__)
# Define Functions # Define Functions
def org_to_jsonl(org_files, org_file_filter, output_file, verbose=0): def org_to_jsonl(org_files, org_file_filter, output_file):
# Input Validation # Input Validation
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter): if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
print("At least one of org-files or org-file-filter is required to be specified") print("At least one of org-files or org-file-filter is required to be specified")
exit(1) exit(1)
# Get Org Files to Process # Get Org Files to Process
org_files = get_org_files(org_files, org_file_filter, verbose) org_files = get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files # Extract Entries from specified Org files
entries = extract_org_entries(org_files) entries = extract_org_entries(org_files)
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
jsonl_data = convert_org_entries_to_jsonl(entries, verbose=verbose) jsonl_data = convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data # Compress JSONL formatted Data
if output_file.suffix == ".gz": if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose) compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl": elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose) dump_jsonl(jsonl_data, output_file)
return entries return entries
def get_org_files(org_files=None, org_file_filter=None, verbose=0): def get_org_files(org_files=None, org_file_filter=None):
"Get Org files to process" "Get Org files to process"
absolute_org_files, filtered_org_files = set(), set() absolute_org_files, filtered_org_files = set(), set()
if org_files: if org_files:
@@ -53,10 +56,9 @@ def get_org_files(org_files=None, org_file_filter=None, verbose=0):
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")} files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions): if any(files_with_non_org_extensions):
print(f"[Warning] There maybe non org-mode files in the input set: {files_with_non_org_extensions}") logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
if verbose > 0: logger.info(f'Processing files: {all_org_files}')
print(f'Processing files: {all_org_files}')
return all_org_files return all_org_files
@@ -72,7 +74,7 @@ def extract_org_entries(org_files):
return entries return entries
def convert_org_entries_to_jsonl(entries, verbose=0) -> str: def convert_org_entries_to_jsonl(entries) -> str:
"Convert each Org-Mode entries to JSON and collate as JSONL" "Convert each Org-Mode entries to JSON and collate as JSONL"
jsonl = '' jsonl = ''
for entry in entries: for entry in entries:
@@ -83,29 +85,24 @@ def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
continue continue
entry_dict["compiled"] = f'{entry.Heading()}.' entry_dict["compiled"] = f'{entry.Heading()}.'
if verbose > 2: logger.debug(f"Title: {entry.Heading()}")
print(f"Title: {entry.Heading()}")
if entry.Tags(): if entry.Tags():
tags_str = " ".join(entry.Tags()) tags_str = " ".join(entry.Tags())
entry_dict["compiled"] += f'\t {tags_str}.' entry_dict["compiled"] += f'\t {tags_str}.'
if verbose > 2: logger.debug(f"Tags: {tags_str}")
print(f"Tags: {tags_str}")
if entry.Closed(): if entry.Closed():
entry_dict["compiled"] += f'\n Closed on {entry.Closed().strftime("%Y-%m-%d")}.' entry_dict["compiled"] += f'\n Closed on {entry.Closed().strftime("%Y-%m-%d")}.'
if verbose > 2: logger.debug(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
print(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
if entry.Scheduled(): if entry.Scheduled():
entry_dict["compiled"] += f'\n Scheduled for {entry.Scheduled().strftime("%Y-%m-%d")}.' entry_dict["compiled"] += f'\n Scheduled for {entry.Scheduled().strftime("%Y-%m-%d")}.'
if verbose > 2: logger.debug(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
print(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
if entry.Body(): if entry.Body():
entry_dict["compiled"] += f'\n {entry.Body()}' entry_dict["compiled"] += f'\n {entry.Body()}'
if verbose > 2: logger.debug(f"Body: {entry.Body()}")
print(f"Body: {entry.Body()}")
if entry_dict: if entry_dict:
entry_dict["raw"] = f'{entry}' entry_dict["raw"] = f'{entry}'
@@ -113,8 +110,7 @@ def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
# Convert Dictionary to JSON and Append to JSONL string # Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n' jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
if verbose > 0: logger.info(f"Converted {len(entries)} to jsonl format")
print(f"Converted {len(entries)} to jsonl format")
return jsonl return jsonl

View File

@@ -2,6 +2,7 @@
import yaml import yaml
import json import json
import time import time
import logging
from typing import Optional from typing import Optional
from functools import lru_cache from functools import lru_cache
@@ -22,9 +23,11 @@ from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict from src.utils.helpers import get_absolute_path, get_from_dict
from src.utils import state, constants from src.utils import state, constants
router = APIRouter()
router = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory) templates = Jinja2Templates(directory=constants.web_directory)
logger = logging.getLogger(__name__)
@router.get("/", response_class=FileResponse) @router.get("/", response_class=FileResponse)
def index(): def index():
@@ -50,7 +53,7 @@ async def config_data(updated_config: FullConfig):
@lru_cache(maxsize=100) @lru_cache(maxsize=100)
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
if q is None or q == '': if q is None or q == '':
print(f'No query param (q) passed in API call to initiate search') logger.info(f'No query param (q) passed in API call to initiate search')
return {} return {}
# initialize variables # initialize variables
@@ -120,11 +123,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
count=results_count) count=results_count)
collate_end = time.time() collate_end = time.time()
if state.verbose > 1: if query_start and query_end:
if query_start and query_end: logger.debug(f"Query took {query_end - query_start:.3f} seconds")
print(f"Query took {query_end - query_start:.3f} seconds") if collate_start and collate_end:
if collate_start and collate_end: logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
print(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results return results

View File

@@ -5,6 +5,7 @@ import pathlib
import copy import copy
import shutil import shutil
import time import time
import logging
# External Packages # External Packages
from sentence_transformers import SentenceTransformer, util from sentence_transformers import SentenceTransformer, util
@@ -17,7 +18,10 @@ from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_mod
import src.utils.exiftool as exiftool import src.utils.exiftool as exiftool
from src.utils.config import ImageSearchModel from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
from src.utils import state
# Create Logger
logger = logging.getLogger(__name__)
def initialize_model(search_config: ImageSearchConfig): def initialize_model(search_config: ImageSearchConfig):
@@ -39,34 +43,33 @@ def initialize_model(search_config: ImageSearchConfig):
return encoder return encoder
def extract_entries(image_directories, verbose=0): def extract_entries(image_directories):
image_names = [] image_names = []
for image_directory in image_directories: for image_directory in image_directories:
image_directory = resolve_absolute_path(image_directory, strict=True) image_directory = resolve_absolute_path(image_directory, strict=True)
image_names.extend(list(image_directory.glob('*.jpg'))) image_names.extend(list(image_directory.glob('*.jpg')))
image_names.extend(list(image_directory.glob('*.jpeg'))) image_names.extend(list(image_directory.glob('*.jpeg')))
if verbose > 0: if logger.level >= logging.INFO:
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories]) image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories])
print(f'Found {len(image_names)} images in {image_directory_names}') logger.info(f'Found {len(image_names)} images in {image_directory_names}')
return sorted(image_names) return sorted(image_names)
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0): def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings" "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate, verbose) image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose) image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)
return image_embeddings, image_metadata_embeddings return image_embeddings, image_metadata_embeddings
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False, verbose=0): def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
# Load pre-computed image embeddings from file if exists # Load pre-computed image embeddings from file if exists
if resolve_absolute_path(embeddings_file).exists() and not regenerate: if resolve_absolute_path(embeddings_file).exists() and not regenerate:
image_embeddings = torch.load(embeddings_file) image_embeddings = torch.load(embeddings_file)
if verbose > 0: logger.info(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}")
print(f"Loaded pre-computed embeddings from {embeddings_file}")
# Else compute the image embeddings from scratch, which can take a while # Else compute the image embeddings from scratch, which can take a while
else: else:
image_embeddings = [] image_embeddings = []
@@ -87,8 +90,7 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
# Save computed image embeddings to file # Save computed image embeddings to file
torch.save(image_embeddings, embeddings_file) torch.save(image_embeddings, embeddings_file)
if verbose > 0: logger.info(f"Saved computed embeddings to {embeddings_file}")
print(f"Saved computed embeddings to {embeddings_file}")
return image_embeddings return image_embeddings
@@ -99,8 +101,7 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
# Load pre-computed image metadata embedding file if exists # Load pre-computed image metadata embedding file if exists
if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate: if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata") image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
if verbose > 0: logger.info(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
# Else compute the image metadata embeddings from scratch, which can take a while # Else compute the image metadata embeddings from scratch, which can take a while
if use_xmp_metadata and image_metadata_embeddings is None: if use_xmp_metadata and image_metadata_embeddings is None:
@@ -113,16 +114,15 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
convert_to_tensor=True, convert_to_tensor=True,
batch_size=min(len(image_metadata), batch_size)) batch_size=min(len(image_metadata), batch_size))
except RuntimeError as e: except RuntimeError as e:
print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}") logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
continue continue
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata") torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
if verbose > 0: logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
return image_metadata_embeddings return image_metadata_embeddings
def extract_metadata(image_name, verbose=0): def extract_metadata(image_name):
with exiftool.ExifTool() as et: with exiftool.ExifTool() as et:
image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name)) image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name))
image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject]) image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject])
@@ -131,8 +131,7 @@ def extract_metadata(image_name, verbose=0):
if len(image_metadata_subjects) > 0: if len(image_metadata_subjects) > 0:
image_processed_metadata += ". " + ", ".join(image_metadata_subjects) image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
if verbose > 2: logger.debug(f"{image_name}:\t{image_processed_metadata}")
print(f"{image_name}:\t{image_processed_metadata}")
return image_processed_metadata return image_processed_metadata
@@ -143,19 +142,16 @@ def query(raw_query, count, model: ImageSearchModel):
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True) query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True)
query = copy.deepcopy(Image.open(query_imagepath)) query = copy.deepcopy(Image.open(query_imagepath))
query.thumbnail((640, query.height)) # scale down image for faster processing query.thumbnail((640, query.height)) # scale down image for faster processing
if model.verbose > 0: logger.info(f"Find Images similar to Image at {query_imagepath}")
print(f"Find Images similar to Image at {query_imagepath}")
else: else:
query = raw_query query = raw_query
if state.verbose > 0: logger.info(f"Find Images by Text: {query}")
print(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string) # Now we encode the query (which can either be an image or a text string)
start = time.time() start = time.time()
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
end = time.time() end = time.time()
if state.verbose > 1: logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
print(f"Query Encode Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
start = time.time() start = time.time()
@@ -163,8 +159,7 @@ def query(raw_query, count, model: ImageSearchModel):
for result for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
end = time.time() end = time.time()
if state.verbose > 1: logger.debug(f"Search Time: {end - start:.3f} seconds")
print(f"Search Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings: if model.image_metadata_embeddings:
@@ -173,8 +168,7 @@ def query(raw_query, count, model: ImageSearchModel):
for result for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
end = time.time() end = time.time()
if state.verbose > 1: logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
print(f"Metadata Search Time: {end - start:.3f} seconds")
# Sum metadata, image scores of the highest ranked images # Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items(): for corpus_id, score in metadata_hits.items():
@@ -237,7 +231,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
# Initialize Model # Initialize Model
encoder = initialize_model(search_config) encoder = initialize_model(search_config)
@@ -245,7 +239,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
absolute_image_files, filtered_image_files = set(), set() absolute_image_files, filtered_image_files = set(), set()
if config.input_directories: if config.input_directories:
image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories] image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
absolute_image_files = set(extract_entries(image_directories, verbose)) absolute_image_files = set(extract_entries(image_directories))
if config.input_filter: if config.input_filter:
filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter))) filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter)))
@@ -259,14 +253,12 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
embeddings_file, embeddings_file,
batch_size=config.batch_size, batch_size=config.batch_size,
regenerate=regenerate, regenerate=regenerate,
use_xmp_metadata=config.use_xmp_metadata, use_xmp_metadata=config.use_xmp_metadata)
verbose=verbose)
return ImageSearchModel(all_image_files, return ImageSearchModel(all_image_files,
image_embeddings, image_embeddings,
image_metadata_embeddings, image_metadata_embeddings,
encoder, encoder)
verbose)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -1,8 +1,9 @@
# Standard Packages # Standard Packages
import argparse import argparse
import pathlib import pathlib
from copy import deepcopy import logging
import time import time
from copy import deepcopy
# External Packages # External Packages
import torch import torch
@@ -16,6 +17,9 @@ from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl from src.utils.jsonl import load_jsonl
logger = logging.getLogger(__name__)
def initialize_model(search_config: TextSearchConfig): def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text" "Initialize model for semantic search on text"
torch.set_num_threads(4) torch.set_num_threads(4)
@@ -46,32 +50,30 @@ def initialize_model(search_config: TextSearchConfig):
return bi_encoder, cross_encoder, top_k return bi_encoder, cross_encoder, top_k
def extract_entries(jsonl_file, verbose=0): def extract_entries(jsonl_file):
"Load entries from compressed jsonl" "Load entries from compressed jsonl"
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'} return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
for entry for entry
in load_jsonl(jsonl_file, verbose=verbose)] in load_jsonl(jsonl_file)]
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0): def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings" "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
# Load pre-computed embeddings from file if exists # Load pre-computed embeddings from file if exists
if embeddings_file.exists() and not regenerate: if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device) corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
if verbose > 0: logger.info(f"Loaded embeddings from {embeddings_file}")
print(f"Loaded embeddings from {embeddings_file}")
else: # Else compute the corpus_embeddings from scratch, which can take a while else: # Else compute the corpus_embeddings from scratch, which can take a while
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True) corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
corpus_embeddings = util.normalize_embeddings(corpus_embeddings) corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
torch.save(corpus_embeddings, embeddings_file) torch.save(corpus_embeddings, embeddings_file)
if verbose > 0: logger.info(f"Computed embeddings and saved them to {embeddings_file}")
print(f"Computed embeddings and saved them to {embeddings_file}")
return corpus_embeddings return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0): def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
"Search for entries that answer the query" "Search for entries that answer the query"
query = raw_query query = raw_query
@@ -85,16 +87,14 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
corpus_embeddings = model.corpus_embeddings corpus_embeddings = model.corpus_embeddings
entries = model.entries entries = model.entries
end = time.time() end = time.time()
if verbose > 1: logger.debug(f"Copy Time: {end - start:.3f} seconds")
print(f"Copy Time: {end - start:.3f} seconds")
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
start = time.time() start = time.time()
for filter in filters_in_query: for filter in filters_in_query:
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings) query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
end = time.time() end = time.time()
if verbose > 1: logger.debug(f"Filter Time: {end - start:.3f} seconds")
print(f"Filter Time: {end - start:.3f} seconds")
if entries is None or len(entries) == 0: if entries is None or len(entries) == 0:
return [], [] return [], []
@@ -104,15 +104,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding) question_embedding = util.normalize_embeddings(question_embedding)
end = time.time() end = time.time()
if verbose > 1: logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
# Find relevant entries for the query # Find relevant entries for the query
start = time.time() start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time() end = time.time()
if verbose > 1: logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
if rank_results: if rank_results:
@@ -120,8 +118,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp) cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time() end = time.time()
if verbose > 1: logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Store cross-encoder scores in results dictionary for ranking # Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)): for idx in range(len(cross_scores)):
@@ -133,8 +130,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
if rank_results: if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time() end = time.time()
if verbose > 1: logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
return hits, entries return hits, entries
@@ -167,24 +163,24 @@ def collate_results(hits, entries, count=5):
in hits[0:count]] in hits[0:count]]
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
# Initialize Model # Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config) bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file # Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
if not config.compressed_jsonl.exists() or regenerate: if not config.compressed_jsonl.exists() or regenerate:
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose) text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl)
# Extract Entries # Extract Entries
entries = extract_entries(config.compressed_jsonl, verbose) entries = extract_entries(config.compressed_jsonl)
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings # Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file) config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose) corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose) return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
if __name__ == '__main__': if __name__ == '__main__':
@@ -200,7 +196,7 @@ if __name__ == '__main__':
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0") parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
args = parser.parse_args() args = parser.parse_args()
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose) entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate)
# Run User Queries on Entries in Interactive Mode # Run User Queries on Entries in Interactive Mode
while args.interactive: while args.interactive:
@@ -213,4 +209,4 @@ if __name__ == '__main__':
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k) hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
# render results # render results
render_results(hits, entries, count=args.results_count) render_results(hits, entries, count=args.results_count)

View File

@@ -20,23 +20,21 @@ class ProcessorType(str, Enum):
class TextSearchModel(): class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose): def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
self.entries = entries self.entries = entries
self.corpus_embeddings = corpus_embeddings self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder self.cross_encoder = cross_encoder
self.top_k = top_k self.top_k = top_k
self.verbose = verbose
class ImageSearchModel(): class ImageSearchModel():
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose): def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
self.image_encoder = image_encoder self.image_encoder = image_encoder
self.image_names = image_names self.image_names = image_names
self.image_embeddings = image_embeddings self.image_embeddings = image_embeddings
self.image_metadata_embeddings = image_metadata_embeddings self.image_metadata_embeddings = image_metadata_embeddings
self.image_encoder = image_encoder self.image_encoder = image_encoder
self.verbose = verbose
@dataclass @dataclass
@@ -49,12 +47,11 @@ class SearchModels():
class ConversationProcessorConfigModel(): class ConversationProcessorConfigModel():
def __init__(self, processor_config: ConversationProcessorConfig, verbose: bool): def __init__(self, processor_config: ConversationProcessorConfig):
self.openai_api_key = processor_config.openai_api_key self.openai_api_key = processor_config.openai_api_key
self.conversation_logfile = Path(processor_config.conversation_logfile) self.conversation_logfile = Path(processor_config.conversation_logfile)
self.chat_session = '' self.chat_session = ''
self.meta_log = [] self.meta_log = []
self.verbose = verbose
@dataclass @dataclass

View File

@@ -1,13 +1,17 @@
# Standard Packages # Standard Packages
import json import json
import gzip import gzip
import logging
# Internal Packages # Internal Packages
from src.utils.constants import empty_escape_sequences from src.utils.constants import empty_escape_sequences
from src.utils.helpers import get_absolute_path from src.utils.helpers import get_absolute_path
def load_jsonl(input_path, verbose=0): logger = logging.getLogger(__name__)
def load_jsonl(input_path):
"Read List of JSON objects from JSON line file" "Read List of JSON objects from JSON line file"
# Initialize Variables # Initialize Variables
data = [] data = []
@@ -27,13 +31,12 @@ def load_jsonl(input_path, verbose=0):
jsonl_file.close() jsonl_file.close()
# Log JSONL entries loaded # Log JSONL entries loaded
if verbose > 0: logger.info(f'Loaded {len(data)} records from {input_path}')
print(f'Loaded {len(data)} records from {input_path}')
return data return data
def dump_jsonl(jsonl_data, output_path, verbose=0): def dump_jsonl(jsonl_data, output_path):
"Write List of JSON objects to JSON line file" "Write List of JSON objects to JSON line file"
# Create output directory, if it doesn't exist # Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
@@ -41,16 +44,14 @@ def dump_jsonl(jsonl_data, output_path, verbose=0):
with open(output_path, 'w', encoding='utf-8') as f: with open(output_path, 'w', encoding='utf-8') as f:
f.write(jsonl_data) f.write(jsonl_data)
if verbose > 0: logger.info(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
print(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
def compress_jsonl_data(jsonl_data, output_path, verbose=0): def compress_jsonl_data(jsonl_data, output_path):
# Create output directory, if it doesn't exist # Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(output_path, 'wt') as gzip_file: with gzip.open(output_path, 'wt') as gzip_file:
gzip_file.write(jsonl_data) gzip_file.write(jsonl_data)
if verbose > 0: logger.info(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')
print(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')