mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Use python standard logging framework for app logs
- Stop passing the verbose flag around app methods
- Minor remap of verbosity levels to match Python logging framework levels
  - verbose = 0 maps to logging.WARN
  - verbose = 1 maps to logging.INFO
  - verbose >= 2 maps to logging.DEBUG
- Minor clean-up of app: unused modules, conversation file opening
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
# System Packages
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
import json
|
||||
|
||||
# Internal Packages
|
||||
@@ -12,10 +12,13 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||
from src.search_type import image_search, text_search
|
||||
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path
|
||||
from src.utils.helpers import resolve_absolute_path
|
||||
from src.utils.rawconfig import FullConfig, ProcessorConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def configure_server(args, required=False):
|
||||
if args.config is None:
|
||||
if required:
|
||||
@@ -27,42 +30,42 @@ def configure_server(args, required=False):
|
||||
state.config = args.config
|
||||
|
||||
# Initialize the search model from Config
|
||||
state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose)
|
||||
state.model = configure_search(state.model, state.config, args.regenerate)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)
|
||||
state.processor_config = configure_processor(args.config.processor)
|
||||
|
||||
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None):
    """Set up the search models for every configured content type.

    Args:
        model: Container whose ``*_search`` attributes are (re)assigned.
        config: Application config; ``content_type`` selects what to index and
            ``search_type`` supplies the matching search settings.
        regenerate: Forwarded unchanged to each ``setup`` call.
        t: Restrict setup to a single search type. ``None`` (the default)
            sets up every content type present in ``config``.

    Returns:
        The same ``model`` object, updated in place.
    """
    # Initialize Org Notes Search
    if (t == SearchType.Org or t is None) and config.content_type.org:
        # Extract Entries, Generate Notes Embeddings
        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)

    # Initialize Org Music Search
    if (t == SearchType.Music or t is None) and config.content_type.music:
        # Extract Entries, Generate Music Embeddings
        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)

    # Initialize Markdown Search
    if (t == SearchType.Markdown or t is None) and config.content_type.markdown:
        # Extract Entries, Generate Markdown Embeddings
        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)

    # Initialize Ledger Search (the only text search using the symmetric config)
    if (t == SearchType.Ledger or t is None) and config.content_type.ledger:
        # Extract Entries, Generate Ledger Embeddings
        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)

    # Initialize Image Search
    if (t == SearchType.Image or t is None) and config.content_type.image:
        # Extract Entries, Generate Image Embeddings
        model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate)

    return model
|
||||
|
||||
|
||||
def configure_processor(processor_config: ProcessorConfig, verbose: int):
|
||||
def configure_processor(processor_config: ProcessorConfig):
|
||||
if not processor_config:
|
||||
return
|
||||
|
||||
@@ -70,27 +73,23 @@ def configure_processor(processor_config: ProcessorConfig, verbose: int):
|
||||
|
||||
# Initialize Conversation Processor
|
||||
if processor_config.conversation:
|
||||
processor.conversation = configure_conversation_processor(processor_config.conversation, verbose)
|
||||
processor.conversation = configure_conversation_processor(processor_config.conversation)
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def configure_conversation_processor(conversation_processor_config):
    """Build the conversation processor and load its saved conversation log.

    Args:
        conversation_processor_config: Raw conversation settings, passed
            straight to ``ConversationProcessorConfigModel``.

    Returns:
        The ``ConversationProcessorConfigModel`` instance. ``meta_log`` holds
        the JSON content of the conversation logfile when that file exists;
        otherwise ``meta_log`` is ``{}`` and ``chat_session`` is ``""``.
    """
    conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
    conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)

    if conversation_logfile.is_file():
        # Load Metadata Logs from Conversation Logfile
        with conversation_logfile.open('r') as f:
            conversation_processor.meta_log = json.load(f)
        logger.info('Conversation logs loaded from disk.')
    else:
        # Initialize Conversation Logs
        conversation_processor.meta_log = {}
        conversation_processor.chat_session = ""

    return conversation_processor
