Index all text, code files in Github repos. Not just md, org files

Debanjum Singh Solanky
2024-04-09 00:22:36 +05:30
parent 8291b898ca
commit a8dec1c9d5


@@ -3,16 +3,19 @@ import time
 from typing import Any, Dict, List, Tuple
 
 import requests
+from magika import Magika
 
 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import GithubConfig, KhojUser
 from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
 from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
+from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
 from khoj.processor.content.text_to_entries import TextToEntries
 from khoj.utils.helpers import timer
-from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
+from khoj.utils.rawconfig import GithubContentConfig, GithubRepoConfig
 
 logger = logging.getLogger(__name__)
+magika = Magika()
 
 
 class GithubToEntries(TextToEntries):
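The key new dependency above is Magika, Google's ML-based file type detector, which the rest of the diff uses to separate indexable text files from binary blobs. A minimal sketch of the detection call this code relies on (the sample bytes are illustrative):

from magika import Magika

magika = Magika()

# Magika infers the content type from the bytes themselves,
# not from the file extension.
result = magika.identify_bytes(b"def hello():\n    print('hello world')\n")

# The MIME type drives the text vs binary decision in get_files() below.
print(result.output.mime_type)  # e.g. "text/x-python" for Python source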
@@ -61,7 +64,7 @@ class GithubToEntries(TextToEntries):
         repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
         repo_shorthand = f"{repo.owner}/{repo.name}"
         logger.info(f"Processing github repo {repo_shorthand}")
-        with timer("Download markdown files from github repo", logger):
+        with timer("Download files from github repo", logger):
             try:
-                markdown_files, org_files = self.get_files(repo_url, repo)
+                markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
             except ConnectionAbortedError as e:
@@ -70,8 +73,9 @@ class GithubToEntries(TextToEntries):
                 logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
                 raise e
 
-        logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
-        logger.info(f"Found {len(org_files)} org files in github repo {repo_shorthand}")
+        logger.info(
+            f"Found {len(markdown_files)} md, {len(org_files)} org and {len(plaintext_files)} text files in github repo {repo_shorthand}"
+        )
 
         current_entries = []
         with timer(f"Extract markdown entries from github repo {repo_shorthand}", logger):
@@ -84,6 +88,11 @@ class GithubToEntries(TextToEntries):
                 *GithubToEntries.extract_org_entries(org_files)
             )
 
+        with timer(f"Extract plaintext entries from github repo {repo_shorthand}", logger):
+            current_entries += PlaintextToEntries.convert_text_files_to_entries(
+                *GithubToEntries.extract_plaintext_entries(plaintext_files)
+            )
+
         with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
             current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
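split_entries_by_max_tokens caps each entry at 256 tokens so chunks fit the embedding model's context window. The real implementation lives in TextToEntries and operates on Entry objects with a model tokenizer; a whitespace-token sketch of the idea, for illustration only:

def split_by_max_tokens(text: str, max_tokens: int = 256) -> list[str]:
    # Illustrative stand-in: approximate tokens with whitespace-separated words.
    words = text.split()
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), max_tokens)]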
@@ -116,10 +125,11 @@ class GithubToEntries(TextToEntries):
             raise ConnectionAbortedError("Github rate limit reached")
 
         # Extract markdown files from the repository
-        markdown_files: List[Any] = []
-        org_files: List[Any] = []
+        markdown_files: List[Dict[str, str]] = []
+        org_files: List[Dict[str, str]] = []
+        plaintext_files: List[Dict[str, str]] = []
         if "tree" not in contents:
-            return markdown_files, org_files
+            return markdown_files, org_files, plaintext_files
 
         for item in contents["tree"]:
             # Find all markdown files in the repository
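For reference, get_files() iterates over the response of GitHub's git trees API (GET /repos/{owner}/{repo}/git/trees/{branch}?recursive=true), which is why it guards on the "tree" key and branches on each item's type. A sketch of the payload shape this loop expects, with illustrative values:

# Shape of the `contents` payload iterated below (values illustrative).
contents = {
    "sha": "abc123",
    "tree": [
        {
            "path": "docs/setup.md",
            "type": "blob",  # "blob" for files, "tree" for directories
            "url": "https://api.github.com/repos/owner/name/git/blobs/def456",
        },
    ],
    "truncated": False,  # True when the repository is too large to list fully
}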
@@ -138,9 +148,27 @@ class GithubToEntries(TextToEntries):
                 # Add org file contents and URL to list
                 org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
-
-        return markdown_files, org_files
+            # Find, index remaining non-binary files in the repository
+            elif item["type"] == "blob":
+                url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
+                content_bytes = self.get_file_contents(item["url"], decode=False)
+                content_type, content_str = None, None
+                try:
+                    content_type = magika.identify_bytes(content_bytes).output.mime_type
+                    content_str = content_bytes.decode("utf-8")
+                except:
+                    logger.error(
+                        f"Unable to identify content type or decode content of file at {url_path}. Skip indexing it"
+                    )
+                    continue
+
+                # Add non-binary file contents and URL to list
+                if content_type.startswith("text/"):
+                    plaintext_files += [{"content": content_str, "path": url_path}]
+
+        return markdown_files, org_files, plaintext_files
 
-    def get_file_contents(self, file_url):
+    def get_file_contents(self, file_url, decode=True):
         # Get text from each markdown file
         headers = {"Accept": "application/vnd.github.v3.raw"}
         response = self.session.get(file_url, headers=headers, stream=True)
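Note the filter in the hunk above is deliberately conservative: only files Magika labels with a text/* MIME type are indexed, so human-readable formats that Magika reports under other types (e.g. application/json) are skipped along with binaries. A hypothetical helper, is_indexable, mirroring that filter:

from magika import Magika

magika = Magika()

def is_indexable(content_bytes: bytes) -> bool:
    # Mirrors get_files(): keep only valid UTF-8 content that Magika calls text/*.
    try:
        mime_type = magika.identify_bytes(content_bytes).output.mime_type
        content_bytes.decode("utf-8")
    except Exception:
        return False
    return mime_type.startswith("text/")

print(is_indexable(b"plain notes or source code"))  # True for text content
print(is_indexable(b"\x89PNG\r\n\x1a\n"))           # False: PNG magic bytes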
@@ -149,11 +177,11 @@ class GithubToEntries(TextToEntries):
         if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0":
             raise ConnectionAbortedError("Github rate limit reached")
 
-        content = ""
+        content = "" if decode else b""
         for chunk in response.iter_content(chunk_size=2048):
             if chunk:
                 try:
-                    content += chunk.decode("utf-8")
+                    content += chunk.decode("utf-8") if decode else chunk
                 except Exception as e:
                     logger.error(f"Unable to decode chunk from {file_url}")
                     logger.error(e)
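With the decode flag, one download helper now serves both paths: the markdown and org readers keep receiving str, while the Magika path gets raw bytes. Hypothetical calls, assuming a constructed GithubToEntries instance named converter and a blob url file_url from the git tree:

text = converter.get_file_contents(file_url)               # str, decoded as UTF-8
raw = converter.get_file_contents(file_url, decode=False)  # bytes, for Magika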
@@ -180,3 +208,14 @@ class GithubToEntries(TextToEntries):
                 doc["content"], doc["path"], entries, entry_to_file_map
             )
         return entries, dict(entry_to_file_map)
+
+    @staticmethod
+    def extract_plaintext_entries(plaintext_files):
+        entries = []
+        entry_to_file_map = []
+        for doc in plaintext_files:
+            entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
+                doc["content"], doc["path"], entries, entry_to_file_map
+            )
+        return entries, dict(entry_to_file_map)
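The new static method mirrors its markdown and org siblings: it runs each downloaded file through PlaintextToEntries and returns the entries along with an entry-to-file map. A sketch of how the pieces added in this commit fit together, with illustrative file data:

# Illustrative input, shaped like the dicts get_files() accumulates.
plaintext_files = [
    {
        "content": 'fn main() { println!("hi"); }',
        "path": "https://github.com/owner/name/blob/main/src/main.rs",
    },
]

# Same call pattern as in process(): extract, then convert to entries.
current_entries = PlaintextToEntries.convert_text_files_to_entries(
    *GithubToEntries.extract_plaintext_entries(plaintext_files)
)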