diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py index 74b1ed33..c88b66c2 100644 --- a/src/processor/ledger/beancount_to_jsonl.py +++ b/src/processor/ledger/beancount_to_jsonl.py @@ -2,7 +2,7 @@ import glob import re import logging -import time +from typing import List # Internal Packages from src.processor.text_to_jsonl import TextToJsonl @@ -109,7 +109,7 @@ class BeancountToJsonl(TextToJsonl): return entries, dict(transaction_to_file_map) @staticmethod - def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]: + def convert_transactions_to_maps(parsed_entries: List[str], transaction_to_file_map) -> List[Entry]: "Convert each parsed Beancount transaction into a Entry" entries = [] for parsed_entry in parsed_entries: @@ -120,6 +120,6 @@ class BeancountToJsonl(TextToJsonl): return entries @staticmethod - def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str: + def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str: "Convert each Beancount transaction entry to JSON and collate as JSONL" return ''.join([f'{entry.to_json()}\n' for entry in entries]) diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py index 773d6dfc..9b326884 100644 --- a/src/processor/markdown/markdown_to_jsonl.py +++ b/src/processor/markdown/markdown_to_jsonl.py @@ -3,6 +3,7 @@ import glob import re import logging import time +from typing import List # Internal Packages from src.processor.text_to_jsonl import TextToJsonl @@ -110,7 +111,7 @@ class MarkdownToJsonl(TextToJsonl): return entries, dict(entry_to_file_map) @staticmethod - def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]: + def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]: "Convert each Markdown entries into a dictionary" entries = [] for parsed_entry in parsed_entries: @@ -121,6 +122,6 @@ class MarkdownToJsonl(TextToJsonl): return entries @staticmethod - def convert_markdown_maps_to_jsonl(entries: list[Entry]): + def convert_markdown_maps_to_jsonl(entries: List[Entry]): "Convert each Markdown entry to JSON and collate as JSONL" return ''.join([f'{entry.to_json()}\n' for entry in entries]) diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index f2c301cd..2e227de5 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -2,7 +2,7 @@ import glob import logging import time -from typing import Iterable +from typing import Iterable, List # Internal Packages from src.processor.org_mode import orgnode @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) class OrgToJsonl(TextToJsonl): # Define Functions - def process(self, previous_entries: list[Entry]=None): + def process(self, previous_entries: List[Entry]=None): # Extract required fields from config org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl index_heading_entries = self.config.index_heading_entries @@ -101,9 +101,9 @@ class OrgToJsonl(TextToJsonl): return entries, dict(entry_to_file_map) @staticmethod - def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]: + def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]: "Convert Org-Mode nodes into list of Entry objects" - entries: list[Entry] = [] + entries: List[Entry] = [] for parsed_entry in parsed_entries: if not parsed_entry.hasBody and not index_heading_entries: # Ignore title notes i.e notes with just headings and empty body diff --git a/src/processor/org_mode/orgnode.py b/src/processor/org_mode/orgnode.py index a5f4cd43..7675650b 100644 --- a/src/processor/org_mode/orgnode.py +++ b/src/processor/org_mode/orgnode.py @@ -37,6 +37,7 @@ import re import datetime from pathlib import Path from os.path import relpath +from typing import List indent_regex = re.compile(r'^ *') @@ -69,7 +70,7 @@ def makelist(filename): sched_date = '' deadline_date = '' logbook = list() - nodelist: list[Orgnode] = list() + nodelist: List[Orgnode] = list() property_map = dict() in_properties_drawer = False in_logbook_drawer = False diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py index 3c198784..33a6a515 100644 --- a/src/processor/text_to_jsonl.py +++ b/src/processor/text_to_jsonl.py @@ -1,9 +1,8 @@ # Standard Packages from abc import ABC, abstractmethod import hashlib -import time import logging -from typing import Callable +from typing import Callable, List, Tuple from src.utils.helpers import timer # Internal Packages @@ -18,16 +17,16 @@ class TextToJsonl(ABC): self.config = config @abstractmethod - def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ... + def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ... @staticmethod def hash_func(key: str) -> Callable: return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest() @staticmethod - def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256, max_word_length: int=500) -> list[Entry]: + def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]: "Split entries if compiled entry length exceeds the max tokens supported by the ML model." - chunked_entries: list[Entry] = [] + chunked_entries: List[Entry] = [] for entry in entries: compiled_entry_words = entry.compiled.split() # Drop long words instead of having entry truncated to maintain quality of entry processed by models @@ -39,7 +38,7 @@ class TextToJsonl(ABC): chunked_entries.append(entry_chunk) return chunked_entries - def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]: + def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]: # Hash all current and previous entries to identify new entries with timer("Hash previous, current entries", logger): current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) diff --git a/src/routers/api.py b/src/routers/api.py index 9480c914..c505c133 100644 --- a/src/routers/api.py +++ b/src/routers/api.py @@ -1,7 +1,7 @@ # Standard Packages import yaml import logging -from typing import Optional +from typing import List, Optional # External Packages from fastapi import APIRouter @@ -38,9 +38,9 @@ async def set_config_data(updated_config: FullConfig): outfile.close() return state.config -@api.get('/search', response_model=list[SearchResponse]) +@api.get('/search', response_model=List[SearchResponse]) def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): - results: list[SearchResponse] = [] + results: List[SearchResponse] = [] if q is None or q == '': logger.info(f'No query param (q) passed in API call to initiate search') return results diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py index a1b56492..5b7c7f60 100644 --- a/src/search_filter/base_filter.py +++ b/src/search_filter/base_filter.py @@ -1,5 +1,6 @@ # Standard Packages from abc import ABC, abstractmethod +from typing import List, Set, Tuple # Internal Packages from src.utils.rawconfig import Entry @@ -7,10 +8,10 @@ from src.utils.rawconfig import Entry class BaseFilter(ABC): @abstractmethod - def load(self, entries: list[Entry], *args, **kwargs): ... + def load(self, entries: List[Entry], *args, **kwargs): ... @abstractmethod def can_filter(self, raw_query:str) -> bool: ... @abstractmethod - def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ... + def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ... diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 6dc57b6e..db358988 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -5,6 +5,7 @@ import copy import shutil import time import logging +from typing import List # External Packages from sentence_transformers import SentenceTransformer, util @@ -189,8 +190,8 @@ def query(raw_query, count, model: ImageSearchModel): return sorted(hits, key=lambda hit: hit["score"], reverse=True) -def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]: - results: list[SearchResponse] = [] +def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> List[SearchResponse]: + results: List[SearchResponse] = [] for index, hit in enumerate(hits[:count]): source_path = image_names[hit['corpus_id']] diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 0aba5e2f..46f83638 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -2,7 +2,7 @@ import logging from pathlib import Path import time -from typing import Type +from typing import List, Tuple, Type # External Packages import torch @@ -53,12 +53,12 @@ def initialize_model(search_config: TextSearchConfig): return bi_encoder, cross_encoder, top_k -def extract_entries(jsonl_file) -> list[Entry]: +def extract_entries(jsonl_file) -> List[Entry]: "Load entries from compressed jsonl" return list(map(Entry.from_dict, load_jsonl(jsonl_file))) -def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False): +def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" new_entries = [] # Load pre-computed embeddings from file if exists and update them if required @@ -90,7 +90,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: Ba return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> tuple[list[dict], list[Entry]]: +def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings @@ -127,7 +127,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> return hits, entries -def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]: +def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]: return [SearchResponse.parse_obj( { "entry": entries[hit['corpus_id']].raw, @@ -141,7 +141,7 @@ def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse] in hits[0:count]] -def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel: +def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = []) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) @@ -166,7 +166,7 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) -def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]: +def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]) -> Tuple[str, List[Entry], torch.Tensor]: '''Filter query, entries and embeddings before semantic search''' with timer("Total Filter Time", logger, state.device): @@ -186,7 +186,7 @@ def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Ten return query, entries, corpus_embeddings -def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]: +def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]: '''Score all retrieved entries using the cross-encoder''' with timer("Cross-Encoder Predict Time", logger, state.device): cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] @@ -199,7 +199,7 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E return hits -def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]: +def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]: '''Order results by cross-encoder score followed by bi-encoder score''' with timer("Rank Time", logger, state.device): hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score @@ -208,7 +208,7 @@ def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]: return hits -def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]: +def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]: '''Deduplicate entries by raw entry text before showing to users Compiled entries are split by max tokens supported by ML models. This can result in duplicate hits, entries shown to user.''' diff --git a/src/utils/config.py b/src/utils/config.py index 118f766f..e1f3ddb4 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints from enum import Enum from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List # External Packages import torch @@ -29,7 +29,7 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: list[BaseFilter], top_k): + def __init__(self, entries: List[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: List[BaseFilter], top_k): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder diff --git a/src/utils/models.py b/src/utils/models.py index 43e02639..53525566 100644 --- a/src/utils/models.py +++ b/src/utils/models.py @@ -1,5 +1,6 @@ # Standard Packages from abc import ABC, abstractmethod +from typing import List # External Packages import openai @@ -15,7 +16,7 @@ class BaseEncoder(ABC): def __init__(self, model_name: str, device: torch.device=None, **kwargs): ... @abstractmethod - def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ... + def encode(self, entries: List[str], device:torch.device=None, **kwargs) -> torch.Tensor: ... class OpenAI(BaseEncoder): diff --git a/src/utils/state.py b/src/utils/state.py index 0e323b89..8574c92d 100644 --- a/src/utils/state.py +++ b/src/utils/state.py @@ -1,5 +1,6 @@ # Standard Packages import threading +from typing import List from packaging import version # External Packages @@ -19,7 +20,7 @@ config_file: Path = None verbose: int = 0 host: str = None port: int = None -cli_args: list[str] = None +cli_args: List[str] = None query_cache = LRU() search_index_lock = threading.Lock()