diff --git a/README.md b/README.md index 84da8a5f..96aebc6c 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ - **General** - **Natural**: Advanced natural language understanding using Transformer based ML Models - **Pluggable**: Modular architecture makes it easy to plug in new data sources, frontends and ML models - - **Multiple Sources**: Index your Org-mode and Markdown notes, Beancount transactions, PDF files and Photos + - **Multiple Sources**: Index your Org-mode and Markdown notes, Beancount transactions, PDF files, Github repositories, and Photos - **Multiple Interfaces**: Interact from your [Web Browser](./src/khoj/interface/web/index.html), [Emacs](./src/interface/emacs/khoj.el) or [Obsidian](./src/interface/obsidian/) ## Demos @@ -75,7 +75,7 @@ https://github.com/debanjum/khoj/assets/6413477/3e33d8ea-25bb-46c8-a3bf-c92f78d0 - Install Khoj via `pip` and start Khoj backend in non-gui mode - Install Khoj plugin via Community Plugins settings pane on Obsidian app - Check the new Khoj plugin settings -- Let Khoj backend index the markdown, pdf files in the current Vault +- Let Khoj backend index the markdown, pdf, Github markdown files in the current Vault - Open Khoj plugin on Obsidian via Search button on Left Pane - Search \"*Announce plugin to folks*\" in the [Obsidian Plugin docs](https://marcus.se.net/obsidian-plugin-docs/) - Jump to the [search result](https://marcus.se.net/obsidian-plugin-docs/publishing/submit-your-plugin) @@ -328,6 +328,11 @@ Add your OpenAI API to Khoj by using either of the two options below: 1. [Setup your OpenAI API key in Khoj](#set-your-openai-api-key-in-khoj) 2. Interact with them from the [Khoj Swagger docs](http://locahost:8000/docs)[^2] +### Use a Github Repository as a source +Note that this plugin is currently *only* indexing Markdown files. It will ignore all other files in the repository. This is because Khoj, as it stands, is a semantic search engine. Eventually, we hope to get to a state where you can search for any file in your repository and even explain code. + +1. Get a [pat token](https://docs.github.com/en/github/authenticating-to-github/keeping-your-account-and-data-secure/creating-a-personal-access-token) with `repo` and `read:org` scopes in the classic flow. +2. Configure your settings to include the `owner` and `repo_name`. The `owner` will be the organization name if the repo is in an organization. The `repo_name` will be the name of the repository. Optionally, you can also supply a branch name. If no branch name is supplied, the `master` branch will be used. ## Performance @@ -396,7 +401,7 @@ git clone https://github.com/debanjum/khoj && cd khoj ##### 2. Configure -- **Required**: Update [docker-compose.yml](./docker-compose.yml) to mount your images, (org-mode or markdown) notes, pdf and beancount directories +- **Required**: Update [docker-compose.yml](./docker-compose.yml) to mount your images, (org-mode or markdown) notes, pdf, Github repositories, and beancount directories - **Optional**: Edit application configuration in [khoj_docker.yml](./config/khoj_docker.yml) ##### 3. Run @@ -458,7 +463,7 @@ conda activate khoj #### Before Creating PR -1. Run Tests +1. Run Tests. If you get an error complaining about a missing `fast_tokenizer_file`, follow the solution [in this Github issue](https://github.com/UKPLab/sentence-transformers/issues/1659). ```shell pytest ``` diff --git a/pyproject.toml b/pyproject.toml index cf77ea79..f44849ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "aiohttp == 3.8.4", "langchain >= 0.0.187", "pypdf >= 3.9.0", + "requests >= 2.26.0", ] dynamic = ["version"] diff --git a/src/interface/obsidian/src/search_modal.ts b/src/interface/obsidian/src/search_modal.ts index 4bad70f6..84ebeaa4 100644 --- a/src/interface/obsidian/src/search_modal.ts +++ b/src/interface/obsidian/src/search_modal.ts @@ -127,6 +127,15 @@ export class KhojSearchModal extends SuggestModal { let entry_snipped_indicator = result.entry.split('\n').length > lines_to_render ? ' **...**' : ''; let snipped_entry = result.entry.split('\n').slice(0, lines_to_render).join('\n'); + // Show reindex hint on first search result + if (this.resultContainerEl.children.length == 1) { + let infoHintEl = createEl("div",{ cls: 'khoj-info-hint' }); + el.insertAdjacentElement("beforebegin", infoHintEl); + setTimeout(() => { + infoHintEl.setText('Unexpected results? Try re-index your vault from the Khoj plugin settings to fix it.'); + }, 3000); + } + // Show filename of each search result for context el.createEl("div",{ cls: 'khoj-result-file' }).setText(filename ?? ""); let result_el = el.createEl("div", { cls: 'khoj-result-entry' }) diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css index e3597abe..be8065b8 100644 --- a/src/interface/obsidian/styles.css +++ b/src/interface/obsidian/styles.css @@ -148,9 +148,9 @@ If your plugin does not need CSS, delete this file. .khoj-result-file { font-weight: 600; - } +} - .khoj-result-entry { +.khoj-result-entry { color: var(--text-muted); margin-left: 2em; padding-left: 0.5em; @@ -160,17 +160,25 @@ If your plugin does not need CSS, delete this file. border-left-style: solid; border-left-color: var(--color-accent-2); white-space: normal; - } +} - .khoj-result-entry > * { +.khoj-result-entry > * { font-size: var(--font-ui-medium); - } +} - .khoj-result-entry > p { +.khoj-result-entry > p { margin-top: 0.2em; margin-bottom: 0.2em; - } +} - .khoj-result-entry p br { +.khoj-result-entry p br { display: none; - } +} + +.khoj-info-hint { + color: var(--text-muted); + font-size: var(--font-ui-small); + font-style: italic; + text-align: center; + margin-bottom: 0.5em; +} diff --git a/src/khoj/configure.py b/src/khoj/configure.py index ae49678b..3aa39f10 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -16,6 +16,7 @@ from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl +from khoj.processor.github.github_to_jsonl import GithubToJsonl from khoj.search_type import image_search, text_search from khoj.utils import constants, state from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel @@ -89,7 +90,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, if (t == state.SearchType.Org or t == None) and config.content_type.org: logger.info("🦄 Setting up search for orgmode notes") # Extract Entries, Generate Notes Embeddings - model.orgmode_search = text_search.setup( + model.org_search = text_search.setup( OrgToJsonl, config.content_type.org, search_config=config.search_type.asymmetric, @@ -135,7 +136,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, # Initialize PDF Search if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf: - logger.info("💸 Setting up search for pdf") + logger.info("🖨️ Setting up search for pdf") # Extract Entries, Generate PDF Embeddings model.pdf_search = text_search.setup( PdfToJsonl, @@ -153,6 +154,17 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.image, search_config=config.search_type.image, regenerate=regenerate ) + if (t == state.SearchType.Github or t == None) and config.content_type.github: + logger.info("🐙 Setting up search for github") + # Extract Entries, Generate Github Embeddings + model.github_search = text_search.setup( + GithubToJsonl, + config.content_type.github, + search_config=config.search_type.asymmetric, + regenerate=regenerate, + filters=[DateFilter(), WordFilter(), FileFilter()], + ) + # Initialize External Plugin Search if (t == None or t in state.SearchType) and config.content_type.plugins: logger.info("🔌 Setting up search for plugins") diff --git a/src/khoj/interface/desktop/labelled_text_field.py b/src/khoj/interface/desktop/labelled_text_field.py index 4032c2c0..a897ee48 100644 --- a/src/khoj/interface/desktop/labelled_text_field.py +++ b/src/khoj/interface/desktop/labelled_text_field.py @@ -3,14 +3,18 @@ from PyQt6 import QtWidgets # Internal Packages from khoj.utils.config import ProcessorType +from khoj.utils.config import SearchType class LabelledTextField(QtWidgets.QWidget): - def __init__(self, title, processor_type: ProcessorType = None, default_value: str = None): + def __init__( + self, title, search_type: SearchType = None, processor_type: ProcessorType = None, default_value: str = None + ): QtWidgets.QWidget.__init__(self) layout = QtWidgets.QHBoxLayout() self.setLayout(layout) self.processor_type = processor_type + self.search_type = search_type self.label = QtWidgets.QLabel() self.label.setText(title) diff --git a/src/khoj/interface/desktop/main_window.py b/src/khoj/interface/desktop/main_window.py index 8d54f209..f4ca7f8c 100644 --- a/src/khoj/interface/desktop/main_window.py +++ b/src/khoj/interface/desktop/main_window.py @@ -62,7 +62,6 @@ class MainWindow(QtWidgets.QMainWindow): search_type, None ) or self.get_default_config(search_type=search_type) self.search_settings_panels += [self.add_settings_panel(current_content_config, search_type)] - # Add Conversation Processor Panel to Configure Screen self.processor_settings_panels = [] conversation_type = ProcessorType.Conversation @@ -88,6 +87,8 @@ class MainWindow(QtWidgets.QMainWindow): if search_type == SearchType.Image: current_content_files = current_content_config.get("input-directories", []) file_input_text = f"{search_type.name} Folders" + elif search_type == SearchType.Github: + return self.add_github_settings_panel(current_content_config, SearchType.Github) else: current_content_files = current_content_config.get("input-files", []) file_input_text = f"{search_type.name} Files" @@ -111,6 +112,37 @@ class MainWindow(QtWidgets.QMainWindow): return search_type_settings + def add_github_settings_panel(self, current_content_config: dict, search_type: SearchType): + search_type_settings = QtWidgets.QWidget() + search_type_layout = QtWidgets.QVBoxLayout(search_type_settings) + enable_search_type = SearchCheckBox(f"Search {search_type.name}", search_type) + # Add labelled text input field + input_fields = [] + + fields = ["pat-token", "repo-name", "repo-owner", "repo-branch"] + active = False + for field in fields: + field_value = current_content_config.get(field, None) + input_field = LabelledTextField(field, search_type=search_type, default_value=field_value) + search_type_layout.addWidget(input_field) + input_fields += [input_field] + if field_value: + active = True + + # Set enabled/disabled based on checkbox state + enable_search_type.setChecked(active) + for input_field in input_fields: + input_field.setEnabled(enable_search_type.isChecked()) + enable_search_type.stateChanged.connect(lambda _: [input_field.setEnabled(enable_search_type.isChecked()) for input_field in input_fields]) # type: ignore[attr-defined] + + # Add setting widgets for given search type to panel + search_type_layout.addWidget(enable_search_type) + for input_field in input_fields: + search_type_layout.addWidget(input_field) + self.wlayout.addWidget(search_type_settings) + + return search_type_settings + def add_processor_panel(self, current_conversation_config: dict, processor_type: ProcessorType): "Add Conversation Processor Panel" # Get current settings from config for given processor type @@ -121,7 +153,9 @@ class MainWindow(QtWidgets.QMainWindow): processor_type_layout = QtWidgets.QVBoxLayout(processor_type_settings) enable_conversation = ProcessorCheckBox(f"Conversation", processor_type) # Add file browser to set input files for given processor type - input_field = LabelledTextField("OpenAI API Key", processor_type, current_openai_api_key) + input_field = LabelledTextField( + "OpenAI API Key", processor_type=processor_type, default_value=current_openai_api_key + ) # Set enabled/disabled based on checkbox state enable_conversation.setChecked(current_openai_api_key is not None) @@ -185,7 +219,7 @@ class MainWindow(QtWidgets.QMainWindow): "Update config with search settings from UI" for settings_panel in self.search_settings_panels: for child in settings_panel.children(): - if not isinstance(child, (SearchCheckBox, FileBrowser)): + if not isinstance(child, (SearchCheckBox, FileBrowser, LabelledTextField)): continue if isinstance(child, SearchCheckBox): # Search Type Disabled @@ -209,6 +243,10 @@ class MainWindow(QtWidgets.QMainWindow): self.new_config["content-type"][child.search_type.value]["input-files"] = ( child.getPaths() if child.getPaths() != [] else None ) + elif isinstance(child, LabelledTextField): + self.new_config["content-type"][child.search_type.value][ + child.label.text() + ] = child.input_field.toPlainText() def update_processor_settings(self): "Update config with conversation settings from UI" diff --git a/src/khoj/interface/web/index.html b/src/khoj/interface/web/index.html index c77e5a3b..51412d75 100644 --- a/src/khoj/interface/web/index.html +++ b/src/khoj/interface/web/index.html @@ -33,7 +33,11 @@ function render_markdown(query, data) { var md = window.markdownit(); return md.render(data.map(function (item) { - return `${item.entry}` + if (item.additional.file.startsWith("http")) { + lines = item.entry.split("\n"); + return `${lines[0]}\t[*](${item.additional.file})\n${lines.slice(1).join("\n")}`; + } + return `${item.entry}`; }).join("\n")); } @@ -65,6 +69,8 @@ return render_ledger(query, data); } else if (type === "pdf") { return render_pdf(query, data); + } else if (type == "github") { + return render_markdown(query, data); } else { return `
` + data.map((item) => `

