Use python standard logging framework for app logs

- Stop passing verbose flag around app methods
- Minor remap of verbosity levels to match python logging framework levels
  - verbose = 0 maps to logging.WARN
  - verbose = 1 maps to logging.INFO
  - verbose >=2 maps to logging.DEBUG
- Minor clean-up of app: unused modules, conversation file opening
This commit is contained in:
Debanjum Singh Solanky
2022-09-03 14:43:32 +03:00
parent d0531c3064
commit 094bd18e57
10 changed files with 184 additions and 155 deletions

View File

@@ -1,8 +1,8 @@
# System Packages
import sys
import logging
# External Packages
import torch
import json
# Internal Packages
@@ -12,10 +12,13 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_type import image_search, text_search
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils import state
from src.utils.helpers import get_absolute_path
from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import FullConfig, ProcessorConfig
logger = logging.getLogger(__name__)
def configure_server(args, required=False):
if args.config is None:
if required:
@@ -27,42 +30,42 @@ def configure_server(args, required=False):
state.config = args.config
# Initialize the search model from Config
state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose)
state.model = configure_search(state.model, state.config, args.regenerate)
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)
state.processor_config = configure_processor(args.config.processor)
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0):
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None):
# Initialize Org Notes Search
if (t == SearchType.Org or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
# Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose)
model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate)
return model
def configure_processor(processor_config: ProcessorConfig, verbose: int):
def configure_processor(processor_config: ProcessorConfig):
if not processor_config:
return
@@ -70,27 +73,23 @@ def configure_processor(processor_config: ProcessorConfig, verbose: int):
# Initialize Conversation Processor
if processor_config.conversation:
processor.conversation = configure_conversation_processor(processor_config.conversation, verbose)
processor.conversation = configure_conversation_processor(processor_config.conversation)
return processor
def configure_conversation_processor(conversation_processor_config, verbose: int):
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config, verbose)
def configure_conversation_processor(conversation_processor_config):
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
conversation_logfile = conversation_processor.conversation_logfile
if conversation_processor.verbose:
print('INFO:\tLoading conversation logs from disk...')
if conversation_logfile.expanduser().absolute().is_file():
if conversation_logfile.is_file():
# Load Metadata Logs from Conversation Logfile
with open(get_absolute_path(conversation_logfile), 'r') as f:
with conversation_logfile.open('r') as f:
conversation_processor.meta_log = json.load(f)
print('INFO:\tConversation logs loaded from disk.')
logger.info('Conversation logs loaded from disk.')
else:
# Initialize Conversation Logs
conversation_processor.meta_log = {}
conversation_processor.chat_session = ""
return conversation_processor
return conversation_processor

View File

