Split Configure API into Content, Model API paths (#857)

## Major: Breaking Changes
- Move API endpoints under /api/configure/<type>/model to /api/model/<type>
- Move API endpoints under /api/configure/content/ to /api/content/
- Accept file deletion requests by clients during sync
- Split /api/v1/index/update into /api/content PUT, PATCH API endpoints

## Minor: Create New API Endpoint
- Create API endpoints to get user content configurations

Related: #852
This commit is contained in:
Debanjum
2024-07-26 23:48:41 -07:00
committed by GitHub
36 changed files with 857 additions and 739 deletions

View File

@@ -233,11 +233,15 @@ function pushDataToKhoj (regenerate = false) {
// Request indexing files on server. With up to 1000 files in each request
for (let i = 0; i < filesDataToPush.length; i += 1000) {
const syncUrl = `${hostURL}/api/content?client=desktop`;
const filesDataGroup = filesDataToPush.slice(i, i + 1000);
const formData = new FormData();
filesDataGroup.forEach(fileData => { formData.append('files', fileData.blob, fileData.path) });
let request = axios.post(`${hostURL}/api/v1/index/update?force=${regenerate}&client=desktop`, formData, { headers });
requests.push(request);
requests.push(
regenerate
? axios.put(syncUrl, formData, { headers })
: axios.patch(syncUrl, formData, { headers })
);
}
// Wait for requests batch to finish

View File

@@ -212,7 +212,7 @@
const headers = { 'Authorization': `Bearer ${khojToken}` };
// Populate type dropdown field with enabled content types only
fetch(`${hostURL}/api/configure/types`, { headers })
fetch(`${hostURL}/api/content/types`, { headers })
.then(response => response.json())
.then(enabled_types => {
// Show warning if no content types are enabled

View File

@@ -424,12 +424,12 @@ Auto invokes setup steps on calling main entrypoint."
"Send multi-part form `BODY' of `CONTENT-TYPE' in request to khoj server.
Append `TYPE-QUERY' as query parameter in request url.
Specify `BOUNDARY' used to separate files in request header."
(let ((url-request-method "POST")
(let ((url-request-method (if force "PUT" "PATCH"))
(url-request-data body)
(url-request-extra-headers `(("content-type" . ,(format "multipart/form-data; boundary=%s" boundary))
("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
(with-current-buffer
(url-retrieve (format "%s/api/v1/index/update?%s&force=%s&client=emacs" khoj-server-url type-query (or force "false"))
(url-retrieve (format "%s/api/content?%s&client=emacs" khoj-server-url type-query)
;; render response from indexing API endpoint on server
(lambda (status)
(if (not (plist-get status :error))
@@ -697,7 +697,7 @@ Optionally apply CALLBACK with JSON parsed response and CBARGS."
(defun khoj--get-enabled-content-types ()
"Get content types enabled for search from API."
(khoj--call-api "/api/configure/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
(khoj--call-api "/api/content/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
(defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar)
"Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params.

View File

@@ -89,10 +89,11 @@ export async function updateContentIndex(vault: Vault, setting: KhojSetting, las
for (let i = 0; i < fileData.length; i += 1000) {
const filesGroup = fileData.slice(i, i + 1000);
const formData = new FormData();
const method = regenerate ? "PUT" : "PATCH";
filesGroup.forEach(fileItem => { formData.append('files', fileItem.blob, fileItem.path) });
// Call Khoj backend to update index with all markdown, pdf files
const response = await fetch(`${setting.khojUrl}/api/v1/index/update?force=${regenerate}&client=obsidian`, {
method: 'POST',
const response = await fetch(`${setting.khojUrl}/api/content?client=obsidian`, {
method: method,
headers: {
'Authorization': `Bearer ${setting.khojApiKey}`,
},

View File

@@ -277,8 +277,8 @@ export function uploadDataForIndexing(
// Wait for all files to be read before making the fetch request
Promise.all(fileReadPromises)
.then(() => {
return fetch("/api/v1/index/update?force=false&client=web", {
method: "POST",
return fetch("/api/content?client=web", {
method: "PATCH",
body: formData,
});
})

View File

@@ -68,8 +68,8 @@ interface ModelPickerProps {
}
export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
const { data: models } = useOptionsRequest('/api/configure/chat/model/options');
const { data: selectedModel } = useSelectedModel('/api/configure/chat/model');
const { data: models } = useOptionsRequest('/api/model/chat/options');
const { data: selectedModel } = useSelectedModel('/api/model/chat');
const [openLoginDialog, setOpenLoginDialog] = React.useState(false);
let userData = useAuthenticatedData();
@@ -94,7 +94,7 @@ export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
props.setModelUsed(model);
}
fetch('/api/configure/chat/model' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) })
fetch('/api/model/chat' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) })
.then((response) => {
if (!response.ok) {
throw new Error('Failed to select model');

View File

@@ -148,7 +148,7 @@ interface FilesMenuProps {
function FilesMenu(props: FilesMenuProps) {
// Use SWR to fetch files
const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/configure/content/computer' : null, fetcher);
const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/content/computer' : null, fetcher);
const { data: selectedFiles, error: selectedFilesError } = useSWR(props.conversationId ? `/api/chat/conversation/file-filters/${props.conversationId}` : null, fetcher);
const [isOpen, setIsOpen] = useState(false);
const [unfilteredFiles, setUnfilteredFiles] = useState<string[]>([]);

View File

@@ -42,7 +42,7 @@ from khoj.database.adapters import (
)
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search
from khoj.routers.api_content import configure_content, configure_search
from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
from khoj.utils.config import SearchType
@@ -308,16 +308,16 @@ def configure_routes(app):
from khoj.routers.api import api
from khoj.routers.api_agents import api_agents
from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config
from khoj.routers.indexer import indexer
from khoj.routers.api_content import api_content
from khoj.routers.api_model import api_model
from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client
app.include_router(api, prefix="/api")
app.include_router(api_chat, prefix="/api/chat")
app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_config, prefix="/api/configure")
app.include_router(indexer, prefix="/api/v1/index")
app.include_router(api_model, prefix="/api/model")
app.include_router(api_content, prefix="/api/content")
app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client)
@@ -336,7 +336,7 @@ def configure_routes(app):
if is_twilio_enabled():
from khoj.routers.api_phone import api_phone
app.include_router(api_phone, prefix="/api/configure/phone")
app.include_router(api_phone, prefix="/api/phone")
logger.info("📞 Enabled Twilio")

View File

@@ -998,8 +998,8 @@ To get started, just start typing below. You can also type / to see a list of co
// Wait for all files to be read before making the fetch request
Promise.all(fileReadPromises)
.then(() => {
return fetch("/api/v1/index/update?force=false&client=web", {
method: "POST",
return fetch("/api/content?client=web", {
method: "PATCH",
body: formData,
});
})
@@ -1954,7 +1954,7 @@ To get started, just start typing below. You can also type / to see a list of co
}
var allFiles;
function renderAllFiles() {
fetch('/api/configure/content/computer')
fetch('/api/content/computer')
.then(response => response.json())
.then(data => {
var indexedFiles = document.getElementsByClassName("indexed-files")[0];

View File

@@ -32,7 +32,7 @@
</style>
<script>
function removeFile(path) {
fetch('/api/configure/content/file?filename=' + path, {
fetch('/api/content/file?filename=' + path, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
@@ -48,7 +48,7 @@
// Get all currently indexed files
function getAllComputerFilenames() {
fetch('/api/configure/content/computer')
fetch('/api/content/computer')
.then(response => response.json())
.then(data => {
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
@@ -122,7 +122,7 @@
deleteAllComputerFilesButton.textContent = "🗑️ Deleting...";
deleteAllComputerFilesButton.disabled = true;
fetch('/api/configure/content/computer', {
fetch('/api/content/computer', {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',

View File

@@ -165,7 +165,7 @@
// Save Github config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/configure/content/github', {
fetch('/api/content/github', {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View File

@@ -45,7 +45,7 @@
// Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/configure/content/notion', {
fetch('/api/content/notion', {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View File

@@ -209,7 +209,7 @@
function populate_type_dropdown() {
// Populate type dropdown field with enabled content types only
fetch("/api/configure/types")
fetch("/api/content/types")
.then(response => response.json())
.then(enabled_types => {
// Show warning if no content types are enabled, or just one ("all")

View File

@@ -394,8 +394,8 @@
function saveProfileGivenName() {
const givenName = document.getElementById("profile_given_name").value;
fetch('/api/configure/user/name?name=' + givenName, {
method: 'POST',
fetch('/api/user/name?name=' + givenName, {
method: 'PATCH',
headers: {
'Content-Type': 'application/json',
}
@@ -421,7 +421,7 @@
saveVoiceModelButton.disabled = true;
saveVoiceModelButton.textContent = "Saving...";
fetch('/api/configure/voice/model?id=' + voiceModel, {
fetch('/api/model/voice?id=' + voiceModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -455,7 +455,7 @@
saveModelButton.innerHTML = "";
saveModelButton.textContent = "Saving...";
fetch('/api/configure/chat/model?id=' + chatModel, {
fetch('/api/model/chat?id=' + chatModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -494,7 +494,7 @@
saveSearchModelButton.disabled = true;
saveSearchModelButton.textContent = "Saving...";
fetch('/api/configure/search/model?id=' + searchModel, {
fetch('/api/model/search?id=' + searchModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -526,7 +526,7 @@
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
fetch('/api/configure/paint/model?id=' + paintModel, {
fetch('/api/model/paint?id=' + paintModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -553,7 +553,7 @@
};
function clearContentType(content_source) {
fetch('/api/configure/content/' + content_source, {
fetch('/api/content/' + content_source, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
@@ -676,7 +676,7 @@
content_sources = ["computer", "github", "notion"];
content_sources.forEach(content_source => {
fetch(`/api/configure/content/${content_source}`, {
fetch(`/api/content/${content_source}`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
@@ -807,7 +807,7 @@
function getIndexedDataSize() {
document.getElementById("indexed-data-size").textContent = "Calculating...";
fetch('/api/configure/content/size')
fetch('/api/content/size')
.then(response => response.json())
.then(data => {
document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used";
@@ -815,7 +815,7 @@
}
function removeFile(path) {
fetch('/api/configure/content/file?filename=' + path, {
fetch('/api/content/file?filename=' + path, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
@@ -890,7 +890,7 @@
})
phonenumberRemoveButton.addEventListener("click", () => {
fetch('/api/configure/phone', {
fetch('/api/phone', {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
@@ -917,7 +917,7 @@
}, 5000);
} else {
const mobileNumber = iti.getNumber();
fetch('/api/configure/phone?phone_number=' + mobileNumber, {
fetch('/api/phone?phone_number=' + mobileNumber, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -970,7 +970,7 @@
return;
}
fetch('/api/configure/phone/verify?code=' + otp, {
fetch('/api/phone/verify?code=' + otp, {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View File

@@ -19,16 +19,11 @@ class DocxToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
# Extract Entries from specified Docx files
with timer("Extract entries from specified DOCX files", logger):

View File

@@ -48,9 +48,7 @@ class GithubToEntries(TextToEntries):
else:
return
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content")

View File

@@ -20,16 +20,11 @@ class ImageToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
# Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger):

View File

@@ -19,16 +19,11 @@ class MarkdownToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
max_tokens = 256
# Extract Entries from specified Markdown files

View File

@@ -78,9 +78,7 @@ class NotionToEntries(TextToEntries):
self.body_params = {"page_size": 100}
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
current_entries = []
# Get all pages

View File

@@ -20,15 +20,10 @@ class OrgToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
# Extract Entries from specified Org files
max_tokens = 256

View File

@@ -22,16 +22,11 @@ class PdfToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
# Extract Entries from specified Pdf files
with timer("Extract entries from specified PDF files", logger):

View File

@@ -20,15 +20,10 @@ class PlaintextToEntries(TextToEntries):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
# Extract Entries from specified plaintext files
with timer("Extract entries from specified Plaintext files", logger):

View File

@@ -31,9 +31,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter()
@abstractmethod
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
...
@staticmethod

View File

@@ -19,6 +19,7 @@ from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters,
@@ -39,6 +40,7 @@ from khoj.routers.helpers import (
CommonQueryParams,
ConversationCommandRateLimiter,
acreate_title_from_query,
get_user_config,
schedule_automation,
update_telemetry_state,
)
@@ -276,6 +278,49 @@ async def transcribe(
return Response(content=content, media_type="application/json", status_code=200)
@api.get("/settings", response_class=Response)
@requires(["authenticated"])
def get_settings(request: Request, detailed: Optional[bool] = False) -> Response:
user = request.user.object
user_config = get_user_config(user, request, is_detailed=detailed)
del user_config["request"]
# Return config data as a JSON response
return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
@api.patch("/user/name", status_code=200)
@requires(["authenticated"])
def set_user_name(
request: Request,
name: str,
client: Optional[str] = None,
):
user = request.user.object
split_name = name.split(" ")
if len(split_name) > 2:
raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
if len(split_name) == 1:
first_name = split_name[0]
last_name = ""
else:
first_name, last_name = split_name[0], split_name[-1]
adapters.set_user_name(user, first_name, last_name)
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_user_name",
client=client,
)
return {"status": "ok"}
async def extract_references_and_questions(
request: Request,
meta_log: dict,

View File

@@ -1,434 +0,0 @@
import json
import logging
import math
from typing import Dict, List, Optional, Union
from asgiref.sync import sync_to_async
from fastapi import APIRouter, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
Subscription,
)
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
from khoj.utils import constants, state
from khoj.utils.rawconfig import (
FullConfig,
GithubContentConfig,
NotionContentConfig,
SearchConfig,
)
from khoj.utils.state import SearchType
api_config = APIRouter()
logger = logging.getLogger(__name__)
def map_config_to_object(content_source: str):
    """Map a content source name to its Django config model class.

    Returns the sentinel string "Computer" for computer-sourced content
    (which has no backing config model), and None for unknown sources.
    """
    source_to_config = {
        DbEntry.EntrySource.GITHUB: GithubConfig,
        DbEntry.EntrySource.NOTION: NotionConfig,
        DbEntry.EntrySource.COMPUTER: "Computer",
    }
    # dict.get returns None for unrecognized sources, matching the
    # fall-through behavior of the original if-chain.
    return source_to_config.get(content_source)
async def map_config_to_db(config: FullConfig, user: KhojUser):
if config.content_type:
if config.content_type.org:
await LocalOrgConfig.objects.filter(user=user).adelete()
await LocalOrgConfig.objects.acreate(
input_files=config.content_type.org.input_files,
input_filter=config.content_type.org.input_filter,
index_heading_entries=config.content_type.org.index_heading_entries,
user=user,
)
if config.content_type.markdown:
await LocalMarkdownConfig.objects.filter(user=user).adelete()
await LocalMarkdownConfig.objects.acreate(
input_files=config.content_type.markdown.input_files,
input_filter=config.content_type.markdown.input_filter,
index_heading_entries=config.content_type.markdown.index_heading_entries,
user=user,
)
if config.content_type.pdf:
await LocalPdfConfig.objects.filter(user=user).adelete()
await LocalPdfConfig.objects.acreate(
input_files=config.content_type.pdf.input_files,
input_filter=config.content_type.pdf.input_filter,
index_heading_entries=config.content_type.pdf.index_heading_entries,
user=user,
)
if config.content_type.plaintext:
await LocalPlaintextConfig.objects.filter(user=user).adelete()
await LocalPlaintextConfig.objects.acreate(
input_files=config.content_type.plaintext.input_files,
input_filter=config.content_type.plaintext.input_filter,
index_heading_entries=config.content_type.plaintext.index_heading_entries,
user=user,
)
if config.content_type.github:
await adapters.set_user_github_config(
user=user,
pat_token=config.content_type.github.pat_token,
repos=config.content_type.github.repos,
)
if config.content_type.notion:
await adapters.set_notion_config(
user=user,
token=config.content_type.notion.token,
)
def _initialize_config():
if state.config is None:
state.config = FullConfig()
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
@api_config.post("/content/github", status_code=200)
@requires(["authenticated"])
async def set_content_github(
request: Request,
updated_config: Union[GithubContentConfig, None],
client: Optional[str] = None,
):
_initialize_config()
user = request.user.object
try:
await adapters.set_user_github_config(
user=user,
pat_token=updated_config.pat_token,
repos=updated_config.repos,
)
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail="Failed to set Github config")
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_content_config",
client=client,
metadata={"content_type": "github"},
)
return {"status": "ok"}
@api_config.post("/content/notion", status_code=200)
@requires(["authenticated"])
async def set_content_notion(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Save the user's Notion content configuration (integration token).

    Raises HTTPException(500) if persisting the config fails.
    """
    _initialize_config()

    user = request.user.object

    try:
        await adapters.set_notion_config(
            user=user,
            token=updated_config.token,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        # Bug fix: error message previously said "Github config" — this is
        # the Notion endpoint, copied from the Github handler above.
        raise HTTPException(status_code=500, detail="Failed to set Notion config")

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )

    return {"status": "ok"}
@api_config.delete("/content/{content_source}", status_code=200)
@requires(["authenticated"])
async def delete_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete all indexed entries (and any stored config) for `content_source`.

    Raises ValueError for an unrecognized content source.
    """
    user = request.user.object

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_source": content_source},
    )

    content_object = map_config_to_object(content_source)
    if content_object is None:
        raise ValueError(f"Invalid content source: {content_source}")
    elif content_object != "Computer":
        # "Computer" sources have no config model to delete; other sources do.
        await content_object.objects.filter(user=user).adelete()
    await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
    # Removed unused `enabled_content` local: it recomputed the user's unique
    # file types but was never read or returned.

    return {"status": "ok"}
@api_config.delete("/content/file", status_code=201)
@requires(["authenticated"])
async def delete_content_file(
request: Request,
filename: str,
client: Optional[str] = None,
):
user = request.user.object
update_telemetry_state(
request=request,
telemetry_type="api",
api="delete_file",
client=client,
)
await EntryAdapters.adelete_entry_by_file(user, filename)
return {"status": "ok"}
@api_config.get("/content/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_content_source(
request: Request,
content_source: str,
client: Optional[str] = None,
):
user = request.user.object
update_telemetry_state(
request=request,
telemetry_type="api",
api="get_all_filenames",
client=client,
)
return await sync_to_async(list)(EntryAdapters.get_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
@api_config.get("/chat/model/options", response_model=Dict[str, Union[str, int]])
def get_chat_model_options(
    request: Request,
    client: Optional[str] = None,
):
    """Return every available chat model as [{"chat_model": ..., "id": ...}, ...]."""
    available_models = [
        {"chat_model": option.chat_model, "id": option.id}
        for option in ConversationAdapters.get_conversation_processor_options().all()
    ]
    return Response(content=json.dumps(available_models), media_type="application/json", status_code=200)
@api_config.get("/chat/model")
@requires(["authenticated"])
def get_user_chat_model(
request: Request,
client: Optional[str] = None,
):
user = request.user.object
chat_model = ConversationAdapters.get_conversation_config(user)
if chat_model is None:
chat_model = ConversationAdapters.get_default_conversation_config()
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
@api_config.post("/chat/model", status_code=200)
@requires(["authenticated", "premium"])
async def update_chat_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_conversation_chat_model",
client=client,
metadata={"processor_conversation_type": "conversation"},
)
if new_config is None:
return {"status": "error", "message": "Model not found"}
return {"status": "ok"}
@api_config.post("/voice/model", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
new_config = await ConversationAdapters.aset_user_voice_model(user, id)
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_voice_model",
client=client,
)
if new_config is None:
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
return Response(status_code=202, content=json.dumps({"status": "ok"}))
@api_config.post("/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the user's search (embedding) model to the model with `id`.

    Changing the embedding model invalidates previously indexed embeddings,
    so all of the user's entries are deleted when the effective model changes;
    clients are expected to re-index afterwards.
    """
    user = request.user.object
    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, int(id))

    if prev_config and int(id) != prev_config.id and new_config:
        # Model actually changed: drop entries indexed with the old model.
        await EntryAdapters.adelete_all_entries(user)
    if not prev_config:
        # If the user was just using the default config, delete all the entries and set the new config.
        await EntryAdapters.adelete_all_entries(user)

    if new_config is None:
        return {"status": "error", "message": "Model not found"}
    else:
        update_telemetry_state(
            request=request,
            telemetry_type="api",
            api="set_search_model",
            client=client,
            metadata={"search_model": new_config.setting.name},
        )
        return {"status": "ok"}
@api_config.post("/paint/model", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_paint_model",
client=client,
metadata={"paint_model": new_config.setting.model_name},
)
if new_config is None:
return {"status": "error", "message": "Model not found"}
return {"status": "ok"}
@api_config.get("/content/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_content_size(request: Request, common: CommonQueryParams):
user = request.user.object
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
return Response(
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
media_type="application/json",
status_code=200,
)
@api_config.post("/user/name", status_code=200)
@requires(["authenticated"])
def set_user_name(
    request: Request,
    name: str,
    client: Optional[str] = None,
):
    """Store the user's first and last name, parsed from a 'Firstname Lastname' string.

    Raises:
        HTTPException: 400 when the name has more than two space-separated parts.
    """
    user = request.user.object
    parts = name.split(" ")
    if len(parts) > 2:
        raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
    first_name = parts[0]
    last_name = parts[-1] if len(parts) > 1 else ""
    adapters.set_user_name(user, first_name, last_name)
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_user_name",
        client=client,
    )
    return {"status": "ok"}
@api_config.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_config_types(
    request: Request,
):
    """List the search types available to the user: every type with indexed or configured content, plus 'all'."""
    user = request.user.object
    # Start from the file types that actually have indexed entries for this user.
    enabled = set(EntryAdapters.get_unique_file_types(user))
    # Fold in content types configured at the server level, if any.
    if state.config and state.config.content_type:
        enabled.update(state.config.content_type.model_dump(exclude_none=True))
    # Iterate SearchType to preserve the enum's declaration order in the response.
    return [s.value for s in SearchType if s == SearchType.All or s.value in enabled]

View File

@@ -0,0 +1,508 @@
import asyncio
import json
import logging
import math
from typing import Dict, List, Optional, Union
from asgiref.sync import sync_to_async
from fastapi import (
APIRouter,
Depends,
Header,
HTTPException,
Request,
Response,
UploadFile,
)
from pydantic import BaseModel
from starlette.authentication import requires
from khoj.database import adapters
from khoj.database.adapters import (
EntryAdapters,
get_user_github_config,
get_user_notion_config,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
GithubRepoConfig,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
)
from khoj.routers.helpers import (
ApiIndexedDataLimiter,
CommonQueryParams,
configure_content,
get_user_config,
update_telemetry_state,
)
from khoj.utils import constants, state
from khoj.utils.config import SearchModels
from khoj.utils.helpers import get_file_type
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
GithubContentConfig,
NotionContentConfig,
SearchConfig,
)
from khoj.utils.state import SearchType
from khoj.utils.yaml import save_config_to_file_updated_state
logger = logging.getLogger(__name__)
# Router for all /api/content endpoints (content sync, source configs, metadata).
api_content = APIRouter()
class File(BaseModel):
    """A single file payload sent for indexing: its path plus raw text or binary content."""
    path: str
    content: Union[str, bytes]
class IndexBatchRequest(BaseModel):
    """A batch of files submitted for indexing in one request."""
    files: list[File]
class IndexerInput(BaseModel):
    """Files to index, grouped by content type; each dict maps filename to file content."""
    org: Optional[dict[str, str]] = None
    markdown: Optional[dict[str, str]] = None
    pdf: Optional[dict[str, bytes]] = None
    plaintext: Optional[dict[str, str]] = None
    image: Optional[dict[str, bytes]] = None
    docx: Optional[dict[str, bytes]] = None
@api_content.put("")
@requires(["authenticated"])
async def put_content(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Replace the user's indexed content with the uploaded files (full regenerate).

    Size limits (in MB) are enforced by the ApiIndexedDataLimiter dependency,
    with higher limits for subscribed users.
    """
    # PUT semantics: delegate to the shared indexer with regenerate=True.
    return await indexer(request, files, t, True, client, user_agent, referer, host)
@api_content.patch("")
@requires(["authenticated"])
async def patch_content(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Incrementally sync the uploaded files into the user's existing content index.

    Size limits (in MB) are enforced by the ApiIndexedDataLimiter dependency,
    with higher limits for subscribed users.
    """
    # PATCH semantics: delegate to the shared indexer with regenerate=False.
    return await indexer(request, files, t, False, client, user_agent, referer, host)
@api_content.get("/github", response_class=Response)
@requires(["authenticated"])
def get_content_github(request: Request) -> Response:
    """Return the user's config page data along with their current Github content settings."""
    user = request.user.object
    user_config = get_user_config(user, request)
    # The request object is not JSON-serializable; drop it before dumping.
    del user_config["request"]
    current_github_config = get_user_github_config(user)
    if current_github_config:
        # Flatten the related repo rows into the raw-config schema.
        raw_repos = current_github_config.githubrepoconfig.all()
        repos = []
        for repo in raw_repos:
            repos.append(
                GithubRepoConfig(
                    name=repo.name,
                    owner=repo.owner,
                    branch=repo.branch,
                )
            )
        current_config = GithubContentConfig(
            pat_token=current_github_config.pat_token,
            repos=repos,
        )
        current_config = json.loads(current_config.json())
    else:
        # No Github config yet: send an empty config so the page renders a blank form.
        current_config = {} # type: ignore
    user_config["current_config"] = current_config
    # Return config data as a JSON response
    return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
@api_content.get("/notion", response_class=Response)
@requires(["authenticated"])
def get_content_notion(request: Request) -> Response:
    """Return the user's config page data along with their current Notion content settings."""
    user = request.user.object
    user_config = get_user_config(user, request)
    # The request object is not JSON-serializable; drop it before dumping.
    del user_config["request"]
    current_notion_config = get_user_notion_config(user)
    # Fall back to an empty token when no Notion config exists yet.
    token = current_notion_config.token if current_notion_config else ""
    current_config = NotionContentConfig(token=token)
    current_config = json.loads(current_config.model_dump_json())
    user_config["current_config"] = current_config
    # Return config data as a JSON response
    return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
@api_content.post("/github", status_code=200)
@requires(["authenticated"])
async def set_content_github(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Save the user's Github content config (PAT token and repository list).

    Raises:
        HTTPException: 400 on a missing/null body, 500 when persisting the config fails.
    """
    _initialize_config()
    user = request.user.object
    # Reject an explicit null body up front; previously the attribute access below
    # raised AttributeError and surfaced as a misleading 500.
    if updated_config is None:
        raise HTTPException(status_code=400, detail="Github config is required")
    try:
        await adapters.set_user_github_config(
            user=user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "github"},
    )
    return {"status": "ok"}
@api_content.post("/notion", status_code=200)
@requires(["authenticated"])
async def set_content_notion(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Save the user's Notion content config (integration token).

    Raises:
        HTTPException: 400 on a missing/null body, 500 when persisting the config fails.
    """
    _initialize_config()
    user = request.user.object
    # Reject an explicit null body up front; previously the attribute access below
    # raised AttributeError and surfaced as a misleading 500.
    if updated_config is None:
        raise HTTPException(status_code=400, detail="Notion config is required")
    try:
        await adapters.set_notion_config(
            user=user,
            token=updated_config.token,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Notion config")
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )
    return {"status": "ok"}
@api_content.delete("/{content_source}", status_code=200)
@requires(["authenticated"])
async def delete_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete all of the user's indexed entries and configuration for the given content source.

    Raises:
        HTTPException: 400 when content_source is not a known source.
    """
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_source": content_source},
    )
    content_object = map_config_to_object(content_source)
    if content_object is None:
        # A bad source name is a client error; raising bare ValueError surfaced as a 500.
        raise HTTPException(status_code=400, detail=f"Invalid content source: {content_source}")
    elif content_object != "Computer":
        # "Computer" (local files) has no config row to delete; other sources do.
        await content_object.objects.filter(user=user).adelete()
    await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
    return {"status": "ok"}
@api_content.delete("/file", status_code=201)
@requires(["authenticated"])
async def delete_content_file(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete every indexed entry derived from the named file."""
    # NOTE(review): 201 Created is an unusual status for a DELETE; presumably clients
    # already depend on it — confirm before changing to 200/204.
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_file",
        client=client,
    )
    await EntryAdapters.adelete_entry_by_file(user, filename)
    return {"status": "ok"}
@api_content.get("/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_content_size(request: Request, common: CommonQueryParams, client: Optional[str] = None):
    """Report the size of the user's indexed data, rounded up to whole megabytes."""
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(request.user.object)
    payload = json.dumps({"indexed_data_size_in_mb": math.ceil(size_in_mb)})
    return Response(content=payload, media_type="application/json", status_code=200)
@api_content.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_content_types(request: Request, client: Optional[str] = None):
    """List content types the user can query: types with indexed or configured content, plus 'all'."""
    user = request.user.object
    configured = set(EntryAdapters.get_unique_file_types(user))
    configured.add("all")
    # Fold in content types configured at the server level, if any.
    if state.config and state.config.content_type:
        configured.update(state.config.content_type.model_dump(exclude_none=True))
    # Restrict the result to valid search types, dropping any stale/unknown entries.
    valid_types = {s.value for s in SearchType}
    return [t for t in configured if t in valid_types]
@api_content.get("/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """List the filenames the user has indexed from the given content source."""
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="get_all_filenames",
        client=client,
    )
    # Materialize the queryset off the event loop before returning it.
    return await sync_to_async(list)(EntryAdapters.get_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
async def indexer(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    regenerate: bool = False,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
):
    """Index the uploaded files into the user's content index.

    Files are bucketed by detected type, then handed to configure_content in a
    worker thread. With regenerate=True the index is rebuilt from scratch;
    otherwise content is synced incrementally.

    Returns:
        Response with a comma-separated list of indexed filenames on success,
        or a plain "Failed" body with status 500 on error.
    """
    user = request.user.object
    method = "regenerate" if regenerate else "sync"
    index_files: Dict[str, Dict[str, str]] = {
        "org": {},
        "markdown": {},
        "pdf": {},
        "plaintext": {},
        "image": {},
        "docx": {},
    }
    try:
        logger.info(f"📬 Updating content index via API call by {client} client")
        for file in files:
            file_content = file.file.read()
            file_type, encoding = get_file_type(file.content_type, file_content)
            if file_type in index_files:
                # Decode text files; keep binary formats (pdf, image, docx) as raw bytes.
                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
            else:
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
        indexer_input = IndexerInput(
            org=index_files["org"],
            markdown=index_files["markdown"],
            pdf=index_files["pdf"],
            plaintext=index_files["plaintext"],
            image=index_files["image"],
            docx=index_files["docx"],
        )
        if state.config is None:
            # First run on a fresh server: persist a default config before indexing.
            logger.info("📬 Initializing content index on first run.")
            default_full_config = FullConfig(
                content_type=None,
                search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
                processor=None,
            )
            state.config = default_full_config
            default_content_config = ContentConfig(
                org=None,
                markdown=None,
                pdf=None,
                docx=None,
                image=None,
                github=None,
                notion=None,
                plaintext=None,
            )
            state.config.content_type = default_content_config
            save_config_to_file_updated_state()
            configure_search(state.search_models, state.config.search_type)
        # Run the CPU/IO-heavy indexing off the event loop.
        loop = asyncio.get_event_loop()
        success = await loop.run_in_executor(
            None,
            configure_content,
            indexer_input.model_dump(),
            regenerate,
            t,
            user,
        )
        if not success:
            raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")
        logger.info(f"Finished {method} {t} data sent by {client} client into content index")
    except Exception as e:
        # Log once with full context (a duplicate, near-identical error log was removed).
        logger.error(
            f"🚨 Failed to {method} {t} data sent by {client} client into content index: {e}",
            exc_info=True,
        )
        return Response(content="Failed", status_code=500)
    indexing_metadata = {
        "num_org": len(index_files["org"]),
        "num_markdown": len(index_files["markdown"]),
        "num_pdf": len(index_files["pdf"]),
        "num_plaintext": len(index_files["plaintext"]),
        "num_image": len(index_files["image"]),
        "num_docx": len(index_files["docx"]),
    }
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="index/update",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
        metadata=indexing_metadata,
    )
    logger.info(f"📪 Content index updated via API call by {client} client")
    indexed_filenames = ",".join(file for ctype in index_files for file in index_files[ctype]) or ""
    return Response(content=indexed_filenames, status_code=200)
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
    """Ensure a SearchModels container exists, creating a default one when None.

    Note: search_config is currently unused; this stub only guarantees a non-None container.
    """
    # Run Validation Checks
    if search_models is None:
        search_models = SearchModels()
    return search_models
def map_config_to_object(content_source: str):
    """Map a content source name to its config model class, the sentinel 'Computer' for
    local files, or None when the source is unknown."""
    source_to_config = {
        DbEntry.EntrySource.GITHUB: GithubConfig,
        DbEntry.EntrySource.NOTION: NotionConfig,
        DbEntry.EntrySource.COMPUTER: "Computer",
    }
    return source_to_config.get(content_source)
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Persist each content-type section of a FullConfig into the user's database config rows.

    Each configured local type replaces the user's existing rows (delete then create);
    Github and Notion configs are delegated to their adapter setters.
    """
    if config.content_type:
        if config.content_type.org:
            await LocalOrgConfig.objects.filter(user=user).adelete()
            await LocalOrgConfig.objects.acreate(
                input_files=config.content_type.org.input_files,
                input_filter=config.content_type.org.input_filter,
                index_heading_entries=config.content_type.org.index_heading_entries,
                user=user,
            )
        if config.content_type.markdown:
            await LocalMarkdownConfig.objects.filter(user=user).adelete()
            await LocalMarkdownConfig.objects.acreate(
                input_files=config.content_type.markdown.input_files,
                input_filter=config.content_type.markdown.input_filter,
                index_heading_entries=config.content_type.markdown.index_heading_entries,
                user=user,
            )
        if config.content_type.pdf:
            await LocalPdfConfig.objects.filter(user=user).adelete()
            await LocalPdfConfig.objects.acreate(
                input_files=config.content_type.pdf.input_files,
                input_filter=config.content_type.pdf.input_filter,
                index_heading_entries=config.content_type.pdf.index_heading_entries,
                user=user,
            )
        if config.content_type.plaintext:
            await LocalPlaintextConfig.objects.filter(user=user).adelete()
            await LocalPlaintextConfig.objects.acreate(
                input_files=config.content_type.plaintext.input_files,
                input_filter=config.content_type.plaintext.input_filter,
                index_heading_entries=config.content_type.plaintext.index_heading_entries,
                user=user,
            )
        if config.content_type.github:
            await adapters.set_user_github_config(
                user=user,
                pat_token=config.content_type.github.pat_token,
                repos=config.content_type.github.repos,
            )
        if config.content_type.notion:
            await adapters.set_notion_config(
                user=user,
                token=config.content_type.notion.token,
            )
def _initialize_config():
    """Create a minimal global server config with default search settings, if none is loaded yet."""
    if state.config is None:
        state.config = FullConfig()
        state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])

View File

@@ -0,0 +1,156 @@
import json
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.routers.helpers import update_telemetry_state
# Router for all /api/model endpoints (chat, voice, search, paint model selection).
api_model = APIRouter()
logger = logging.getLogger(__name__)
@api_model.get("/chat/options", response_model=Dict[str, Union[str, int]])
def get_chat_model_options(
    request: Request,
    client: Optional[str] = None,
):
    """List all available chat models with their ids."""
    # NOTE(review): no @requires guard here, so this endpoint is unauthenticated —
    # confirm that exposing the model list publicly is intended.
    conversation_options = ConversationAdapters.get_conversation_processor_options().all()
    all_conversation_options = list()
    for conversation_option in conversation_options:
        all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
    return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
@api_model.get("/chat")
@requires(["authenticated"])
def get_user_chat_model(
    request: Request,
    client: Optional[str] = None,
):
    """Return the user's configured chat model, falling back to the server default.

    Raises:
        HTTPException: 404 when neither a user nor a default chat model is configured.
    """
    user = request.user.object
    chat_model = ConversationAdapters.get_conversation_config(user)
    if chat_model is None:
        chat_model = ConversationAdapters.get_default_conversation_config()
    # Guard against a server with no default chat model; previously this dereferenced
    # None below and surfaced as a 500.
    if chat_model is None:
        raise HTTPException(status_code=404, detail="Chat model not configured")
    return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
@api_model.post("/chat", status_code=200)
@requires(["authenticated", "premium"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the premium user's chat model to the conversation config with the given id."""
    user = request.user.object
    new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
    # Telemetry is recorded even when the model lookup fails below.
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_conversation_chat_model",
        client=client,
        metadata={"processor_conversation_type": "conversation"},
    )
    if new_config is None:
        return {"status": "error", "message": "Model not found"}
    return {"status": "ok"}
@api_model.post("/voice", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the premium user's text-to-speech voice model to the one with the given id."""
    user = request.user.object
    new_config = await ConversationAdapters.aset_user_voice_model(user, id)
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_voice_model",
        client=client,
    )
    # 404 when the model id is unknown, 202 on success.
    if new_config is None:
        return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
    return Response(status_code=202, content=json.dumps({"status": "ok"}))
@api_model.post("/search", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the user's search (embedding) model; wipes indexed entries when the model changes."""
    user = request.user.object
    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, int(id))
    if new_config is None:
        return {"status": "error", "message": "Model not found"}
    # Old embeddings are incompatible with a different model, so drop all entries to
    # force re-indexing. A user without a previous config was on the default model, so
    # any explicit selection counts as a change. Entries are only deleted once the
    # requested model is known to exist; previously an invalid id could still wipe a
    # default-config user's entries before returning "Model not found".
    if prev_config is None or int(id) != prev_config.id:
        await EntryAdapters.adelete_all_entries(user)
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_search_model",
        client=client,
        metadata={"search_model": new_config.setting.name},
    )
    return {"status": "ok"}
@api_model.post("/paint", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the premium user's text-to-image model to the one with the given id.

    Raises:
        HTTPException: 403 if the user is not subscribed to premium.
    """
    user = request.user.object
    subscribed = has_required_scope(request, ["premium"])
    if not subscribed:
        raise HTTPException(status_code=403, detail="User is not subscribed to premium")
    new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
    # Check for a missing model BEFORE dereferencing new_config; previously the telemetry
    # call read new_config.setting.model_name first and raised AttributeError on None.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_paint_model",
        client=client,
        metadata={"paint_model": new_config.setting.model_name},
    )
    return {"status": "ok"}

View File

@@ -1313,7 +1313,6 @@ def configure_content(
files: Optional[dict[str, dict[str, str]]],
regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All,
full_corpus: bool = True,
user: KhojUser = None,
) -> bool:
success = True
@@ -1344,7 +1343,6 @@ def configure_content(
OrgToEntries,
files.get("org"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
@@ -1362,7 +1360,6 @@ def configure_content(
MarkdownToEntries,
files.get("markdown"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
@@ -1379,7 +1376,6 @@ def configure_content(
PdfToEntries,
files.get("pdf"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
@@ -1398,7 +1394,6 @@ def configure_content(
PlaintextToEntries,
files.get("plaintext"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
@@ -1418,7 +1413,6 @@ def configure_content(
GithubToEntries,
None,
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
config=github_config,
)
@@ -1439,7 +1433,6 @@ def configure_content(
NotionToEntries,
None,
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
config=notion_config,
)
@@ -1459,7 +1452,6 @@ def configure_content(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
@@ -1472,7 +1464,6 @@ def configure_content(
DocxToEntries,
files.get("docx"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:

View File

@@ -1,166 +0,0 @@
import asyncio
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, Depends, Header, Request, Response, UploadFile
from pydantic import BaseModel
from starlette.authentication import requires
from khoj.routers.helpers import (
ApiIndexedDataLimiter,
configure_content,
update_telemetry_state,
)
from khoj.utils import constants, state
from khoj.utils.config import SearchModels
from khoj.utils.helpers import get_file_type
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state
logger = logging.getLogger(__name__)
indexer = APIRouter()
class File(BaseModel):
path: str
content: Union[str, bytes]
class IndexBatchRequest(BaseModel):
files: list[File]
class IndexerInput(BaseModel):
org: Optional[dict[str, str]] = None
markdown: Optional[dict[str, str]] = None
pdf: Optional[dict[str, bytes]] = None
plaintext: Optional[dict[str, str]] = None
image: Optional[dict[str, bytes]] = None
docx: Optional[dict[str, bytes]] = None
@indexer.post("/update")
@requires(["authenticated"])
async def update(
request: Request,
files: list[UploadFile],
force: bool = False,
t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
ApiIndexedDataLimiter(
incoming_entries_size_limit=10,
subscribed_incoming_entries_size_limit=25,
total_entries_size_limit=10,
subscribed_total_entries_size_limit=100,
)
),
):
user = request.user.object
index_files: Dict[str, Dict[str, str]] = {
"org": {},
"markdown": {},
"pdf": {},
"plaintext": {},
"image": {},
"docx": {},
}
try:
logger.info(f"📬 Updating content index via API call by {client} client")
for file in files:
file_content = file.file.read()
file_type, encoding = get_file_type(file.content_type, file_content)
if file_type in index_files:
index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
else:
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
indexer_input = IndexerInput(
org=index_files["org"],
markdown=index_files["markdown"],
pdf=index_files["pdf"],
plaintext=index_files["plaintext"],
image=index_files["image"],
docx=index_files["docx"],
)
if state.config == None:
logger.info("📬 Initializing content index on first run.")
default_full_config = FullConfig(
content_type=None,
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
processor=None,
)
state.config = default_full_config
default_content_config = ContentConfig(
org=None,
markdown=None,
pdf=None,
docx=None,
image=None,
github=None,
notion=None,
plaintext=None,
)
state.config.content_type = default_content_config
save_config_to_file_updated_state()
configure_search(state.search_models, state.config.search_type)
# Extract required fields from config
loop = asyncio.get_event_loop()
success = await loop.run_in_executor(
None,
configure_content,
indexer_input.model_dump(),
force,
t,
False,
user,
)
if not success:
raise RuntimeError("Failed to update content index")
logger.info(f"Finished processing batch indexing request")
except Exception as e:
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
logger.error(
f'🚨 Failed to {"force " if force else ""}update {t} content index triggered via API call by {client} client: {e}',
exc_info=True,
)
return Response(content="Failed", status_code=500)
indexing_metadata = {
"num_org": len(index_files["org"]),
"num_markdown": len(index_files["markdown"]),
"num_pdf": len(index_files["pdf"]),
"num_plaintext": len(index_files["plaintext"]),
"num_image": len(index_files["image"]),
"num_docx": len(index_files["docx"]),
}
update_telemetry_state(
request=request,
telemetry_type="api",
api="index/update",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
metadata=indexing_metadata,
)
logger.info(f"📪 Content index updated via API call by {client} client")
indexed_filenames = ",".join(file for ctype in index_files for file in index_files[ctype]) or ""
return Response(content=indexed_filenames, status_code=200)
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
# Run Validation Checks
if search_models is None:
search_models = SearchModels()
return search_models

View File

@@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
notion_redirect = str(request.app.url_path_for("notion_config_page"))
# Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user)
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
return RedirectResponse(notion_redirect)

View File

@@ -199,17 +199,16 @@ def setup(
text_to_entries: Type[TextToEntries],
files: dict[str, str],
regenerate: bool,
full_corpus: bool = True,
user: KhojUser = None,
config=None,
) -> None:
) -> Tuple[int, int]:
if config:
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
files=files, user=user, regenerate=regenerate
)
else:
num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
files=files, user=user, regenerate=regenerate
)
if files:
@@ -219,6 +218,8 @@ def setup(
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
)
return num_new_embeddings, num_deleted_embeddings
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder"""

View File

@@ -124,7 +124,7 @@ def get_org_files(config: TextContentConfig):
logger.debug("At least one of org-files or org-file-filter is required to be specified")
return {}
"Get Org files to process"
# Get Org files to process
absolute_org_files, filtered_org_files = set(), set()
if org_files:
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}