|
||||
|
||||
42
src/main.py
42
src/main.py
@@ -2,6 +2,7 @@
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import logging
|
||||
from platform import system
|
||||
|
||||
# External Packages
|
||||
@@ -25,6 +26,33 @@ app = FastAPI()
|
||||
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
|
||||
app.include_router(router)
|
||||
|
||||
logger = logging.getLogger('src')
|
||||
|
||||
|
||||
class CustomFormatter(logging.Formatter):
    """Log formatter that colours each record by severity with ANSI escapes.

    Records at one of the five standard levels are wrapped in the matching
    colour code and a trailing reset. Records at any other level are rendered
    with the same layout but without colour.
    """

    # ANSI colour escape sequences
    blue = "\x1b[1;34m"
    green = "\x1b[1;32m"
    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    # Named ``format_str`` so it is not shadowed by the ``format`` method below
    format_str = "%(levelname)s: %(asctime)s: %(name)s | %(message)s"

    FORMATS = {
        logging.DEBUG: blue + format_str + reset,
        logging.INFO: green + format_str + reset,
        logging.WARNING: yellow + format_str + reset,
        logging.ERROR: red + format_str + reset,
        logging.CRITICAL: bold_red + format_str + reset,
    }

    # Build one formatter per level up front instead of one per log record
    _FORMATTERS = {level: logging.Formatter(fmt) for level, fmt in FORMATS.items()}
    # Fallback for custom levels: same layout, no colour
    _DEFAULT_FORMATTER = logging.Formatter(format_str)

    def format(self, record):
        """Return ``record`` rendered with the formatter for its level."""
        formatter = self._FORMATTERS.get(record.levelno, self._DEFAULT_FORMATTER)
        return formatter.format(record)
|
||||
|
||||
|
||||
def run():
|
||||
# Turn Tokenizers Parallelism Off. App does not support it.
|
||||
@@ -33,6 +61,20 @@ def run():
|
||||
# Load config from CLI
|
||||
state.cli_args = sys.argv[1:]
|
||||
args = cli(state.cli_args)
|
||||
|
||||
# Setup Logger
|
||||
# logging.basicConfig(format='%(levelname)s: %(asctime)s : %(module)s | %(message)s', datefmt='%d-%m-%Y %H:%M:%S')
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(CustomFormatter())
|
||||
logger.addHandler(ch)
|
||||
if args.verbose == 0:
|
||||
logger.setLevel(logging.WARN)
|
||||
elif args.verbose == 1:
|
||||
logger.setLevel(logging.INFO)
|
||||
elif args.verbose >= 2:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
logger.info("Starting Khoj...")
|
||||
set_state(args)
|
||||
|
||||
if args.no_gui:
|
||||
|
||||
@@ -6,6 +6,7 @@ import argparse
|
||||
import pathlib
|
||||
import glob
|
||||
import re
|
||||
import logging
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
@@ -13,32 +14,35 @@ from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define Functions
|
||||
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file):
    """Convert Beancount files to JSONL and write the result to ``output_file``.

    Args:
        beancount_files: Explicit Beancount files to process; may be empty
            when ``beancount_file_filter`` is given.
        beancount_file_filter: Filter used to select Beancount files; may be
            empty when ``beancount_files`` is given.
        output_file: Destination path. A ``.gz`` suffix writes compressed
            JSONL, a ``.jsonl`` suffix writes plain JSONL.

    Returns:
        The entries extracted from the selected Beancount files.

    Raises:
        SystemExit: With status 1 when neither input argument is provided.
    """
    # Input Validation
    if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
        logger.error("At least one of beancount-files or beancount-file-filter is required to be specified")
        # Same effect as exit(1), without relying on the site-provided builtin
        raise SystemExit(1)

    # Get Beancount Files to Process
    beancount_files = get_beancount_files(beancount_files, beancount_file_filter)

    # Extract Entries from specified Beancount files
    entries = extract_beancount_entries(beancount_files)

    # Process Each Entry from All Notes Files
    jsonl_data = convert_beancount_entries_to_jsonl(entries)

    # Compress JSONL formatted Data
    if output_file.suffix == ".gz":
        compress_jsonl_data(jsonl_data, output_file)
    elif output_file.suffix == ".jsonl":
        dump_jsonl(jsonl_data, output_file)
    else:
        # Nothing is written for other suffixes; surface that instead of staying silent
        logger.warning(f"Unsupported output file suffix '{output_file.suffix}'. No output written to {output_file}")

    return entries