@@ -2,6 +2,7 @@
import os
import signal
import sys
import logging
from platform import system
# External Packages
@@ -25,6 +26,33 @@ app = FastAPI()
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
app.include_router(router)
logger = logging.getLogger('src')
class CustomFormatter(logging.Formatter):
blue = "\x1b[1;34m"
green = "\x1b[1;32m"
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
format = "%(levelname)s: %(asctime)s: %(name)s | %(message)s"
FORMATS = {
logging.DEBUG: blue + format + reset,
logging.INFO: green + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
def run():
# Turn Tokenizers Parallelism Off. App does not support it.
@@ -33,6 +61,20 @@ def run():
# Load config from CLI
state.cli_args = sys.argv[1:]
args = cli(state.cli_args)
# Setup Logger
# logging.basicConfig(format='%(levelname)s: %(asctime)s : %(module)s | %(message)s', datefmt='%d-%m-%Y %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
logger.addHandler(ch)
if args.verbose == 0:
logger.setLevel(logging.WARN)
elif args.verbose == 1:
logger.setLevel(logging.INFO)
elif args.verbose >= 2:
logger.setLevel(logging.DEBUG)
logger.info("Starting Khoj...")
set_state(args)
if args.no_gui:

View File

@@ -6,6 +6,7 @@ import argparse
import pathlib
import glob
import re
import logging
# Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty
@@ -13,32 +14,35 @@ from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
logger = logging.getLogger(__name__)
# Define Functions
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, verbose=0):
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file):
# Input Validation
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
print("At least one of beancount-files or beancount-file-filter is required to be specified")
exit(1)
# Get Beancount Files to Process
beancount_files = get_beancount_files(beancount_files, beancount_file_filter, verbose)
beancount_files = get_beancount_files(beancount_files, beancount_file_filter)
# Extract Entries from specified Beancount files
entries = extract_beancount_entries(beancount_files)
# Process Each Entry from All Notes Files
jsonl_data = convert_beancount_entries_to_jsonl(entries, verbose=verbose)
jsonl_data = convert_beancount_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose)
dump_jsonl(jsonl_data, output_file)
return entries
def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbose=0):
def get_beancount_files(beancount_files=None, beancount_file_filter=None):
"Get Beancount files to process"
absolute_beancount_files, filtered_beancount_files = set(), set()
if beancount_files:
@@ -57,8 +61,7 @@ def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbos
if any(files_with_non_beancount_extensions):
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
if verbose > 0:
print(f'Processing files: {all_beancount_files}')
logger.info(f'Processing files: {all_beancount_files}')
return all_beancount_files
@@ -82,7 +85,7 @@ def extract_beancount_entries(beancount_files):
return entries
def convert_beancount_entries_to_jsonl(entries, verbose=0):
def convert_beancount_entries_to_jsonl(entries):
"Convert each Beancount transaction to JSON and collate as JSONL"
jsonl = ''
for entry in entries:
@@ -90,8 +93,7 @@ def convert_beancount_entries_to_jsonl(entries, verbose=0):
# Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
if verbose > 0:
print(f"Converted {len(entries)} to jsonl format")
logger.info(f"Converted {len(entries)} to jsonl format")
return jsonl

View File