${item.entry}

`).join("\n") @@ -295,7 +301,7 @@ text-align: left; white-space: pre-line; } - #results-markdown { + #results-markdown, #results-github { text-align: left; } #results-music, diff --git a/src/khoj/processor/github/__init__.py b/src/khoj/processor/github/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py new file mode 100644 index 00000000..80d55f38 --- /dev/null +++ b/src/khoj/processor/github/github_to_jsonl.py @@ -0,0 +1,166 @@ +# Standard Packages +import logging +import time +from typing import Dict, List + +# External Packages +import requests + +# Internal Packages +from khoj.utils.helpers import timer +from khoj.utils.rawconfig import Entry, GithubContentConfig +from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl +from khoj.processor.text_to_jsonl import TextToJsonl +from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data + + +logger = logging.getLogger(__name__) + + +class GithubToJsonl(TextToJsonl): + def __init__(self, config: GithubContentConfig): + super().__init__(config) + self.config = config + self.repo_url = f"https://api.github.com/repos/{self.config.repo_owner}/{self.config.repo_name}" + + @staticmethod + def wait_for_rate_limit_reset(response, func, *args, **kwargs): + if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0": + wait_time = int(response.headers.get("X-RateLimit-Reset")) - int(time.time()) + logger.info(f"Github Rate limit reached. Waiting for {wait_time} seconds") + time.sleep(wait_time) + return func(*args, **kwargs) + else: + return + + def process(self, previous_entries=None): + with timer("Download markdown files from github repo", logger): + try: + docs = self.get_markdown_files() + except Exception as e: + logger.error(f"Unable to download github repo {self.config.repo_owner}/{self.config.repo_name}") + raise e + + logger.info(f"Found {len(docs)} documents in github repo {self.config.repo_owner}/{self.config.repo_name}") + + with timer("Extract markdown entries from github repo", logger): + current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps( + *GithubToJsonl.extract_markdown_entries(docs) + ) + + with timer("Extract commit messages from github repo", logger): + current_entries += self.convert_commits_to_entries(self.get_commits()) + + with timer("Split entries by max token size supported by model", logger): + current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256) + + # Identify, mark and merge any new entries with previous entries + with timer("Identify new or updated entries", logger): + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = TextToJsonl.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) + + with timer("Write github entries to JSONL file", logger): + # Process Each Entry from All Notes Files + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) + + # Compress JSONL formatted Data + if self.config.compressed_jsonl.suffix == ".gz": + compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) + elif self.config.compressed_jsonl.suffix == ".jsonl": + dump_jsonl(jsonl_data, self.config.compressed_jsonl) + + return entries_with_ids + + def get_markdown_files(self): + # Get the contents of the repository + repo_content_url = f"{self.repo_url}/git/trees/{self.config.repo_branch}" + headers = {"Authorization": f"token {self.config.pat_token}"} + params = {"recursive": "true"} + response = requests.get(repo_content_url, headers=headers, params=params) + contents = response.json() + + # Wait for rate limit reset if needed + result = self.wait_for_rate_limit_reset(response, self.get_markdown_files) + if result is not None: + return result + + # Extract markdown files from the repository + markdown_files = [] + for item in contents["tree"]: + # Find all markdown files in the repository + if item["type"] == "blob" and item["path"].endswith(".md"): + # Create URL for each markdown file on Github + url_path = f'https://github.com/{self.config.repo_owner}/{self.config.repo_name}/blob/{self.config.repo_branch}/{item["path"]}' + + # Add markdown file contents and URL to list + markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}] + + return markdown_files + + def get_file_contents(self, file_url): + # Get text from each markdown file + headers = {"Authorization": f"token {self.config.pat_token}", "Accept": "application/vnd.github.v3.raw"} + response = requests.get(file_url, headers=headers) + + # Wait for rate limit reset if needed + result = self.wait_for_rate_limit_reset(response, self.get_file_contents, file_url) + if result is not None: + return result + + return response.content.decode("utf-8") + + def get_commits(self) -> List[Dict]: + # Get commit messages from the repository using the Github API + commits_url = f"{self.repo_url}/commits" + headers = {"Authorization": f"token {self.config.pat_token}"} + params = {"per_page": 100} + commits = [] + + while commits_url is not None: + # Get the next page of commits + response = requests.get(commits_url, headers=headers, params=params) + raw_commits = response.json() + + # Wait for rate limit reset if needed + result = self.wait_for_rate_limit_reset(response, self.get_commits) + if result is not None: + return result + + # Extract commit messages from the response + for commit in raw_commits: + commits += [{"content": commit["commit"]["message"], "path": commit["html_url"]}] + + # Get the URL for the next page of commits, if any + commits_url = response.links.get("next", {}).get("url") + + return commits + + def convert_commits_to_entries(self, commits) -> List[Entry]: + entries: List[Entry] = [] + for commit in commits: + compiled = f'Commit message from {self.config.repo_owner}/{self.config.repo_name}:\n{commit["content"]}' + entries.append( + Entry( + compiled=compiled, + raw=f'### {commit["content"]}', + heading=commit["content"].split("\n")[0], + file=commit["path"], + ) + ) + + return entries + + @staticmethod + def extract_markdown_entries(markdown_files): + entries = [] + entry_to_file_map = [] + for doc in markdown_files: + entries, entry_to_file_map = MarkdownToJsonl.process_single_markdown_file( + doc["content"], doc["path"], entries, entry_to_file_map + ) + return entries, dict(entry_to_file_map) diff --git a/src/khoj/processor/jsonl/jsonl_to_jsonl.py b/src/khoj/processor/jsonl/jsonl_to_jsonl.py index 83c82374..f743d5d5 100644 --- a/src/khoj/processor/jsonl/jsonl_to_jsonl.py +++ b/src/khoj/processor/jsonl/jsonl_to_jsonl.py @@ -41,7 +41,7 @@ class JsonlToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update( + entries_with_ids = TextToJsonl.mark_entries_for_update( current_entries, previous_entries, key="compiled", diff --git a/src/khoj/processor/ledger/beancount_to_jsonl.py b/src/khoj/processor/ledger/beancount_to_jsonl.py index 49c43301..347012a3 100644 --- a/src/khoj/processor/ledger/beancount_to_jsonl.py +++ b/src/khoj/processor/ledger/beancount_to_jsonl.py @@ -48,7 +48,7 @@ class BeancountToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update( + entries_with_ids = TextToJsonl.mark_entries_for_update( current_entries, previous_entries, key="compiled", logger=logger ) diff --git a/src/khoj/processor/markdown/markdown_to_jsonl.py b/src/khoj/processor/markdown/markdown_to_jsonl.py index 0179e05e..efb508ad 100644 --- a/src/khoj/processor/markdown/markdown_to_jsonl.py +++ b/src/khoj/processor/markdown/markdown_to_jsonl.py @@ -49,7 +49,7 @@ class MarkdownToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update( + entries_with_ids = TextToJsonl.mark_entries_for_update( current_entries, previous_entries, key="compiled", logger=logger ) @@ -101,27 +101,37 @@ class MarkdownToJsonl(TextToJsonl): "Extract entries by heading from specified Markdown files" # Regex to extract Markdown Entries by Heading - markdown_heading_regex = r"^#" entries = [] entry_to_file_map = [] for markdown_file in markdown_files: with open(markdown_file, "r", encoding="utf8") as f: markdown_content = f.read() - markdown_entries_per_file = [] - any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE) - for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE): - # Add heading level as the regex split removed it from entries with headings - prefix = "#" if entry.startswith("#") else "# " if any_headings else "" - stripped_entry = entry.strip(empty_escape_sequences) - if stripped_entry != "": - markdown_entries_per_file.append(f"{prefix}{stripped_entry}") - - entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file)) - entries.extend(markdown_entries_per_file) + entries, entry_to_file_map = MarkdownToJsonl.process_single_markdown_file( + markdown_content, markdown_file, entries, entry_to_file_map + ) return entries, dict(entry_to_file_map) + @staticmethod + def process_single_markdown_file( + markdown_content: str, markdown_file: Path, entries: List, entry_to_file_map: List + ): + markdown_heading_regex = r"^#" + + markdown_entries_per_file = [] + any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE) + for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE): + # Add heading level as the regex split removed it from entries with headings + prefix = "#" if entry.startswith("#") else "# " if any_headings else "" + stripped_entry = entry.strip(empty_escape_sequences) + if stripped_entry != "": + markdown_entries_per_file.append(f"{prefix}{stripped_entry}") + + entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file)) + entries.extend(markdown_entries_per_file) + return entries, entry_to_file_map + @staticmethod def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]: "Convert each Markdown entries into a dictionary" diff --git a/src/khoj/processor/org_mode/org_to_jsonl.py b/src/khoj/processor/org_mode/org_to_jsonl.py index e5ec7cc6..96f2238e 100644 --- a/src/khoj/processor/org_mode/org_to_jsonl.py +++ b/src/khoj/processor/org_mode/org_to_jsonl.py @@ -50,7 +50,7 @@ class OrgToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update( + entries_with_ids = TextToJsonl.mark_entries_for_update( current_entries, previous_entries, key="compiled", logger=logger ) diff --git a/src/khoj/processor/pdf/pdf_to_jsonl.py b/src/khoj/processor/pdf/pdf_to_jsonl.py index 27c03d55..d8092cc8 100644 --- a/src/khoj/processor/pdf/pdf_to_jsonl.py +++ b/src/khoj/processor/pdf/pdf_to_jsonl.py @@ -48,7 +48,7 @@ class PdfToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = self.mark_entries_for_update( + entries_with_ids = TextToJsonl.mark_entries_for_update( current_entries, previous_entries, key="compiled", logger=logger ) diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index 3dd0d1b5..f7bca376 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ b/src/khoj/processor/text_to_jsonl.py @@ -6,14 +6,14 @@ from typing import Callable, List, Tuple from khoj.utils.helpers import timer # Internal Packages -from khoj.utils.rawconfig import Entry, TextContentConfig +from khoj.utils.rawconfig import Entry, TextConfigBase logger = logging.getLogger(__name__) class TextToJsonl(ABC): - def __init__(self, config: TextContentConfig): + def __init__(self, config: TextConfigBase): self.config = config @abstractmethod @@ -60,8 +60,9 @@ class TextToJsonl(ABC): return chunked_entries + @staticmethod def mark_entries_for_update( - self, current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None + current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None ) -> List[Tuple[int, Entry]]: # Hash all current and previous entries to identify new entries with timer("Hash previous, current entries", logger): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index cc25ee39..d0962d62 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -44,7 +44,10 @@ def get_config_types(): return [ search_type.value for search_type in SearchType - if search_type.value in configured_content_types + if ( + search_type.value in configured_content_types + and getattr(state.model, f"{search_type.value}_search") is not None + ) or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) ] @@ -101,11 +104,11 @@ def search( logger.debug(f"Return response from query cache") return state.query_cache[query_cache_key] - if (t == SearchType.Org or t == None) and state.model.orgmode_search: + if (t == SearchType.Org or t == None) and state.model.org_search: # query org-mode notes with timer("Query took", logger): hits, entries = text_search.query( - user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe + user_query, state.model.org_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe ) # collate and return results @@ -134,6 +137,17 @@ def search( with timer("Collating results took", logger): results = text_search.collate_results(hits, entries, results_count) + elif (t == SearchType.Github or t == None) and state.model.github_search: + # query github embeddings + with timer("Query took", logger): + hits, entries = text_search.query( + user_query, state.model.github_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe + ) + + # collate and return results + with timer("Collating results took", logger): + results = text_search.collate_results(hits, entries, results_count) + elif (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions with timer("Query took", logger): diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 7b590d13..a83f7814 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -23,6 +23,7 @@ class SearchType(str, Enum): Markdown = "markdown" Image = "image" Pdf = "pdf" + Github = "github" class ProcessorType(str, Enum): @@ -58,12 +59,13 @@ class ImageSearchModel: @dataclass class SearchModels: - orgmode_search: TextSearchModel = None + org_search: TextSearchModel = None ledger_search: TextSearchModel = None music_search: TextSearchModel = None markdown_search: TextSearchModel = None pdf_search: TextSearchModel = None image_search: ImageSearchModel = None + github_search: TextSearchModel = None plugin_search: Dict[str, TextSearchModel] = None diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 8958a57a..9bc0c418 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -47,6 +47,14 @@ default_config = { "compressed-jsonl": "~/.khoj/content/music/music.jsonl.gz", "embeddings-file": "~/.khoj/content/music/music_embeddings.pt", }, + "github": { + "pat-token": None, + "repo-name": None, + "repo-owner": None, + "repo-branch": "master", + "compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz", + "embeddings-file": "~/.khoj/content/github/github_embeddings.pt", + }, }, "search-type": { "symmetric": { diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index af254168..e4f5074a 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -22,11 +22,14 @@ class ConfigBase(BaseModel): return setattr(self, key, value) -class TextContentConfig(ConfigBase): - input_files: Optional[List[Path]] - input_filter: Optional[List[str]] +class TextConfigBase(ConfigBase): compressed_jsonl: Path embeddings_file: Path + + +class TextContentConfig(TextConfigBase): + input_files: Optional[List[Path]] + input_filter: Optional[List[str]] index_heading_entries: Optional[bool] = False @validator("input_filter") @@ -38,6 +41,13 @@ class TextContentConfig(ConfigBase): return input_filter +class GithubContentConfig(TextConfigBase): + pat_token: str + repo_name: str + repo_owner: str + repo_branch: Optional[str] = "master" + + class ImageContentConfig(ConfigBase): input_directories: Optional[List[Path]] input_filter: Optional[List[str]] @@ -63,6 +73,7 @@ class ContentConfig(ConfigBase): music: Optional[TextContentConfig] markdown: Optional[TextContentConfig] pdf: Optional[TextContentConfig] + github: Optional[GithubContentConfig] plugins: Optional[Dict[str, TextContentConfig]] diff --git a/tests/conftest.py b/tests/conftest.py index 84ec658d..d4638adb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ from khoj.utils.rawconfig import ( ConversationProcessorConfig, ProcessorConfig, TextContentConfig, + GithubContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, @@ -89,6 +90,15 @@ def content_config(tmp_path_factory, search_config: SearchConfig): ) } + content_config.github = GithubContentConfig( + pat_token=os.getenv("GITHUB_PAT_TOKEN", ""), + repo_name="lantern", + repo_owner="khoj-ai", + repo_branch="master", + compressed_jsonl=content_dir.joinpath("github.jsonl.gz"), + embeddings_file=content_dir.joinpath("github_embeddings.pt"), + ) + filters = [DateFilter(), WordFilter(), FileFilter()] text_search.setup( JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters @@ -159,6 +169,10 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor state.config.search_type = search_config state.SearchType = configure_search_types(state.config) + # These lines help us Mock the Search models for these search types + state.model.org_search = {} + state.model.image_search = {} + configure_routes(app) return TestClient(app) diff --git a/tests/test_client.py b/tests/test_client.py index cee0ee67..d74b4f2d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -166,7 +166,7 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) user_query = quote("How to git install application?") # Act @@ -183,7 +183,7 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter(), FileFilter()] - model.orgmode_search = text_search.setup( + model.org_search = text_search.setup( OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters ) user_query = quote('+"Emacs" file:"*.org"') @@ -202,7 +202,7 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.orgmode_search = text_search.setup( + model.org_search = text_search.setup( OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters ) user_query = quote('How to git install application? +"Emacs"') @@ -221,7 +221,7 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig, def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.orgmode_search = text_search.setup( + model.org_search = text_search.setup( OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters ) user_query = quote('How to git install application? -"clone"') diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py index 06a507c5..ac8a7665 100644 --- a/tests/test_conversation_utils.py +++ b/tests/test_conversation_utils.py @@ -32,7 +32,7 @@ class TestTruncateMessage: def test_truncate_message_first_large(self): chat_messages = ChatMessageFactory.build_batch(25) - big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000)) + big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000)) big_chat_message.content = big_chat_message.content + "\n" + "Question?" copy_big_chat_message = big_chat_message.copy() chat_messages.insert(0, big_chat_message) diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 830feb9b..6634a671 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,6 +1,7 @@ # System Packages import logging from pathlib import Path +import os # External Packages import pytest @@ -10,6 +11,7 @@ from khoj.utils.state import model from khoj.search_type import text_search from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl +from khoj.processor.github.github_to_jsonl import GithubToJsonl # Test @@ -170,3 +172,14 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search # Cleanup # reset input_files in config to empty list content_config.org.input_files = [] + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") +def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig): + # Act + # Regenerate github embeddings to test asymmetric setup without caching + github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True) + + # Assert + assert len(github_model.entries) > 1