|
||||
|
||||
|
||||
def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbose=0):
|
||||
def get_beancount_files(beancount_files=None, beancount_file_filter=None):
|
||||
"Get Beancount files to process"
|
||||
absolute_beancount_files, filtered_beancount_files = set(), set()
|
||||
if beancount_files:
|
||||
@@ -57,8 +61,7 @@ def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbos
|
||||
if any(files_with_non_beancount_extensions):
|
||||
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Processing files: {all_beancount_files}')
|
||||
logger.info(f'Processing files: {all_beancount_files}')
|
||||
|
||||
return all_beancount_files
|
||||
|
||||
@@ -82,7 +85,7 @@ def extract_beancount_entries(beancount_files):
|
||||
return entries
|
||||
|
||||
|
||||
def convert_beancount_entries_to_jsonl(entries, verbose=0):
|
||||
def convert_beancount_entries_to_jsonl(entries):
|
||||
"Convert each Beancount transaction to JSON and collate as JSONL"
|
||||
jsonl = ''
|
||||
for entry in entries:
|
||||
@@ -90,8 +93,7 @@ def convert_beancount_entries_to_jsonl(entries, verbose=0):
|
||||
# Convert Dictionary to JSON and Append to JSONL string
|
||||
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Converted {len(entries)} to jsonl format")
|
||||
logger.info(f"Converted {len(entries)} to jsonl format")
|
||||
|
||||
return jsonl
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import argparse
|
||||
import pathlib
|
||||
import glob
|
||||
import re
|
||||
import logging
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
@@ -13,32 +14,35 @@ from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define Functions
|
||||
def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file):
    """Convert Markdown files to JSONL and write the result to ``output_file``.

    Args:
        markdown_files: Explicit Markdown files to process; may be empty when
            ``markdown_file_filter`` is given.
        markdown_file_filter: Filter used to select Markdown files; may be
            empty when ``markdown_files`` is given.
        output_file: Destination path. A ``.gz`` suffix writes compressed
            JSONL, a ``.jsonl`` suffix writes plain JSONL.

    Returns:
        The entries extracted from the selected Markdown files.

    Raises:
        SystemExit: With status 1 when neither input argument is provided.
    """
    # Input Validation
    if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
        logger.error("At least one of markdown-files or markdown-file-filter is required to be specified")
        # Same effect as exit(1), without relying on the site-provided builtin
        raise SystemExit(1)

    # Get Markdown Files to Process
    markdown_files = get_markdown_files(markdown_files, markdown_file_filter)

    # Extract Entries from specified Markdown files
    entries = extract_markdown_entries(markdown_files)

    # Process Each Entry from All Notes Files
    jsonl_data = convert_markdown_entries_to_jsonl(entries)

    # Compress JSONL formatted Data
    if output_file.suffix == ".gz":
        compress_jsonl_data(jsonl_data, output_file)
    elif output_file.suffix == ".jsonl":
        dump_jsonl(jsonl_data, output_file)
    else:
        # Nothing is written for other suffixes; surface that instead of staying silent
        logger.warning(f"Unsupported output file suffix '{output_file.suffix}'. No output written to {output_file}")

    return entries