@@ -6,6 +6,7 @@ import argparse
import pathlib
import glob
import re
import logging
# Internal Packages
from src.utils.helpers import get_absolute_path, is_none_or_empty
@@ -13,32 +14,35 @@ from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
logger = logging.getLogger(__name__)
# Define Functions
def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file, verbose=0):
def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file):
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
print("At least one of markdown-files or markdown-file-filter is required to be specified")
exit(1)
# Get Markdown Files to Process
markdown_files = get_markdown_files(markdown_files, markdown_file_filter, verbose)
markdown_files = get_markdown_files(markdown_files, markdown_file_filter)
# Extract Entries from specified Markdown files
entries = extract_markdown_entries(markdown_files)
# Process Each Entry from All Notes Files
jsonl_data = convert_markdown_entries_to_jsonl(entries, verbose=verbose)
jsonl_data = convert_markdown_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose)
dump_jsonl(jsonl_data, output_file)
return entries
def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0):
def get_markdown_files(markdown_files=None, markdown_file_filter=None):
"Get Markdown files to process"
absolute_markdown_files, filtered_markdown_files = set(), set()
if markdown_files:
@@ -56,10 +60,9 @@ def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0
}
if any(files_with_non_markdown_extensions):
print(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
if verbose > 0:
print(f'Processing files: {all_markdown_files}')
logger.info(f'Processing files: {all_markdown_files}')
return all_markdown_files
@@ -81,7 +84,7 @@ def extract_markdown_entries(markdown_files):
return entries
def convert_markdown_entries_to_jsonl(entries, verbose=0):
def convert_markdown_entries_to_jsonl(entries):
"Convert each Markdown entries to JSON and collate as JSONL"
jsonl = ''
for entry in entries:
@@ -89,8 +92,7 @@ def convert_markdown_entries_to_jsonl(entries, verbose=0):
# Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
if verbose > 0:
print(f"Converted {len(entries)} to jsonl format")
logger.info(f"Converted {len(entries)} to jsonl format")
return jsonl

View File

@@ -6,40 +6,43 @@ import json
import argparse
import pathlib
import glob
import logging
# Internal Packages
from src.processor.org_mode import orgnode
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
logger = logging.getLogger(__name__)
# Define Functions
def org_to_jsonl(org_files, org_file_filter, output_file, verbose=0):
def org_to_jsonl(org_files, org_file_filter, output_file):
# Input Validation
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
print("At least one of org-files or org-file-filter is required to be specified")
exit(1)
# Get Org Files to Process
org_files = get_org_files(org_files, org_file_filter, verbose)
org_files = get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files
entries = extract_org_entries(org_files)
# Process Each Entry from All Notes Files
jsonl_data = convert_org_entries_to_jsonl(entries, verbose=verbose)
jsonl_data = convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file, verbose=verbose)
dump_jsonl(jsonl_data, output_file)
return entries
def get_org_files(org_files=None, org_file_filter=None, verbose=0):
def get_org_files(org_files=None, org_file_filter=None):
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
@@ -53,10 +56,9 @@ def get_org_files(org_files=None, org_file_filter=None, verbose=0):
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
if any(files_with_non_org_extensions):
print(f"[Warning] There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
if verbose > 0:
print(f'Processing files: {all_org_files}')
logger.info(f'Processing files: {all_org_files}')
return all_org_files
@@ -72,7 +74,7 @@ def extract_org_entries(org_files):
return entries
def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
def convert_org_entries_to_jsonl(entries) -> str:
"Convert each Org-Mode entries to JSON and collate as JSONL"
jsonl = ''
for entry in entries:
@@ -83,29 +85,24 @@ def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
continue
entry_dict["compiled"] = f'{entry.Heading()}.'
if verbose > 2:
print(f"Title: {entry.Heading()}")
logger.debug(f"Title: {entry.Heading()}")
if entry.Tags():
tags_str = " ".join(entry.Tags())
entry_dict["compiled"] += f'\t {tags_str}.'
if verbose > 2:
print(f"Tags: {tags_str}")
logger.debug(f"Tags: {tags_str}")
if entry.Closed():
entry_dict["compiled"] += f'\n Closed on {entry.Closed().strftime("%Y-%m-%d")}.'
if verbose > 2:
print(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
logger.debug(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
if entry.Scheduled():
entry_dict["compiled"] += f'\n Scheduled for {entry.Scheduled().strftime("%Y-%m-%d")}.'
if verbose > 2:
print(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
logger.debug(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
if entry.Body():
entry_dict["compiled"] += f'\n {entry.Body()}'
if verbose > 2:
print(f"Body: {entry.Body()}")
logger.debug(f"Body: {entry.Body()}")
if entry_dict:
entry_dict["raw"] = f'{entry}'
@@ -113,8 +110,7 @@ def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
# Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
if verbose > 0:
print(f"Converted {len(entries)} to jsonl format")
logger.info(f"Converted {len(entries)} to jsonl format")
return jsonl

View File

@@ -2,6 +2,7 @@
import yaml
import json
import time
import logging
from typing import Optional
from functools import lru_cache
@@ -22,9 +23,11 @@ from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict
from src.utils import state, constants
router = APIRouter()
router = APIRouter()
templates = Jinja2Templates(directory=constants.web_directory)
logger = logging.getLogger(__name__)
@router.get("/", response_class=FileResponse)
def index():
@@ -50,7 +53,7 @@ async def config_data(updated_config: FullConfig):
@lru_cache(maxsize=100)
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
if q is None or q == '':
print(f'No query param (q) passed in API call to initiate search')
logger.info(f'No query param (q) passed in API call to initiate search')
return {}
# initialize variables
@@ -120,11 +123,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
count=results_count)
collate_end = time.time()
if state.verbose > 1:
if query_start and query_end:
print(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
print(f"Collating results took {collate_end - collate_start:.3f} seconds")
if query_start and query_end:
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results

View File

@@ -5,6 +5,7 @@ import pathlib
import copy
import shutil
import time
import logging
# External Packages
from sentence_transformers import SentenceTransformer, util
@@ -17,7 +18,10 @@ from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_mod
import src.utils.exiftool as exiftool
from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
from src.utils import state
# Create Logger
logger = logging.getLogger(__name__)
def initialize_model(search_config: ImageSearchConfig):
@@ -39,34 +43,33 @@ def initialize_model(search_config: ImageSearchConfig):
return encoder
def extract_entries(image_directories, verbose=0):
def extract_entries(image_directories):
image_names = []
for image_directory in image_directories:
image_directory = resolve_absolute_path(image_directory, strict=True)
image_names.extend(list(image_directory.glob('*.jpg')))
image_names.extend(list(image_directory.glob('*.jpeg')))
if verbose > 0:
if logger.level >= logging.INFO:
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories])
print(f'Found {len(image_names)} images in {image_directory_names}')
logger.info(f'Found {len(image_names)} images in {image_directory_names}')
return sorted(image_names)
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate, verbose)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose)
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)
return image_embeddings, image_metadata_embeddings
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False, verbose=0):
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
# Load pre-computed image embeddings from file if exists
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
image_embeddings = torch.load(embeddings_file)
if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}")
logger.info(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}")
# Else compute the image embeddings from scratch, which can take a while
else:
image_embeddings = []
@@ -87,8 +90,7 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
# Save computed image embeddings to file
torch.save(image_embeddings, embeddings_file)
if verbose > 0:
print(f"Saved computed embeddings to {embeddings_file}")
logger.info(f"Saved computed embeddings to {embeddings_file}")
return image_embeddings
@@ -99,8 +101,7 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
# Load pre-computed image metadata embedding file if exists
if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
logger.info(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
# Else compute the image metadata embeddings from scratch, which can take a while
if use_xmp_metadata and image_metadata_embeddings is None:
@@ -113,16 +114,15 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
convert_to_tensor=True,
batch_size=min(len(image_metadata), batch_size))
except RuntimeError as e:
print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
continue
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
if verbose > 0:
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
return image_metadata_embeddings
def extract_metadata(image_name, verbose=0):
def extract_metadata(image_name):
with exiftool.ExifTool() as et:
image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name))
image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject])
@@ -131,8 +131,7 @@ def extract_metadata(image_name, verbose=0):
if len(image_metadata_subjects) > 0:
image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
if verbose > 2:
print(f"{image_name}:\t{image_processed_metadata}")
logger.debug(f"{image_name}:\t{image_processed_metadata}")
return image_processed_metadata
@@ -143,19 +142,16 @@ def query(raw_query, count, model: ImageSearchModel):
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True)
query = copy.deepcopy(Image.open(query_imagepath))
query.thumbnail((640, query.height)) # scale down image for faster processing
if model.verbose > 0:
print(f"Find Images similar to Image at {query_imagepath}")
logger.info(f"Find Images similar to Image at {query_imagepath}")
else:
query = raw_query
if state.verbose > 0:
print(f"Find Images by Text: {query}")
logger.info(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string)
start = time.time()
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
end = time.time()
if state.verbose > 1:
print(f"Query Encode Time: {end - start:.3f} seconds")
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
start = time.time()
@@ -163,8 +159,7 @@ def query(raw_query, count, model: ImageSearchModel):
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
end = time.time()
if state.verbose > 1:
print(f"Search Time: {end - start:.3f} seconds")
logger.debug(f"Search Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
@@ -173,8 +168,7 @@ def query(raw_query, count, model: ImageSearchModel):
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
end = time.time()
if state.verbose > 1:
print(f"Metadata Search Time: {end - start:.3f} seconds")
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
# Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items():
@@ -237,7 +231,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
# Initialize Model
encoder = initialize_model(search_config)
@@ -245,7 +239,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
absolute_image_files, filtered_image_files = set(), set()
if config.input_directories:
image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
absolute_image_files = set(extract_entries(image_directories, verbose))
absolute_image_files = set(extract_entries(image_directories))
if config.input_filter:
filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter)))
@@ -259,14 +253,12 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
embeddings_file,
batch_size=config.batch_size,
regenerate=regenerate,
use_xmp_metadata=config.use_xmp_metadata,
verbose=verbose)
use_xmp_metadata=config.use_xmp_metadata)
return ImageSearchModel(all_image_files,
image_embeddings,
image_metadata_embeddings,
encoder,
verbose)
encoder)
if __name__ == '__main__':

