mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Index all text and code files in GitHub repos, not just Markdown and Org files
This commit is contained in:
@@ -3,16 +3,19 @@ import time
|
|||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from magika import Magika
|
||||||
|
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import GithubConfig, KhojUser
|
from khoj.database.models import GithubConfig, KhojUser
|
||||||
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
|
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
|
||||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||||
|
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||||
from khoj.processor.content.text_to_entries import TextToEntries
|
from khoj.processor.content.text_to_entries import TextToEntries
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
from khoj.utils.rawconfig import GithubContentConfig, GithubRepoConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
magika = Magika()
|
||||||
|
|
||||||
|
|
||||||
class GithubToEntries(TextToEntries):
|
class GithubToEntries(TextToEntries):
|
||||||
@@ -61,7 +64,7 @@ class GithubToEntries(TextToEntries):
|
|||||||
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
|
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
|
||||||
repo_shorthand = f"{repo.owner}/{repo.name}"
|
repo_shorthand = f"{repo.owner}/{repo.name}"
|
||||||
logger.info(f"Processing github repo {repo_shorthand}")
|
logger.info(f"Processing github repo {repo_shorthand}")
|
||||||
with timer("Download markdown files from github repo", logger):
|
with timer("Download files from github repo", logger):
|
||||||
try:
|
try:
|
||||||
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
|
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
|
||||||
except ConnectionAbortedError as e:
|
except ConnectionAbortedError as e:
|
||||||
@@ -70,8 +73,9 @@ class GithubToEntries(TextToEntries):
|
|||||||
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
|
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
|
logger.info(
|
||||||
logger.info(f"Found {len(org_files)} org files in github repo {repo_shorthand}")
|
f"Found {len(markdown_files)} md, {len(org_files)} org and {len(plaintext_files)} text files in github repo {repo_shorthand}"
|
||||||
|
)
|
||||||
current_entries = []
|
current_entries = []
|
||||||
|
|
||||||
with timer(f"Extract markdown entries from github repo {repo_shorthand}", logger):
|
with timer(f"Extract markdown entries from github repo {repo_shorthand}", logger):
|
||||||
@@ -84,6 +88,11 @@ class GithubToEntries(TextToEntries):
|
|||||||
*GithubToEntries.extract_org_entries(org_files)
|
*GithubToEntries.extract_org_entries(org_files)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with timer(f"Extract plaintext entries from github repo {repo_shorthand}", logger):
|
||||||
|
current_entries += PlaintextToEntries.convert_text_files_to_entries(
|
||||||
|
*GithubToEntries.extract_plaintext_entries(plaintext_files)
|
||||||
|
)
|
||||||
|
|
||||||
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
|
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
|
||||||
current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
|
|
||||||
@@ -116,10 +125,11 @@ class GithubToEntries(TextToEntries):
|
|||||||
raise ConnectionAbortedError("Github rate limit reached")
|
raise ConnectionAbortedError("Github rate limit reached")
|
||||||
|
|
||||||
# Extract markdown files from the repository
|
# Extract markdown files from the repository
|
||||||
markdown_files: List[Any] = []
|
markdown_files: List[Dict[str, str]] = []
|
||||||
org_files: List[Any] = []
|
org_files: List[Dict[str, str]] = []
|
||||||
|
plaintext_files: List[Dict[str, str]] = []
|
||||||
if "tree" not in contents:
|
if "tree" not in contents:
|
||||||
return markdown_files, org_files
|
return markdown_files, org_files, plaintext_files
|
||||||
|
|
||||||
for item in contents["tree"]:
|
for item in contents["tree"]:
|
||||||
# Find all markdown files in the repository
|
# Find all markdown files in the repository
|
||||||
@@ -138,9 +148,27 @@ class GithubToEntries(TextToEntries):
|
|||||||
# Add org file contents and URL to list
|
# Add org file contents and URL to list
|
||||||
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
|
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
|
||||||
|
|
||||||
return markdown_files, org_files
|
# Find, index remaining non-binary files in the repository
|
||||||
|
elif item["type"] == "blob":
|
||||||
|
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
|
||||||
|
content_bytes = self.get_file_contents(item["url"], decode=False)
|
||||||
|
content_type, content_str = None, None
|
||||||
|
try:
|
||||||
|
content_type = magika.identify_bytes(content_bytes).output.mime_type
|
||||||
|
content_str = content_bytes.decode("utf-8")
|
||||||
|
except:
|
||||||
|
logger.error(
|
||||||
|
f"Unable to identify content type or decode content of file at {url_path}. Skip indexing it"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
def get_file_contents(self, file_url):
|
# Add non-binary file contents and URL to list
|
||||||
|
if content_type.startswith("text/"):
|
||||||
|
plaintext_files += [{"content": content_str, "path": url_path}]
|
||||||
|
|
||||||
|
return markdown_files, org_files, plaintext_files
|
||||||
|
|
||||||
|
def get_file_contents(self, file_url, decode=True):
|
||||||
# Get text from each markdown file
|
# Get text from each markdown file
|
||||||
headers = {"Accept": "application/vnd.github.v3.raw"}
|
headers = {"Accept": "application/vnd.github.v3.raw"}
|
||||||
response = self.session.get(file_url, headers=headers, stream=True)
|
response = self.session.get(file_url, headers=headers, stream=True)
|
||||||
@@ -149,11 +177,11 @@ class GithubToEntries(TextToEntries):
|
|||||||
if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0":
|
if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0":
|
||||||
raise ConnectionAbortedError("Github rate limit reached")
|
raise ConnectionAbortedError("Github rate limit reached")
|
||||||
|
|
||||||
content = ""
|
content = "" if decode else b""
|
||||||
for chunk in response.iter_content(chunk_size=2048):
|
for chunk in response.iter_content(chunk_size=2048):
|
||||||
if chunk:
|
if chunk:
|
||||||
try:
|
try:
|
||||||
content += chunk.decode("utf-8")
|
content += chunk.decode("utf-8") if decode else chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unable to decode chunk from {file_url}")
|
logger.error(f"Unable to decode chunk from {file_url}")
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
@@ -180,3 +208,14 @@ class GithubToEntries(TextToEntries):
|
|||||||
doc["content"], doc["path"], entries, entry_to_file_map
|
doc["content"], doc["path"], entries, entry_to_file_map
|
||||||
)
|
)
|
||||||
return entries, dict(entry_to_file_map)
|
return entries, dict(entry_to_file_map)
|
||||||
|
|
||||||
|
@staticmethod
def extract_plaintext_entries(plaintext_files):
    """Turn downloaded plaintext files into entries.

    Args:
        plaintext_files: Iterable of dicts, each read for its "content"
            and "path" keys.

    Returns:
        A 2-tuple of the accumulated entries and ``dict(...)`` of the
        accumulated entry-to-file pairs.
    """
    collected, file_pairs = [], []

    for doc in plaintext_files:
        text, location = doc["content"], doc["path"]
        # The helper receives both accumulators and hands back their
        # replacements; presumably it appends this file's entries and
        # (entry, path) pairs -- confirm against PlaintextToEntries.
        collected, file_pairs = PlaintextToEntries.process_single_plaintext_file(
            text, location, collected, file_pairs
        )

    return collected, dict(file_pairs)
|
||||||
|
|||||||
Reference in New Issue
Block a user