Use Black to format Khoj server code and tests

This commit is contained in:
Debanjum Singh Solanky
2023-02-17 10:04:26 -06:00
parent 6130fddf45
commit 5e83baab21
44 changed files with 1167 additions and 915 deletions

View File

@@ -26,10 +26,12 @@ logger = logging.getLogger(__name__)
def configure_server(args, required=False):
if args.config is None:
if required:
logger.error(f'Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.')
logger.error(f"Exiting as Khoj is not configured.\nConfigure it via GUI or by editing {state.config_file}.")
sys.exit(1)
else:
logger.warn(f'Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}.')
logger.warn(
f"Khoj is not configured.\nConfigure it via khoj GUI, plugins or by editing {state.config_file}."
)
return
else:
state.config = args.config
@@ -60,7 +62,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.org,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
@@ -70,7 +73,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.music,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter()])
filters=[DateFilter(), WordFilter()],
)
# Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
@@ -80,7 +84,8 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.markdown,
search_config=config.search_type.asymmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
@@ -90,15 +95,15 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
config.content_type.ledger,
search_config=config.search_type.symmetric,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()])
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(
config.content_type.image,
search_config=config.search_type.image,
regenerate=regenerate)
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
)
# Invalidate Query Cache
state.query_cache = LRU()
@@ -125,9 +130,9 @@ def configure_conversation_processor(conversation_processor_config):
if conversation_logfile.is_file():
# Load Metadata Logs from Conversation Logfile
with conversation_logfile.open('r') as f:
with conversation_logfile.open("r") as f:
conversation_processor.meta_log = json.load(f)
logger.info('Conversation logs loaded from disk.')
logger.info("Conversation logs loaded from disk.")
else:
# Initialize Conversation Logs
conversation_processor.meta_log = {}

View File

@@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty
class FileBrowser(QtWidgets.QWidget):
def __init__(self, title, search_type: SearchType=None, default_files:list=[]):
def __init__(self, title, search_type: SearchType = None, default_files: list = []):
QtWidgets.QWidget.__init__(self)
layout = QtWidgets.QHBoxLayout()
self.setLayout(layout)
@@ -22,51 +22,54 @@ class FileBrowser(QtWidgets.QWidget):
self.label.setFixedWidth(95)
self.label.setWordWrap(True)
layout.addWidget(self.label)
self.lineEdit = QtWidgets.QPlainTextEdit(self)
self.lineEdit.setFixedWidth(330)
self.setFiles(default_files)
self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))
self.lineEdit.textChanged.connect(self.updateFieldHeight)
layout.addWidget(self.lineEdit)
self.button = QtWidgets.QPushButton('Add')
self.button = QtWidgets.QPushButton("Add")
self.button.clicked.connect(self.storeFilesSelectedInFileDialog)
layout.addWidget(self.button)
layout.addStretch()
def getFileFilter(self, search_type):
if search_type == SearchType.Org:
return 'Org-Mode Files (*.org)'
return "Org-Mode Files (*.org)"
elif search_type == SearchType.Ledger:
return 'Beancount Files (*.bean *.beancount)'
return "Beancount Files (*.bean *.beancount)"
elif search_type == SearchType.Markdown:
return 'Markdown Files (*.md *.markdown)'
return "Markdown Files (*.md *.markdown)"
elif search_type == SearchType.Music:
return 'Org-Music Files (*.org)'
return "Org-Music Files (*.org)"
elif search_type == SearchType.Image:
return 'Images (*.jp[e]g)'
return "Images (*.jp[e]g)"
def storeFilesSelectedInFileDialog(self):
filepaths = self.getPaths()
if self.search_type == SearchType.Image:
filepaths.append(QtWidgets.QFileDialog.getExistingDirectory(self, caption='Choose Folder',
directory=self.dirpath))
filepaths.append(
QtWidgets.QFileDialog.getExistingDirectory(self, caption="Choose Folder", directory=self.dirpath)
)
else:
filepaths.extend(QtWidgets.QFileDialog.getOpenFileNames(self, caption='Choose Files',
directory=self.dirpath,
filter=self.filter_name)[0])
filepaths.extend(
QtWidgets.QFileDialog.getOpenFileNames(
self, caption="Choose Files", directory=self.dirpath, filter=self.filter_name
)[0]
)
self.setFiles(filepaths)
def setFiles(self, paths:list):
def setFiles(self, paths: list):
self.filepaths = [path for path in paths if not is_none_or_empty(path)]
self.lineEdit.setPlainText("\n".join(self.filepaths))
def getPaths(self) -> list:
if self.lineEdit.toPlainText() == '':
if self.lineEdit.toPlainText() == "":
return []
else:
return self.lineEdit.toPlainText().split('\n')
return self.lineEdit.toPlainText().split("\n")
def updateFieldHeight(self):
self.lineEdit.setFixedHeight(min(7+20*len(self.lineEdit.toPlainText().split('\n')),90))
self.lineEdit.setFixedHeight(min(7 + 20 * len(self.lineEdit.toPlainText().split("\n")), 90))

View File

@@ -6,7 +6,7 @@ from khoj.utils.config import ProcessorType
class LabelledTextField(QtWidgets.QWidget):
def __init__(self, title, processor_type: ProcessorType=None, default_value: str=None):
def __init__(self, title, processor_type: ProcessorType = None, default_value: str = None):
QtWidgets.QWidget.__init__(self)
layout = QtWidgets.QHBoxLayout()
self.setLayout(layout)

View File

@@ -31,9 +31,9 @@ class MainWindow(QtWidgets.QMainWindow):
self.config_file = config_file
# Set regenerate flag to regenerate embeddings everytime user clicks configure
if state.cli_args:
state.cli_args += ['--regenerate']
state.cli_args += ["--regenerate"]
else:
state.cli_args = ['--regenerate']
state.cli_args = ["--regenerate"]
# Load config from existing config, if exists, else load from default config
if resolve_absolute_path(self.config_file).exists():
@@ -49,8 +49,8 @@ class MainWindow(QtWidgets.QMainWindow):
self.setFixedWidth(600)
# Set Window Icon
icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
self.setWindowIcon(QtGui.QIcon(f'{icon_path.absolute()}'))
icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
self.setWindowIcon(QtGui.QIcon(f"{icon_path.absolute()}"))
# Initialize Configure Window Layout
self.layout = QtWidgets.QVBoxLayout()
@@ -58,13 +58,13 @@ class MainWindow(QtWidgets.QMainWindow):
# Add Settings Panels for each Search Type to Configure Window Layout
self.search_settings_panels = []
for search_type in SearchType:
current_content_config = self.current_config['content-type'].get(search_type, {})
current_content_config = self.current_config["content-type"].get(search_type, {})
self.search_settings_panels += [self.add_settings_panel(current_content_config, search_type)]
# Add Conversation Processor Panel to Configure Screen
self.processor_settings_panels = []
conversation_type = ProcessorType.Conversation
current_conversation_config = self.current_config['processor'].get(conversation_type, {})
current_conversation_config = self.current_config["processor"].get(conversation_type, {})
self.processor_settings_panels += [self.add_processor_panel(current_conversation_config, conversation_type)]
# Add Action Buttons Panel
@@ -81,11 +81,11 @@ class MainWindow(QtWidgets.QMainWindow):
"Add Settings Panel for specified Search Type. Toggle Editable Search Types"
# Get current files from config for given search type
if search_type == SearchType.Image:
current_content_files = current_content_config.get('input-directories', [])
file_input_text = f'{search_type.name} Folders'
current_content_files = current_content_config.get("input-directories", [])
file_input_text = f"{search_type.name} Folders"
else:
current_content_files = current_content_config.get('input-files', [])
file_input_text = f'{search_type.name} Files'
current_content_files = current_content_config.get("input-files", [])
file_input_text = f"{search_type.name} Files"
# Create widgets to display settings for given search type
search_type_settings = QtWidgets.QWidget()
@@ -109,7 +109,7 @@ class MainWindow(QtWidgets.QMainWindow):
def add_processor_panel(self, current_conversation_config: dict, processor_type: ProcessorType):
"Add Conversation Processor Panel"
# Get current settings from config for given processor type
current_openai_api_key = current_conversation_config.get('openai-api-key', None)
current_openai_api_key = current_conversation_config.get("openai-api-key", None)
# Create widgets to display settings for given processor type
processor_type_settings = QtWidgets.QWidget()
@@ -137,20 +137,22 @@ class MainWindow(QtWidgets.QMainWindow):
action_bar_layout = QtWidgets.QHBoxLayout(action_bar)
self.configure_button = QtWidgets.QPushButton("Configure", clicked=self.configure_app)
self.search_button = QtWidgets.QPushButton("Search", clicked=lambda: webbrowser.open(f'http://{state.host}:{state.port}/'))
self.search_button = QtWidgets.QPushButton(
"Search", clicked=lambda: webbrowser.open(f"http://{state.host}:{state.port}/")
)
self.search_button.setEnabled(not self.first_run)
action_bar_layout.addWidget(self.configure_button)
action_bar_layout.addWidget(self.search_button)
self.layout.addWidget(action_bar)
def get_default_config(self, search_type:SearchType=None, processor_type:ProcessorType=None):
def get_default_config(self, search_type: SearchType = None, processor_type: ProcessorType = None):
"Get default config"
config = constants.default_config
if search_type:
return config['content-type'][search_type]
return config["content-type"][search_type]
elif processor_type:
return config['processor'][processor_type]
return config["processor"][processor_type]
else:
return config
@@ -160,7 +162,9 @@ class MainWindow(QtWidgets.QMainWindow):
for message_prefix in ErrorType:
for i in reversed(range(self.layout.count())):
current_widget = self.layout.itemAt(i).widget()
if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(message_prefix.value):
if isinstance(current_widget, QtWidgets.QLabel) and current_widget.text().startswith(
message_prefix.value
):
self.layout.removeWidget(current_widget)
current_widget.deleteLater()
@@ -180,18 +184,24 @@ class MainWindow(QtWidgets.QMainWindow):
continue
if isinstance(child, SearchCheckBox):
# Search Type Disabled
if not child.isChecked() and child.search_type in self.new_config['content-type']:
del self.new_config['content-type'][child.search_type]
if not child.isChecked() and child.search_type in self.new_config["content-type"]:
del self.new_config["content-type"][child.search_type]
# Search Type (re)-Enabled
if child.isChecked():
current_search_config = self.current_config['content-type'].get(child.search_type, {})
default_search_config = self.get_default_config(search_type = child.search_type)
self.new_config['content-type'][child.search_type.value] = merge_dicts(current_search_config, default_search_config)
elif isinstance(child, FileBrowser) and child.search_type in self.new_config['content-type']:
current_search_config = self.current_config["content-type"].get(child.search_type, {})
default_search_config = self.get_default_config(search_type=child.search_type)
self.new_config["content-type"][child.search_type.value] = merge_dicts(
current_search_config, default_search_config
)
elif isinstance(child, FileBrowser) and child.search_type in self.new_config["content-type"]:
if child.search_type.value == SearchType.Image:
self.new_config['content-type'][child.search_type.value]['input-directories'] = child.getPaths() if child.getPaths() != [] else None
self.new_config["content-type"][child.search_type.value]["input-directories"] = (
child.getPaths() if child.getPaths() != [] else None
)
else:
self.new_config['content-type'][child.search_type.value]['input-files'] = child.getPaths() if child.getPaths() != [] else None
self.new_config["content-type"][child.search_type.value]["input-files"] = (
child.getPaths() if child.getPaths() != [] else None
)
def update_processor_settings(self):
"Update config with conversation settings from UI"
@@ -201,16 +211,20 @@ class MainWindow(QtWidgets.QMainWindow):
continue
if isinstance(child, ProcessorCheckBox):
# Processor Type Disabled
if not child.isChecked() and child.processor_type in self.new_config['processor']:
del self.new_config['processor'][child.processor_type]
if not child.isChecked() and child.processor_type in self.new_config["processor"]:
del self.new_config["processor"][child.processor_type]
# Processor Type (re)-Enabled
if child.isChecked():
current_processor_config = self.current_config['processor'].get(child.processor_type, {})
default_processor_config = self.get_default_config(processor_type = child.processor_type)
self.new_config['processor'][child.processor_type.value] = merge_dicts(current_processor_config, default_processor_config)
elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config['processor']:
current_processor_config = self.current_config["processor"].get(child.processor_type, {})
default_processor_config = self.get_default_config(processor_type=child.processor_type)
self.new_config["processor"][child.processor_type.value] = merge_dicts(
current_processor_config, default_processor_config
)
elif isinstance(child, LabelledTextField) and child.processor_type in self.new_config["processor"]:
if child.processor_type == ProcessorType.Conversation:
self.new_config['processor'][child.processor_type.value]['openai-api-key'] = child.input_field.toPlainText() if child.input_field.toPlainText() != '' else None
self.new_config["processor"][child.processor_type.value]["openai-api-key"] = (
child.input_field.toPlainText() if child.input_field.toPlainText() != "" else None
)
def save_settings_to_file(self) -> bool:
"Save validated settings to file"
@@ -278,7 +292,7 @@ class MainWindow(QtWidgets.QMainWindow):
self.show()
self.setWindowState(Qt.WindowState.WindowActive)
self.activateWindow() # For Bringing to Top on Windows
self.raise_() # For Bringing to Top from Minimized State on OSX
self.raise_() # For Bringing to Top from Minimized State on OSX
class SettingsLoader(QObject):
@@ -312,6 +326,7 @@ class ProcessorCheckBox(QtWidgets.QCheckBox):
self.processor_type = processor_type
super(ProcessorCheckBox, self).__init__(text, parent=parent)
class ErrorType(Enum):
"Error Types"
ConfigLoadingError = "Config Loading Error"