|
||||
|
||||
|
||||
def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0):
|
||||
def get_markdown_files(markdown_files=None, markdown_file_filter=None):
|
||||
"Get Markdown files to process"
|
||||
absolute_markdown_files, filtered_markdown_files = set(), set()
|
||||
if markdown_files:
|
||||
@@ -56,10 +60,9 @@ def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0
|
||||
}
|
||||
|
||||
if any(files_with_non_markdown_extensions):
|
||||
print(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
|
||||
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Processing files: {all_markdown_files}')
|
||||
logger.info(f'Processing files: {all_markdown_files}')
|
||||
|
||||
return all_markdown_files
|
||||
|
||||
@@ -81,7 +84,7 @@ def extract_markdown_entries(markdown_files):
|
||||
return entries
|
||||
|
||||
|
||||
def convert_markdown_entries_to_jsonl(entries, verbose=0):
|
||||
def convert_markdown_entries_to_jsonl(entries):
|
||||
"Convert each Markdown entries to JSON and collate as JSONL"
|
||||
jsonl = ''
|
||||
for entry in entries:
|
||||
@@ -89,8 +92,7 @@ def convert_markdown_entries_to_jsonl(entries, verbose=0):
|
||||
# Convert Dictionary to JSON and Append to JSONL string
|
||||
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Converted {len(entries)} to jsonl format")
|
||||
logger.info(f"Converted {len(entries)} to jsonl format")
|
||||
|
||||
return jsonl
|
||||
|
||||
|
||||
@@ -6,40 +6,43 @@ import json
|
||||
import argparse
|
||||
import pathlib
|
||||
import glob
|
||||
import logging
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.org_mode import orgnode
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define Functions
|
||||
def org_to_jsonl(org_files, org_file_filter, output_file):
    """Convert Org-mode files to JSONL and write the result to ``output_file``.

    Args:
        org_files: Explicit Org files to process; may be empty when
            ``org_file_filter`` is given.
        org_file_filter: Filter used to select Org files; may be empty when
            ``org_files`` is given.
        output_file: Destination path. A ``.gz`` suffix writes compressed
            JSONL, a ``.jsonl`` suffix writes plain JSONL.

    Returns:
        The entries extracted from the selected Org files.

    Raises:
        SystemExit: With status 1 when neither input argument is provided.
    """
    # Input Validation
    if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
        logger.error("At least one of org-files or org-file-filter is required to be specified")
        # Same effect as exit(1), without relying on the site-provided builtin
        raise SystemExit(1)

    # Get Org Files to Process
    org_files = get_org_files(org_files, org_file_filter)

    # Extract Entries from specified Org files
    entries = extract_org_entries(org_files)

    # Process Each Entry from All Notes Files
    jsonl_data = convert_org_entries_to_jsonl(entries)

    # Compress JSONL formatted Data
    if output_file.suffix == ".gz":
        compress_jsonl_data(jsonl_data, output_file)
    elif output_file.suffix == ".jsonl":
        dump_jsonl(jsonl_data, output_file)
    else:
        # Nothing is written for other suffixes; surface that instead of staying silent
        logger.warning(f"Unsupported output file suffix '{output_file.suffix}'. No output written to {output_file}")

    return entries
|
||||
|
||||
|
||||
def get_org_files(org_files=None, org_file_filter=None, verbose=0):
|
||||
def get_org_files(org_files=None, org_file_filter=None):
|
||||
"Get Org files to process"
|
||||
absolute_org_files, filtered_org_files = set(), set()
|
||||
if org_files:
|
||||
@@ -53,10 +56,9 @@ def get_org_files(org_files=None, org_file_filter=None, verbose=0):
|
||||
|
||||
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
|
||||
if any(files_with_non_org_extensions):
|
||||
print(f"[Warning] There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Processing files: {all_org_files}')
|
||||
logger.info(f'Processing files: {all_org_files}')
|
||||
|
||||
return all_org_files
|
||||
|
||||
@@ -72,7 +74,7 @@ def extract_org_entries(org_files):
|
||||
return entries
|
||||
|
||||
|
||||
def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
|
||||
def convert_org_entries_to_jsonl(entries) -> str:
|
||||
"Convert each Org-Mode entries to JSON and collate as JSONL"
|
||||
jsonl = ''
|
||||
for entry in entries:
|
||||
@@ -83,29 +85,24 @@ def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
|
||||
continue
|
||||
|
||||
entry_dict["compiled"] = f'{entry.Heading()}.'
|
||||
if verbose > 2:
|
||||
print(f"Title: {entry.Heading()}")
|
||||
logger.debug(f"Title: {entry.Heading()}")
|
||||
|
||||
if entry.Tags():
|
||||
tags_str = " ".join(entry.Tags())
|
||||
entry_dict["compiled"] += f'\t {tags_str}.'
|
||||
if verbose > 2:
|
||||
print(f"Tags: {tags_str}")
|
||||
logger.debug(f"Tags: {tags_str}")
|
||||
|
||||
if entry.Closed():
|
||||
entry_dict["compiled"] += f'\n Closed on {entry.Closed().strftime("%Y-%m-%d")}.'
|
||||
if verbose > 2:
|
||||
print(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
|
||||
logger.debug(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
|
||||
|
||||
if entry.Scheduled():
|
||||
entry_dict["compiled"] += f'\n Scheduled for {entry.Scheduled().strftime("%Y-%m-%d")}.'
|
||||
if verbose > 2:
|
||||
print(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
|
||||
logger.debug(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
|
||||
|
||||
if entry.Body():
|
||||
entry_dict["compiled"] += f'\n {entry.Body()}'
|
||||
if verbose > 2:
|
||||
print(f"Body: {entry.Body()}")
|
||||
logger.debug(f"Body: {entry.Body()}")
|
||||
|
||||
if entry_dict:
|
||||
entry_dict["raw"] = f'{entry}'
|
||||
@@ -113,8 +110,7 @@ def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
|
||||
# Convert Dictionary to JSON and Append to JSONL string
|
||||
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Converted {len(entries)} to jsonl format")
|
||||
logger.info(f"Converted {len(entries)} to jsonl format")
|
||||
|
||||
return jsonl
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import yaml
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
|
||||
@@ -22,9 +23,11 @@ from src.utils.config import SearchType
|
||||
from src.utils.helpers import get_absolute_path, get_from_dict
|
||||
from src.utils import state, constants
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory=constants.web_directory)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/", response_class=FileResponse)
|
||||
def index():
|
||||
@@ -50,7 +53,7 @@ async def config_data(updated_config: FullConfig):
|
||||
@lru_cache(maxsize=100)
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||
if q is None or q == '':
|
||||
print(f'No query param (q) passed in API call to initiate search')
|
||||
logger.info(f'No query param (q) passed in API call to initiate search')
|
||||
return {}
|
||||
|
||||
# initialize variables
|
||||
@@ -120,11 +123,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
||||
count=results_count)
|
||||
collate_end = time.time()
|
||||
|
||||
if state.verbose > 1:
|
||||
if query_start and query_end:
|
||||
print(f"Query took {query_end - query_start:.3f} seconds")
|
||||
if collate_start and collate_end:
|
||||
print(f"Collating results took {collate_end - collate_start:.3f} seconds")
|
||||
if query_start and query_end:
|
||||
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
|
||||
if collate_start and collate_end:
|
||||
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import pathlib
|
||||
import copy
|
||||
import shutil
|
||||
import time
|
||||
import logging
|
||||
|
||||
# External Packages
|
||||
from sentence_transformers import SentenceTransformer, util
|
||||
@@ -17,7 +18,10 @@ from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_mod
|
||||
import src.utils.exiftool as exiftool
|
||||
from src.utils.config import ImageSearchModel
|
||||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
|
||||
from src.utils import state
|
||||
|
||||
|
||||
# Create Logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_model(search_config: ImageSearchConfig):
|
||||
@@ -39,34 +43,33 @@ def initialize_model(search_config: ImageSearchConfig):
|
||||
return encoder
|
||||
|
||||
|
||||
def extract_entries(image_directories):
    """Return the sorted ``*.jpg`` and ``*.jpeg`` files found in the given directories.

    Args:
        image_directories: Iterable of directories to scan. Each one is
            resolved with ``resolve_absolute_path(..., strict=True)`` and
            globbed non-recursively.

    Returns:
        Sorted list of the matching image paths.
    """
    image_names = []
    for image_directory in image_directories:
        image_directory = resolve_absolute_path(image_directory, strict=True)
        image_names.extend(image_directory.glob('*.jpg'))
        image_names.extend(image_directory.glob('*.jpeg'))

    # isEnabledFor honours the effective (inherited) level. Comparing
    # logger.level directly is wrong: a child logger's own level is usually
    # NOTSET (0), and the comparison is inverted for DEBUG.
    if logger.isEnabledFor(logging.INFO):
        image_directory_names = ', '.join(str(image_directory) for image_directory in image_directories)
        logger.info(f'Found {len(image_names)} images in {image_directory_names}')
    return sorted(image_names)
|
||||
|
||||
|
||||
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
    """Compute (and Save) Embeddings or Load Pre-Computed Embeddings

    Delegates to ``compute_image_embeddings`` and
    ``compute_metadata_embeddings`` with the same ``image_names``,
    ``encoder``, ``embeddings_file``, ``batch_size`` and ``regenerate``.

    Returns:
        Tuple ``(image_embeddings, image_metadata_embeddings)``.
    """
    image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
    image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)

    return image_embeddings, image_metadata_embeddings
|
||||
|
||||
|
||||
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False, verbose=0):
|
||||
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
|
||||
# Load pre-computed image embeddings from file if exists
|
||||
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
|
||||
image_embeddings = torch.load(embeddings_file)
|
||||
if verbose > 0:
|
||||
print(f"Loaded pre-computed embeddings from {embeddings_file}")
|
||||
logger.info(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}")
|
||||
# Else compute the image embeddings from scratch, which can take a while
|
||||
else:
|
||||
image_embeddings = []
|
||||
@@ -87,8 +90,7 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
|
||||
|
||||
# Save computed image embeddings to file
|
||||
torch.save(image_embeddings, embeddings_file)
|
||||
if verbose > 0:
|
||||
print(f"Saved computed embeddings to {embeddings_file}")
|
||||
logger.info(f"Saved computed embeddings to {embeddings_file}")
|
||||
|
||||
return image_embeddings
|
||||
|
||||
@@ -99,8 +101,7 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
|
||||
# Load pre-computed image metadata embedding file if exists
|
||||
if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
|
||||
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
|
||||
if verbose > 0:
|
||||
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
|
||||
logger.info(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
|
||||
|
||||
# Else compute the image metadata embeddings from scratch, which can take a while
|
||||
if use_xmp_metadata and image_metadata_embeddings is None:
|
||||
@@ -113,16 +114,15 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
|
||||
convert_to_tensor=True,
|
||||
batch_size=min(len(image_metadata), batch_size))
|
||||
except RuntimeError as e:
|
||||
print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
|
||||
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
|
||||
continue
|
||||
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
||||
if verbose > 0:
|
||||
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||
|
||||
return image_metadata_embeddings
|
||||
|
||||
|
||||
def extract_metadata(image_name, verbose=0):
|
||||
def extract_metadata(image_name):
|
||||
with exiftool.ExifTool() as et:
|
||||
image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name))
|
||||
image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject])
|
||||
@@ -131,8 +131,7 @@ def extract_metadata(image_name, verbose=0):
|
||||
if len(image_metadata_subjects) > 0:
|
||||
image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
|
||||
|
||||
if verbose > 2:
|
||||
print(f"{image_name}:\t{image_processed_metadata}")
|
||||
logger.debug(f"{image_name}:\t{image_processed_metadata}")
|
||||
|
||||
return image_processed_metadata
|
||||
|
||||
@@ -143,19 +142,16 @@ def query(raw_query, count, model: ImageSearchModel):
|
||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True)
|
||||
query = copy.deepcopy(Image.open(query_imagepath))
|
||||
query.thumbnail((640, query.height)) # scale down image for faster processing
|
||||
if model.verbose > 0:
|
||||
print(f"Find Images similar to Image at {query_imagepath}")
|
||||
logger.info(f"Find Images similar to Image at {query_imagepath}")
|
||||
else:
|
||||
query = raw_query
|
||||
if state.verbose > 0:
|
||||
print(f"Find Images by Text: {query}")
|
||||
logger.info(f"Find Images by Text: {query}")
|
||||
|
||||
# Now we encode the query (which can either be an image or a text string)
|
||||
start = time.time()
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
end = time.time()
|
||||
if state.verbose > 1:
|
||||
print(f"Query Encode Time: {end - start:.3f} seconds")
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
start = time.time()
|
||||
@@ -163,8 +159,7 @@ def query(raw_query, count, model: ImageSearchModel):
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
if state.verbose > 1:
|
||||
print(f"Search Time: {end - start:.3f} seconds")
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds")
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||
if model.image_metadata_embeddings:
|
||||
@@ -173,8 +168,7 @@ def query(raw_query, count, model: ImageSearchModel):
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
if state.verbose > 1:
|
||||
print(f"Metadata Search Time: {end - start:.3f} seconds")
|
||||
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
|
||||
|
||||
# Sum metadata, image scores of the highest ranked images
|
||||
for corpus_id, score in metadata_hits.items():
|
||||
@@ -237,7 +231,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||
return results
|
||||
|
||||
|
||||
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
|
||||
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
||||
# Initialize Model
|
||||
encoder = initialize_model(search_config)
|
||||
|
||||
@@ -245,7 +239,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
||||
absolute_image_files, filtered_image_files = set(), set()
|
||||
if config.input_directories:
|
||||
image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
|
||||
absolute_image_files = set(extract_entries(image_directories, verbose))
|
||||
absolute_image_files = set(extract_entries(image_directories))
|
||||
if config.input_filter:
|
||||
filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter)))
|
||||
|
||||
@@ -259,14 +253,12 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
||||
embeddings_file,
|
||||
batch_size=config.batch_size,
|
||||
regenerate=regenerate,
|
||||
use_xmp_metadata=config.use_xmp_metadata,
|
||||
verbose=verbose)
|
||||
use_xmp_metadata=config.use_xmp_metadata)
|
||||
|
||||
return ImageSearchModel(all_image_files,
|
||||
image_embeddings,
|
||||
image_metadata_embeddings,
|
||||
encoder,
|
||||
verbose)
|
||||
encoder)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Standard Packages
|
||||
import argparse
|
||||
import pathlib
|
||||
from copy import deepcopy
|
||||
import logging
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
@@ -16,6 +17,9 @@ from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||
from src.utils.jsonl import load_jsonl
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_model(search_config: TextSearchConfig):
|
||||
"Initialize model for semantic search on text"
|
||||
torch.set_num_threads(4)
|
||||
@@ -46,32 +50,30 @@ def initialize_model(search_config: TextSearchConfig):
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
|
||||
def extract_entries(jsonl_file, verbose=0):
|
||||
def extract_entries(jsonl_file):
|
||||
"Load entries from compressed jsonl"
|
||||
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
|
||||
for entry
|
||||
in load_jsonl(jsonl_file, verbose=verbose)]
|
||||
in load_jsonl(jsonl_file)]
|
||||
|
||||
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
# Load pre-computed embeddings from file if exists
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
if verbose > 0:
|
||||
print(f"Loaded embeddings from {embeddings_file}")
|
||||
logger.info(f"Loaded embeddings from {embeddings_file}")
|
||||
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, embeddings_file)
|
||||
if verbose > 0:
|
||||
print(f"Computed embeddings and saved them to {embeddings_file}")
|
||||
logger.info(f"Computed embeddings and saved them to {embeddings_file}")
|
||||
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
|
||||
"Search for entries that answer the query"
|
||||
query = raw_query
|
||||
|
||||
@@ -85,16 +87,14 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
||||
corpus_embeddings = model.corpus_embeddings
|
||||
entries = model.entries
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Copy Time: {end - start:.3f} seconds")
|
||||
logger.debug(f"Copy Time: {end - start:.3f} seconds")
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
start = time.time()
|
||||
for filter in filters_in_query:
|
||||
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Filter Time: {end - start:.3f} seconds")
|
||||
logger.debug(f"Filter Time: {end - start:.3f} seconds")
|
||||
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
@@ -104,15 +104,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Find relevant entries for the query
|
||||
start = time.time()
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
@@ -120,8 +118,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
@@ -133,8 +130,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
return hits, entries
|
||||
|
||||
@@ -167,24 +163,24 @@ def collate_results(hits, entries, count=5):
|
||||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
if not config.compressed_jsonl.exists() or regenerate:
|
||||
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
|
||||
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl)
|
||||
|
||||
# Extract Entries
|
||||
entries = extract_entries(config.compressed_jsonl, verbose)
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -200,7 +196,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
|
||||
args = parser.parse_args()
|
||||
|
||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
|
||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate)
|
||||
|
||||
# Run User Queries on Entries in Interactive Mode
|
||||
while args.interactive:
|
||||
@@ -213,4 +209,4 @@ if __name__ == '__main__':
|
||||
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
|
||||
|
||||
# render results
|
||||
render_results(hits, entries, count=args.results_count)
|
||||
render_results(hits, entries, count=args.results_count)
|
||||
|
||||
@@ -20,23 +20,21 @@ class ProcessorType(str, Enum):
|
||||
|
||||
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose):
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
self.cross_encoder = cross_encoder
|
||||
self.top_k = top_k
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
class ImageSearchModel():
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose):
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
|
||||
self.image_encoder = image_encoder
|
||||
self.image_names = image_names
|
||||
self.image_embeddings = image_embeddings
|
||||
self.image_metadata_embeddings = image_metadata_embeddings
|
||||
self.image_encoder = image_encoder
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -49,12 +47,11 @@ class SearchModels():
|
||||
|
||||
|
||||
class ConversationProcessorConfigModel():
|
||||
def __init__(self, processor_config: ConversationProcessorConfig, verbose: bool):
|
||||
def __init__(self, processor_config: ConversationProcessorConfig):
|
||||
self.openai_api_key = processor_config.openai_api_key
|
||||
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
||||
self.chat_session = ''
|
||||
self.meta_log = []
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
# Standard Packages
|
||||
import json
|
||||
import gzip
|
||||
import logging
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.helpers import get_absolute_path
|
||||
|
||||
|
||||
def load_jsonl(input_path, verbose=0):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_jsonl(input_path):
|
||||
"Read List of JSON objects from JSON line file"
|
||||
# Initialize Variables
|
||||
data = []
|
||||
@@ -27,13 +31,12 @@ def load_jsonl(input_path, verbose=0):
|
||||
jsonl_file.close()
|
||||
|
||||
# Log JSONL entries loaded
|
||||
if verbose > 0:
|
||||
print(f'Loaded {len(data)} records from {input_path}')
|
||||
logger.info(f'Loaded {len(data)} records from {input_path}')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def dump_jsonl(jsonl_data, output_path, verbose=0):
|
||||
def dump_jsonl(jsonl_data, output_path):
|
||||
"Write List of JSON objects to JSON line file"
|
||||
# Create output directory, if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -41,16 +44,14 @@ def dump_jsonl(jsonl_data, output_path, verbose=0):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(jsonl_data)
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
|
||||
logger.info(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
|
||||
|
||||
|
||||
def compress_jsonl_data(jsonl_data, output_path, verbose=0):
|
||||
def compress_jsonl_data(jsonl_data, output_path):
|
||||
# Create output directory, if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with gzip.open(output_path, 'wt') as gzip_file:
|
||||
gzip_file.write(jsonl_data)
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')
|
||||
logger.info(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')
|
||||
Reference in New Issue
Block a user