View File

@@ -1,8 +1,9 @@
# Standard Packages
import argparse
import pathlib
from copy import deepcopy
import logging
import time
from copy import deepcopy
# External Packages
import torch
@@ -16,6 +17,9 @@ from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl
logger = logging.getLogger(__name__)
def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
@@ -46,32 +50,30 @@ def initialize_model(search_config: TextSearchConfig):
return bi_encoder, cross_encoder, top_k
def extract_entries(jsonl_file, verbose=0):
def extract_entries(jsonl_file):
"Load entries from compressed jsonl"
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
for entry
in load_jsonl(jsonl_file, verbose=verbose)]
in load_jsonl(jsonl_file)]
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
# Load pre-computed embeddings from file if exists
if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
if verbose > 0:
print(f"Loaded embeddings from {embeddings_file}")
logger.info(f"Loaded embeddings from {embeddings_file}")
else: # Else compute the corpus_embeddings from scratch, which can take a while
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
torch.save(corpus_embeddings, embeddings_file)
if verbose > 0:
print(f"Computed embeddings and saved them to {embeddings_file}")
logger.info(f"Computed embeddings and saved them to {embeddings_file}")
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
"Search for entries that answer the query"
query = raw_query
@@ -85,16 +87,14 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
corpus_embeddings = model.corpus_embeddings
entries = model.entries
end = time.time()
if verbose > 1:
print(f"Copy Time: {end - start:.3f} seconds")
logger.debug(f"Copy Time: {end - start:.3f} seconds")
# Filter query, entries and embeddings before semantic search
start = time.time()
for filter in filters_in_query:
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
end = time.time()
if verbose > 1:
print(f"Filter Time: {end - start:.3f} seconds")
logger.debug(f"Filter Time: {end - start:.3f} seconds")
if entries is None or len(entries) == 0:
return [], []
@@ -104,15 +104,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
if verbose > 1:
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
# Find relevant entries for the query
start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
if verbose > 1:
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
# Score all retrieved entries using the cross-encoder
if rank_results:
@@ -120,8 +118,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
if verbose > 1:
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
@@ -133,8 +130,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
if verbose > 1:
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
return hits, entries
@@ -167,24 +163,24 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
if not config.compressed_jsonl.exists() or regenerate:
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl)
# Extract Entries
entries = extract_entries(config.compressed_jsonl, verbose)
entries = extract_entries(config.compressed_jsonl)
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
if __name__ == '__main__':
@@ -200,7 +196,7 @@ if __name__ == '__main__':
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
args = parser.parse_args()
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate)
# Run User Queries on Entries in Interactive Mode
while args.interactive:
@@ -213,4 +209,4 @@ if __name__ == '__main__':
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
# render results
render_results(hits, entries, count=args.results_count)
render_results(hits, entries, count=args.results_count)