View File

@@ -17,17 +17,17 @@ def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow):
"""
# Create the system tray with icon
icon_path = constants.web_directory / 'assets/icons/favicon-144x144.png'
icon = QtGui.QIcon(f'{icon_path.absolute()}')
icon_path = constants.web_directory / "assets/icons/favicon-144x144.png"
icon = QtGui.QIcon(f"{icon_path.absolute()}")
tray = QtWidgets.QSystemTrayIcon(icon)
tray.setVisible(True)
# Create the menu and menu actions
menu = QtWidgets.QMenu()
menu_actions = [
('Search', lambda: webbrowser.open(f'http://{state.host}:{state.port}/')),
('Configure', main_window.show_on_top),
('Quit', gui.quit),
("Search", lambda: webbrowser.open(f"http://{state.host}:{state.port}/")),
("Configure", main_window.show_on_top),
("Quit", gui.quit),
]
# Add the menu actions to the menu

View File

@@ -8,8 +8,8 @@ import warnings
from platform import system
# Ignore non-actionable warnings
warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning)
warnings.filterwarnings("ignore", message=r'legacy way to download files from the HF hub,', category=FutureWarning)
warnings.filterwarnings("ignore", message=r"snapshot_download.py has been made private", category=FutureWarning)
warnings.filterwarnings("ignore", message=r"legacy way to download files from the HF hub,", category=FutureWarning)
# External Packages
import uvicorn
@@ -43,11 +43,12 @@ rich_handler = RichHandler(rich_tracebacks=True)
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
logging.basicConfig(handlers=[rich_handler])
logger = logging.getLogger('khoj')
logger = logging.getLogger("khoj")
def run():
# Turn Tokenizers Parallelism Off. App does not support it.
os.environ["TOKENIZERS_PARALLELISM"] = 'false'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Load config from CLI
state.cli_args = sys.argv[1:]
@@ -66,7 +67,7 @@ def run():
logger.setLevel(logging.DEBUG)
# Set Log File
fh = logging.FileHandler(state.config_file.parent / 'khoj.log')
fh = logging.FileHandler(state.config_file.parent / "khoj.log")
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
@@ -87,7 +88,7 @@ def run():
# On Linux (Gnome) the System tray is not supported.
# Since only the Main Window is available
# Quitting it should quit the application
if system() in ['Windows', 'Darwin']:
if system() in ["Windows", "Darwin"]:
gui.setQuitOnLastWindowClosed(False)
tray = create_system_tray(gui, main_window)
tray.show()
@@ -97,7 +98,7 @@ def run():
server = ServerThread(app, args.host, args.port, args.socket)
# Show Main Window on First Run Experience or if on Linux
if args.config is None or system() not in ['Windows', 'Darwin']:
if args.config is None or system() not in ["Windows", "Darwin"]:
main_window.show()
# Setup Signal Handlers
@@ -112,9 +113,10 @@ def run():
gui.aboutToQuit.connect(server.terminate)
# Close Splash Screen if still open
if system() != 'Darwin':
if system() != "Darwin":
try:
import pyi_splash
# Update the text on the splash screen
pyi_splash.update_text("Khoj setup complete")
# Close Splash Screen
@@ -167,5 +169,5 @@ class ServerThread(QThread):
start_server(self.app, self.host, self.port, self.socket)
if __name__ == '__main__':
if __name__ == "__main__":
run()

View File

@@ -19,31 +19,27 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
# Setup Prompt based on Summary Type
if summary_type == "chat":
prompt = f'''
prompt = f"""
You are an AI. Summarize the conversation below from your perspective:
{text}
Summarize the conversation from the AI's first-person perspective:'''
Summarize the conversation from the AI's first-person perspective:"""
elif summary_type == "notes":
prompt = f'''
prompt = f"""
Summarize the below notes about {user_query}:
{text}
Summarize the notes in second person perspective:'''
Summarize the notes in second person perspective:"""
# Get Response from GPT
response = openai.Completion.create(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop="\"\"\"")
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
)
# Extract, Clean Message from GPT's Response
story = response['choices'][0]['text']
story = response["choices"][0]["text"]
return str(story).replace("\n\n", "")
@@ -53,7 +49,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
understand_primer = '''
understand_primer = """
Objective: Extract search type from user query and return information as JSON
Allowed search types are listed below:
@@ -73,7 +69,7 @@ A:{ "search-type": "notes" }
Q: When did I buy Groceries last?
A:{ "search-type": "ledger" }
Q:When did I go surfing last?
A:{ "search-type": "notes" }'''
A:{ "search-type": "notes" }"""
# Setup Prompt with Understand Primer
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
@@ -82,15 +78,11 @@ A:{ "search-type": "notes" }'''
# Get Response from GPT
response = openai.Completion.create(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop=["\n"])
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
)
# Extract, Clean Message from GPT's Response
story = str(response['choices'][0]['text'])
story = str(response["choices"][0]["text"])
return json.loads(story.strip(empty_escape_sequences))
@@ -100,7 +92,7 @@ def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbo
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
understand_primer = '''
understand_primer = """
Objective: Extract intent and trigger emotion information as JSON from each chat message
Potential intent types and valid argument values are listed below:
@@ -142,7 +134,7 @@ A: { "intent": {"type": "remember", "memory-type": "notes", "query": "recommend
Q: When did I go surfing last?
A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" }
Q: Can you dance for me?
A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }'''
A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }"""
# Setup Prompt with Understand Primer
prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
@@ -151,15 +143,11 @@ A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance
# Get Response from GPT
response = openai.Completion.create(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop=["\n"])
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
)
# Extract, Clean Message from GPT's Response
story = str(response['choices'][0]['text'])
story = str(response["choices"][0]["text"])
return json.loads(story.strip(empty_escape_sequences))
@@ -171,15 +159,15 @@ def converse(text, model, conversation_history=None, api_key=None, temperature=0
max_words = 500
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
conversation_primer = f'''
conversation_primer = f"""
The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion.
Human: Hello, who are you?
AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?'''
AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?"""
# Setup Prompt with Primer or Conversation History
prompt = message_to_prompt(text, conversation_history or conversation_primer)
prompt = ' '.join(prompt.split()[:max_words])
prompt = " ".join(prompt.split()[:max_words])
# Get Response from GPT
response = openai.Completion.create(
@@ -188,14 +176,17 @@ AI: Hi, I am an AI conversational companion created by OpenAI. How can I help yo
temperature=temperature,
max_tokens=max_tokens,
presence_penalty=0.6,
stop=["\n", "Human:", "AI:"])
stop=["\n", "Human:", "AI:"],
)
# Extract, Clean Message from GPT's Response
story = str(response['choices'][0]['text'])
story = str(response["choices"][0]["text"])
return story.strip(empty_escape_sequences)
def message_to_prompt(user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"):
def message_to_prompt(
user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
):
"""Create prompt for GPT from messages and conversation history"""
gpt_message = f" {gpt_message}" if gpt_message else ""
@@ -205,12 +196,8 @@ def message_to_prompt(user_message, conversation_history="", gpt_message=None, s
def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]):
"""Create json logs from messages, metadata for conversation log"""
default_user_message_metadata = {
"intent": {
"type": "remember",
"memory-type": "notes",
"query": user_message
},
"trigger-emotion": "calm"
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
"trigger-emotion": "calm",
}
current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -229,5 +216,4 @@ def message_to_log(user_message, gpt_message, user_message_metadata={}, conversa
def extract_summaries(metadata):
"""Extract summaries from metadata"""
return ''.join(
[f'\n{session["summary"]}' for session in metadata])
return "".join([f'\n{session["summary"]}' for session in metadata])

View File

@@ -19,7 +19,11 @@ class BeancountToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
# Extract required fields from config
beancount_files, beancount_file_filter, output_file = self.config.input_files, self.config.input_filter,self.config.compressed_jsonl
beancount_files, beancount_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Input Validation
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
@@ -31,7 +35,9 @@ class BeancountToJsonl(TextToJsonl):
# Extract Entries from specified Beancount files
with timer("Parse transactions from Beancount files into dictionaries", logger):
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
current_entries = BeancountToJsonl.convert_transactions_to_maps(
*BeancountToJsonl.extract_beancount_transactions(beancount_files)
)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -42,7 +48,9 @@ class BeancountToJsonl(TextToJsonl):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
entries_with_ids = self.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
with timer("Write transactions to JSONL file", logger):
# Process Each Entry from All Notes Files
@@ -62,9 +70,7 @@ class BeancountToJsonl(TextToJsonl):
"Get Beancount files to process"
absolute_beancount_files, filtered_beancount_files = set(), set()
if beancount_files:
absolute_beancount_files = {get_absolute_path(beancount_file)
for beancount_file
in beancount_files}
absolute_beancount_files = {get_absolute_path(beancount_file) for beancount_file in beancount_files}
if beancount_file_filters:
filtered_beancount_files = {
filtered_file
@@ -76,14 +82,13 @@ class BeancountToJsonl(TextToJsonl):
files_with_non_beancount_extensions = {
beancount_file
for beancount_file
in all_beancount_files
for beancount_file in all_beancount_files
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
}
if any(files_with_non_beancount_extensions):
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
logger.info(f'Processing files: {all_beancount_files}')
logger.info(f"Processing files: {all_beancount_files}")
return all_beancount_files
@@ -92,19 +97,20 @@ class BeancountToJsonl(TextToJsonl):
"Extract entries from specified Beancount files"
# Initialize Regex for extracting Beancount Entries
transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
empty_newline = f'^[\n\r\t\ ]*$'
transaction_regex = r"^\n?\d{4}-\d{2}-\d{2} [\*|\!] "
empty_newline = f"^[\n\r\t\ ]*$"
entries = []
transaction_to_file_map = []
for beancount_file in beancount_files:
with open(beancount_file) as f:
ledger_content = f.read()
transactions_per_file = [entry.strip(empty_escape_sequences)
for entry
in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
if re.match(transaction_regex, entry)]
transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file))
transactions_per_file = [
entry.strip(empty_escape_sequences)
for entry in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
if re.match(transaction_regex, entry)
]
transaction_to_file_map += zip(transactions_per_file, [beancount_file] * len(transactions_per_file))
entries.extend(transactions_per_file)
return entries, dict(transaction_to_file_map)
@@ -113,7 +119,9 @@ class BeancountToJsonl(TextToJsonl):
"Convert each parsed Beancount transaction into a Entry"
entries = []
for parsed_entry in parsed_entries:
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))
entries.append(
Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{transaction_to_file_map[parsed_entry]}")
)
logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
@@ -122,4 +130,4 @@ class BeancountToJsonl(TextToJsonl):
@staticmethod
def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str:
"Convert each Beancount transaction entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])
return "".join([f"{entry.to_json()}\n" for entry in entries])

View File

@@ -20,7 +20,11 @@ class MarkdownToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
# Extract required fields from config
markdown_files, markdown_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
markdown_files, markdown_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Input Validation
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
@@ -32,7 +36,9 @@ class MarkdownToJsonl(TextToJsonl):
# Extract Entries from specified Markdown files
with timer("Parse entries from Markdown files into dictionaries", logger):
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
*MarkdownToJsonl.extract_markdown_entries(markdown_files)
)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
@@ -43,7 +49,9 @@ class MarkdownToJsonl(TextToJsonl):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
entries_with_ids = self.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
with timer("Write markdown entries to JSONL file", logger):
# Process Each Entry from All Notes Files
@@ -75,15 +83,16 @@ class MarkdownToJsonl(TextToJsonl):
files_with_non_markdown_extensions = {
md_file
for md_file
in all_markdown_files
if not md_file.endswith(".md") and not md_file.endswith('.markdown')
for md_file in all_markdown_files
if not md_file.endswith(".md") and not md_file.endswith(".markdown")
}
if any(files_with_non_markdown_extensions):
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
logger.warn(
f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}"
)
logger.info(f'Processing files: {all_markdown_files}')
logger.info(f"Processing files: {all_markdown_files}")
return all_markdown_files
@@ -92,20 +101,20 @@ class MarkdownToJsonl(TextToJsonl):
"Extract entries by heading from specified Markdown files"
# Regex to extract Markdown Entries by Heading
markdown_heading_regex = r'^#'
markdown_heading_regex = r"^#"
entries = []
entry_to_file_map = []
for markdown_file in markdown_files:
with open(markdown_file, 'r', encoding='utf8') as f:
with open(markdown_file, "r", encoding="utf8") as f:
markdown_content = f.read()
markdown_entries_per_file = []
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
prefix = '#' if entry.startswith('#') else '# '
if entry.strip(empty_escape_sequences) != '':
markdown_entries_per_file.append(f'{prefix}{entry.strip(empty_escape_sequences)}')
prefix = "#" if entry.startswith("#") else "# "
if entry.strip(empty_escape_sequences) != "":
markdown_entries_per_file.append(f"{prefix}{entry.strip(empty_escape_sequences)}")
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file))
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
entries.extend(markdown_entries_per_file)
return entries, dict(entry_to_file_map)
@@ -115,7 +124,7 @@ class MarkdownToJsonl(TextToJsonl):
"Convert each Markdown entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))
entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f"{entry_to_file_map[parsed_entry]}"))
logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
@@ -124,4 +133,4 @@ class MarkdownToJsonl(TextToJsonl):
@staticmethod
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
"Convert each Markdown entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])
return "".join([f"{entry.to_json()}\n" for entry in entries])

View File

@@ -18,9 +18,13 @@ logger = logging.getLogger(__name__)
class OrgToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries: List[Entry]=None):
def process(self, previous_entries: List[Entry] = None):
# Extract required fields from config
org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
org_files, org_file_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
index_heading_entries = self.config.index_heading_entries
# Input Validation
@@ -46,7 +50,9 @@ class OrgToJsonl(TextToJsonl):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
entries_with_ids = self.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
# Process Each Entry from All Notes Files
with timer("Write org entries to JSONL file", logger):
@@ -66,11 +72,7 @@ class OrgToJsonl(TextToJsonl):
"Get Org files to process"
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {
get_absolute_path(org_file)
for org_file
in org_files
}
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
if org_file_filters:
filtered_org_files = {
filtered_file
@@ -84,7 +86,7 @@ class OrgToJsonl(TextToJsonl):
if any(files_with_non_org_extensions):
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
logger.info(f'Processing files: {all_org_files}')
logger.info(f"Processing files: {all_org_files}")
return all_org_files
@@ -95,13 +97,15 @@ class OrgToJsonl(TextToJsonl):
entry_to_file_map = []
for org_file in org_files:
org_file_entries = orgnode.makelist(str(org_file))
entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries)
return entries, dict(entry_to_file_map)
@staticmethod
def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]:
def convert_org_nodes_to_entries(
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
) -> List[Entry]:
"Convert Org-Mode nodes into list of Entry objects"
entries: List[Entry] = []
for parsed_entry in parsed_entries:
@@ -109,13 +113,13 @@ class OrgToJsonl(TextToJsonl):
# Ignore title notes i.e notes with just headings and empty body
continue
compiled = f'{parsed_entry.heading}.'
compiled = f"{parsed_entry.heading}."
if state.verbose > 2:
logger.debug(f"Title: {parsed_entry.heading}")
if parsed_entry.tags:
tags_str = " ".join(parsed_entry.tags)
compiled += f'\t {tags_str}.'
compiled += f"\t {tags_str}."
if state.verbose > 2:
logger.debug(f"Tags: {tags_str}")
@@ -130,19 +134,16 @@ class OrgToJsonl(TextToJsonl):
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
if parsed_entry.hasBody:
compiled += f'\n {parsed_entry.body}'
compiled += f"\n {parsed_entry.body}"
if state.verbose > 2:
logger.debug(f"Body: {parsed_entry.body}")
if compiled:
entries += [Entry(
compiled=compiled,
raw=f'{parsed_entry}',
file=f'{entry_to_file_map[parsed_entry]}')]
entries += [Entry(compiled=compiled, raw=f"{parsed_entry}", file=f"{entry_to_file_map[parsed_entry]}")]
return entries
@staticmethod
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
"Convert each Org-Mode entry to JSON and collate as JSONL"
return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])
return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])

View File

@@ -39,182 +39,197 @@ from pathlib import Path
from os.path import relpath
from typing import List
indent_regex = re.compile(r'^ *')
indent_regex = re.compile(r"^ *")
def normalize_filename(filename):
"Normalize and escape filename for rendering"
if not Path(filename).is_absolute():
# Normalize relative filename to be relative to current directory
normalized_filename = f'~/{relpath(filename, start=Path.home())}'
else:
normalized_filename = filename
escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]")
return escaped_filename
"Normalize and escape filename for rendering"
if not Path(filename).is_absolute():
# Normalize relative filename to be relative to current directory
normalized_filename = f"~/{relpath(filename, start=Path.home())}"
else:
normalized_filename = filename
escaped_filename = f"{normalized_filename}".replace("[", "\[").replace("]", "\]")
return escaped_filename
def makelist(filename):
"""
Read an org-mode file and return a list of Orgnode objects
created from this file.
"""
ctr = 0
"""
Read an org-mode file and return a list of Orgnode objects
created from this file.
"""
ctr = 0
f = open(filename, 'r')
f = open(filename, "r")
todos = { "TODO": "", "WAITING": "", "ACTIVE": "",
"DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line
level = ""
heading = ""
bodytext = ""
tags = list() # set of all tags in headline
closed_date = ''
sched_date = ''
deadline_date = ''
logbook = list()
nodelist: List[Orgnode] = list()
property_map = dict()
in_properties_drawer = False
in_logbook_drawer = False
file_title = f'{filename}'
todos = {
"TODO": "",
"WAITING": "",
"ACTIVE": "",
"DONE": "",
"CANCELLED": "",
"FAILED": "",
} # populated from #+SEQ_TODO line
level = ""
heading = ""
bodytext = ""
tags = list() # set of all tags in headline
closed_date = ""
sched_date = ""
deadline_date = ""
logbook = list()
nodelist: List[Orgnode] = list()
property_map = dict()
in_properties_drawer = False
in_logbook_drawer = False
file_title = f"{filename}"
for line in f:
ctr += 1
heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line)
if heading_search: # we are processing a heading line
if heading: # if we have are on second heading, append first heading to headings list
thisNode = Orgnode(level, heading, bodytext, tags)
if closed_date:
thisNode.closed = closed_date
closed_date = ''
if sched_date:
thisNode.scheduled = sched_date
sched_date = ""
if deadline_date:
thisNode.deadline = deadline_date
deadline_date = ''
if logbook:
thisNode.logbook = logbook
logbook = list()
thisNode.properties = property_map
nodelist.append( thisNode )
property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
level = heading_search.group(1)
heading = heading_search.group(2)
bodytext = ""
tags = list() # set of all tags in headline
tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
if tag_search:
heading = tag_search.group(1)
parsedtags = tag_search.group(2)
if parsedtags:
for parsedtag in parsedtags.split(':'):
if parsedtag != '': tags.append(parsedtag)
else: # we are processing a non-heading line
if line[:10] == '#+SEQ_TODO':
kwlist = re.findall(r'([A-Z]+)\(', line)
for kw in kwlist: todos[kw] = ""
for line in f:
ctr += 1
heading_search = re.search(r"^(\*+)\s(.*?)\s*$", line)
if heading_search: # we are processing a heading line
if heading: # if we have are on second heading, append first heading to headings list
thisNode = Orgnode(level, heading, bodytext, tags)
if closed_date:
thisNode.closed = closed_date
closed_date = ""
if sched_date:
thisNode.scheduled = sched_date
sched_date = ""
if deadline_date:
thisNode.deadline = deadline_date
deadline_date = ""
if logbook:
thisNode.logbook = logbook
logbook = list()
thisNode.properties = property_map
nodelist.append(thisNode)
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
level = heading_search.group(1)
heading = heading_search.group(2)
bodytext = ""
tags = list() # set of all tags in headline
tag_search = re.search(r"(.*?)\s*:([a-zA-Z0-9].*?):$", heading)
if tag_search:
heading = tag_search.group(1)
parsedtags = tag_search.group(2)
if parsedtags:
for parsedtag in parsedtags.split(":"):
if parsedtag != "":
tags.append(parsedtag)
else: # we are processing a non-heading line
if line[:10] == "#+SEQ_TODO":
kwlist = re.findall(r"([A-Z]+)\(", line)
for kw in kwlist:
todos[kw] = ""
# Set file title to TITLE property, if it exists
title_search = re.search(r'^#\+TITLE:\s*(.*)$', line)
if title_search and title_search.group(1).strip() != '':
title_text = title_search.group(1).strip()
if file_title == f'{filename}':
file_title = title_text
else:
file_title += f' {title_text}'
continue
# Set file title to TITLE property, if it exists
title_search = re.search(r"^#\+TITLE:\s*(.*)$", line)
if title_search and title_search.group(1).strip() != "":
title_text = title_search.group(1).strip()
if file_title == f"{filename}":
file_title = title_text
else:
file_title += f" {title_text}"
continue
# Ignore Properties Drawers Completely
if re.search(':PROPERTIES:', line):
in_properties_drawer=True
continue
if in_properties_drawer and re.search(':END:', line):
in_properties_drawer=False
continue
# Ignore Properties Drawers Completely
if re.search(":PROPERTIES:", line):
in_properties_drawer = True
continue
if in_properties_drawer and re.search(":END:", line):
in_properties_drawer = False
continue
# Ignore Logbook Drawer Start, End Lines
if re.search(':LOGBOOK:', line):
in_logbook_drawer=True
continue
if in_logbook_drawer and re.search(':END:', line):
in_logbook_drawer=False
continue
# Ignore Logbook Drawer Start, End Lines
if re.search(":LOGBOOK:", line):
in_logbook_drawer = True
continue
if in_logbook_drawer and re.search(":END:", line):
in_logbook_drawer = False
continue
# Extract Clocking Lines
clocked_re = re.search(r'CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]', line)
if clocked_re:
# convert clock in, clock out strings to datetime objects
clocked_in = datetime.datetime.strptime(clocked_re.group(1), '%Y-%m-%d %a %H:%M')
clocked_out = datetime.datetime.strptime(clocked_re.group(2), '%Y-%m-%d %a %H:%M')
# add clocked time to the entries logbook list
logbook += [(clocked_in, clocked_out)]
line = ""
# Extract Clocking Lines
clocked_re = re.search(
r"CLOCK:\s*\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]--\[([0-9]{4}-[0-9]{2}-[0-9]{2} [a-zA-Z]{3} [0-9]{2}:[0-9]{2})\]",
line,
)
if clocked_re:
# convert clock in, clock out strings to datetime objects
clocked_in = datetime.datetime.strptime(clocked_re.group(1), "%Y-%m-%d %a %H:%M")
clocked_out = datetime.datetime.strptime(clocked_re.group(2), "%Y-%m-%d %a %H:%M")
# add clocked time to the entries logbook list
logbook += [(clocked_in, clocked_out)]
line = ""
property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
if property_search:
# Set ID property to an id based org-mode link to the entry
if property_search.group(1) == 'ID':
property_map['ID'] = f'id:{property_search.group(2)}'
else:
property_map[property_search.group(1)] = property_search.group(2)
continue
property_search = re.search(r"^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$", line)
if property_search:
# Set ID property to an id based org-mode link to the entry
if property_search.group(1) == "ID":
property_map["ID"] = f"id:{property_search.group(2)}"
else:
property_map[property_search.group(1)] = property_search.group(2)
continue
cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line)
if cd_re:
closed_date = datetime.date(int(cd_re.group(1)),
int(cd_re.group(2)),
int(cd_re.group(3)) )
sd_re = re.search(r'SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)', line)
if sd_re:
sched_date = datetime.date(int(sd_re.group(1)),
int(sd_re.group(2)),
int(sd_re.group(3)) )
dd_re = re.search(r'DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)', line)
if dd_re:
deadline_date = datetime.date(int(dd_re.group(1)),
int(dd_re.group(2)),
int(dd_re.group(3)) )
cd_re = re.search(r"CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})", line)
if cd_re:
closed_date = datetime.date(int(cd_re.group(1)), int(cd_re.group(2)), int(cd_re.group(3)))
sd_re = re.search(r"SCHEDULED:\s*<([0-9]+)\-([0-9]+)\-([0-9]+)", line)
if sd_re:
sched_date = datetime.date(int(sd_re.group(1)), int(sd_re.group(2)), int(sd_re.group(3)))
dd_re = re.search(r"DEADLINE:\s*<(\d+)\-(\d+)\-(\d+)", line)
if dd_re:
deadline_date = datetime.date(int(dd_re.group(1)), int(dd_re.group(2)), int(dd_re.group(3)))
# Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body
if not in_properties_drawer and not cd_re and not sd_re and not dd_re and not clocked_re and line[:1] != '#':
bodytext = bodytext + line
# Ignore property drawer, scheduled, closed, deadline, logbook entries and # lines from body
if (
not in_properties_drawer
and not cd_re
and not sd_re
and not dd_re
and not clocked_re
and line[:1] != "#"
):
bodytext = bodytext + line
# write out last node
thisNode = Orgnode(level, heading or file_title, bodytext, tags)
thisNode.properties = property_map
if sched_date:
thisNode.scheduled = sched_date
if deadline_date:
thisNode.deadline = deadline_date
if closed_date:
thisNode.closed = closed_date
if logbook:
thisNode.logbook = logbook
nodelist.append( thisNode )
# write out last node
thisNode = Orgnode(level, heading or file_title, bodytext, tags)
thisNode.properties = property_map
if sched_date:
thisNode.scheduled = sched_date
if deadline_date:
thisNode.deadline = deadline_date
if closed_date:
thisNode.closed = closed_date
if logbook:
thisNode.logbook = logbook
nodelist.append(thisNode)
# using the list of TODO keywords found in the file
# process the headings searching for TODO keywords
for n in nodelist:
todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading)
if todo_search:
if todo_search.group(1) in todos:
n.heading = todo_search.group(2)
n.todo = todo_search.group(1)
# using the list of TODO keywords found in the file
# process the headings searching for TODO keywords
for n in nodelist:
todo_search = re.search(r"([A-Z]+)\s(.*?)$", n.heading)
if todo_search:
if todo_search.group(1) in todos:
n.heading = todo_search.group(2)
n.todo = todo_search.group(1)
# extract, set priority from heading, update heading if necessary
priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading)
if priority_search:
n.priority = priority_search.group(1)
n.heading = priority_search.group(2)
# extract, set priority from heading, update heading if necessary
priority_search = re.search(r"^\[\#(A|B|C)\] (.*?)$", n.heading)
if priority_search:
n.priority = priority_search.group(1)
n.heading = priority_search.group(2)
# Set SOURCE property to a file+heading based org-mode link to the entry
if n.level == 0:
n.properties['LINE'] = f'file:{normalize_filename(filename)}::0'
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]'
else:
escaped_heading = n.heading.replace("[","\\[").replace("]","\\]")
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]'
# Set SOURCE property to a file+heading based org-mode link to the entry
if n.level == 0:
n.properties["LINE"] = f"file:{normalize_filename(filename)}::0"
n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}]]"
else:
escaped_heading = n.heading.replace("[", "\\[").replace("]", "\\]")
n.properties["SOURCE"] = f"[[file:{normalize_filename(filename)}::*{escaped_heading}]]"
return nodelist
return nodelist
######################
class Orgnode(object):
@@ -222,6 +237,7 @@ class Orgnode(object):
Orgnode class represents a headline, tags and text associated
with the headline.
"""
def __init__(self, level, headline, body, tags):
"""
Create an Orgnode object given the parameters of level (as the
@@ -232,14 +248,14 @@ class Orgnode(object):
self._level = len(level)
self._heading = headline
self._body = body
self._tags = tags # All tags in the headline
self._tags = tags # All tags in the headline
self._todo = ""
self._priority = "" # empty of A, B or C
self._scheduled = "" # Scheduled date
self._deadline = "" # Deadline date
self._closed = "" # Closed date
self._priority = "" # empty of A, B or C
self._scheduled = "" # Scheduled date
self._deadline = "" # Deadline date
self._closed = "" # Closed date
self._properties = dict()
self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries
self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries
# Look for priority in headline and transfer to prty field
@@ -270,7 +286,7 @@ class Orgnode(object):
"""
Returns True if node has non empty body, else False
"""
return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != ''
return self._body and re.sub(r"\n|\t|\r| ", "", self._body) != ""
@property
def level(self):
@@ -417,20 +433,20 @@ class Orgnode(object):
text as used to construct the node.
"""
# Output heading line
n = ''
n = ""
for _ in range(0, self._level):
n = n + '*'
n = n + ' '
n = n + "*"
n = n + " "
if self._todo:
n = n + self._todo + ' '
n = n + self._todo + " "
if self._priority:
n = n + '[#' + self._priority + '] '
n = n + "[#" + self._priority + "] "
n = n + self._heading
n = "%-60s " % n # hack - tags will start in column 62
closecolon = ''
n = "%-60s " % n # hack - tags will start in column 62
closecolon = ""
for t in self._tags:
n = n + ':' + t
closecolon = ':'
n = n + ":" + t
closecolon = ":"
n = n + closecolon
n = n + "\n"
@@ -439,24 +455,24 @@ class Orgnode(object):
# Output Closed Date, Scheduled Date, Deadline Date
if self._closed or self._scheduled or self._deadline:
n = n + indent
n = n + indent
if self._closed:
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] '
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] '
if self._scheduled:
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> '
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> '
if self._deadline:
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
if self._closed or self._scheduled or self._deadline:
n = n + '\n'
n = n + "\n"
# Ouput Property Drawer
n = n + indent + ":PROPERTIES:\n"
for key, value in self._properties.items():
n = n + indent + f":{key}: {value}\n"
n = n + indent + f":{key}: {value}\n"
n = n + indent + ":END:\n"
# Output Body
if self.hasBody:
n = n + self._body
n = n + self._body
return n

View File

@@ -17,14 +17,17 @@ class TextToJsonl(ABC):
self.config = config
@abstractmethod
def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ...
def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]:
...
@staticmethod
def hash_func(key: str) -> Callable:
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
@staticmethod
def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]:
def split_entries_by_max_tokens(
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
) -> List[Entry]:
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
chunked_entries: List[Entry] = []
for entry in entries:
@@ -32,13 +35,15 @@ class TextToJsonl(ABC):
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
chunked_entries.append(entry_chunk)
return chunked_entries
def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]:
def mark_entries_for_update(
self, current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None
) -> List[Tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
@@ -54,10 +59,7 @@ class TextToJsonl(ABC):
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with -1 id to flag for later embeddings generation
new_entries = [
(-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
@@ -67,4 +69,4 @@ class TextToJsonl(ABC):
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
return entries_with_ids
return entries_with_ids

View File

@@ -22,27 +22,30 @@ logger = logging.getLogger(__name__)
# Create Routes
@api.get('/config/data/default')
@api.get("/config/data/default")
def get_default_config_data():
return constants.default_config
@api.get('/config/data', response_model=FullConfig)
@api.get("/config/data", response_model=FullConfig)
def get_config_data():
return state.config
@api.post('/config/data')
@api.post("/config/data")
async def set_config_data(updated_config: FullConfig):
state.config = updated_config
with open(state.config_file, 'w') as outfile:
with open(state.config_file, "w") as outfile:
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close()
return state.config
@api.get('/search', response_model=List[SearchResponse])
@api.get("/search", response_model=List[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
results: List[SearchResponse] = []
if q is None or q == '':
logger.info(f'No query param (q) passed in API call to initiate search')
if q is None or q == "":
logger.info(f"No query param (q) passed in API call to initiate search")
return results
# initialize variables
@@ -50,9 +53,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
results_count = n
# return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}'
query_cache_key = f"{user_query}-{n}-{t}-{r}"
if query_cache_key in state.query_cache:
logger.info(f'Return response from query cache')
logger.info(f"Return response from query cache")
return state.query_cache[query_cache_key]
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
@@ -95,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
# query images
with timer("Query took", logger):
hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images'
output_directory = constants.web_directory / "images"
# collate and return results
with timer("Collating results took", logger):
@@ -103,8 +106,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=results_count)
image_files_url="/static/images",
count=results_count,
)
# Cache results
state.query_cache[query_cache_key] = results
@@ -112,7 +116,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
return results
@api.get('/update')
@api.get("/update")
def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
try:
state.search_index_lock.acquire()
@@ -132,4 +136,4 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
else:
logger.info("Processor reconfigured via API call")
return {'status': 'ok', 'message': 'khoj reloaded'}
return {"status": "ok", "message": "khoj reloaded"}

View File

@@ -9,7 +9,14 @@ from fastapi import APIRouter
# Internal Packages
from khoj.routers.api import search
from khoj.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from khoj.processor.conversation.gpt import (
converse,
extract_search_type,
message_to_log,
message_to_prompt,
understand,
summarize,
)
from khoj.utils.config import SearchType
from khoj.utils.helpers import get_from_dict, resolve_absolute_path
from khoj.utils import state
@@ -21,7 +28,7 @@ logger = logging.getLogger(__name__)
# Create Routes
@api_beta.get('/search')
@api_beta.get("/search")
def search_beta(q: str, n: Optional[int] = 1):
# Initialize Variables
model = state.processor_config.conversation.model
@@ -32,16 +39,16 @@ def search_beta(q: str, n: Optional[int] = 1):
metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type")
except Exception as e:
return {'status': 'error', 'result': [str(e)], 'type': None}
return {"status": "error", "result": [str(e)], "type": None}
# Search
search_results = search(q, n=n, t=SearchType(search_type))
# Return response
return {'status': 'ok', 'result': search_results, 'type': search_type}
return {"status": "ok", "result": search_results, "type": search_type}
@api_beta.get('/summarize')
@api_beta.get("/summarize")
def summarize_beta(q: str):
# Initialize Variables
model = state.processor_config.conversation.model
@@ -54,23 +61,25 @@ def summarize_beta(q: str):
# Converse with OpenAI GPT
result_list = search(q, n=1, r=True)
collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f'Semantically Similar Notes:\n{collated_result}')
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
status = 'ok'
status = "ok"
except Exception as e:
gpt_response = str(e)
status = 'error'
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, conversation_log=meta_log.get('chat', []))
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, conversation_log=meta_log.get("chat", [])
)
return {'status': status, 'response': gpt_response}
return {"status": status, "response": gpt_response}
@api_beta.get('/chat')
def chat(q: Optional[str]=None):
@api_beta.get("/chat")
def chat(q: Optional[str] = None):
# Initialize Variables
model = state.processor_config.conversation.model
api_key = state.processor_config.conversation.openai_api_key
@@ -81,10 +90,10 @@ def chat(q: Optional[str]=None):
# If user query is empty, return chat history
if not q:
if meta_log.get('chat'):
return {'status': 'ok', 'response': meta_log["chat"]}
if meta_log.get("chat"):
return {"status": "ok", "response": meta_log["chat"]}
else:
return {'status': 'ok', 'response': []}
return {"status": "ok", "response": []}
# Converse with OpenAI GPT
metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose)
@@ -94,32 +103,39 @@ def chat(q: Optional[str]=None):
query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org, r=True)
collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f'Semantically Similar Notes:\n{collated_result}')
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
status = 'ok'
status = "ok"
except Exception as e:
gpt_response = str(e)
status = 'error'
status = "error"
else:
try:
gpt_response = converse(q, model, chat_session, api_key=api_key)
status = 'ok'
status = "ok"
except Exception as e:
gpt_response = str(e)
status = 'error'
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log['chat'] = message_to_log(q, gpt_response, metadata, meta_log.get('chat', []))
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, metadata, meta_log.get("chat", [])
)
return {'status': status, 'response': gpt_response}
return {"status": status, "response": gpt_response}
@schedule.repeat(schedule.every(5).minutes)
def save_chat_session():
# No need to create empty log file
if not (state.processor_config and state.processor_config.conversation and state.processor_config.conversation.meta_log and state.processor_config.conversation.chat_session):
if not (
state.processor_config
and state.processor_config.conversation
and state.processor_config.conversation.meta_log
and state.processor_config.conversation.chat_session
):
return
# Summarize Conversation Logs for this Session
@@ -130,19 +146,19 @@ def save_chat_session():
session = {
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"])
}
if 'session' in conversation_log:
conversation_log['session'].append(session)
"session-end": len(conversation_log["chat"]),
}
if "session" in conversation_log:
conversation_log["session"].append(session)
else:
conversation_log['session'] = [session]
logger.info('Added new chat session to conversation logs')
conversation_log["session"] = [session]
logger.info("Added new chat session to conversation logs")
# Save Conversation Metadata Logs to Disk
conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
with open(conversation_logfile, "w+", encoding='utf-8') as logfile:
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
json.dump(conversation_log, logfile)
state.processor_config.conversation.chat_session = None
logger.info('Saved updated conversation logs to disk.')
logger.info("Saved updated conversation logs to disk.")

View File

@@ -18,10 +18,12 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index():
return FileResponse(constants.web_directory / "index.html")
@web_client.get('/config', response_class=HTMLResponse)
@web_client.get("/config", response_class=HTMLResponse)
def config_page(request: Request):
return templates.TemplateResponse("config.html", context={'request': request})
return templates.TemplateResponse("config.html", context={"request": request})
@web_client.get("/chat", response_class=FileResponse)
def chat_page():
return FileResponse(constants.web_directory / "chat.html")
return FileResponse(constants.web_directory / "chat.html")

View File

@@ -8,10 +8,13 @@ from khoj.utils.rawconfig import Entry
class BaseFilter(ABC):
@abstractmethod
def load(self, entries: List[Entry], *args, **kwargs): ...
def load(self, entries: List[Entry], *args, **kwargs):
...
@abstractmethod
def can_filter(self, raw_query:str) -> bool: ...
def can_filter(self, raw_query: str) -> bool:
...
@abstractmethod
def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ...
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
...

View File

@@ -26,21 +26,19 @@ class DateFilter(BaseFilter):
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def __init__(self, entry_key='raw'):
def __init__(self, entry_key="raw"):
self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
with timer("Created date filter index", logger):
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
except ValueError:
continue
self.date_to_entry_ids[date_in_entry].add(id)
@@ -49,7 +47,6 @@ class DateFilter(BaseFilter):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
@@ -61,8 +58,8 @@ class DateFilter(BaseFilter):
return query, set(range(len(entries)))
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
# return results from cache if exists
cache_key = tuple(query_daterange)
@@ -87,7 +84,6 @@ class DateFilter(BaseFilter):
return query, entries_to_include
def extract_date_range(self, query):
# find date range filter in query
date_range_matches = re.findall(self.date_regex, query)
@@ -98,7 +94,7 @@ class DateFilter(BaseFilter):
# extract, parse natural dates ranges from date range filter passed in query
# e.g today maps to (start_of_day, start_of_tomorrow)
date_ranges_from_filter = []
for (cmp, date_str) in date_range_matches:
for cmp, date_str in date_range_matches:
if self.parse(date_str):
dt_start, dt_end = self.parse(date_str)
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
@@ -111,15 +107,15 @@ class DateFilter(BaseFilter):
effective_date_range = [0, inf]
date_range_considering_comparator = []
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
if cmp == '>':
if cmp == ">":
date_range_considering_comparator += [[dtrange_end, inf]]
elif cmp == '>=':
elif cmp == ">=":
date_range_considering_comparator += [[dtrange_start, inf]]
elif cmp == '<':
elif cmp == "<":
date_range_considering_comparator += [[0, dtrange_start]]
elif cmp == '<=':
elif cmp == "<=":
date_range_considering_comparator += [[0, dtrange_end]]
elif cmp == '=' or cmp == ':' or cmp == '==':
elif cmp == "=" or cmp == ":" or cmp == "==":
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
# Combine above intervals (via AND/intersect)
@@ -129,48 +125,48 @@ class DateFilter(BaseFilter):
for date_range in date_range_considering_comparator:
effective_date_range = [
max(effective_date_range[0], date_range[0]),
min(effective_date_range[1], date_range[1])]
min(effective_date_range[1], date_range[1]),
]
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return None
else:
return effective_date_range
def parse(self, date_str, relative_base=None):
"Parse date string passed in date filter of query to datetime object"
# clean date string to handle future date parsing by date parser
future_strings = ['later', 'from now', 'from today']
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
clean_date_str = re.sub('|'.join(future_strings), '', date_str)
future_strings = ["later", "from now", "from today"]
prefer_dates_from = {True: "future", False: "past"}[any([True for fstr in future_strings if fstr in date_str])]
clean_date_str = re.sub("|".join(future_strings), "", date_str)
# parse date passed in query date filter
parsed_date = dtparse.parse(
clean_date_str,
settings= {
'RELATIVE_BASE': relative_base or datetime.now(),
'PREFER_DAY_OF_MONTH': 'first',
'PREFER_DATES_FROM': prefer_dates_from
})
settings={
"RELATIVE_BASE": relative_base or datetime.now(),
"PREFER_DAY_OF_MONTH": "first",
"PREFER_DATES_FROM": prefer_dates_from,
},
)
if parsed_date is None:
return None
return self.date_to_daterange(parsed_date, date_str)
def date_to_daterange(self, parsed_date, date_str):
"Convert parsed date to date ranges at natural granularity (day, week, month or year)"
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
if 'year' in date_str:
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))
if 'month' in date_str:
if "year" in date_str:
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year + 1, 1, 1, 0, 0, 0))
if "month" in date_str:
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
next_month = start_of_month + relativedelta(months=1)
return (start_of_month, next_month)
if 'week' in date_str:
if "week" in date_str:
# if week in date string, dateparser parses it to next week start
# so today = end of this week
start_of_week = start_of_day - timedelta(days=7)

View File

@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
class FileFilter(BaseFilter):
file_filter_regex = r'file:"(.+?)" ?'
def __init__(self, entry_key='file'):
def __init__(self, entry_key="file"):
self.entry_key = entry_key
self.file_to_entry_map = defaultdict(set)
self.cache = LRU()
@@ -40,13 +40,13 @@ class FileFilter(BaseFilter):
# e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file:
files_to_search += [f'*{file}']
if "/" not in file and "\\" not in file and "*" not in file:
files_to_search += [f"*{file}"]
else:
files_to_search += [file]
# Return item from cache if exists
query = re.sub(self.file_filter_regex, '', query).strip()
query = re.sub(self.file_filter_regex, "", query).strip()
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.info(f"Return file filter results from cache")
@@ -58,10 +58,15 @@ class FileFilter(BaseFilter):
# Mark entries that contain any blocked_words for exclusion
with timer("Mark entries satisfying filter", logger):
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
included_entry_indices = set.union(
*[
self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set())
if fnmatch.fnmatch(entry_file, search_file)
],
set(),
)
if not included_entry_indices:
return query, {}

View File

@@ -17,26 +17,26 @@ class WordFilter(BaseFilter):
required_regex = r'\+"([a-zA-Z0-9_-]+)" ?'
blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
def __init__(self, entry_key='raw'):
def __init__(self, entry_key="raw"):
self.entry_key = entry_key
self.word_to_entry_index = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
with timer("Created word filter index", logger):
self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
entry_splitter = (
r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
)
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '':
if word == "":
continue
self.word_to_entry_index[word].add(entry_index)
return self.word_to_entry_index
def can_filter(self, raw_query):
"Check if query contains word filters"
required_words = re.findall(self.required_regex, raw_query)
@@ -44,14 +44,13 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
query = re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries)))
@@ -70,12 +69,16 @@ class WordFilter(BaseFilter):
with timer("Mark entries satisfying filter", logger):
entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
entries_with_all_required_words = set.intersection(
*[self.word_to_entry_index.get(word, set()) for word in required_words]
)
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
entries_with_any_blocked_words = set.union(
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
)
# get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words

View File

@@ -35,9 +35,10 @@ def initialize_model(search_config: ImageSearchConfig):
# Load the CLIP model
encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = search_config.encoder_type or SentenceTransformer)
model_dir=search_config.model_directory,
model_name=search_config.encoder,
model_type=search_config.encoder_type or SentenceTransformer,
)
return encoder
@@ -46,12 +47,12 @@ def extract_entries(image_directories):
image_names = []
for image_directory in image_directories:
image_directory = resolve_absolute_path(image_directory, strict=True)
image_names.extend(list(image_directory.glob('*.jpg')))
image_names.extend(list(image_directory.glob('*.jpeg')))
image_names.extend(list(image_directory.glob("*.jpg")))
image_names.extend(list(image_directory.glob("*.jpeg")))
if logger.level >= logging.INFO:
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories])
logger.info(f'Found {len(image_names)} images in {image_directory_names}')
image_directory_names = ", ".join([str(image_directory) for image_directory in image_directories])
logger.info(f"Found {len(image_names)} images in {image_directory_names}")
return sorted(image_names)
@@ -59,7 +60,9 @@ def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)
image_metadata_embeddings = compute_metadata_embeddings(
image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate
)
return image_embeddings, image_metadata_embeddings
@@ -74,15 +77,12 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
image_embeddings = []
for index in trange(0, len(image_names), batch_size):
images = []
for image_name in image_names[index:index+batch_size]:
for image_name in image_names[index : index + batch_size]:
image = Image.open(image_name)
# Resize images to max width of 640px for faster processing
image.thumbnail((640, image.height))
images += [image]
image_embeddings += encoder.encode(
images,
convert_to_tensor=True,
batch_size=min(len(images), batch_size))
image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=min(len(images), batch_size))
# Create directory for embeddings file, if it doesn't exist
embeddings_file.parent.mkdir(parents=True, exist_ok=True)
@@ -94,7 +94,9 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
return image_embeddings
def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
def compute_metadata_embeddings(
image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0
):
image_metadata_embeddings = None
# Load pre-computed image metadata embedding file if exists
@@ -106,14 +108,17 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
if use_xmp_metadata and image_metadata_embeddings is None:
image_metadata_embeddings = []
for index in trange(0, len(image_names), batch_size):
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
image_metadata = [
extract_metadata(image_name, verbose) for image_name in image_names[index : index + batch_size]
]
try:
image_metadata_embeddings += encoder.encode(
image_metadata,
convert_to_tensor=True,
batch_size=min(len(image_metadata), batch_size))
image_metadata, convert_to_tensor=True, batch_size=min(len(image_metadata), batch_size)
)
except RuntimeError as e:
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
logger.error(
f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}"
)
continue
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
@@ -123,8 +128,10 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
def extract_metadata(image_name):
image_xmp_metadata = Image.open(image_name).getxmp()
image_description = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'description', 'Alt', 'li', 'text')
image_subjects = get_from_dict(image_xmp_metadata, 'xmpmeta', 'RDF', 'Description', 'subject', 'Bag', 'li')
image_description = get_from_dict(
image_xmp_metadata, "xmpmeta", "RDF", "Description", "description", "Alt", "li", "text"
)
image_subjects = get_from_dict(image_xmp_metadata, "xmpmeta", "RDF", "Description", "subject", "Bag", "li")
image_metadata_subjects = set([subject.split(":")[1] for subject in image_subjects if ":" in subject])
image_processed_metadata = image_description
@@ -141,7 +148,7 @@ def query(raw_query, count, model: ImageSearchModel):
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
query = copy.deepcopy(Image.open(query_imagepath))
query.thumbnail((640, query.height)) # scale down image for faster processing
query.thumbnail((640, query.height)) # scale down image for faster processing
logger.info(f"Find Images by Image: {query_imagepath}")
else:
# Truncate words in query to stay below max_tokens supported by ML model
@@ -155,36 +162,42 @@ def query(raw_query, count, model: ImageSearchModel):
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
with timer("Metadata Search Time", logger):
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
metadata_hits = {
result["corpus_id"]: result["score"]
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
}
# Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items():
scaling_factor = 0.33
if 'corpus_id' in image_hits:
image_hits[corpus_id].update({
'metadata_score': score,
'score': image_hits[corpus_id].get('score', 0) + scaling_factor*score,
})
if "corpus_id" in image_hits:
image_hits[corpus_id].update(
{
"metadata_score": score,
"score": image_hits[corpus_id].get("score", 0) + scaling_factor * score,
}
)
else:
image_hits[corpus_id] = {'metadata_score': score, 'score': scaling_factor*score}
image_hits[corpus_id] = {"metadata_score": score, "score": scaling_factor * score}
# Reformat results in original form from sentence transformer semantic_search()
hits = [
{
'corpus_id': corpus_id,
'score': scores['score'],
'image_score': scores.get('image_score', 0),
'metadata_score': scores.get('metadata_score', 0),
} for corpus_id, scores in image_hits.items()]
"corpus_id": corpus_id,
"score": scores["score"],
"image_score": scores.get("image_score", 0),
"metadata_score": scores.get("metadata_score", 0),
}
for corpus_id, scores in image_hits.items()
]
# Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
@@ -194,7 +207,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
results: List[SearchResponse] = []
for index, hit in enumerate(hits[:count]):
source_path = image_names[hit['corpus_id']]
source_path = image_names[hit["corpus_id"]]
target_image_name = f"{index}{source_path.suffix}"
target_path = resolve_absolute_path(f"{output_directory}/{target_image_name}")
@@ -207,17 +220,18 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
shutil.copy(source_path, target_path)
# Add the image metadata to the results
results += [SearchResponse.parse_obj(
{
"entry": f'{image_files_url}/{target_image_name}',
"score": f"{hit['score']:.9f}",
"additional":
results += [
SearchResponse.parse_obj(
{
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
"entry": f"{image_files_url}/{target_image_name}",
"score": f"{hit['score']:.9f}",
"additional": {
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
},
}
}
)]
)
]
return results
@@ -248,9 +262,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
embeddings_file,
batch_size=config.batch_size,
regenerate=regenerate,
use_xmp_metadata=config.use_xmp_metadata)
use_xmp_metadata=config.use_xmp_metadata,
)
return ImageSearchModel(all_image_files,
image_embeddings,
image_metadata_embeddings,
encoder)
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)

View File

@@ -38,17 +38,19 @@ def initialize_model(search_config: TextSearchConfig):
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = search_config.encoder_type or SentenceTransformer,
device=f'{state.device}')
model_dir=search_config.model_directory,
model_name=search_config.encoder,
model_type=search_config.encoder_type or SentenceTransformer,
device=f"{state.device}",
)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.cross_encoder,
model_type = CrossEncoder,
device=f'{state.device}')
model_dir=search_config.model_directory,
model_name=search_config.cross_encoder,
model_type=CrossEncoder,
device=f"{state.device}",
)
return bi_encoder, cross_encoder, top_k
@@ -58,7 +60,9 @@ def extract_entries(jsonl_file) -> List[Entry]:
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
def compute_embeddings(
entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
@@ -69,17 +73,23 @@ def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: Ba
# Encode any new entries in the corpus and update corpus embeddings
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
if new_entries:
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
new_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
if existing_entry_ids:
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device))
existing_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
)
else:
existing_embeddings = torch.tensor([], device=state.device)
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
# Else compute the corpus embeddings from scratch
else:
new_entries = [entry.compiled for _, entry in entries_with_ids]
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
corpus_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
# Save regenerated or updated embeddings to file
if new_entries:
@@ -112,7 +122,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
# Find relevant entries for the query
with timer("Search Time", logger, state.device):
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
hits = util.semantic_search(
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
)[0]
# Score all retrieved entries using the cross-encoder
if rank_results:
@@ -128,26 +140,33 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
return [SearchResponse.parse_obj(
{
"entry": entries[hit['corpus_id']].raw,
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
"additional": {
"file": entries[hit['corpus_id']].file,
"compiled": entries[hit['corpus_id']].compiled
return [
SearchResponse.parse_obj(
{
"entry": entries[hit["corpus_id"]].raw,
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
"additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
}
})
for hit
in hits[0:count]]
)
for hit in hits[0:count]
]
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = []) -> TextSearchModel:
def setup(
text_to_jsonl: Type[TextToJsonl],
config: TextContentConfig,
search_config: TextSearchConfig,
regenerate: bool,
filters: List[BaseFilter] = [],
) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
previous_entries = (
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
)
entries_with_indices = text_to_jsonl(config).process(previous_entries)
# Extract Updated Entries
@@ -158,7 +177,9 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate)
corpus_embeddings = compute_embeddings(
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
)
for filter in filters:
filter.load(entries, regenerate=regenerate)
@@ -166,8 +187,10 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]) -> Tuple[str, List[Entry], torch.Tensor]:
'''Filter query, entries and embeddings before semantic search'''
def apply_filters(
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
) -> Tuple[str, List[Entry], torch.Tensor]:
"""Filter query, entries and embeddings before semantic search"""
with timer("Total Filter Time", logger, state.device):
included_entry_indices = set(range(len(entries)))
@@ -178,45 +201,50 @@ def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Ten
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
return "", [], torch.tensor([], device=state.device)
else:
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
corpus_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
)
return query, entries, corpus_embeddings
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
'''Score all retrieved entries using the cross-encoder'''
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
hits[idx]["cross-score"] = cross_scores[idx]
return hits
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
'''Order results by cross-encoder score followed by bi-encoder score'''
"""Order results by cross-encoder score followed by bi-encoder score"""
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score
return hits
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
'''Deduplicate entries by raw entry text before showing to users
"""Deduplicate entries by raw entry text before showing to users
Compiled entries are split by max tokens supported by ML models.
This can result in duplicate hits, entries shown to user.'''
This can result in duplicate hits, entries shown to user."""
with timer("Deduplication Time", logger, state.device):
seen, original_hits_count = set(), len(hits)
hits = [hit for hit in hits
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] # type: ignore[func-returns-value]
hits = [
hit
for hit in hits
if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value]
]
duplicate_hits = original_hits_count - len(hits)
logger.debug(f"Removed {duplicate_hits} duplicates")

