diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py
index d6fd1a91..8fc1e18b 100644
--- a/src/khoj/processor/content/github/github_to_entries.py
+++ b/src/khoj/processor/content/github/github_to_entries.py
@@ -3,16 +3,19 @@ import time
 from typing import Any, Dict, List, Tuple
 
 import requests
+from magika import Magika
 
 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import GithubConfig, KhojUser
 from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
 from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
+from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
 from khoj.processor.content.text_to_entries import TextToEntries
 from khoj.utils.helpers import timer
-from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
+from khoj.utils.rawconfig import GithubContentConfig, GithubRepoConfig
 
 logger = logging.getLogger(__name__)
+magika = Magika()
 
 
 class GithubToEntries(TextToEntries):
@@ -61,7 +64,7 @@ class GithubToEntries(TextToEntries):
         repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
         repo_shorthand = f"{repo.owner}/{repo.name}"
         logger.info(f"Processing github repo {repo_shorthand}")
-        with timer("Download markdown files from github repo", logger):
+        with timer("Download files from github repo", logger):
             try:
-                markdown_files, org_files = self.get_files(repo_url, repo)
+                markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
             except ConnectionAbortedError as e:
@@ -70,8 +73,9 @@ class GithubToEntries(TextToEntries):
                 logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
                 raise e
 
-        logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
-        logger.info(f"Found {len(org_files)} org files in github repo {repo_shorthand}")
+        logger.info(
+            f"Found {len(markdown_files)} md, {len(org_files)} org and {len(plaintext_files)} text files in github repo {repo_shorthand}"
+        )
 
         current_entries = []
         with timer(f"Extract markdown entries from github repo {repo_shorthand}", logger):
@@ -84,6 +88,11 @@ class GithubToEntries(TextToEntries):
                 *GithubToEntries.extract_org_entries(org_files)
             )
 
+        with timer(f"Extract plaintext entries from github repo {repo_shorthand}", logger):
+            current_entries += PlaintextToEntries.convert_text_files_to_entries(
+                *GithubToEntries.extract_plaintext_entries(plaintext_files)
+            )
+
         with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
             current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
 
@@ -116,10 +125,11 @@ class GithubToEntries(TextToEntries):
             raise ConnectionAbortedError("Github rate limit reached")
 
         # Extract markdown files from the repository
-        markdown_files: List[Any] = []
-        org_files: List[Any] = []
+        markdown_files: List[Dict[str, str]] = []
+        org_files: List[Dict[str, str]] = []
+        plaintext_files: List[Dict[str, str]] = []
         if "tree" not in contents:
-            return markdown_files, org_files
+            return markdown_files, org_files, plaintext_files
 
         for item in contents["tree"]:
             # Find all markdown files in the repository
@@ -138,9 +148,27 @@ class GithubToEntries(TextToEntries):
                 # Add org file contents and URL to list
                 org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
 
-        return markdown_files, org_files
+            # Find, index remaining non-binary files in the repository
+            elif item["type"] == "blob":
+                url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
+                content_bytes = self.get_file_contents(item["url"], decode=False)
+                content_type, content_str = None, None
+                try:
+                    content_type = magika.identify_bytes(content_bytes).output.mime_type
+                    content_str = content_bytes.decode("utf-8")
+                except:
+                    logger.error(
+                        f"Unable to identify content type or decode content of file at {url_path}. Skip indexing it"
+                    )
+                    continue
 
-    def get_file_contents(self, file_url):
+                # Add non-binary file contents and URL to list
+                if content_type.startswith("text/"):
+                    plaintext_files += [{"content": content_str, "path": url_path}]
+
+        return markdown_files, org_files, plaintext_files
+
+    def get_file_contents(self, file_url, decode=True):
         # Get text from each markdown file
         headers = {"Accept": "application/vnd.github.v3.raw"}
         response = self.session.get(file_url, headers=headers, stream=True)
@@ -149,11 +177,11 @@ class GithubToEntries(TextToEntries):
         if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0":
             raise ConnectionAbortedError("Github rate limit reached")
 
-        content = ""
+        content = "" if decode else b""
        for chunk in response.iter_content(chunk_size=2048):
             if chunk:
                 try:
-                    content += chunk.decode("utf-8")
+                    content += chunk.decode("utf-8") if decode else chunk
                 except Exception as e:
                     logger.error(f"Unable to decode chunk from {file_url}")
                     logger.error(e)
@@ -180,3 +208,14 @@ class GithubToEntries(TextToEntries):
                 doc["content"], doc["path"], entries, entry_to_file_map
             )
         return entries, dict(entry_to_file_map)
+
+    @staticmethod
+    def extract_plaintext_entries(plaintext_files):
+        entries = []
+        entry_to_file_map = []
+
+        for doc in plaintext_files:
+            entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
+                doc["content"], doc["path"], entries, entry_to_file_map
+            )
+        return entries, dict(entry_to_file_map)
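
Note for reviewers unfamiliar with Magika: the new `elif item["type"] == "blob"` branch in `get_files` keeps a blob only when Magika classifies its raw bytes as a `text/*` MIME type and the bytes decode as UTF-8; everything else is skipped as binary. Below is a minimal standalone sketch of that gate, not part of the diff — the helper name and sample byte strings are made up for illustration, and it assumes the `magika` package is installed.

```python
from magika import Magika

magika = Magika()


def is_indexable_text(content_bytes: bytes) -> bool:
    """Mirror the gate in get_files: text/* MIME type and valid UTF-8."""
    try:
        content_type = magika.identify_bytes(content_bytes).output.mime_type
        content_bytes.decode("utf-8")
    except Exception:
        return False
    return content_type.startswith("text/")


# Hypothetical sample blobs: a small source snippet and a PNG-like header
print(is_indexable_text(b"def hello():\n    return 'world'\n"))  # expected True (text/*)
print(is_indexable_text(b"\x89PNG\r\n\x1a\n\x00\x00"))  # expected False (not text/*)
```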