Index all text and code files in Github repos, not just md and org files

This commit is contained in:
Debanjum Singh Solanky
2024-04-09 00:22:36 +05:30
parent 8291b898ca
commit a8dec1c9d5

View File

@@ -3,16 +3,19 @@ import time
from typing import Any, Dict, List, Tuple
import requests
from magika import Magika
from khoj.database.models import Entry as DbEntry
from khoj.database.models import GithubConfig, KhojUser
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.utils.rawconfig import GithubContentConfig, GithubRepoConfig
logger = logging.getLogger(__name__)
magika = Magika()
class GithubToEntries(TextToEntries):
@@ -61,7 +64,7 @@ class GithubToEntries(TextToEntries):
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
repo_shorthand = f"{repo.owner}/{repo.name}"
logger.info(f"Processing github repo {repo_shorthand}")
with timer("Download markdown files from github repo", logger):
with timer("Download files from github repo", logger):
try:
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
except ConnectionAbortedError as e:
@@ -70,8 +73,9 @@ class GithubToEntries(TextToEntries):
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
raise e
logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
logger.info(f"Found {len(org_files)} org files in github repo {repo_shorthand}")
logger.info(
f"Found {len(markdown_files)} md, {len(org_files)} org and {len(plaintext_files)} text files in github repo {repo_shorthand}"
)
current_entries = []
with timer(f"Extract markdown entries from github repo {repo_shorthand}", logger):
@@ -84,6 +88,11 @@ class GithubToEntries(TextToEntries):
*GithubToEntries.extract_org_entries(org_files)
)
with timer(f"Extract plaintext entries from github repo {repo_shorthand}", logger):
current_entries += PlaintextToEntries.convert_text_files_to_entries(
*GithubToEntries.extract_plaintext_entries(plaintext_files)
)
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
@@ -116,10 +125,11 @@ class GithubToEntries(TextToEntries):
raise ConnectionAbortedError("Github rate limit reached")
# Extract markdown files from the repository
markdown_files: List[Any] = []
org_files: List[Any] = []
markdown_files: List[Dict[str, str]] = []
org_files: List[Dict[str, str]] = []
plaintext_files: List[Dict[str, str]] = []
if "tree" not in contents:
return markdown_files, org_files
return markdown_files, org_files, plaintext_files
for item in contents["tree"]:
# Find all markdown files in the repository
@@ -138,9 +148,27 @@ class GithubToEntries(TextToEntries):
# Add org file contents and URL to list
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
return markdown_files, org_files
# Find, index remaining non-binary files in the repository
elif item["type"] == "blob":
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
content_bytes = self.get_file_contents(item["url"], decode=False)
content_type, content_str = None, None
try:
content_type = magika.identify_bytes(content_bytes).output.mime_type
content_str = content_bytes.decode("utf-8")
except:
logger.error(
f"Unable to identify content type or decode content of file at {url_path}. Skip indexing it"
)
continue
def get_file_contents(self, file_url):
# Add non-binary file contents and URL to list
if content_type.startswith("text/"):
plaintext_files += [{"content": content_str, "path": url_path}]
return markdown_files, org_files, plaintext_files
def get_file_contents(self, file_url, decode=True):
# Get text from each markdown file
headers = {"Accept": "application/vnd.github.v3.raw"}
response = self.session.get(file_url, headers=headers, stream=True)
@@ -149,11 +177,11 @@ class GithubToEntries(TextToEntries):
if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0":
raise ConnectionAbortedError("Github rate limit reached")
content = ""
content = "" if decode else b""
for chunk in response.iter_content(chunk_size=2048):
if chunk:
try:
content += chunk.decode("utf-8")
content += chunk.decode("utf-8") if decode else chunk
except Exception as e:
logger.error(f"Unable to decode chunk from {file_url}")
logger.error(e)
@@ -180,3 +208,14 @@ class GithubToEntries(TextToEntries):
doc["content"], doc["path"], entries, entry_to_file_map
)
return entries, dict(entry_to_file_map)
@staticmethod
def extract_plaintext_entries(plaintext_files):
    """Convert downloaded plaintext files into indexable entries.

    Each item in *plaintext_files* is a dict with "content" (file text)
    and "path" (source URL). Returns the accumulated entries plus a dict
    mapping entries to the files they came from.
    """
    extracted = []
    file_map = []
    # Delegate per-file parsing to the plaintext processor, threading the
    # accumulators through each call so results build up across files.
    for doc in plaintext_files:
        extracted, file_map = PlaintextToEntries.process_single_plaintext_file(
            doc["content"], doc["path"], extracted, file_map
        )
    return extracted, dict(file_map)