View File

@@ -10,21 +10,36 @@ from khoj.utils.yaml import parse_config_from_file
def cli(args=None):
# Setup Argument Parser for the Commandline Interface
parser = argparse.ArgumentParser(description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos")
parser.add_argument('--config-file', '-c', default='~/.khoj/khoj.yml', type=pathlib.Path, help="YAML file to configure Khoj")
parser.add_argument('--no-gui', action='store_true', default=False, help="Do not show native desktop GUI. Default: false")
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate model embeddings from source files. Default: false")
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs. Default: 0")
parser.add_argument('--host', type=str, default='127.0.0.1', help="Host address of the server. Default: 127.0.0.1")
parser.add_argument('--port', '-p', type=int, default=8000, help="Port of the server. Default: 8000")
parser.add_argument('--socket', type=pathlib.Path, help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock")
parser.add_argument('--version', '-V', action='store_true', help="Print the installed Khoj version and exit")
parser = argparse.ArgumentParser(
description="Start Khoj; A Natural Language Search Engine for your personal Notes, Transactions and Photos"
)
parser.add_argument(
"--config-file", "-c", default="~/.khoj/khoj.yml", type=pathlib.Path, help="YAML file to configure Khoj"
)
parser.add_argument(
"--no-gui", action="store_true", default=False, help="Do not show native desktop GUI. Default: false"
)
parser.add_argument(
"--regenerate",
action="store_true",
default=False,
help="Regenerate model embeddings from source files. Default: false",
)
parser.add_argument("--verbose", "-v", action="count", default=0, help="Show verbose conversion logs. Default: 0")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host address of the server. Default: 127.0.0.1")
parser.add_argument("--port", "-p", type=int, default=8000, help="Port of the server. Default: 8000")
parser.add_argument(
"--socket",
type=pathlib.Path,
help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
)
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
args = parser.parse_args(args)
if args.version:
# Show version of khoj installed and exit
print(version('khoj-assistant'))
print(version("khoj-assistant"))
exit(0)
# Normalize config_file path to absolute path

View File

@@ -28,8 +28,16 @@ class ProcessorType(str, Enum):
Conversation = "conversation"
class TextSearchModel():
def __init__(self, entries: List[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: List[BaseFilter], top_k):
class TextSearchModel:
def __init__(
self,
entries: List[Entry],
corpus_embeddings: torch.Tensor,
bi_encoder: BaseEncoder,
cross_encoder: CrossEncoder,
filters: List[BaseFilter],
top_k,
):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
@@ -38,7 +46,7 @@ class TextSearchModel():
self.top_k = top_k
class ImageSearchModel():
class ImageSearchModel:
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
self.image_encoder = image_encoder
self.image_names = image_names
@@ -48,7 +56,7 @@ class ImageSearchModel():
@dataclass
class SearchModels():
class SearchModels:
orgmode_search: TextSearchModel = None
ledger_search: TextSearchModel = None
music_search: TextSearchModel = None
@@ -56,15 +64,15 @@ class SearchModels():
image_search: ImageSearchModel = None
class ConversationProcessorConfigModel():
class ConversationProcessorConfigModel:
def __init__(self, processor_config: ConversationProcessorConfig):
self.openai_api_key = processor_config.openai_api_key
self.model = processor_config.model
self.conversation_logfile = Path(processor_config.conversation_logfile)
self.chat_session = ''
self.chat_session = ""
self.meta_log: dict = {}
@dataclass
class ProcessorConfigModel():
class ProcessorConfigModel:
conversation: ConversationProcessorConfigModel = None

View File

@@ -1,65 +1,62 @@
from pathlib import Path
app_root_directory = Path(__file__).parent.parent.parent
web_directory = app_root_directory / 'khoj/interface/web/'
empty_escape_sequences = '\n|\r|\t| '
web_directory = app_root_directory / "khoj/interface/web/"
empty_escape_sequences = "\n|\r|\t| "
# default app config to use
default_config = {
'content-type': {
'org': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/org/org.jsonl.gz',
'embeddings-file': '~/.khoj/content/org/org_embeddings.pt',
'index_heading_entries': False
"content-type": {
"org": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
"index_heading_entries": False,
},
'markdown': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/markdown/markdown.jsonl.gz',
'embeddings-file': '~/.khoj/content/markdown/markdown_embeddings.pt'
"markdown": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
},
'ledger': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/ledger/ledger.jsonl.gz',
'embeddings-file': '~/.khoj/content/ledger/ledger_embeddings.pt'
"ledger": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/ledger/ledger.jsonl.gz",
"embeddings-file": "~/.khoj/content/ledger/ledger_embeddings.pt",
},
'image': {
'input-directories': None,
'input-filter': None,
'embeddings-file': '~/.khoj/content/image/image_embeddings.pt',
'batch-size': 50,
'use-xmp-metadata': False
"image": {
"input-directories": None,
"input-filter": None,
"embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
"batch-size": 50,
"use-xmp-metadata": False,
},
'music': {
'input-files': None,
'input-filter': None,
'compressed-jsonl': '~/.khoj/content/music/music.jsonl.gz',
'embeddings-file': '~/.khoj/content/music/music_embeddings.pt'
"music": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/music/music.jsonl.gz",
"embeddings-file": "~/.khoj/content/music/music_embeddings.pt",
},
},
"search-type": {
"symmetric": {
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/symmetric/",
},
"asymmetric": {
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/asymmetric/",
},
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
"processor": {
"conversation": {
"openai-api-key": None,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
}
},
'search-type': {
'symmetric': {
'encoder': 'sentence-transformers/all-MiniLM-L6-v2',
'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
'model_directory': '~/.khoj/search/symmetric/'
},
'asymmetric': {
'encoder': 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
'cross-encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
'model_directory': '~/.khoj/search/asymmetric/'
},
'image': {
'encoder': 'sentence-transformers/clip-ViT-B-32',
'model_directory': '~/.khoj/search/image/'
}
},
'processor': {
'conversation': {
'openai-api-key': None,
'conversation-logfile': '~/.khoj/processor/conversation/conversation_logs.json'
}
}
}
}

View File

@@ -13,16 +13,17 @@ from typing import Optional, Union, TYPE_CHECKING
if TYPE_CHECKING:
# External Packages
from sentence_transformers import CrossEncoder
# Internal Packages
from khoj.utils.models import BaseEncoder
def is_none_or_empty(item):
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
def to_snake_case_from_dash(item: str):
return item.replace('_', '-')
return item.replace("_", "-")
def get_absolute_path(filepath: Union[str, Path]) -> str:
@@ -34,11 +35,11 @@ def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) ->
def get_from_dict(dictionary, *args):
'''null-aware get from a nested dictionary
Returns: dictionary[args[0]][args[1]]... or None if any keys missing'''
"""null-aware get from a nested dictionary
Returns: dictionary[args[0]][args[1]]... or None if any keys missing"""
current = dictionary
for arg in args:
if not hasattr(current, '__iter__') or not arg in current:
if not hasattr(current, "__iter__") or not arg in current:
return None
current = current[arg]
return current
@@ -54,7 +55,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
"Load model from disk or huggingface"
# Construct model path
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
@@ -74,17 +75,18 @@ def load_model(model_name: str, model_type, model_dir=None, device:str=None) ->
def is_pyinstaller_app():
"Returns true if the app is running from Native GUI created by PyInstaller"
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
return getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
def get_class_by_name(name: str) -> object:
"Returns the class object from name string"
module_name, class_name = name.rsplit('.', 1)
module_name, class_name = name.rsplit(".", 1)
return getattr(import_module(module_name), class_name)
class timer:
'''Context manager to log time taken for a block of code to run'''
"""Context manager to log time taken for a block of code to run"""
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
self.message = message
self.logger = logger
@@ -116,4 +118,4 @@ class LRU(OrderedDict):
super().__setitem__(key, value)
if len(self) > self.capacity:
oldest = next(iter(self))
del self[oldest]
del self[oldest]

View File

@@ -19,9 +19,9 @@ def load_jsonl(input_path):
# Open JSONL file
if input_path.suffix == ".gz":
jsonl_file = gzip.open(get_absolute_path(input_path), 'rt', encoding='utf-8')
jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
elif input_path.suffix == ".jsonl":
jsonl_file = open(get_absolute_path(input_path), 'r', encoding='utf-8')
jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")
# Read JSONL file
for line in jsonl_file:
@@ -31,7 +31,7 @@ def load_jsonl(input_path):
jsonl_file.close()
# Log JSONL entries loaded
logger.info(f'Loaded {len(data)} records from {input_path}')
logger.info(f"Loaded {len(data)} records from {input_path}")
return data
@@ -41,17 +41,17 @@ def dump_jsonl(jsonl_data, output_path):
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
with open(output_path, "w", encoding="utf-8") as f:
f.write(jsonl_data)
logger.info(f'Wrote jsonl data to {output_path}')
logger.info(f"Wrote jsonl data to {output_path}")
def compress_jsonl_data(jsonl_data, output_path):
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(output_path, 'wt', encoding='utf-8') as gzip_file:
with gzip.open(output_path, "wt", encoding="utf-8") as gzip_file:
gzip_file.write(jsonl_data)
logger.info(f'Wrote jsonl data to gzip compressed jsonl at {output_path}')
logger.info(f"Wrote jsonl data to gzip compressed jsonl at {output_path}")

View File

@@ -13,17 +13,25 @@ from khoj.utils.state import processor_config, config_file
class BaseEncoder(ABC):
@abstractmethod
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
def __init__(self, model_name: str, device: torch.device = None, **kwargs):
...
@abstractmethod
def encode(self, entries: List[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
...
class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None):
self.model_name = model_name
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
raise Exception(f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}")
if (
not processor_config
or not processor_config.conversation
or not processor_config.conversation.openai_api_key
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}"
)
openai.api_key = processor_config.conversation.openai_api_key
self.embedding_dimensions = None
@@ -32,7 +40,7 @@ class OpenAI(BaseEncoder):
for index in trange(0, len(entries)):
# OpenAI models create better embeddings for entries without newlines
processed_entry = entries[index].replace('\n', ' ')
processed_entry = entries[index].replace("\n", " ")
try:
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
@@ -41,10 +49,12 @@ class OpenAI(BaseEncoder):
# Else default to embedding dimensions of the text-embedding-ada-002 model
self.embedding_dimensions = len(response.data[0].embedding) if not self.embedding_dimensions else 1536
except Exception as e:
print(f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}")
print(
f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}"
)
# Use zero embedding vector for entries with failed embeddings
# This ensures entry embeddings match the order of the source entries
# And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector)
embedding_tensors += [torch.zeros(self.embedding_dimensions, device=device)]
return torch.stack(embedding_tensors)
return torch.stack(embedding_tensors)