View File

@@ -20,23 +20,21 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose):
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.top_k = top_k
self.verbose = verbose
class ImageSearchModel():
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose):
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings
self.image_metadata_embeddings = image_metadata_embeddings
self.image_encoder = image_encoder
self.verbose = verbose
@dataclass
@@ -49,12 +47,11 @@ class SearchModels():
class ConversationProcessorConfigModel():
def __init__(self, processor_config: ConversationProcessorConfig, verbose: bool):
def __init__(self, processor_config: ConversationProcessorConfig):
self.openai_api_key = processor_config.openai_api_key
self.conversation_logfile = Path(processor_config.conversation_logfile)
self.chat_session = ''
self.meta_log = []
self.verbose = verbose
@dataclass

View File

@@ -1,13 +1,17 @@
# Standard Packages
import json
import gzip
import logging
# Internal Packages
from src.utils.constants import empty_escape_sequences
from src.utils.helpers import get_absolute_path
def load_jsonl(input_path, verbose=0):
logger = logging.getLogger(__name__)
def load_jsonl(input_path):
"Read List of JSON objects from JSON line file"
# Initialize Variables
data = []
@@ -27,13 +31,12 @@ def load_jsonl(input_path, verbose=0):
jsonl_file.close()
# Log JSONL entries loaded
if verbose > 0:
print(f'Loaded {len(data)} records from {input_path}')
logger.info(f'Loaded {len(data)} records from {input_path}')
return data
def dump_jsonl(jsonl_data, output_path, verbose=0):
def dump_jsonl(jsonl_data, output_path):
"Write List of JSON objects to JSON line file"
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
@@ -41,16 +44,14 @@ def dump_jsonl(jsonl_data, output_path, verbose=0):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(jsonl_data)
if verbose > 0:
print(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
logger.info(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
def compress_jsonl_data(jsonl_data, output_path, verbose=0):
def compress_jsonl_data(jsonl_data, output_path):
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(output_path, 'wt') as gzip_file:
gzip_file.write(jsonl_data)
if verbose > 0:
print(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')
logger.info(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')