mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 05:39:12 +00:00
Split Configure API into Content, Model API paths (#857)
## Major: Breaking Changes
- Move API endpoints under /configure/<type>/model to /api/model/<type>
- Move API endpoints under /api/configure/content/ to /api/content/
- Accept file deletion requests by clients during sync
- Split /api/v1/index/update into /api/content PUT, PATCH API endpoints

## Minor: Create New API Endpoint
- Create API endpoints to get user content configurations

Related: #852
This commit is contained in:
@@ -233,11 +233,15 @@ function pushDataToKhoj (regenerate = false) {
|
||||
|
||||
// Request indexing files on server. With upto 1000 files in each request
|
||||
for (let i = 0; i < filesDataToPush.length; i += 1000) {
|
||||
const syncUrl = `${hostURL}/api/content?client=desktop`;
|
||||
const filesDataGroup = filesDataToPush.slice(i, i + 1000);
|
||||
const formData = new FormData();
|
||||
filesDataGroup.forEach(fileData => { formData.append('files', fileData.blob, fileData.path) });
|
||||
let request = axios.post(`${hostURL}/api/v1/index/update?force=${regenerate}&client=desktop`, formData, { headers });
|
||||
requests.push(request);
|
||||
requests.push(
|
||||
regenerate
|
||||
? axios.put(syncUrl, formData, { headers })
|
||||
: axios.patch(syncUrl, formData, { headers })
|
||||
);
|
||||
}
|
||||
|
||||
// Wait for requests batch to finish
|
||||
|
||||
@@ -212,7 +212,7 @@
|
||||
const headers = { 'Authorization': `Bearer ${khojToken}` };
|
||||
|
||||
// Populate type dropdown field with enabled content types only
|
||||
fetch(`${hostURL}/api/configure/types`, { headers })
|
||||
fetch(`${hostURL}/api/content/types`, { headers })
|
||||
.then(response => response.json())
|
||||
.then(enabled_types => {
|
||||
// Show warning if no content types are enabled
|
||||
|
||||
@@ -424,12 +424,12 @@ Auto invokes setup steps on calling main entrypoint."
|
||||
"Send multi-part form `BODY' of `CONTENT-TYPE' in request to khoj server.
|
||||
Append 'TYPE-QUERY' as query parameter in request url.
|
||||
Specify `BOUNDARY' used to separate files in request header."
|
||||
(let ((url-request-method "POST")
|
||||
(let ((url-request-method ((if force) "PUT" "PATCH"))
|
||||
(url-request-data body)
|
||||
(url-request-extra-headers `(("content-type" . ,(format "multipart/form-data; boundary=%s" boundary))
|
||||
("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
|
||||
(with-current-buffer
|
||||
(url-retrieve (format "%s/api/v1/index/update?%s&force=%s&client=emacs" khoj-server-url type-query (or force "false"))
|
||||
(url-retrieve (format "%s/api/content?%s&client=emacs" khoj-server-url type-query)
|
||||
;; render response from indexing API endpoint on server
|
||||
(lambda (status)
|
||||
(if (not (plist-get status :error))
|
||||
@@ -697,7 +697,7 @@ Optionally apply CALLBACK with JSON parsed response and CBARGS."
|
||||
|
||||
(defun khoj--get-enabled-content-types ()
|
||||
"Get content types enabled for search from API."
|
||||
(khoj--call-api "/api/configure/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
|
||||
(khoj--call-api "/api/content/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
|
||||
|
||||
(defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar)
|
||||
"Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params.
|
||||
|
||||
@@ -89,10 +89,11 @@ export async function updateContentIndex(vault: Vault, setting: KhojSetting, las
|
||||
for (let i = 0; i < fileData.length; i += 1000) {
|
||||
const filesGroup = fileData.slice(i, i + 1000);
|
||||
const formData = new FormData();
|
||||
const method = regenerate ? "PUT" : "PATCH";
|
||||
filesGroup.forEach(fileItem => { formData.append('files', fileItem.blob, fileItem.path) });
|
||||
// Call Khoj backend to update index with all markdown, pdf files
|
||||
const response = await fetch(`${setting.khojUrl}/api/v1/index/update?force=${regenerate}&client=obsidian`, {
|
||||
method: 'POST',
|
||||
const response = await fetch(`${setting.khojUrl}/api/content?client=obsidian`, {
|
||||
method: method,
|
||||
headers: {
|
||||
'Authorization': `Bearer ${setting.khojApiKey}`,
|
||||
},
|
||||
|
||||
@@ -277,8 +277,8 @@ export function uploadDataForIndexing(
|
||||
// Wait for all files to be read before making the fetch request
|
||||
Promise.all(fileReadPromises)
|
||||
.then(() => {
|
||||
return fetch("/api/v1/index/update?force=false&client=web", {
|
||||
method: "POST",
|
||||
return fetch("/api/content?client=web", {
|
||||
method: "PATCH",
|
||||
body: formData,
|
||||
});
|
||||
})
|
||||
|
||||
@@ -68,8 +68,8 @@ interface ModelPickerProps {
|
||||
}
|
||||
|
||||
export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
|
||||
const { data: models } = useOptionsRequest('/api/configure/chat/model/options');
|
||||
const { data: selectedModel } = useSelectedModel('/api/configure/chat/model');
|
||||
const { data: models } = useOptionsRequest('/api/model/chat/options');
|
||||
const { data: selectedModel } = useSelectedModel('/api/model/chat');
|
||||
const [openLoginDialog, setOpenLoginDialog] = React.useState(false);
|
||||
|
||||
let userData = useAuthenticatedData();
|
||||
@@ -94,7 +94,7 @@ export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
|
||||
props.setModelUsed(model);
|
||||
}
|
||||
|
||||
fetch('/api/configure/chat/model' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) })
|
||||
fetch('/api/model/chat' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) })
|
||||
.then((response) => {
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to select model');
|
||||
|
||||
@@ -148,7 +148,7 @@ interface FilesMenuProps {
|
||||
|
||||
function FilesMenu(props: FilesMenuProps) {
|
||||
// Use SWR to fetch files
|
||||
const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/configure/content/computer' : null, fetcher);
|
||||
const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/content/computer' : null, fetcher);
|
||||
const { data: selectedFiles, error: selectedFilesError } = useSWR(props.conversationId ? `/api/chat/conversation/file-filters/${props.conversationId}` : null, fetcher);
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [unfilteredFiles, setUnfilteredFiles] = useState<string[]>([]);
|
||||
|
||||
@@ -42,7 +42,7 @@ from khoj.database.adapters import (
|
||||
)
|
||||
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.routers.indexer import configure_content, configure_search
|
||||
from khoj.routers.api_content import configure_content, configure_search
|
||||
from khoj.routers.twilio import is_twilio_enabled
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchType
|
||||
@@ -308,16 +308,16 @@ def configure_routes(app):
|
||||
from khoj.routers.api import api
|
||||
from khoj.routers.api_agents import api_agents
|
||||
from khoj.routers.api_chat import api_chat
|
||||
from khoj.routers.api_config import api_config
|
||||
from khoj.routers.indexer import indexer
|
||||
from khoj.routers.api_content import api_content
|
||||
from khoj.routers.api_model import api_model
|
||||
from khoj.routers.notion import notion_router
|
||||
from khoj.routers.web_client import web_client
|
||||
|
||||
app.include_router(api, prefix="/api")
|
||||
app.include_router(api_chat, prefix="/api/chat")
|
||||
app.include_router(api_agents, prefix="/api/agents")
|
||||
app.include_router(api_config, prefix="/api/configure")
|
||||
app.include_router(indexer, prefix="/api/v1/index")
|
||||
app.include_router(api_model, prefix="/api/model")
|
||||
app.include_router(api_content, prefix="/api/content")
|
||||
app.include_router(notion_router, prefix="/api/notion")
|
||||
app.include_router(web_client)
|
||||
|
||||
@@ -336,7 +336,7 @@ def configure_routes(app):
|
||||
if is_twilio_enabled():
|
||||
from khoj.routers.api_phone import api_phone
|
||||
|
||||
app.include_router(api_phone, prefix="/api/configure/phone")
|
||||
app.include_router(api_phone, prefix="/api/phone")
|
||||
logger.info("📞 Enabled Twilio")
|
||||
|
||||
|
||||
|
||||
@@ -998,8 +998,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
// Wait for all files to be read before making the fetch request
|
||||
Promise.all(fileReadPromises)
|
||||
.then(() => {
|
||||
return fetch("/api/v1/index/update?force=false&client=web", {
|
||||
method: "POST",
|
||||
return fetch("/api/content?client=web", {
|
||||
method: "PATCH",
|
||||
body: formData,
|
||||
});
|
||||
})
|
||||
@@ -1954,7 +1954,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
}
|
||||
var allFiles;
|
||||
function renderAllFiles() {
|
||||
fetch('/api/configure/content/computer')
|
||||
fetch('/api/content/computer')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
</style>
|
||||
<script>
|
||||
function removeFile(path) {
|
||||
fetch('/api/configure/content/file?filename=' + path, {
|
||||
fetch('/api/content/file?filename=' + path, {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -48,7 +48,7 @@
|
||||
|
||||
// Get all currently indexed files
|
||||
function getAllComputerFilenames() {
|
||||
fetch('/api/configure/content/computer')
|
||||
fetch('/api/content/computer')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
||||
@@ -122,7 +122,7 @@
|
||||
deleteAllComputerFilesButton.textContent = "🗑️ Deleting...";
|
||||
deleteAllComputerFilesButton.disabled = true;
|
||||
|
||||
fetch('/api/configure/content/computer', {
|
||||
fetch('/api/content/computer', {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
||||
@@ -165,7 +165,7 @@
|
||||
|
||||
// Save Github config on server
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/configure/content/github', {
|
||||
fetch('/api/content/github', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
||||
@@ -45,7 +45,7 @@
|
||||
|
||||
// Save Notion config on server
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/configure/content/notion', {
|
||||
fetch('/api/content/notion', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
||||
@@ -209,7 +209,7 @@
|
||||
|
||||
function populate_type_dropdown() {
|
||||
// Populate type dropdown field with enabled content types only
|
||||
fetch("/api/configure/types")
|
||||
fetch("/api/content/types")
|
||||
.then(response => response.json())
|
||||
.then(enabled_types => {
|
||||
// Show warning if no content types are enabled, or just one ("all")
|
||||
|
||||
@@ -394,8 +394,8 @@
|
||||
|
||||
function saveProfileGivenName() {
|
||||
const givenName = document.getElementById("profile_given_name").value;
|
||||
fetch('/api/configure/user/name?name=' + givenName, {
|
||||
method: 'POST',
|
||||
fetch('/api/user/name?name=' + givenName, {
|
||||
method: 'PATCH',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
@@ -421,7 +421,7 @@
|
||||
saveVoiceModelButton.disabled = true;
|
||||
saveVoiceModelButton.textContent = "Saving...";
|
||||
|
||||
fetch('/api/configure/voice/model?id=' + voiceModel, {
|
||||
fetch('/api/model/voice?id=' + voiceModel, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -455,7 +455,7 @@
|
||||
saveModelButton.innerHTML = "";
|
||||
saveModelButton.textContent = "Saving...";
|
||||
|
||||
fetch('/api/configure/chat/model?id=' + chatModel, {
|
||||
fetch('/api/model/chat?id=' + chatModel, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -494,7 +494,7 @@
|
||||
saveSearchModelButton.disabled = true;
|
||||
saveSearchModelButton.textContent = "Saving...";
|
||||
|
||||
fetch('/api/configure/search/model?id=' + searchModel, {
|
||||
fetch('/api/model/search?id=' + searchModel, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -526,7 +526,7 @@
|
||||
saveModelButton.disabled = true;
|
||||
saveModelButton.innerHTML = "Saving...";
|
||||
|
||||
fetch('/api/configure/paint/model?id=' + paintModel, {
|
||||
fetch('/api/model/paint?id=' + paintModel, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -553,7 +553,7 @@
|
||||
};
|
||||
|
||||
function clearContentType(content_source) {
|
||||
fetch('/api/configure/content/' + content_source, {
|
||||
fetch('/api/content/' + content_source, {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -676,7 +676,7 @@
|
||||
|
||||
content_sources = ["computer", "github", "notion"];
|
||||
content_sources.forEach(content_source => {
|
||||
fetch(`/api/configure/content/${content_source}`, {
|
||||
fetch(`/api/content/${content_source}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -807,7 +807,7 @@
|
||||
|
||||
function getIndexedDataSize() {
|
||||
document.getElementById("indexed-data-size").textContent = "Calculating...";
|
||||
fetch('/api/configure/content/size')
|
||||
fetch('/api/content/size')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used";
|
||||
@@ -815,7 +815,7 @@
|
||||
}
|
||||
|
||||
function removeFile(path) {
|
||||
fetch('/api/configure/content/file?filename=' + path, {
|
||||
fetch('/api/content/file?filename=' + path, {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -890,7 +890,7 @@
|
||||
})
|
||||
|
||||
phonenumberRemoveButton.addEventListener("click", () => {
|
||||
fetch('/api/configure/phone', {
|
||||
fetch('/api/phone', {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -917,7 +917,7 @@
|
||||
}, 5000);
|
||||
} else {
|
||||
const mobileNumber = iti.getNumber();
|
||||
fetch('/api/configure/phone?phone_number=' + mobileNumber, {
|
||||
fetch('/api/phone?phone_number=' + mobileNumber, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -970,7 +970,7 @@
|
||||
return;
|
||||
}
|
||||
|
||||
fetch('/api/configure/phone/verify?code=' + otp, {
|
||||
fetch('/api/phone/verify?code=' + otp, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
||||
@@ -19,16 +19,11 @@ class DocxToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
else:
|
||||
deletion_file_names = None
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
|
||||
# Extract Entries from specified Docx files
|
||||
with timer("Extract entries from specified DOCX files", logger):
|
||||
|
||||
@@ -48,9 +48,7 @@ class GithubToEntries(TextToEntries):
|
||||
else:
|
||||
return
|
||||
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
if self.config.pat_token is None or self.config.pat_token == "":
|
||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||
|
||||
@@ -20,16 +20,11 @@ class ImageToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
else:
|
||||
deletion_file_names = None
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
|
||||
# Extract Entries from specified image files
|
||||
with timer("Extract entries from specified Image files", logger):
|
||||
|
||||
@@ -19,16 +19,11 @@ class MarkdownToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
else:
|
||||
deletion_file_names = None
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
|
||||
max_tokens = 256
|
||||
# Extract Entries from specified Markdown files
|
||||
|
||||
@@ -78,9 +78,7 @@ class NotionToEntries(TextToEntries):
|
||||
|
||||
self.body_params = {"page_size": 100}
|
||||
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
current_entries = []
|
||||
|
||||
# Get all pages
|
||||
|
||||
@@ -20,15 +20,10 @@ class OrgToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
else:
|
||||
deletion_file_names = None
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
|
||||
# Extract Entries from specified Org files
|
||||
max_tokens = 256
|
||||
|
||||
@@ -22,16 +22,11 @@ class PdfToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
else:
|
||||
deletion_file_names = None
|
||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
|
||||
# Extract Entries from specified Pdf files
|
||||
with timer("Extract entries from specified PDF files", logger):
|
||||
|
||||
@@ -20,15 +20,10 @@ class PlaintextToEntries(TextToEntries):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
else:
|
||||
deletion_file_names = None
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
files = {file: files[file] for file in files_to_process}
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
with timer("Extract entries from specified Plaintext files", logger):
|
||||
|
||||
@@ -31,9 +31,7 @@ class TextToEntries(ABC):
|
||||
self.date_filter = DateFilter()
|
||||
|
||||
@abstractmethod
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -19,6 +19,7 @@ from fastapi.responses import Response
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
from khoj.configure import initialize_content
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import (
|
||||
AutomationAdapters,
|
||||
ConversationAdapters,
|
||||
@@ -39,6 +40,7 @@ from khoj.routers.helpers import (
|
||||
CommonQueryParams,
|
||||
ConversationCommandRateLimiter,
|
||||
acreate_title_from_query,
|
||||
get_user_config,
|
||||
schedule_automation,
|
||||
update_telemetry_state,
|
||||
)
|
||||
@@ -276,6 +278,49 @@ async def transcribe(
|
||||
return Response(content=content, media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api.get("/settings", response_class=Response)
@requires(["authenticated"])
def get_settings(request: Request, detailed: Optional[bool] = False) -> Response:
    """Return the authenticated user's settings as a JSON response.

    Args:
        request: Incoming request; the user is read from `request.user.object`.
        detailed: Forwarded to `get_user_config` as `is_detailed`.

    Returns:
        A 200 `Response` whose body is the JSON-serialized user config with
        its "request" entry removed.
    """
    settings = get_user_config(request.user.object, request, is_detailed=detailed)

    # The "request" entry is removed before serializing
    # (presumably it holds the request object, which is not JSON serializable; verify in get_user_config)
    del settings["request"]

    payload = json.dumps(settings)
    return Response(content=payload, media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api.patch("/user/name", status_code=200)
@requires(["authenticated"])
def set_user_name(
    request: Request,
    name: str,
    client: Optional[str] = None,
):
    """Set the first and last name of the authenticated user.

    Args:
        request: Incoming request; the user is read from `request.user.object`.
        name: Full name as "Firstname" or "Firstname Lastname".
        client: Optional client identifier, recorded in telemetry.

    Returns:
        `{"status": "ok"}` once the name is stored.

    Raises:
        HTTPException: 400 if `name` has more than two whitespace-separated parts.
    """
    user = request.user.object

    # Split on runs of whitespace rather than single spaces. Splitting on " "
    # turned leading or repeated spaces into empty parts, which either stored
    # an empty first name (" John") or rejected a valid name ("John  Doe").
    name_parts = name.split()

    if len(name_parts) > 2:
        raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")

    # A blank name keeps the previous behavior of storing two empty strings
    first_name = name_parts[0] if name_parts else ""
    last_name = name_parts[1] if len(name_parts) == 2 else ""

    adapters.set_user_name(user, first_name, last_name)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_user_name",
        client=client,
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
async def extract_references_and_questions(
|
||||
request: Request,
|
||||
meta_log: dict,
|
||||
|
||||
@@ -1,434 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import (
|
||||
GithubConfig,
|
||||
KhojUser,
|
||||
LocalMarkdownConfig,
|
||||
LocalOrgConfig,
|
||||
LocalPdfConfig,
|
||||
LocalPlaintextConfig,
|
||||
NotionConfig,
|
||||
Subscription,
|
||||
)
|
||||
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.rawconfig import (
|
||||
FullConfig,
|
||||
GithubContentConfig,
|
||||
NotionContentConfig,
|
||||
SearchConfig,
|
||||
)
|
||||
from khoj.utils.state import SearchType
|
||||
|
||||
api_config = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def map_config_to_object(content_source: str):
    """Map a content source to the object used to manage its configuration.

    Returns `GithubConfig` or `NotionConfig` for the Github and Notion sources,
    the literal string "Computer" for the computer source, and None for any
    other value.
    """
    source_to_object = (
        (DbEntry.EntrySource.GITHUB, GithubConfig),
        (DbEntry.EntrySource.NOTION, NotionConfig),
        (DbEntry.EntrySource.COMPUTER, "Computer"),
    )

    # Compare with == in declaration order and return the first match
    for source, config_object in source_to_object:
        if content_source == source:
            return config_object
    return None
|
||||
|
||||
|
||||
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Persist the content type settings in `config` to the database for `user`.

    For each local content type (org, markdown, pdf, plaintext) that is set,
    the user's existing rows are deleted and a new row is created from the
    config's input files, input filter and heading-entry flag. Github and
    Notion settings, when set, are stored through their adapter helpers.
    Does nothing when `config.content_type` is unset.
    """
    if not config.content_type:
        return

    # Local content types share one shape: replace the user's rows wholesale
    local_content_models = (
        ("org", LocalOrgConfig),
        ("markdown", LocalMarkdownConfig),
        ("pdf", LocalPdfConfig),
        ("plaintext", LocalPlaintextConfig),
    )
    for content_name, config_model in local_content_models:
        local_config = getattr(config.content_type, content_name)
        if not local_config:
            continue
        await config_model.objects.filter(user=user).adelete()
        await config_model.objects.acreate(
            input_files=local_config.input_files,
            input_filter=local_config.input_filter,
            index_heading_entries=local_config.index_heading_entries,
            user=user,
        )

    github_config = config.content_type.github
    if github_config:
        await adapters.set_user_github_config(
            user=user,
            pat_token=github_config.pat_token,
            repos=github_config.repos,
        )

    notion_config = config.content_type.notion
    if notion_config:
        await adapters.set_notion_config(
            user=user,
            token=notion_config.token,
        )
|
||||
|
||||
|
||||
def _initialize_config():
    """Create the global `state.config` with the default search settings if it is unset."""
    if state.config is not None:
        return

    state.config = FullConfig()
    # Seed the search settings from the "search-type" section of the default config
    default_search_type = constants.default_config["search-type"]
    state.config.search_type = SearchConfig.model_validate(default_search_type)
|
||||
|
||||
|
||||
@api_config.post("/content/github", status_code=200)
@requires(["authenticated"])
async def set_content_github(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Store the Github content configuration of the authenticated user.

    Args:
        request: Incoming request; the user is read from `request.user.object`.
        updated_config: Github settings providing `pat_token` and `repos`.
        client: Optional client identifier, recorded in telemetry.

    Returns:
        `{"status": "ok"}` once the configuration is stored.

    Raises:
        HTTPException: 500 if storing the configuration fails for any reason.
    """
    _initialize_config()

    user = request.user.object

    try:
        await adapters.set_user_github_config(
            user=user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as error:
        # Log the underlying failure with its traceback; the client only sees a generic message
        logger.error(error, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")

    telemetry_metadata = {"content_type": "github"}
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata=telemetry_metadata,
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.post("/content/notion", status_code=200)
@requires(["authenticated"])
async def set_content_notion(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Store the Notion content configuration of the authenticated user.

    Args:
        request: Incoming request; the user is read from `request.user.object`.
        updated_config: Notion settings providing `token`.
        client: Optional client identifier, recorded in telemetry.

    Returns:
        `{"status": "ok"}` once the configuration is stored.

    Raises:
        HTTPException: 500 if storing the configuration fails for any reason.
    """
    _initialize_config()

    user = request.user.object

    try:
        await adapters.set_notion_config(
            user=user,
            token=updated_config.token,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        # Report the failure against Notion; the message previously named Github
        raise HTTPException(status_code=500, detail="Failed to set Notion config")

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
# NOTE: The literal "/content/file" route is declared before the parameterized
# "/content/{content_source}" route. Routes are matched in declaration order,
# so with the reverse order a DELETE to "/content/file" is captured by the
# parameterized handler with content_source="file" and never reaches this one.
@api_config.delete("/content/file", status_code=201)
@requires(["authenticated"])
async def delete_content_file(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete the indexed entries of a single file for the authenticated user.

    Args:
        request: Incoming request; the user is read from `request.user.object`.
        filename: Name of the file whose entries should be deleted.
        client: Optional client identifier, recorded in telemetry.

    Returns:
        `{"status": "ok"}` once the entries are deleted.
    """
    user = request.user.object

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_file",
        client=client,
    )

    await EntryAdapters.adelete_entry_by_file(user, filename)

    return {"status": "ok"}


@api_config.delete("/content/{content_source}", status_code=200)
@requires(["authenticated"])
async def delete_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete the configuration and indexed entries of a content source.

    Args:
        request: Incoming request; the user is read from `request.user.object`.
        content_source: Content source to clear, as understood by `map_config_to_object`.
        client: Optional client identifier, recorded in telemetry.

    Returns:
        `{"status": "ok"}` once the content source is cleared.

    Raises:
        HTTPException: 400 if `content_source` is not a known content source.
    """
    user = request.user.object

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_source": content_source},
    )

    content_object = map_config_to_object(content_source)
    if content_object is None:
        # An unknown source is a client error; a bare ValueError surfaced as a 500
        raise HTTPException(status_code=400, detail=f"Invalid content source: {content_source}")
    elif content_object != "Computer":
        # Only sources backed by a config model have configuration rows to remove
        await content_object.objects.filter(user=user).adelete()

    # Entries are removed for every valid source, including "Computer"
    await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.get("/content/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Return the filenames indexed for the requesting user from one content source.

    Telemetry for the call is recorded before the filenames are looked up.
    """
    current_user = request.user.object

    telemetry_fields = {
        "request": request,
        "telemetry_type": "api",
        "api": "get_all_filenames",
        "client": client,
    }
    update_telemetry_state(**telemetry_fields)

    # Turn the adapter result into a list through sync_to_async.
    filenames = EntryAdapters.get_all_filenames_by_source(current_user, content_source)
    return await sync_to_async(list)(filenames)  # type: ignore[call-arg]
|
||||
|
||||
|
||||
@api_config.get("/chat/model/options", response_model=List[Dict[str, Union[str, int]]])
def get_chat_model_options(
    request: Request,
    client: Optional[str] = None,
):
    """List the available chat models as ``{"chat_model": ..., "id": ...}`` objects.

    The declared response model is a list of dicts, matching the JSON array
    this handler actually returns.

    Returns:
        A JSON response with status 200 holding one object per chat model option.
    """
    conversation_options = ConversationAdapters.get_conversation_processor_options().all()

    all_conversation_options = [
        {"chat_model": conversation_option.chat_model, "id": conversation_option.id}
        for conversation_option in conversation_options
    ]

    return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api_config.get("/chat/model")
@requires(["authenticated"])
def get_user_chat_model(
    request: Request,
    client: Optional[str] = None,
):
    """Return the chat model configured for the requesting user.

    Falls back to the default conversation config when the user has none.

    Returns:
        A JSON response with status 200 holding the model ``id`` and ``chat_model``.
    """
    user = request.user.object

    chat_model = ConversationAdapters.get_conversation_config(user)

    if chat_model is None:
        chat_model = ConversationAdapters.get_default_conversation_config()

    # Label the JSON body as JSON, like the sibling handlers do.
    return Response(
        status_code=200,
        content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}),
        media_type="application/json",
    )
|
||||
|
||||
|
||||
@api_config.post("/chat/model", status_code=200)
@requires(["authenticated", "premium"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the requesting user's chat model to the option with the given id.

    Telemetry is recorded whether or not the model was found.

    Returns:
        ``{"status": "ok"}`` on success, or an error dict when the adapter
        returns ``None`` for the requested id.
    """
    current_user = request.user.object
    model_id = int(id)

    updated = await ConversationAdapters.aset_user_conversation_processor(current_user, model_id)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_conversation_chat_model",
        client=client,
        metadata={"processor_conversation_type": "conversation"},
    )

    if updated is not None:
        return {"status": "ok"}
    return {"status": "error", "message": "Model not found"}
|
||||
|
||||
|
||||
@api_config.post("/voice/model", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the requesting user's voice model to the option with the given id.

    Returns:
        A 202 response with ``{"status": "ok"}`` on success, or a 404 response
        with an error payload when the adapter returns ``None``.
    """
    current_user = request.user.object

    updated = await ConversationAdapters.aset_user_voice_model(current_user, id)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_voice_model",
        client=client,
    )

    if updated is None:
        status_code, payload = 404, {"status": "error", "message": "Model not found"}
    else:
        status_code, payload = 202, {"status": "ok"}

    return Response(status_code=status_code, content=json.dumps(payload))
|
||||
|
||||
|
||||
@api_config.post("/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the requesting user's search model and drop entries indexed with the old one.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        id: Identifier of the search model to select (parsed with ``int``).
        client: Optional client identifier forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` on success, or an error dict when the requested
        model is not found. Nothing is deleted in the error case.
    """
    user = request.user.object

    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, int(id))

    # Validate first: a request for an unknown model must never wipe the
    # user's entries (previously it did when the user had no prior config).
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    # Delete entries when the user was on the default config (no stored config)
    # or has switched to a different model.
    if not prev_config or int(id) != prev_config.id:
        await EntryAdapters.adelete_all_entries(user)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_search_model",
        client=client,
        metadata={"search_model": new_config.setting.name},
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.post("/paint/model", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the requesting user's text-to-image model; requires the ``premium`` scope.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        id: Identifier of the paint model to select (parsed with ``int``).
        client: Optional client identifier forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` on success, or an error dict when the model is not found.

    Raises:
        HTTPException: 403 when the request lacks the ``premium`` scope.
    """
    user = request.user.object
    subscribed = has_required_scope(request, ["premium"])

    if not subscribed:
        raise HTTPException(status_code=403, detail="User is not subscribed to premium")

    new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

    # Check for a missing model before touching new_config.setting; the
    # telemetry call used to dereference it first and crash on None.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_paint_model",
        client=client,
        metadata={"paint_model": new_config.setting.model_name},
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.get("/content/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_content_size(request: Request, common: CommonQueryParams):
    """Report the size of the requesting user's indexed data, rounded up to whole MB."""
    current_user = request.user.object
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(current_user)
    payload = {"indexed_data_size_in_mb": math.ceil(size_in_mb)}
    return Response(content=json.dumps(payload), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api_config.post("/user/name", status_code=200)
@requires(["authenticated"])
def set_user_name(
    request: Request,
    name: str,
    client: Optional[str] = None,
):
    """Store the first and last name of the requesting user.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        name: ``"Firstname"`` or ``"Firstname Lastname"``.
        client: Optional client identifier forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` once the name has been stored.

    Raises:
        HTTPException: 400 when ``name`` has more than two whitespace-separated parts.
    """
    user = request.user.object

    # split() without a separator collapses repeated, leading and trailing
    # whitespace, so "Jane  Doe" is not mistaken for three name parts.
    name_parts = name.split()

    if len(name_parts) > 2:
        raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")

    first_name = name_parts[0] if name_parts else ""
    last_name = name_parts[1] if len(name_parts) == 2 else ""

    adapters.set_user_name(user, first_name, last_name)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_user_name",
        client=client,
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_config.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_config_types(
    request: Request,
):
    """List the search types usable by the requesting user, in ``SearchType`` order.

    A type is included when the user has entries of that file type, when it is
    set in the server content config, or when it is ``SearchType.All``.
    """
    current_user = request.user.object
    configured_content_types = list(EntryAdapters.get_unique_file_types(current_user))

    if state.config and state.config.content_type:
        # Iterating the dumped model yields its non-None field names.
        configured_content_types.extend(state.config.content_type.model_dump(exclude_none=True))

    enabled_types = []
    for search_type in SearchType:
        if search_type == SearchType.All or search_type.value in configured_content_types:
            enabled_types.append(search_type.value)
    return enabled_types
|
||||
508
src/khoj/routers/api_content.py
Normal file
508
src/khoj/routers/api_content.py
Normal file
@@ -0,0 +1,508 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
Header,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from starlette.authentication import requires
|
||||
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
get_user_github_config,
|
||||
get_user_notion_config,
|
||||
)
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import (
|
||||
GithubConfig,
|
||||
GithubRepoConfig,
|
||||
KhojUser,
|
||||
LocalMarkdownConfig,
|
||||
LocalOrgConfig,
|
||||
LocalPdfConfig,
|
||||
LocalPlaintextConfig,
|
||||
NotionConfig,
|
||||
)
|
||||
from khoj.routers.helpers import (
|
||||
ApiIndexedDataLimiter,
|
||||
CommonQueryParams,
|
||||
configure_content,
|
||||
get_user_config,
|
||||
update_telemetry_state,
|
||||
)
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.helpers import get_file_type
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
FullConfig,
|
||||
GithubContentConfig,
|
||||
NotionContentConfig,
|
||||
SearchConfig,
|
||||
)
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
api_content = APIRouter()
|
||||
|
||||
|
||||
class File(BaseModel):
    """A file described by its path and its text or binary content."""

    # Path identifying the file.
    path: str
    # File body, either text or raw bytes.
    content: Union[str, bytes]
|
||||
|
||||
|
||||
class IndexBatchRequest(BaseModel):
    """A batch of ``File`` objects."""

    files: List[File]
|
||||
|
||||
|
||||
class IndexerInput(BaseModel):
    """Per-content-type mapping of filename to file content; every field is optional."""

    # Text formats: filename -> str content.
    org: Optional[Dict[str, str]] = None
    markdown: Optional[Dict[str, str]] = None
    # Binary format: filename -> bytes content.
    pdf: Optional[Dict[str, bytes]] = None
    plaintext: Optional[Dict[str, str]] = None
    image: Optional[Dict[str, bytes]] = None
    docx: Optional[Dict[str, bytes]] = None
|
||||
|
||||
|
||||
@api_content.put("")
@requires(["authenticated"])
async def put_content(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Hand the uploaded files to ``indexer`` with ``regenerate=True``.

    ``indexed_data_limiter`` is declared only as a dependency; its value is
    not used in the body.
    """
    return await indexer(
        request=request,
        files=files,
        t=t,
        regenerate=True,
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )
|
||||
|
||||
|
||||
@api_content.patch("")
@requires(["authenticated"])
async def patch_content(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Hand the uploaded files to ``indexer`` with ``regenerate=False``.

    ``indexed_data_limiter`` is declared only as a dependency; its value is
    not used in the body.
    """
    return await indexer(
        request=request,
        files=files,
        t=t,
        regenerate=False,
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )
|
||||
|
||||
|
||||
@api_content.get("/github", response_class=Response)
@requires(["authenticated"])
def get_content_github(request: Request) -> Response:
    """Return the user config together with the user's current Github content config.

    Returns:
        A JSON response with status 200. ``current_config`` holds the stored
        PAT token and repos, or ``{}`` when no Github config is stored.
    """
    user = request.user.object
    user_config = get_user_config(user, request)
    # The request object is dropped before the config is serialised.
    del user_config["request"]

    current_github_config = get_user_github_config(user)

    if current_github_config:
        repos = [
            GithubRepoConfig(
                name=repo.name,
                owner=repo.owner,
                branch=repo.branch,
            )
            for repo in current_github_config.githubrepoconfig.all()
        ]
        github_content_config = GithubContentConfig(
            pat_token=current_github_config.pat_token,
            repos=repos,
        )
        # Serialise with model_dump_json(), as the Notion handler does,
        # rather than the deprecated .json() alias.
        current_config = json.loads(github_content_config.model_dump_json())
    else:
        current_config = {}  # type: ignore

    user_config["current_config"] = current_config

    # Return config data as a JSON response
    return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api_content.get("/notion", response_class=Response)
@requires(["authenticated"])
def get_content_notion(request: Request) -> Response:
    """Return the user config together with the user's current Notion content config.

    The token is an empty string when the user has no Notion config stored.
    """
    current_user = request.user.object
    user_config = get_user_config(current_user, request)
    del user_config["request"]

    notion_config = get_user_notion_config(current_user)
    if notion_config:
        token = notion_config.token
    else:
        token = ""

    serialized_config = NotionContentConfig(token=token).model_dump_json()
    user_config["current_config"] = json.loads(serialized_config)

    # Send the assembled config back as JSON.
    return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api_content.post("/github", status_code=200)
@requires(["authenticated"])
async def set_content_github(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Save the Github PAT token and repos sent by the client for the requesting user.

    Any failure while saving is logged and reported as an HTTP 500.
    """
    _initialize_config()

    current_user = request.user.object

    try:
        await adapters.set_user_github_config(
            user=current_user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as err:
        logger.error(err, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")

    telemetry_fields = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_content_config",
        "client": client,
        "metadata": {"content_type": "github"},
    }
    update_telemetry_state(**telemetry_fields)

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_content.post("/notion", status_code=200)
@requires(["authenticated"])
async def set_content_notion(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Save the Notion token sent by the client for the requesting user.

    Any failure while saving is logged and reported as an HTTP 500.
    """
    _initialize_config()

    current_user = request.user.object

    try:
        await adapters.set_notion_config(user=current_user, token=updated_config.token)
    except Exception as err:
        logger.error(err, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Notion config")

    telemetry_fields = {
        "request": request,
        "telemetry_type": "api",
        "api": "set_content_config",
        "client": client,
        "metadata": {"content_type": "notion"},
    }
    update_telemetry_state(**telemetry_fields)

    return {"status": "ok"}
|
||||
|
||||
|
||||
# NOTE: the literal "/file" route must be registered before the parameterised
# "/{content_source}" route. Routes are matched in declaration order, so the
# reverse order sends every "file" delete request to delete_content_source
# with content_source="file".
@api_content.delete("/file", status_code=201)
@requires(["authenticated"])
async def delete_content_file(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete the indexed entries of a single file for the requesting user.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        filename: Name of the file whose entries are removed.
        client: Optional client identifier forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` after the delete call completes.
    """
    user = request.user.object

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_file",
        client=client,
    )

    await EntryAdapters.adelete_entry_by_file(user, filename)

    return {"status": "ok"}


@api_content.delete("/{content_source}", status_code=200)
@requires(["authenticated"])
async def delete_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete the stored config (if any) and all entries of one content source.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        content_source: Source identifier resolved through ``map_config_to_object``.
        client: Optional client identifier forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` after the entries are deleted.

    Raises:
        HTTPException: 404 when ``content_source`` is not a known source.
    """
    user = request.user.object

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_source": content_source},
    )

    content_object = map_config_to_object(content_source)
    if content_object is None:
        # Answer with a client error instead of letting a ValueError surface as a 500.
        raise HTTPException(status_code=404, detail=f"Invalid content source: {content_source}")
    elif content_object != "Computer":
        # Only sources backed by a config model have a config row to remove.
        await content_object.objects.filter(user=user).adelete()
    await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_content.get("/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_content_size(request: Request, common: CommonQueryParams, client: Optional[str] = None):
    """Report the size of the requesting user's indexed data, rounded up to whole MB."""
    current_user = request.user.object
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(current_user)
    payload = {"indexed_data_size_in_mb": math.ceil(size_in_mb)}
    return Response(content=json.dumps(payload), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api_content.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_content_types(request: Request, client: Optional[str] = None):
    """List the content types usable by the requesting user.

    A type is included when the user has entries of that file type, when it is
    set in the server content config, or when it is ``"all"``. Only values that
    are valid ``SearchType`` members are returned.

    Returns:
        The enabled type names in ``SearchType`` declaration order, so the
        response is stable across calls.
    """
    user = request.user.object
    configured_content_types = set(EntryAdapters.get_unique_file_types(user))
    configured_content_types.add("all")

    if state.config and state.config.content_type:
        # Iterating the dumped model yields its non-None field names.
        configured_content_types.update(state.config.content_type.model_dump(exclude_none=True))

    # Walk the enum rather than intersecting two sets: same members, deterministic order.
    return [search_type.value for search_type in SearchType if search_type.value in configured_content_types]
|
||||
|
||||
|
||||
@api_content.get("/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Return the filenames indexed for the requesting user from one content source.

    Telemetry for the call is recorded before the filenames are looked up.
    """
    current_user = request.user.object

    telemetry_fields = {
        "request": request,
        "telemetry_type": "api",
        "api": "get_all_filenames",
        "client": client,
    }
    update_telemetry_state(**telemetry_fields)

    # Turn the adapter result into a list through sync_to_async.
    filenames = EntryAdapters.get_all_filenames_by_source(current_user, content_source)
    return await sync_to_async(list)(filenames)  # type: ignore[call-arg]
|
||||
|
||||
|
||||
async def indexer(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    regenerate: bool = False,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
):
    """Group uploaded files by detected type and run ``configure_content`` on them.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        files: Uploaded files to index.
        t: Search type forwarded to ``configure_content``.
        regenerate: Forwarded to ``configure_content``; also selects the
            "regenerate" or "sync" wording in log messages.
        client: Optional client identifier used in logs and telemetry.
        user_agent: Forwarded to telemetry.
        referer: Forwarded to telemetry.
        host: Forwarded to telemetry.

    Returns:
        A 200 response whose body is the comma-separated list of accepted
        filenames, or a 500 response with body ``"Failed"`` on any error.
    """
    user = request.user.object
    method = "regenerate" if regenerate else "sync"
    # Values are decoded text when get_file_type reports an encoding, raw bytes otherwise.
    index_files: Dict[str, Dict[str, Union[str, bytes]]] = {
        "org": {},
        "markdown": {},
        "pdf": {},
        "plaintext": {},
        "image": {},
        "docx": {},
    }
    try:
        logger.info(f"📬 Updating content index via API call by {client} client")
        for file in files:
            file_content = file.file.read()
            file_type, encoding = get_file_type(file.content_type, file_content)
            if file_type in index_files:
                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
            else:
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")

        # The bucket names are exactly the IndexerInput field names.
        indexer_input = IndexerInput(**index_files)

        if state.config is None:
            logger.info("📬 Initializing content index on first run.")
            default_full_config = FullConfig(
                content_type=None,
                search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
                processor=None,
            )
            state.config = default_full_config
            default_content_config = ContentConfig(
                org=None,
                markdown=None,
                pdf=None,
                docx=None,
                image=None,
                github=None,
                notion=None,
                plaintext=None,
            )
            state.config.content_type = default_content_config
            save_config_to_file_updated_state()
            configure_search(state.search_models, state.config.search_type)

        # Run configure_content in the default executor; inside a coroutine the
        # running loop is the one to use.
        loop = asyncio.get_running_loop()
        success = await loop.run_in_executor(
            None,
            configure_content,
            indexer_input.model_dump(),
            regenerate,
            t,
            user,
        )
        if not success:
            raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")
        logger.info(f"Finished {method} {t} data sent by {client} client into content index")
    except Exception as e:
        # Log the failure once (it used to be logged twice with near-identical messages).
        logger.error(
            f"🚨 Failed to {method} {t} data sent by {client} client into content index: {e}",
            exc_info=True,
        )
        return Response(content="Failed", status_code=500)

    # One "num_<type>" counter per bucket, in bucket order.
    indexing_metadata = {f"num_{content_type}": len(entries) for content_type, entries in index_files.items()}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="index/update",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
        metadata=indexing_metadata,
    )

    logger.info(f"📪 Content index updated via API call by {client} client")

    indexed_filenames = ",".join(file for ctype in index_files for file in index_files[ctype])
    return Response(content=indexed_filenames, status_code=200)
|
||||
|
||||
|
||||
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
    """Return ``search_models``, or a fresh ``SearchModels()`` when it is ``None``.

    ``search_config`` is accepted but not used.
    """
    return SearchModels() if search_models is None else search_models
|
||||
|
||||
|
||||
def map_config_to_object(content_source: str):
    """Resolve a content source name to its config model.

    Returns ``GithubConfig`` or ``NotionConfig`` for those sources, the string
    ``"Computer"`` for the computer source, and ``None`` for anything else.
    """
    resolved = None
    if content_source == DbEntry.EntrySource.GITHUB:
        resolved = GithubConfig
    elif content_source == DbEntry.EntrySource.NOTION:
        resolved = NotionConfig
    elif content_source == DbEntry.EntrySource.COMPUTER:
        resolved = "Computer"
    return resolved
|
||||
|
||||
|
||||
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Copy the content section of ``config`` into the database for ``user``.

    Each local content type that is set replaces the user's existing rows for
    that type. Github and Notion settings go through their adapters.
    Nothing happens when ``config.content_type`` is unset.
    """
    content_type = config.content_type
    if not content_type:
        return

    # Local sources share the same fields; handle them in their original order.
    local_sources = (
        (LocalOrgConfig, "org"),
        (LocalMarkdownConfig, "markdown"),
        (LocalPdfConfig, "pdf"),
        (LocalPlaintextConfig, "plaintext"),
    )
    for config_model, field_name in local_sources:
        local_config = getattr(content_type, field_name)
        if not local_config:
            continue
        await config_model.objects.filter(user=user).adelete()
        await config_model.objects.acreate(
            input_files=local_config.input_files,
            input_filter=local_config.input_filter,
            index_heading_entries=local_config.index_heading_entries,
            user=user,
        )

    github_config = content_type.github
    if github_config:
        await adapters.set_user_github_config(
            user=user,
            pat_token=github_config.pat_token,
            repos=github_config.repos,
        )

    notion_config = content_type.notion
    if notion_config:
        await adapters.set_notion_config(
            user=user,
            token=notion_config.token,
        )
|
||||
|
||||
|
||||
def _initialize_config():
    """Create ``state.config`` with the default search-type settings if it is unset."""
    if state.config is not None:
        return

    state.config = FullConfig()
    default_search_type = constants.default_config["search-type"]
    state.config.search_type = SearchConfig.model_validate(default_search_type)
|
||||
156
src/khoj/routers/api_model.py
Normal file
156
src/khoj/routers/api_model.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.routers.helpers import update_telemetry_state
|
||||
|
||||
api_model = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@api_model.get("/chat/options", response_model=list[Dict[str, Union[str, int]]])
def get_chat_model_options(
    request: Request,
    client: Optional[str] = None,
):
    """List the available chat models as ``{"chat_model": ..., "id": ...}`` objects.

    The declared response model is a list of dicts, matching the JSON array
    this handler actually returns.

    Returns:
        A JSON response with status 200 holding one object per chat model option.
    """
    conversation_options = ConversationAdapters.get_conversation_processor_options().all()

    all_conversation_options = [
        {"chat_model": conversation_option.chat_model, "id": conversation_option.id}
        for conversation_option in conversation_options
    ]

    return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api_model.get("/chat")
@requires(["authenticated"])
def get_user_chat_model(
    request: Request,
    client: Optional[str] = None,
):
    """Return the chat model configured for the requesting user.

    Falls back to the default conversation config when the user has none.

    Returns:
        A JSON response with status 200 holding the model ``id`` and ``chat_model``.
    """
    user = request.user.object

    chat_model = ConversationAdapters.get_conversation_config(user)

    if chat_model is None:
        chat_model = ConversationAdapters.get_default_conversation_config()

    # Label the JSON body as JSON, like get_chat_model_options does.
    return Response(
        status_code=200,
        content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}),
        media_type="application/json",
    )
|
||||
|
||||
|
||||
@api_model.post("/chat", status_code=200)
@requires(["authenticated", "premium"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the requesting user's chat model to the option with the given id.

    Telemetry is recorded whether or not the model was found.

    Returns:
        ``{"status": "ok"}`` on success, or an error dict when the adapter
        returns ``None`` for the requested id.
    """
    current_user = request.user.object
    model_id = int(id)

    updated = await ConversationAdapters.aset_user_conversation_processor(current_user, model_id)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_conversation_chat_model",
        client=client,
        metadata={"processor_conversation_type": "conversation"},
    )

    if updated is not None:
        return {"status": "ok"}
    return {"status": "error", "message": "Model not found"}
|
||||
|
||||
|
||||
@api_model.post("/voice", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the requesting user's voice model to the option with the given id.

    Returns:
        A 202 response with ``{"status": "ok"}`` on success, or a 404 response
        with an error payload when the adapter returns ``None``.
    """
    current_user = request.user.object

    updated = await ConversationAdapters.aset_user_voice_model(current_user, id)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_voice_model",
        client=client,
    )

    if updated is None:
        status_code, payload = 404, {"status": "error", "message": "Model not found"}
    else:
        status_code, payload = 202, {"status": "ok"}

    return Response(status_code=status_code, content=json.dumps(payload))
|
||||
|
||||
|
||||
@api_model.post("/search", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the requesting user's search model and drop entries indexed with the old one.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        id: Identifier of the search model to select (parsed with ``int``).
        client: Optional client identifier forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` on success, or an error dict when the requested
        model is not found. Nothing is deleted in the error case.
    """
    user = request.user.object

    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, int(id))

    # Validate first: a request for an unknown model must never wipe the
    # user's entries (previously it did when the user had no prior config).
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    # Delete entries when the user was on the default config (no stored config)
    # or has switched to a different model.
    if not prev_config or int(id) != prev_config.id:
        await EntryAdapters.adelete_all_entries(user)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_search_model",
        client=client,
        metadata={"search_model": new_config.setting.name},
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
@api_model.post("/paint", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the requesting user's text-to-image model; requires the ``premium`` scope.

    Args:
        request: Incoming request; the user is read from ``request.user.object``.
        id: Identifier of the paint model to select (parsed with ``int``).
        client: Optional client identifier forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` on success, or an error dict when the model is not found.

    Raises:
        HTTPException: 403 when the request lacks the ``premium`` scope.
    """
    user = request.user.object
    subscribed = has_required_scope(request, ["premium"])

    if not subscribed:
        raise HTTPException(status_code=403, detail="User is not subscribed to premium")

    new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

    # Check for a missing model before touching new_config.setting; the
    # telemetry call used to dereference it first and crash on None.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_paint_model",
        client=client,
        metadata={"paint_model": new_config.setting.model_name},
    )

    return {"status": "ok"}
|
||||
@@ -1313,7 +1313,6 @@ def configure_content(
|
||||
files: Optional[dict[str, dict[str, str]]],
|
||||
regenerate: bool = False,
|
||||
t: Optional[state.SearchType] = state.SearchType.All,
|
||||
full_corpus: bool = True,
|
||||
user: KhojUser = None,
|
||||
) -> bool:
|
||||
success = True
|
||||
@@ -1344,7 +1343,6 @@ def configure_content(
|
||||
OrgToEntries,
|
||||
files.get("org"),
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1362,7 +1360,6 @@ def configure_content(
|
||||
MarkdownToEntries,
|
||||
files.get("markdown"),
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
|
||||
@@ -1379,7 +1376,6 @@ def configure_content(
|
||||
PdfToEntries,
|
||||
files.get("pdf"),
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
|
||||
@@ -1398,7 +1394,6 @@ def configure_content(
|
||||
PlaintextToEntries,
|
||||
files.get("plaintext"),
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
|
||||
@@ -1418,7 +1413,6 @@ def configure_content(
|
||||
GithubToEntries,
|
||||
None,
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
config=github_config,
|
||||
)
|
||||
@@ -1439,7 +1433,6 @@ def configure_content(
|
||||
NotionToEntries,
|
||||
None,
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
config=notion_config,
|
||||
)
|
||||
@@ -1459,7 +1452,6 @@ def configure_content(
|
||||
ImageToEntries,
|
||||
files.get("image"),
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1472,7 +1464,6 @@ def configure_content(
|
||||
DocxToEntries,
|
||||
files.get("docx"),
|
||||
regenerate=regenerate,
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request, Response, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from starlette.authentication import requires
|
||||
|
||||
from khoj.routers.helpers import (
|
||||
ApiIndexedDataLimiter,
|
||||
configure_content,
|
||||
update_telemetry_state,
|
||||
)
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.helpers import get_file_type
|
||||
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
indexer = APIRouter()
|
||||
|
||||
|
||||
class File(BaseModel):
    """Schema for one file: a path paired with its content."""

    # Path of the file
    path: str
    # File content, either as text or as raw bytes
    content: Union[str, bytes]
|
||||
|
||||
|
||||
class IndexBatchRequest(BaseModel):
    """Schema for a batch request carrying a list of File entries."""

    # Files included in the batch
    files: list[File]
|
||||
|
||||
|
||||
class IndexerInput(BaseModel):
    """Files to index, grouped by content type.

    Every field is an optional mapping that defaults to None. The update
    endpoint in this module fills each mapping with uploaded filename ->
    file content.
    """

    # Content types declared with str values
    org: Optional[dict[str, str]] = None
    markdown: Optional[dict[str, str]] = None
    # Declared with bytes values
    pdf: Optional[dict[str, bytes]] = None
    plaintext: Optional[dict[str, str]] = None
    # Declared with bytes values
    image: Optional[dict[str, bytes]] = None
    docx: Optional[dict[str, bytes]] = None
|
||||
|
||||
|
||||
@indexer.post("/update")
@requires(["authenticated"])
async def update(
    request: Request,
    files: list[UploadFile],
    force: bool = False,
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Index a batch of uploaded files for the authenticated user.

    Each upload is classified with get_file_type and stored under its content
    type, keyed by filename. Uploads whose type is not one of the supported
    buckets are skipped with a warning. The grouped files are then handed to
    configure_content, which runs in an executor so the event loop stays free.

    Args:
        request: Incoming request; the user is read from request.user.object.
        files: Uploaded files to index.
        force: Forwarded to configure_content as its second positional
            argument (the regenerate flag).
        t: Search type forwarded to configure_content.
        client: Name of the calling client, used in logs and telemetry.
        user_agent: User-Agent header, forwarded to telemetry.
        referer: Referer header, forwarded to telemetry.
        host: Host header, forwarded to telemetry.
        indexed_data_limiter: Dependency built from the size limits above; it
            is not referenced in the body.

    Returns:
        A Response with status 200 whose body is the comma-joined names of the
        accepted files, or a Response with status 500 and body "Failed" when
        any step of the indexing raised or configure_content reported failure.
    """
    user = request.user.object
    # Values are decoded text when get_file_type reports an encoding, raw bytes otherwise
    index_files: Dict[str, Dict[str, Union[str, bytes]]] = {
        "org": {},
        "markdown": {},
        "pdf": {},
        "plaintext": {},
        "image": {},
        "docx": {},
    }
    try:
        logger.info(f"📬 Updating content index via API call by {client} client")
        for file in files:
            # Use the async reader so a large upload does not block the event loop
            file_content = await file.read()
            file_type, encoding = get_file_type(file.content_type, file_content)
            if file_type in index_files:
                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
            else:
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")

        # The bucket names match the IndexerInput field names one to one
        indexer_input = IndexerInput(**index_files)

        if state.config is None:
            logger.info("📬 Initializing content index on first run.")
            state.config = FullConfig(
                content_type=None,
                search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
                processor=None,
            )
            state.config.content_type = ContentConfig(
                org=None,
                markdown=None,
                pdf=None,
                docx=None,
                image=None,
                github=None,
                notion=None,
                plaintext=None,
            )
            save_config_to_file_updated_state()
            # NOTE(review): the return value is discarded here - confirm whether
            # it was meant to be assigned back to state.search_models.
            configure_search(state.search_models, state.config.search_type)

        # Run the blocking indexing job on the loop's default executor.
        # Positional arguments: files, regenerate flag, search type, False, user.
        loop = asyncio.get_running_loop()
        success = await loop.run_in_executor(
            None,
            configure_content,
            indexer_input.model_dump(),
            force,
            t,
            False,
            user,
        )
        if not success:
            raise RuntimeError("Failed to update content index")
        logger.info("Finished processing batch indexing request")
    except Exception as e:
        # Log the failure once, with traceback, and report a generic error to the client
        logger.error(
            f'🚨 Failed to {"force " if force else ""}update {t} content index triggered via API call by {client} client: {e}',
            exc_info=True,
        )
        return Response(content="Failed", status_code=500)

    # One "num_<type>" counter per content bucket, in bucket order
    indexing_metadata = {f"num_{content_type}": len(entries) for content_type, entries in index_files.items()}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="index/update",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
        metadata=indexing_metadata,
    )

    logger.info(f"📪 Content index updated via API call by {client} client")

    # str.join already yields "" when nothing was indexed
    indexed_filenames = ",".join(filename for entries in index_files.values() for filename in entries)
    return Response(content=indexed_filenames, status_code=200)
|
||||
|
||||
|
||||
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
    """Return the given search models, or a new SearchModels when none is given.

    Args:
        search_models: Existing search models container, or None.
        search_config: Accepted but not used by this function.

    Returns:
        search_models unchanged when it is not None, otherwise a freshly
        constructed SearchModels instance.
    """
    # Only None triggers the fallback; other falsy values are returned as is
    return SearchModels() if search_models is None else search_models
|
||||
@@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
|
||||
notion_redirect = str(request.app.url_path_for("notion_config_page"))
|
||||
|
||||
# Trigger an async job to configure_content. Let it run without blocking the response.
|
||||
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user)
|
||||
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
|
||||
|
||||
return RedirectResponse(notion_redirect)
|
||||
|
||||
@@ -199,17 +199,16 @@ def setup(
|
||||
text_to_entries: Type[TextToEntries],
|
||||
files: dict[str, str],
|
||||
regenerate: bool,
|
||||
full_corpus: bool = True,
|
||||
user: KhojUser = None,
|
||||
config=None,
|
||||
) -> None:
|
||||
) -> Tuple[int, int]:
|
||||
if config:
|
||||
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
|
||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||
files=files, user=user, regenerate=regenerate
|
||||
)
|
||||
else:
|
||||
num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
|
||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||
files=files, user=user, regenerate=regenerate
|
||||
)
|
||||
|
||||
if files:
|
||||
@@ -219,6 +218,8 @@ def setup(
|
||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||
"""Score all retrieved entries using the cross-encoder"""
|
||||
|
||||
@@ -124,7 +124,7 @@ def get_org_files(config: TextContentConfig):
|
||||
logger.debug("At least one of org-files or org-file-filter is required to be specified")
|
||||
return {}
|
||||
|
||||
"Get Org files to process"
|
||||
# Get Org files to process
|
||||
absolute_org_files, filtered_org_files = set(), set()
|
||||
if org_files:
|
||||
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
||||
|
||||
Reference in New Issue
Block a user