View File

@@ -9,11 +9,13 @@ from pydantic import BaseModel, validator
# Internal Packages
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
class ConfigBase(BaseModel):
class Config:
alias_generator = to_snake_case_from_dash
allow_population_by_field_name = True
class TextContentConfig(ConfigBase):
input_files: Optional[List[Path]]
input_filter: Optional[List[str]]
@@ -21,12 +23,15 @@ class TextContentConfig(ConfigBase):
embeddings_file: Path
index_heading_entries: Optional[bool] = False
@validator('input_filter')
@validator("input_filter")
def input_filter_or_files_required(cls, input_filter, values, **kwargs):
if is_none_or_empty(input_filter) and ('input_files' not in values or values["input_files"] is None):
raise ValueError("Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file")
if is_none_or_empty(input_filter) and ("input_files" not in values or values["input_files"] is None):
raise ValueError(
"Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file"
)
return input_filter
class ImageContentConfig(ConfigBase):
input_directories: Optional[List[Path]]
input_filter: Optional[List[str]]
@@ -34,12 +39,17 @@ class ImageContentConfig(ConfigBase):
use_xmp_metadata: bool
batch_size: int
@validator('input_filter')
@validator("input_filter")
def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
if is_none_or_empty(input_filter) and ('input_directories' not in values or values["input_directories"] is None):
raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file")
if is_none_or_empty(input_filter) and (
"input_directories" not in values or values["input_directories"] is None
):
raise ValueError(
"Either input_filter or input_directories required in all content-type.image section of Khoj config file"
)
return input_filter
class ContentConfig(ConfigBase):
org: Optional[TextContentConfig]
ledger: Optional[TextContentConfig]
@@ -47,41 +57,49 @@ class ContentConfig(ConfigBase):
music: Optional[TextContentConfig]
markdown: Optional[TextContentConfig]
class TextSearchConfig(ConfigBase):
encoder: str
cross_encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path]
class ImageSearchConfig(ConfigBase):
encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path]
class SearchConfig(ConfigBase):
asymmetric: Optional[TextSearchConfig]
symmetric: Optional[TextSearchConfig]
image: Optional[ImageSearchConfig]
class ConversationProcessorConfig(ConfigBase):
openai_api_key: str
conversation_logfile: Path
model: Optional[str] = "text-davinci-003"
class ProcessorConfig(ConfigBase):
conversation: Optional[ConversationProcessorConfig]
class FullConfig(ConfigBase):
content_type: Optional[ContentConfig]
search_type: Optional[SearchConfig]
processor: Optional[ProcessorConfig]
class SearchResponse(ConfigBase):
entry: str
score: str
additional: Optional[dict]
class Entry():
class Entry:
raw: str
compiled: str
file: Optional[str]
@@ -99,8 +117,4 @@ class Entry():
@classmethod
def from_dict(cls, dictionary: dict):
return cls(
raw=dictionary['raw'],
compiled=dictionary['compiled'],
file=dictionary.get('file', None)
)
return cls(raw=dictionary["raw"], compiled=dictionary["compiled"], file=dictionary.get("file", None))

View File

@@ -17,14 +17,14 @@ def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
# Create output directory, if it doesn't exist
yaml_config_file.parent.mkdir(parents=True, exist_ok=True)
with open(yaml_config_file, 'w', encoding='utf-8') as config_file:
with open(yaml_config_file, "w", encoding="utf-8") as config_file:
yaml.safe_dump(yaml_config, config_file, allow_unicode=True)
def load_config_from_file(yaml_config_file: Path) -> dict:
"Read config from YML file"
config_from_file = None
with open(yaml_config_file, 'r', encoding='utf-8') as config_file:
with open(yaml_config_file, "r", encoding="utf-8") as config_file:
config_from_file = yaml.safe_load(config_file)
return config_from_file