mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Split Configure API into Content, Model API paths (#857)
## Major: Breaking Changes - Move API endpoints under /configure/<type>/model to /api/model/<type> - Move API endpoints under /api/configure/content/ to /api/content/ - Accept file deletion requests by clients during sync - Split /api/v1/index/update into /api/content PUT, PATCH API endpoints ## Minor: Create New API Endpoint - Create API endpoints to get user content configurations Related: #852
This commit is contained in:
@@ -233,11 +233,15 @@ function pushDataToKhoj (regenerate = false) {
|
|||||||
|
|
||||||
// Request indexing files on server. With upto 1000 files in each request
|
// Request indexing files on server. With upto 1000 files in each request
|
||||||
for (let i = 0; i < filesDataToPush.length; i += 1000) {
|
for (let i = 0; i < filesDataToPush.length; i += 1000) {
|
||||||
|
const syncUrl = `${hostURL}/api/content?client=desktop`;
|
||||||
const filesDataGroup = filesDataToPush.slice(i, i + 1000);
|
const filesDataGroup = filesDataToPush.slice(i, i + 1000);
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
filesDataGroup.forEach(fileData => { formData.append('files', fileData.blob, fileData.path) });
|
filesDataGroup.forEach(fileData => { formData.append('files', fileData.blob, fileData.path) });
|
||||||
let request = axios.post(`${hostURL}/api/v1/index/update?force=${regenerate}&client=desktop`, formData, { headers });
|
requests.push(
|
||||||
requests.push(request);
|
regenerate
|
||||||
|
? axios.put(syncUrl, formData, { headers })
|
||||||
|
: axios.patch(syncUrl, formData, { headers })
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for requests batch to finish
|
// Wait for requests batch to finish
|
||||||
|
|||||||
@@ -212,7 +212,7 @@
|
|||||||
const headers = { 'Authorization': `Bearer ${khojToken}` };
|
const headers = { 'Authorization': `Bearer ${khojToken}` };
|
||||||
|
|
||||||
// Populate type dropdown field with enabled content types only
|
// Populate type dropdown field with enabled content types only
|
||||||
fetch(`${hostURL}/api/configure/types`, { headers })
|
fetch(`${hostURL}/api/content/types`, { headers })
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(enabled_types => {
|
.then(enabled_types => {
|
||||||
// Show warning if no content types are enabled
|
// Show warning if no content types are enabled
|
||||||
|
|||||||
@@ -424,12 +424,12 @@ Auto invokes setup steps on calling main entrypoint."
|
|||||||
"Send multi-part form `BODY' of `CONTENT-TYPE' in request to khoj server.
|
"Send multi-part form `BODY' of `CONTENT-TYPE' in request to khoj server.
|
||||||
Append 'TYPE-QUERY' as query parameter in request url.
|
Append 'TYPE-QUERY' as query parameter in request url.
|
||||||
Specify `BOUNDARY' used to separate files in request header."
|
Specify `BOUNDARY' used to separate files in request header."
|
||||||
(let ((url-request-method "POST")
|
(let ((url-request-method ((if force) "PUT" "PATCH"))
|
||||||
(url-request-data body)
|
(url-request-data body)
|
||||||
(url-request-extra-headers `(("content-type" . ,(format "multipart/form-data; boundary=%s" boundary))
|
(url-request-extra-headers `(("content-type" . ,(format "multipart/form-data; boundary=%s" boundary))
|
||||||
("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
|
("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
|
||||||
(with-current-buffer
|
(with-current-buffer
|
||||||
(url-retrieve (format "%s/api/v1/index/update?%s&force=%s&client=emacs" khoj-server-url type-query (or force "false"))
|
(url-retrieve (format "%s/api/content?%s&client=emacs" khoj-server-url type-query)
|
||||||
;; render response from indexing API endpoint on server
|
;; render response from indexing API endpoint on server
|
||||||
(lambda (status)
|
(lambda (status)
|
||||||
(if (not (plist-get status :error))
|
(if (not (plist-get status :error))
|
||||||
@@ -697,7 +697,7 @@ Optionally apply CALLBACK with JSON parsed response and CBARGS."
|
|||||||
|
|
||||||
(defun khoj--get-enabled-content-types ()
|
(defun khoj--get-enabled-content-types ()
|
||||||
"Get content types enabled for search from API."
|
"Get content types enabled for search from API."
|
||||||
(khoj--call-api "/api/configure/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
|
(khoj--call-api "/api/content/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
|
||||||
|
|
||||||
(defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar)
|
(defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar)
|
||||||
"Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params.
|
"Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params.
|
||||||
|
|||||||
@@ -89,10 +89,11 @@ export async function updateContentIndex(vault: Vault, setting: KhojSetting, las
|
|||||||
for (let i = 0; i < fileData.length; i += 1000) {
|
for (let i = 0; i < fileData.length; i += 1000) {
|
||||||
const filesGroup = fileData.slice(i, i + 1000);
|
const filesGroup = fileData.slice(i, i + 1000);
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
|
const method = regenerate ? "PUT" : "PATCH";
|
||||||
filesGroup.forEach(fileItem => { formData.append('files', fileItem.blob, fileItem.path) });
|
filesGroup.forEach(fileItem => { formData.append('files', fileItem.blob, fileItem.path) });
|
||||||
// Call Khoj backend to update index with all markdown, pdf files
|
// Call Khoj backend to update index with all markdown, pdf files
|
||||||
const response = await fetch(`${setting.khojUrl}/api/v1/index/update?force=${regenerate}&client=obsidian`, {
|
const response = await fetch(`${setting.khojUrl}/api/content?client=obsidian`, {
|
||||||
method: 'POST',
|
method: method,
|
||||||
headers: {
|
headers: {
|
||||||
'Authorization': `Bearer ${setting.khojApiKey}`,
|
'Authorization': `Bearer ${setting.khojApiKey}`,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -277,8 +277,8 @@ export function uploadDataForIndexing(
|
|||||||
// Wait for all files to be read before making the fetch request
|
// Wait for all files to be read before making the fetch request
|
||||||
Promise.all(fileReadPromises)
|
Promise.all(fileReadPromises)
|
||||||
.then(() => {
|
.then(() => {
|
||||||
return fetch("/api/v1/index/update?force=false&client=web", {
|
return fetch("/api/content?client=web", {
|
||||||
method: "POST",
|
method: "PATCH",
|
||||||
body: formData,
|
body: formData,
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -68,8 +68,8 @@ interface ModelPickerProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
|
export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
|
||||||
const { data: models } = useOptionsRequest('/api/configure/chat/model/options');
|
const { data: models } = useOptionsRequest('/api/model/chat/options');
|
||||||
const { data: selectedModel } = useSelectedModel('/api/configure/chat/model');
|
const { data: selectedModel } = useSelectedModel('/api/model/chat');
|
||||||
const [openLoginDialog, setOpenLoginDialog] = React.useState(false);
|
const [openLoginDialog, setOpenLoginDialog] = React.useState(false);
|
||||||
|
|
||||||
let userData = useAuthenticatedData();
|
let userData = useAuthenticatedData();
|
||||||
@@ -94,7 +94,7 @@ export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
|
|||||||
props.setModelUsed(model);
|
props.setModelUsed(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
fetch('/api/configure/chat/model' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) })
|
fetch('/api/model/chat' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) })
|
||||||
.then((response) => {
|
.then((response) => {
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error('Failed to select model');
|
throw new Error('Failed to select model');
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ interface FilesMenuProps {
|
|||||||
|
|
||||||
function FilesMenu(props: FilesMenuProps) {
|
function FilesMenu(props: FilesMenuProps) {
|
||||||
// Use SWR to fetch files
|
// Use SWR to fetch files
|
||||||
const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/configure/content/computer' : null, fetcher);
|
const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/content/computer' : null, fetcher);
|
||||||
const { data: selectedFiles, error: selectedFilesError } = useSWR(props.conversationId ? `/api/chat/conversation/file-filters/${props.conversationId}` : null, fetcher);
|
const { data: selectedFiles, error: selectedFilesError } = useSWR(props.conversationId ? `/api/chat/conversation/file-filters/${props.conversationId}` : null, fetcher);
|
||||||
const [isOpen, setIsOpen] = useState(false);
|
const [isOpen, setIsOpen] = useState(false);
|
||||||
const [unfilteredFiles, setUnfilteredFiles] = useState<string[]>([]);
|
const [unfilteredFiles, setUnfilteredFiles] = useState<string[]>([]);
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from khoj.database.adapters import (
|
|||||||
)
|
)
|
||||||
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
|
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
|
||||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.routers.indexer import configure_content, configure_search
|
from khoj.routers.api_content import configure_content, configure_search
|
||||||
from khoj.routers.twilio import is_twilio_enabled
|
from khoj.routers.twilio import is_twilio_enabled
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import SearchType
|
from khoj.utils.config import SearchType
|
||||||
@@ -308,16 +308,16 @@ def configure_routes(app):
|
|||||||
from khoj.routers.api import api
|
from khoj.routers.api import api
|
||||||
from khoj.routers.api_agents import api_agents
|
from khoj.routers.api_agents import api_agents
|
||||||
from khoj.routers.api_chat import api_chat
|
from khoj.routers.api_chat import api_chat
|
||||||
from khoj.routers.api_config import api_config
|
from khoj.routers.api_content import api_content
|
||||||
from khoj.routers.indexer import indexer
|
from khoj.routers.api_model import api_model
|
||||||
from khoj.routers.notion import notion_router
|
from khoj.routers.notion import notion_router
|
||||||
from khoj.routers.web_client import web_client
|
from khoj.routers.web_client import web_client
|
||||||
|
|
||||||
app.include_router(api, prefix="/api")
|
app.include_router(api, prefix="/api")
|
||||||
app.include_router(api_chat, prefix="/api/chat")
|
app.include_router(api_chat, prefix="/api/chat")
|
||||||
app.include_router(api_agents, prefix="/api/agents")
|
app.include_router(api_agents, prefix="/api/agents")
|
||||||
app.include_router(api_config, prefix="/api/configure")
|
app.include_router(api_model, prefix="/api/model")
|
||||||
app.include_router(indexer, prefix="/api/v1/index")
|
app.include_router(api_content, prefix="/api/content")
|
||||||
app.include_router(notion_router, prefix="/api/notion")
|
app.include_router(notion_router, prefix="/api/notion")
|
||||||
app.include_router(web_client)
|
app.include_router(web_client)
|
||||||
|
|
||||||
@@ -336,7 +336,7 @@ def configure_routes(app):
|
|||||||
if is_twilio_enabled():
|
if is_twilio_enabled():
|
||||||
from khoj.routers.api_phone import api_phone
|
from khoj.routers.api_phone import api_phone
|
||||||
|
|
||||||
app.include_router(api_phone, prefix="/api/configure/phone")
|
app.include_router(api_phone, prefix="/api/phone")
|
||||||
logger.info("📞 Enabled Twilio")
|
logger.info("📞 Enabled Twilio")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -998,8 +998,8 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
// Wait for all files to be read before making the fetch request
|
// Wait for all files to be read before making the fetch request
|
||||||
Promise.all(fileReadPromises)
|
Promise.all(fileReadPromises)
|
||||||
.then(() => {
|
.then(() => {
|
||||||
return fetch("/api/v1/index/update?force=false&client=web", {
|
return fetch("/api/content?client=web", {
|
||||||
method: "POST",
|
method: "PATCH",
|
||||||
body: formData,
|
body: formData,
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
@@ -1954,7 +1954,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
}
|
}
|
||||||
var allFiles;
|
var allFiles;
|
||||||
function renderAllFiles() {
|
function renderAllFiles() {
|
||||||
fetch('/api/configure/content/computer')
|
fetch('/api/content/computer')
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
||||||
|
|||||||
@@ -32,7 +32,7 @@
|
|||||||
</style>
|
</style>
|
||||||
<script>
|
<script>
|
||||||
function removeFile(path) {
|
function removeFile(path) {
|
||||||
fetch('/api/configure/content/file?filename=' + path, {
|
fetch('/api/content/file?filename=' + path, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -48,7 +48,7 @@
|
|||||||
|
|
||||||
// Get all currently indexed files
|
// Get all currently indexed files
|
||||||
function getAllComputerFilenames() {
|
function getAllComputerFilenames() {
|
||||||
fetch('/api/configure/content/computer')
|
fetch('/api/content/computer')
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
var indexedFiles = document.getElementsByClassName("indexed-files")[0];
|
||||||
@@ -122,7 +122,7 @@
|
|||||||
deleteAllComputerFilesButton.textContent = "🗑️ Deleting...";
|
deleteAllComputerFilesButton.textContent = "🗑️ Deleting...";
|
||||||
deleteAllComputerFilesButton.disabled = true;
|
deleteAllComputerFilesButton.disabled = true;
|
||||||
|
|
||||||
fetch('/api/configure/content/computer', {
|
fetch('/api/content/computer', {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|||||||
@@ -165,7 +165,7 @@
|
|||||||
|
|
||||||
// Save Github config on server
|
// Save Github config on server
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
fetch('/api/configure/content/github', {
|
fetch('/api/content/github', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|||||||
@@ -45,7 +45,7 @@
|
|||||||
|
|
||||||
// Save Notion config on server
|
// Save Notion config on server
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
fetch('/api/configure/content/notion', {
|
fetch('/api/content/notion', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|||||||
@@ -209,7 +209,7 @@
|
|||||||
|
|
||||||
function populate_type_dropdown() {
|
function populate_type_dropdown() {
|
||||||
// Populate type dropdown field with enabled content types only
|
// Populate type dropdown field with enabled content types only
|
||||||
fetch("/api/configure/types")
|
fetch("/api/content/types")
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(enabled_types => {
|
.then(enabled_types => {
|
||||||
// Show warning if no content types are enabled, or just one ("all")
|
// Show warning if no content types are enabled, or just one ("all")
|
||||||
|
|||||||
@@ -394,8 +394,8 @@
|
|||||||
|
|
||||||
function saveProfileGivenName() {
|
function saveProfileGivenName() {
|
||||||
const givenName = document.getElementById("profile_given_name").value;
|
const givenName = document.getElementById("profile_given_name").value;
|
||||||
fetch('/api/configure/user/name?name=' + givenName, {
|
fetch('/api/user/name?name=' + givenName, {
|
||||||
method: 'POST',
|
method: 'PATCH',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
@@ -421,7 +421,7 @@
|
|||||||
saveVoiceModelButton.disabled = true;
|
saveVoiceModelButton.disabled = true;
|
||||||
saveVoiceModelButton.textContent = "Saving...";
|
saveVoiceModelButton.textContent = "Saving...";
|
||||||
|
|
||||||
fetch('/api/configure/voice/model?id=' + voiceModel, {
|
fetch('/api/model/voice?id=' + voiceModel, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -455,7 +455,7 @@
|
|||||||
saveModelButton.innerHTML = "";
|
saveModelButton.innerHTML = "";
|
||||||
saveModelButton.textContent = "Saving...";
|
saveModelButton.textContent = "Saving...";
|
||||||
|
|
||||||
fetch('/api/configure/chat/model?id=' + chatModel, {
|
fetch('/api/model/chat?id=' + chatModel, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -494,7 +494,7 @@
|
|||||||
saveSearchModelButton.disabled = true;
|
saveSearchModelButton.disabled = true;
|
||||||
saveSearchModelButton.textContent = "Saving...";
|
saveSearchModelButton.textContent = "Saving...";
|
||||||
|
|
||||||
fetch('/api/configure/search/model?id=' + searchModel, {
|
fetch('/api/model/search?id=' + searchModel, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -526,7 +526,7 @@
|
|||||||
saveModelButton.disabled = true;
|
saveModelButton.disabled = true;
|
||||||
saveModelButton.innerHTML = "Saving...";
|
saveModelButton.innerHTML = "Saving...";
|
||||||
|
|
||||||
fetch('/api/configure/paint/model?id=' + paintModel, {
|
fetch('/api/model/paint?id=' + paintModel, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -553,7 +553,7 @@
|
|||||||
};
|
};
|
||||||
|
|
||||||
function clearContentType(content_source) {
|
function clearContentType(content_source) {
|
||||||
fetch('/api/configure/content/' + content_source, {
|
fetch('/api/content/' + content_source, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -676,7 +676,7 @@
|
|||||||
|
|
||||||
content_sources = ["computer", "github", "notion"];
|
content_sources = ["computer", "github", "notion"];
|
||||||
content_sources.forEach(content_source => {
|
content_sources.forEach(content_source => {
|
||||||
fetch(`/api/configure/content/${content_source}`, {
|
fetch(`/api/content/${content_source}`, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -807,7 +807,7 @@
|
|||||||
|
|
||||||
function getIndexedDataSize() {
|
function getIndexedDataSize() {
|
||||||
document.getElementById("indexed-data-size").textContent = "Calculating...";
|
document.getElementById("indexed-data-size").textContent = "Calculating...";
|
||||||
fetch('/api/configure/content/size')
|
fetch('/api/content/size')
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used";
|
document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used";
|
||||||
@@ -815,7 +815,7 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
function removeFile(path) {
|
function removeFile(path) {
|
||||||
fetch('/api/configure/content/file?filename=' + path, {
|
fetch('/api/content/file?filename=' + path, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -890,7 +890,7 @@
|
|||||||
})
|
})
|
||||||
|
|
||||||
phonenumberRemoveButton.addEventListener("click", () => {
|
phonenumberRemoveButton.addEventListener("click", () => {
|
||||||
fetch('/api/configure/phone', {
|
fetch('/api/phone', {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -917,7 +917,7 @@
|
|||||||
}, 5000);
|
}, 5000);
|
||||||
} else {
|
} else {
|
||||||
const mobileNumber = iti.getNumber();
|
const mobileNumber = iti.getNumber();
|
||||||
fetch('/api/configure/phone?phone_number=' + mobileNumber, {
|
fetch('/api/phone?phone_number=' + mobileNumber, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -970,7 +970,7 @@
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
fetch('/api/configure/phone/verify?code=' + otp, {
|
fetch('/api/phone/verify?code=' + otp, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|||||||
@@ -19,16 +19,11 @@ class DocxToEntries(TextToEntries):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified Docx files
|
# Extract Entries from specified Docx files
|
||||||
with timer("Extract entries from specified DOCX files", logger):
|
with timer("Extract entries from specified DOCX files", logger):
|
||||||
|
|||||||
@@ -48,9 +48,7 @@ class GithubToEntries(TextToEntries):
|
|||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
if self.config.pat_token is None or self.config.pat_token == "":
|
if self.config.pat_token is None or self.config.pat_token == "":
|
||||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||||
|
|||||||
@@ -20,16 +20,11 @@ class ImageToEntries(TextToEntries):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified image files
|
# Extract Entries from specified image files
|
||||||
with timer("Extract entries from specified Image files", logger):
|
with timer("Extract entries from specified Image files", logger):
|
||||||
|
|||||||
@@ -19,16 +19,11 @@ class MarkdownToEntries(TextToEntries):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
# Extract Entries from specified Markdown files
|
# Extract Entries from specified Markdown files
|
||||||
|
|||||||
@@ -78,9 +78,7 @@ class NotionToEntries(TextToEntries):
|
|||||||
|
|
||||||
self.body_params = {"page_size": 100}
|
self.body_params = {"page_size": 100}
|
||||||
|
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
current_entries = []
|
current_entries = []
|
||||||
|
|
||||||
# Get all pages
|
# Get all pages
|
||||||
|
|||||||
@@ -20,15 +20,10 @@ class OrgToEntries(TextToEntries):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
) -> Tuple[int, int]:
|
files_to_process = set(files) - deletion_file_names
|
||||||
if not full_corpus:
|
files = {file: files[file] for file in files_to_process}
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
|
||||||
files_to_process = set(files) - deletion_file_names
|
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified Org files
|
# Extract Entries from specified Org files
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
|
|||||||
@@ -22,16 +22,11 @@ class PdfToEntries(TextToEntries):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
deletion_file_names = set([file for file in files if files[file] == b""])
|
||||||
deletion_file_names = set([file for file in files if files[file] == b""])
|
files_to_process = set(files) - deletion_file_names
|
||||||
files_to_process = set(files) - deletion_file_names
|
files = {file: files[file] for file in files_to_process}
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified Pdf files
|
# Extract Entries from specified Pdf files
|
||||||
with timer("Extract entries from specified PDF files", logger):
|
with timer("Extract entries from specified PDF files", logger):
|
||||||
|
|||||||
@@ -20,15 +20,10 @@ class PlaintextToEntries(TextToEntries):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
) -> Tuple[int, int]:
|
files_to_process = set(files) - deletion_file_names
|
||||||
if not full_corpus:
|
files = {file: files[file] for file in files_to_process}
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
|
||||||
files_to_process = set(files) - deletion_file_names
|
|
||||||
files = {file: files[file] for file in files_to_process}
|
|
||||||
else:
|
|
||||||
deletion_file_names = None
|
|
||||||
|
|
||||||
# Extract Entries from specified plaintext files
|
# Extract Entries from specified plaintext files
|
||||||
with timer("Extract entries from specified Plaintext files", logger):
|
with timer("Extract entries from specified Plaintext files", logger):
|
||||||
|
|||||||
@@ -31,9 +31,7 @@ class TextToEntries(ABC):
|
|||||||
self.date_filter = DateFilter()
|
self.date_filter = DateFilter()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(
|
def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
|
||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
...
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from fastapi.responses import Response
|
|||||||
from starlette.authentication import has_required_scope, requires
|
from starlette.authentication import has_required_scope, requires
|
||||||
|
|
||||||
from khoj.configure import initialize_content
|
from khoj.configure import initialize_content
|
||||||
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
AutomationAdapters,
|
AutomationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
@@ -39,6 +40,7 @@ from khoj.routers.helpers import (
|
|||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
acreate_title_from_query,
|
acreate_title_from_query,
|
||||||
|
get_user_config,
|
||||||
schedule_automation,
|
schedule_automation,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
)
|
)
|
||||||
@@ -276,6 +278,49 @@ async def transcribe(
|
|||||||
return Response(content=content, media_type="application/json", status_code=200)
|
return Response(content=content, media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@api.get("/settings", response_class=Response)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def get_settings(request: Request, detailed: Optional[bool] = False) -> Response:
|
||||||
|
user = request.user.object
|
||||||
|
user_config = get_user_config(user, request, is_detailed=detailed)
|
||||||
|
del user_config["request"]
|
||||||
|
|
||||||
|
# Return config data as a JSON response
|
||||||
|
return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@api.patch("/user/name", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def set_user_name(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
split_name = name.split(" ")
|
||||||
|
|
||||||
|
if len(split_name) > 2:
|
||||||
|
raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
|
||||||
|
|
||||||
|
if len(split_name) == 1:
|
||||||
|
first_name = split_name[0]
|
||||||
|
last_name = ""
|
||||||
|
else:
|
||||||
|
first_name, last_name = split_name[0], split_name[-1]
|
||||||
|
|
||||||
|
adapters.set_user_name(user, first_name, last_name)
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_user_name",
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
async def extract_references_and_questions(
|
async def extract_references_and_questions(
|
||||||
request: Request,
|
request: Request,
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
|
|||||||
@@ -1,434 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from fastapi.requests import Request
|
|
||||||
from fastapi.responses import Response
|
|
||||||
from starlette.authentication import has_required_scope, requires
|
|
||||||
|
|
||||||
from khoj.database import adapters
|
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
|
||||||
from khoj.database.models import Entry as DbEntry
|
|
||||||
from khoj.database.models import (
|
|
||||||
GithubConfig,
|
|
||||||
KhojUser,
|
|
||||||
LocalMarkdownConfig,
|
|
||||||
LocalOrgConfig,
|
|
||||||
LocalPdfConfig,
|
|
||||||
LocalPlaintextConfig,
|
|
||||||
NotionConfig,
|
|
||||||
Subscription,
|
|
||||||
)
|
|
||||||
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
|
|
||||||
from khoj.utils import constants, state
|
|
||||||
from khoj.utils.rawconfig import (
|
|
||||||
FullConfig,
|
|
||||||
GithubContentConfig,
|
|
||||||
NotionContentConfig,
|
|
||||||
SearchConfig,
|
|
||||||
)
|
|
||||||
from khoj.utils.state import SearchType
|
|
||||||
|
|
||||||
api_config = APIRouter()
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def map_config_to_object(content_source: str):
|
|
||||||
if content_source == DbEntry.EntrySource.GITHUB:
|
|
||||||
return GithubConfig
|
|
||||||
if content_source == DbEntry.EntrySource.NOTION:
|
|
||||||
return NotionConfig
|
|
||||||
if content_source == DbEntry.EntrySource.COMPUTER:
|
|
||||||
return "Computer"
|
|
||||||
|
|
||||||
|
|
||||||
async def map_config_to_db(config: FullConfig, user: KhojUser):
|
|
||||||
if config.content_type:
|
|
||||||
if config.content_type.org:
|
|
||||||
await LocalOrgConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalOrgConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.org.input_files,
|
|
||||||
input_filter=config.content_type.org.input_filter,
|
|
||||||
index_heading_entries=config.content_type.org.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.markdown:
|
|
||||||
await LocalMarkdownConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalMarkdownConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.markdown.input_files,
|
|
||||||
input_filter=config.content_type.markdown.input_filter,
|
|
||||||
index_heading_entries=config.content_type.markdown.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.pdf:
|
|
||||||
await LocalPdfConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalPdfConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.pdf.input_files,
|
|
||||||
input_filter=config.content_type.pdf.input_filter,
|
|
||||||
index_heading_entries=config.content_type.pdf.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.plaintext:
|
|
||||||
await LocalPlaintextConfig.objects.filter(user=user).adelete()
|
|
||||||
await LocalPlaintextConfig.objects.acreate(
|
|
||||||
input_files=config.content_type.plaintext.input_files,
|
|
||||||
input_filter=config.content_type.plaintext.input_filter,
|
|
||||||
index_heading_entries=config.content_type.plaintext.index_heading_entries,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
if config.content_type.github:
|
|
||||||
await adapters.set_user_github_config(
|
|
||||||
user=user,
|
|
||||||
pat_token=config.content_type.github.pat_token,
|
|
||||||
repos=config.content_type.github.repos,
|
|
||||||
)
|
|
||||||
if config.content_type.notion:
|
|
||||||
await adapters.set_notion_config(
|
|
||||||
user=user,
|
|
||||||
token=config.content_type.notion.token,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_config():
|
|
||||||
if state.config is None:
|
|
||||||
state.config = FullConfig()
|
|
||||||
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.post("/content/github", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def set_content_github(
|
|
||||||
request: Request,
|
|
||||||
updated_config: Union[GithubContentConfig, None],
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
_initialize_config()
|
|
||||||
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
try:
|
|
||||||
await adapters.set_user_github_config(
|
|
||||||
user=user,
|
|
||||||
pat_token=updated_config.pat_token,
|
|
||||||
repos=updated_config.repos,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_content_config",
|
|
||||||
client=client,
|
|
||||||
metadata={"content_type": "github"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.post("/content/notion", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def set_content_notion(
|
|
||||||
request: Request,
|
|
||||||
updated_config: Union[NotionContentConfig, None],
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
_initialize_config()
|
|
||||||
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
try:
|
|
||||||
await adapters.set_notion_config(
|
|
||||||
user=user,
|
|
||||||
token=updated_config.token,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_content_config",
|
|
||||||
client=client,
|
|
||||||
metadata={"content_type": "notion"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.delete("/content/{content_source}", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def delete_content_source(
|
|
||||||
request: Request,
|
|
||||||
content_source: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="delete_content_config",
|
|
||||||
client=client,
|
|
||||||
metadata={"content_source": content_source},
|
|
||||||
)
|
|
||||||
|
|
||||||
content_object = map_config_to_object(content_source)
|
|
||||||
if content_object is None:
|
|
||||||
raise ValueError(f"Invalid content source: {content_source}")
|
|
||||||
elif content_object != "Computer":
|
|
||||||
await content_object.objects.filter(user=user).adelete()
|
|
||||||
await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
|
|
||||||
|
|
||||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.delete("/content/file", status_code=201)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def delete_content_file(
|
|
||||||
request: Request,
|
|
||||||
filename: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="delete_file",
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
await EntryAdapters.adelete_entry_by_file(user, filename)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.get("/content/{content_source}", response_model=List[str])
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def get_content_source(
|
|
||||||
request: Request,
|
|
||||||
content_source: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="get_all_filenames",
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await sync_to_async(list)(EntryAdapters.get_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.get("/chat/model/options", response_model=Dict[str, Union[str, int]])
|
|
||||||
def get_chat_model_options(
|
|
||||||
request: Request,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
conversation_options = ConversationAdapters.get_conversation_processor_options().all()
|
|
||||||
|
|
||||||
all_conversation_options = list()
|
|
||||||
for conversation_option in conversation_options:
|
|
||||||
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
|
|
||||||
|
|
||||||
return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.get("/chat/model")
|
|
||||||
@requires(["authenticated"])
|
|
||||||
def get_user_chat_model(
|
|
||||||
request: Request,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
chat_model = ConversationAdapters.get_conversation_config(user)
|
|
||||||
|
|
||||||
if chat_model is None:
|
|
||||||
chat_model = ConversationAdapters.get_default_conversation_config()
|
|
||||||
|
|
||||||
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.post("/chat/model", status_code=200)
|
|
||||||
@requires(["authenticated", "premium"])
|
|
||||||
async def update_chat_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_conversation_chat_model",
|
|
||||||
client=client,
|
|
||||||
metadata={"processor_conversation_type": "conversation"},
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return {"status": "error", "message": "Model not found"}
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.post("/voice/model", status_code=200)
|
|
||||||
@requires(["authenticated", "premium"])
|
|
||||||
async def update_voice_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
new_config = await ConversationAdapters.aset_user_voice_model(user, id)
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_voice_model",
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
|
|
||||||
|
|
||||||
return Response(status_code=202, content=json.dumps({"status": "ok"}))
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.post("/search/model", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def update_search_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
prev_config = await adapters.aget_user_search_model(user)
|
|
||||||
new_config = await adapters.aset_user_search_model(user, int(id))
|
|
||||||
|
|
||||||
if prev_config and int(id) != prev_config.id and new_config:
|
|
||||||
await EntryAdapters.adelete_all_entries(user)
|
|
||||||
|
|
||||||
if not prev_config:
|
|
||||||
# If the use was just using the default config, delete all the entries and set the new config.
|
|
||||||
await EntryAdapters.adelete_all_entries(user)
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return {"status": "error", "message": "Model not found"}
|
|
||||||
else:
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_search_model",
|
|
||||||
client=client,
|
|
||||||
metadata={"search_model": new_config.setting.name},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.post("/paint/model", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def update_paint_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
subscribed = has_required_scope(request, ["premium"])
|
|
||||||
|
|
||||||
if not subscribed:
|
|
||||||
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
|
||||||
|
|
||||||
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_paint_model",
|
|
||||||
client=client,
|
|
||||||
metadata={"paint_model": new_config.setting.model_name},
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return {"status": "error", "message": "Model not found"}
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.get("/content/size", response_model=Dict[str, int])
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def get_content_size(request: Request, common: CommonQueryParams):
|
|
||||||
user = request.user.object
|
|
||||||
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
|
|
||||||
return Response(
|
|
||||||
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
|
|
||||||
media_type="application/json",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.post("/user/name", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
def set_user_name(
|
|
||||||
request: Request,
|
|
||||||
name: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
split_name = name.split(" ")
|
|
||||||
|
|
||||||
if len(split_name) > 2:
|
|
||||||
raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
|
|
||||||
|
|
||||||
if len(split_name) == 1:
|
|
||||||
first_name = split_name[0]
|
|
||||||
last_name = ""
|
|
||||||
else:
|
|
||||||
first_name, last_name = split_name[0], split_name[-1]
|
|
||||||
|
|
||||||
adapters.set_user_name(user, first_name, last_name)
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_user_name",
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_config.get("/types", response_model=List[str])
|
|
||||||
@requires(["authenticated"])
|
|
||||||
def get_config_types(
|
|
||||||
request: Request,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
enabled_file_types = EntryAdapters.get_unique_file_types(user)
|
|
||||||
configured_content_types = list(enabled_file_types)
|
|
||||||
|
|
||||||
if state.config and state.config.content_type:
|
|
||||||
for ctype in state.config.content_type.model_dump(exclude_none=True):
|
|
||||||
configured_content_types.append(ctype)
|
|
||||||
|
|
||||||
return [
|
|
||||||
search_type.value
|
|
||||||
for search_type in SearchType
|
|
||||||
if (search_type.value in configured_content_types) or search_type == SearchType.All
|
|
||||||
]
|
|
||||||
508
src/khoj/routers/api_content.py
Normal file
508
src/khoj/routers/api_content.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
Depends,
|
||||||
|
Header,
|
||||||
|
HTTPException,
|
||||||
|
Request,
|
||||||
|
Response,
|
||||||
|
UploadFile,
|
||||||
|
)
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.authentication import requires
|
||||||
|
|
||||||
|
from khoj.database import adapters
|
||||||
|
from khoj.database.adapters import (
|
||||||
|
EntryAdapters,
|
||||||
|
get_user_github_config,
|
||||||
|
get_user_notion_config,
|
||||||
|
)
|
||||||
|
from khoj.database.models import Entry as DbEntry
|
||||||
|
from khoj.database.models import (
|
||||||
|
GithubConfig,
|
||||||
|
GithubRepoConfig,
|
||||||
|
KhojUser,
|
||||||
|
LocalMarkdownConfig,
|
||||||
|
LocalOrgConfig,
|
||||||
|
LocalPdfConfig,
|
||||||
|
LocalPlaintextConfig,
|
||||||
|
NotionConfig,
|
||||||
|
)
|
||||||
|
from khoj.routers.helpers import (
|
||||||
|
ApiIndexedDataLimiter,
|
||||||
|
CommonQueryParams,
|
||||||
|
configure_content,
|
||||||
|
get_user_config,
|
||||||
|
update_telemetry_state,
|
||||||
|
)
|
||||||
|
from khoj.utils import constants, state
|
||||||
|
from khoj.utils.config import SearchModels
|
||||||
|
from khoj.utils.helpers import get_file_type
|
||||||
|
from khoj.utils.rawconfig import (
|
||||||
|
ContentConfig,
|
||||||
|
FullConfig,
|
||||||
|
GithubContentConfig,
|
||||||
|
NotionContentConfig,
|
||||||
|
SearchConfig,
|
||||||
|
)
|
||||||
|
from khoj.utils.state import SearchType
|
||||||
|
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
api_content = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class File(BaseModel):
|
||||||
|
path: str
|
||||||
|
content: Union[str, bytes]
|
||||||
|
|
||||||
|
|
||||||
|
class IndexBatchRequest(BaseModel):
|
||||||
|
files: list[File]
|
||||||
|
|
||||||
|
|
||||||
|
class IndexerInput(BaseModel):
|
||||||
|
org: Optional[dict[str, str]] = None
|
||||||
|
markdown: Optional[dict[str, str]] = None
|
||||||
|
pdf: Optional[dict[str, bytes]] = None
|
||||||
|
plaintext: Optional[dict[str, str]] = None
|
||||||
|
image: Optional[dict[str, bytes]] = None
|
||||||
|
docx: Optional[dict[str, bytes]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.put("")
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def put_content(
|
||||||
|
request: Request,
|
||||||
|
files: list[UploadFile],
|
||||||
|
t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
user_agent: Optional[str] = Header(None),
|
||||||
|
referer: Optional[str] = Header(None),
|
||||||
|
host: Optional[str] = Header(None),
|
||||||
|
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
|
||||||
|
ApiIndexedDataLimiter(
|
||||||
|
incoming_entries_size_limit=10,
|
||||||
|
subscribed_incoming_entries_size_limit=25,
|
||||||
|
total_entries_size_limit=10,
|
||||||
|
subscribed_total_entries_size_limit=100,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return await indexer(request, files, t, True, client, user_agent, referer, host)
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.patch("")
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def patch_content(
|
||||||
|
request: Request,
|
||||||
|
files: list[UploadFile],
|
||||||
|
t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
user_agent: Optional[str] = Header(None),
|
||||||
|
referer: Optional[str] = Header(None),
|
||||||
|
host: Optional[str] = Header(None),
|
||||||
|
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
|
||||||
|
ApiIndexedDataLimiter(
|
||||||
|
incoming_entries_size_limit=10,
|
||||||
|
subscribed_incoming_entries_size_limit=25,
|
||||||
|
total_entries_size_limit=10,
|
||||||
|
subscribed_total_entries_size_limit=100,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return await indexer(request, files, t, False, client, user_agent, referer, host)
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.get("/github", response_class=Response)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def get_content_github(request: Request) -> Response:
|
||||||
|
user = request.user.object
|
||||||
|
user_config = get_user_config(user, request)
|
||||||
|
del user_config["request"]
|
||||||
|
|
||||||
|
current_github_config = get_user_github_config(user)
|
||||||
|
|
||||||
|
if current_github_config:
|
||||||
|
raw_repos = current_github_config.githubrepoconfig.all()
|
||||||
|
repos = []
|
||||||
|
for repo in raw_repos:
|
||||||
|
repos.append(
|
||||||
|
GithubRepoConfig(
|
||||||
|
name=repo.name,
|
||||||
|
owner=repo.owner,
|
||||||
|
branch=repo.branch,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_config = GithubContentConfig(
|
||||||
|
pat_token=current_github_config.pat_token,
|
||||||
|
repos=repos,
|
||||||
|
)
|
||||||
|
current_config = json.loads(current_config.json())
|
||||||
|
else:
|
||||||
|
current_config = {} # type: ignore
|
||||||
|
|
||||||
|
user_config["current_config"] = current_config
|
||||||
|
|
||||||
|
# Return config data as a JSON response
|
||||||
|
return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.get("/notion", response_class=Response)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def get_content_notion(request: Request) -> Response:
|
||||||
|
user = request.user.object
|
||||||
|
user_config = get_user_config(user, request)
|
||||||
|
del user_config["request"]
|
||||||
|
|
||||||
|
current_notion_config = get_user_notion_config(user)
|
||||||
|
token = current_notion_config.token if current_notion_config else ""
|
||||||
|
current_config = NotionContentConfig(token=token)
|
||||||
|
current_config = json.loads(current_config.model_dump_json())
|
||||||
|
|
||||||
|
user_config["current_config"] = current_config
|
||||||
|
|
||||||
|
# Return config data as a JSON response
|
||||||
|
return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.post("/github", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def set_content_github(
|
||||||
|
request: Request,
|
||||||
|
updated_config: Union[GithubContentConfig, None],
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
_initialize_config()
|
||||||
|
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
try:
|
||||||
|
await adapters.set_user_github_config(
|
||||||
|
user=user,
|
||||||
|
pat_token=updated_config.pat_token,
|
||||||
|
repos=updated_config.repos,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_content_config",
|
||||||
|
client=client,
|
||||||
|
metadata={"content_type": "github"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.post("/notion", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def set_content_notion(
|
||||||
|
request: Request,
|
||||||
|
updated_config: Union[NotionContentConfig, None],
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
_initialize_config()
|
||||||
|
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
try:
|
||||||
|
await adapters.set_notion_config(
|
||||||
|
user=user,
|
||||||
|
token=updated_config.token,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to set Notion config")
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_content_config",
|
||||||
|
client=client,
|
||||||
|
metadata={"content_type": "notion"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.delete("/{content_source}", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def delete_content_source(
|
||||||
|
request: Request,
|
||||||
|
content_source: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="delete_content_config",
|
||||||
|
client=client,
|
||||||
|
metadata={"content_source": content_source},
|
||||||
|
)
|
||||||
|
|
||||||
|
content_object = map_config_to_object(content_source)
|
||||||
|
if content_object is None:
|
||||||
|
raise ValueError(f"Invalid content source: {content_source}")
|
||||||
|
elif content_object != "Computer":
|
||||||
|
await content_object.objects.filter(user=user).adelete()
|
||||||
|
await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
|
||||||
|
|
||||||
|
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.delete("/file", status_code=201)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def delete_content_file(
|
||||||
|
request: Request,
|
||||||
|
filename: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="delete_file",
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
await EntryAdapters.adelete_entry_by_file(user, filename)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.get("/size", response_model=Dict[str, int])
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def get_content_size(request: Request, common: CommonQueryParams, client: Optional[str] = None):
|
||||||
|
user = request.user.object
|
||||||
|
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
|
||||||
|
media_type="application/json",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.get("/types", response_model=List[str])
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def get_content_types(request: Request, client: Optional[str] = None):
|
||||||
|
user = request.user.object
|
||||||
|
all_content_types = {s.value for s in SearchType}
|
||||||
|
configured_content_types = set(EntryAdapters.get_unique_file_types(user))
|
||||||
|
configured_content_types |= {"all"}
|
||||||
|
|
||||||
|
if state.config and state.config.content_type:
|
||||||
|
for ctype in state.config.content_type.model_dump(exclude_none=True):
|
||||||
|
configured_content_types.add(ctype)
|
||||||
|
|
||||||
|
return list(configured_content_types & all_content_types)
|
||||||
|
|
||||||
|
|
||||||
|
@api_content.get("/{content_source}", response_model=List[str])
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def get_content_source(
|
||||||
|
request: Request,
|
||||||
|
content_source: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="get_all_filenames",
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await sync_to_async(list)(EntryAdapters.get_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
|
async def indexer(
|
||||||
|
request: Request,
|
||||||
|
files: list[UploadFile],
|
||||||
|
t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
|
||||||
|
regenerate: bool = False,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
user_agent: Optional[str] = Header(None),
|
||||||
|
referer: Optional[str] = Header(None),
|
||||||
|
host: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
method = "regenerate" if regenerate else "sync"
|
||||||
|
index_files: Dict[str, Dict[str, str]] = {
|
||||||
|
"org": {},
|
||||||
|
"markdown": {},
|
||||||
|
"pdf": {},
|
||||||
|
"plaintext": {},
|
||||||
|
"image": {},
|
||||||
|
"docx": {},
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
logger.info(f"📬 Updating content index via API call by {client} client")
|
||||||
|
for file in files:
|
||||||
|
file_content = file.file.read()
|
||||||
|
file_type, encoding = get_file_type(file.content_type, file_content)
|
||||||
|
if file_type in index_files:
|
||||||
|
index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
|
||||||
|
else:
|
||||||
|
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
|
||||||
|
|
||||||
|
indexer_input = IndexerInput(
|
||||||
|
org=index_files["org"],
|
||||||
|
markdown=index_files["markdown"],
|
||||||
|
pdf=index_files["pdf"],
|
||||||
|
plaintext=index_files["plaintext"],
|
||||||
|
image=index_files["image"],
|
||||||
|
docx=index_files["docx"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if state.config == None:
|
||||||
|
logger.info("📬 Initializing content index on first run.")
|
||||||
|
default_full_config = FullConfig(
|
||||||
|
content_type=None,
|
||||||
|
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
|
||||||
|
processor=None,
|
||||||
|
)
|
||||||
|
state.config = default_full_config
|
||||||
|
default_content_config = ContentConfig(
|
||||||
|
org=None,
|
||||||
|
markdown=None,
|
||||||
|
pdf=None,
|
||||||
|
docx=None,
|
||||||
|
image=None,
|
||||||
|
github=None,
|
||||||
|
notion=None,
|
||||||
|
plaintext=None,
|
||||||
|
)
|
||||||
|
state.config.content_type = default_content_config
|
||||||
|
save_config_to_file_updated_state()
|
||||||
|
configure_search(state.search_models, state.config.search_type)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
success = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
configure_content,
|
||||||
|
indexer_input.model_dump(),
|
||||||
|
regenerate,
|
||||||
|
t,
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")
|
||||||
|
logger.info(f"Finished {method} {t} data sent by {client} client into content index")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to {method} {t} data sent by {client} client into content index: {e}", exc_info=True)
|
||||||
|
logger.error(
|
||||||
|
f"🚨 Failed to {method} {t} data sent by {client} client into content index: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return Response(content="Failed", status_code=500)
|
||||||
|
|
||||||
|
indexing_metadata = {
|
||||||
|
"num_org": len(index_files["org"]),
|
||||||
|
"num_markdown": len(index_files["markdown"]),
|
||||||
|
"num_pdf": len(index_files["pdf"]),
|
||||||
|
"num_plaintext": len(index_files["plaintext"]),
|
||||||
|
"num_image": len(index_files["image"]),
|
||||||
|
"num_docx": len(index_files["docx"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="index/update",
|
||||||
|
client=client,
|
||||||
|
user_agent=user_agent,
|
||||||
|
referer=referer,
|
||||||
|
host=host,
|
||||||
|
metadata=indexing_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"📪 Content index updated via API call by {client} client")
|
||||||
|
|
||||||
|
indexed_filenames = ",".join(file for ctype in index_files for file in index_files[ctype]) or ""
|
||||||
|
return Response(content=indexed_filenames, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
|
||||||
|
# Run Validation Checks
|
||||||
|
if search_models is None:
|
||||||
|
search_models = SearchModels()
|
||||||
|
|
||||||
|
return search_models
|
||||||
|
|
||||||
|
|
||||||
|
def map_config_to_object(content_source: str):
|
||||||
|
if content_source == DbEntry.EntrySource.GITHUB:
|
||||||
|
return GithubConfig
|
||||||
|
if content_source == DbEntry.EntrySource.NOTION:
|
||||||
|
return NotionConfig
|
||||||
|
if content_source == DbEntry.EntrySource.COMPUTER:
|
||||||
|
return "Computer"
|
||||||
|
|
||||||
|
|
||||||
|
async def map_config_to_db(config: FullConfig, user: KhojUser):
|
||||||
|
if config.content_type:
|
||||||
|
if config.content_type.org:
|
||||||
|
await LocalOrgConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalOrgConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.org.input_files,
|
||||||
|
input_filter=config.content_type.org.input_filter,
|
||||||
|
index_heading_entries=config.content_type.org.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.markdown:
|
||||||
|
await LocalMarkdownConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalMarkdownConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.markdown.input_files,
|
||||||
|
input_filter=config.content_type.markdown.input_filter,
|
||||||
|
index_heading_entries=config.content_type.markdown.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.pdf:
|
||||||
|
await LocalPdfConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalPdfConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.pdf.input_files,
|
||||||
|
input_filter=config.content_type.pdf.input_filter,
|
||||||
|
index_heading_entries=config.content_type.pdf.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.plaintext:
|
||||||
|
await LocalPlaintextConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalPlaintextConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.plaintext.input_files,
|
||||||
|
input_filter=config.content_type.plaintext.input_filter,
|
||||||
|
index_heading_entries=config.content_type.plaintext.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.github:
|
||||||
|
await adapters.set_user_github_config(
|
||||||
|
user=user,
|
||||||
|
pat_token=config.content_type.github.pat_token,
|
||||||
|
repos=config.content_type.github.repos,
|
||||||
|
)
|
||||||
|
if config.content_type.notion:
|
||||||
|
await adapters.set_notion_config(
|
||||||
|
user=user,
|
||||||
|
token=config.content_type.notion.token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_config():
|
||||||
|
if state.config is None:
|
||||||
|
state.config = FullConfig()
|
||||||
|
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
|
||||||
156
src/khoj/routers/api_model.py
Normal file
156
src/khoj/routers/api_model.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.requests import Request
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from starlette.authentication import has_required_scope, requires
|
||||||
|
|
||||||
|
from khoj.database import adapters
|
||||||
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
|
from khoj.routers.helpers import update_telemetry_state
|
||||||
|
|
||||||
|
api_model = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@api_model.get("/chat/options", response_model=Dict[str, Union[str, int]])
|
||||||
|
def get_chat_model_options(
|
||||||
|
request: Request,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
conversation_options = ConversationAdapters.get_conversation_processor_options().all()
|
||||||
|
|
||||||
|
all_conversation_options = list()
|
||||||
|
for conversation_option in conversation_options:
|
||||||
|
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
|
||||||
|
|
||||||
|
return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@api_model.get("/chat")
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def get_user_chat_model(
|
||||||
|
request: Request,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
chat_model = ConversationAdapters.get_conversation_config(user)
|
||||||
|
|
||||||
|
if chat_model is None:
|
||||||
|
chat_model = ConversationAdapters.get_default_conversation_config()
|
||||||
|
|
||||||
|
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
||||||
|
|
||||||
|
|
||||||
|
@api_model.post("/chat", status_code=200)
|
||||||
|
@requires(["authenticated", "premium"])
|
||||||
|
async def update_chat_model(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_conversation_chat_model",
|
||||||
|
client=client,
|
||||||
|
metadata={"processor_conversation_type": "conversation"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_config is None:
|
||||||
|
return {"status": "error", "message": "Model not found"}
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_model.post("/voice", status_code=200)
|
||||||
|
@requires(["authenticated", "premium"])
|
||||||
|
async def update_voice_model(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
new_config = await ConversationAdapters.aset_user_voice_model(user, id)
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_voice_model",
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_config is None:
|
||||||
|
return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
|
||||||
|
|
||||||
|
return Response(status_code=202, content=json.dumps({"status": "ok"}))
|
||||||
|
|
||||||
|
|
||||||
|
@api_model.post("/search", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def update_search_model(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
|
||||||
|
prev_config = await adapters.aget_user_search_model(user)
|
||||||
|
new_config = await adapters.aset_user_search_model(user, int(id))
|
||||||
|
|
||||||
|
if prev_config and int(id) != prev_config.id and new_config:
|
||||||
|
await EntryAdapters.adelete_all_entries(user)
|
||||||
|
|
||||||
|
if not prev_config:
|
||||||
|
# If the use was just using the default config, delete all the entries and set the new config.
|
||||||
|
await EntryAdapters.adelete_all_entries(user)
|
||||||
|
|
||||||
|
if new_config is None:
|
||||||
|
return {"status": "error", "message": "Model not found"}
|
||||||
|
else:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_search_model",
|
||||||
|
client=client,
|
||||||
|
metadata={"search_model": new_config.setting.name},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api_model.post("/paint", status_code=200)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def update_paint_model(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
):
|
||||||
|
user = request.user.object
|
||||||
|
subscribed = has_required_scope(request, ["premium"])
|
||||||
|
|
||||||
|
if not subscribed:
|
||||||
|
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||||
|
|
||||||
|
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="set_paint_model",
|
||||||
|
client=client,
|
||||||
|
metadata={"paint_model": new_config.setting.model_name},
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_config is None:
|
||||||
|
return {"status": "error", "message": "Model not found"}
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
@@ -1313,7 +1313,6 @@ def configure_content(
|
|||||||
files: Optional[dict[str, dict[str, str]]],
|
files: Optional[dict[str, dict[str, str]]],
|
||||||
regenerate: bool = False,
|
regenerate: bool = False,
|
||||||
t: Optional[state.SearchType] = state.SearchType.All,
|
t: Optional[state.SearchType] = state.SearchType.All,
|
||||||
full_corpus: bool = True,
|
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
success = True
|
success = True
|
||||||
@@ -1344,7 +1343,6 @@ def configure_content(
|
|||||||
OrgToEntries,
|
OrgToEntries,
|
||||||
files.get("org"),
|
files.get("org"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1362,7 +1360,6 @@ def configure_content(
|
|||||||
MarkdownToEntries,
|
MarkdownToEntries,
|
||||||
files.get("markdown"),
|
files.get("markdown"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1379,7 +1376,6 @@ def configure_content(
|
|||||||
PdfToEntries,
|
PdfToEntries,
|
||||||
files.get("pdf"),
|
files.get("pdf"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1398,7 +1394,6 @@ def configure_content(
|
|||||||
PlaintextToEntries,
|
PlaintextToEntries,
|
||||||
files.get("plaintext"),
|
files.get("plaintext"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1418,7 +1413,6 @@ def configure_content(
|
|||||||
GithubToEntries,
|
GithubToEntries,
|
||||||
None,
|
None,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
config=github_config,
|
config=github_config,
|
||||||
)
|
)
|
||||||
@@ -1439,7 +1433,6 @@ def configure_content(
|
|||||||
NotionToEntries,
|
NotionToEntries,
|
||||||
None,
|
None,
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
config=notion_config,
|
config=notion_config,
|
||||||
)
|
)
|
||||||
@@ -1459,7 +1452,6 @@ def configure_content(
|
|||||||
ImageToEntries,
|
ImageToEntries,
|
||||||
files.get("image"),
|
files.get("image"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1472,7 +1464,6 @@ def configure_content(
|
|||||||
DocxToEntries,
|
DocxToEntries,
|
||||||
files.get("docx"),
|
files.get("docx"),
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
full_corpus=full_corpus,
|
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,166 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Optional, Union
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Request, Response, UploadFile
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from starlette.authentication import requires
|
|
||||||
|
|
||||||
from khoj.routers.helpers import (
|
|
||||||
ApiIndexedDataLimiter,
|
|
||||||
configure_content,
|
|
||||||
update_telemetry_state,
|
|
||||||
)
|
|
||||||
from khoj.utils import constants, state
|
|
||||||
from khoj.utils.config import SearchModels
|
|
||||||
from khoj.utils.helpers import get_file_type
|
|
||||||
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
|
|
||||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
indexer = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
class File(BaseModel):
|
|
||||||
path: str
|
|
||||||
content: Union[str, bytes]
|
|
||||||
|
|
||||||
|
|
||||||
class IndexBatchRequest(BaseModel):
|
|
||||||
files: list[File]
|
|
||||||
|
|
||||||
|
|
||||||
class IndexerInput(BaseModel):
|
|
||||||
org: Optional[dict[str, str]] = None
|
|
||||||
markdown: Optional[dict[str, str]] = None
|
|
||||||
pdf: Optional[dict[str, bytes]] = None
|
|
||||||
plaintext: Optional[dict[str, str]] = None
|
|
||||||
image: Optional[dict[str, bytes]] = None
|
|
||||||
docx: Optional[dict[str, bytes]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@indexer.post("/update")
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def update(
|
|
||||||
request: Request,
|
|
||||||
files: list[UploadFile],
|
|
||||||
force: bool = False,
|
|
||||||
t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
user_agent: Optional[str] = Header(None),
|
|
||||||
referer: Optional[str] = Header(None),
|
|
||||||
host: Optional[str] = Header(None),
|
|
||||||
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
|
|
||||||
ApiIndexedDataLimiter(
|
|
||||||
incoming_entries_size_limit=10,
|
|
||||||
subscribed_incoming_entries_size_limit=25,
|
|
||||||
total_entries_size_limit=10,
|
|
||||||
subscribed_total_entries_size_limit=100,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
index_files: Dict[str, Dict[str, str]] = {
|
|
||||||
"org": {},
|
|
||||||
"markdown": {},
|
|
||||||
"pdf": {},
|
|
||||||
"plaintext": {},
|
|
||||||
"image": {},
|
|
||||||
"docx": {},
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
logger.info(f"📬 Updating content index via API call by {client} client")
|
|
||||||
for file in files:
|
|
||||||
file_content = file.file.read()
|
|
||||||
file_type, encoding = get_file_type(file.content_type, file_content)
|
|
||||||
if file_type in index_files:
|
|
||||||
index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
|
|
||||||
else:
|
|
||||||
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
|
|
||||||
|
|
||||||
indexer_input = IndexerInput(
|
|
||||||
org=index_files["org"],
|
|
||||||
markdown=index_files["markdown"],
|
|
||||||
pdf=index_files["pdf"],
|
|
||||||
plaintext=index_files["plaintext"],
|
|
||||||
image=index_files["image"],
|
|
||||||
docx=index_files["docx"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if state.config == None:
|
|
||||||
logger.info("📬 Initializing content index on first run.")
|
|
||||||
default_full_config = FullConfig(
|
|
||||||
content_type=None,
|
|
||||||
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
|
|
||||||
processor=None,
|
|
||||||
)
|
|
||||||
state.config = default_full_config
|
|
||||||
default_content_config = ContentConfig(
|
|
||||||
org=None,
|
|
||||||
markdown=None,
|
|
||||||
pdf=None,
|
|
||||||
docx=None,
|
|
||||||
image=None,
|
|
||||||
github=None,
|
|
||||||
notion=None,
|
|
||||||
plaintext=None,
|
|
||||||
)
|
|
||||||
state.config.content_type = default_content_config
|
|
||||||
save_config_to_file_updated_state()
|
|
||||||
configure_search(state.search_models, state.config.search_type)
|
|
||||||
|
|
||||||
# Extract required fields from config
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
success = await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
configure_content,
|
|
||||||
indexer_input.model_dump(),
|
|
||||||
force,
|
|
||||||
t,
|
|
||||||
False,
|
|
||||||
user,
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
raise RuntimeError("Failed to update content index")
|
|
||||||
logger.info(f"Finished processing batch indexing request")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to process batch indexing request: {e}", exc_info=True)
|
|
||||||
logger.error(
|
|
||||||
f'🚨 Failed to {"force " if force else ""}update {t} content index triggered via API call by {client} client: {e}',
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
return Response(content="Failed", status_code=500)
|
|
||||||
|
|
||||||
indexing_metadata = {
|
|
||||||
"num_org": len(index_files["org"]),
|
|
||||||
"num_markdown": len(index_files["markdown"]),
|
|
||||||
"num_pdf": len(index_files["pdf"]),
|
|
||||||
"num_plaintext": len(index_files["plaintext"]),
|
|
||||||
"num_image": len(index_files["image"]),
|
|
||||||
"num_docx": len(index_files["docx"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="index/update",
|
|
||||||
client=client,
|
|
||||||
user_agent=user_agent,
|
|
||||||
referer=referer,
|
|
||||||
host=host,
|
|
||||||
metadata=indexing_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"📪 Content index updated via API call by {client} client")
|
|
||||||
|
|
||||||
indexed_filenames = ",".join(file for ctype in index_files for file in index_files[ctype]) or ""
|
|
||||||
return Response(content=indexed_filenames, status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
|
|
||||||
# Run Validation Checks
|
|
||||||
if search_models is None:
|
|
||||||
search_models = SearchModels()
|
|
||||||
|
|
||||||
return search_models
|
|
||||||
@@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
|
|||||||
notion_redirect = str(request.app.url_path_for("notion_config_page"))
|
notion_redirect = str(request.app.url_path_for("notion_config_page"))
|
||||||
|
|
||||||
# Trigger an async job to configure_content. Let it run without blocking the response.
|
# Trigger an async job to configure_content. Let it run without blocking the response.
|
||||||
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user)
|
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
|
||||||
|
|
||||||
return RedirectResponse(notion_redirect)
|
return RedirectResponse(notion_redirect)
|
||||||
|
|||||||
@@ -199,17 +199,16 @@ def setup(
|
|||||||
text_to_entries: Type[TextToEntries],
|
text_to_entries: Type[TextToEntries],
|
||||||
files: dict[str, str],
|
files: dict[str, str],
|
||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
full_corpus: bool = True,
|
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
config=None,
|
config=None,
|
||||||
) -> None:
|
) -> Tuple[int, int]:
|
||||||
if config:
|
if config:
|
||||||
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
|
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
|
||||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
files=files, user=user, regenerate=regenerate
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
|
num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
|
||||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
files=files, user=user, regenerate=regenerate
|
||||||
)
|
)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
@@ -219,6 +218,8 @@ def setup(
|
|||||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
|
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
||||||
|
|
||||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||||
"""Score all retrieved entries using the cross-encoder"""
|
"""Score all retrieved entries using the cross-encoder"""
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ def get_org_files(config: TextContentConfig):
|
|||||||
logger.debug("At least one of org-files or org-file-filter is required to be specified")
|
logger.debug("At least one of org-files or org-file-filter is required to be specified")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
"Get Org files to process"
|
# Get Org files to process
|
||||||
absolute_org_files, filtered_org_files = set(), set()
|
absolute_org_files, filtered_org_files = set(), set()
|
||||||
if org_files:
|
if org_files:
|
||||||
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from khoj.database.models import (
|
|||||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||||
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
|
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.routers.indexer import configure_content
|
from khoj.routers.api_content import configure_content
|
||||||
from khoj.search_type import text_search
|
from khoj.search_type import text_search
|
||||||
from khoj.utils import fs_syncer, state
|
from khoj.utils import fs_syncer, state
|
||||||
from khoj.utils.config import SearchModels
|
from khoj.utils.config import SearchModels
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ def test_index_update_with_no_auth_key(client):
|
|||||||
files = get_sample_files_data()
|
files = get_sample_files_data()
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post("/api/v1/index/update", files=files)
|
response = client.patch("/api/content", files=files)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
@@ -89,7 +89,7 @@ def test_index_update_with_invalid_auth_key(client):
|
|||||||
headers = {"Authorization": "Bearer kk-invalid-token"}
|
headers = {"Authorization": "Bearer kk-invalid-token"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
@@ -130,7 +130,7 @@ def test_index_update_big_files(client):
|
|||||||
headers = {"Authorization": "Bearer kk-secret"}
|
headers = {"Authorization": "Bearer kk-secret"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
@@ -146,7 +146,7 @@ def test_index_update_medium_file_unsubscribed(client, api_user4: KhojApiUser):
|
|||||||
headers = {"Authorization": f"Bearer {api_token}"}
|
headers = {"Authorization": f"Bearer {api_token}"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
@@ -162,7 +162,7 @@ def test_index_update_normal_file_unsubscribed(client, api_user4: KhojApiUser):
|
|||||||
headers = {"Authorization": f"Bearer {api_token}"}
|
headers = {"Authorization": f"Bearer {api_token}"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -177,7 +177,7 @@ def test_index_update_big_files_no_billing(client):
|
|||||||
headers = {"Authorization": "Bearer kk-secret"}
|
headers = {"Authorization": "Bearer kk-secret"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -191,7 +191,7 @@ def test_index_update(client):
|
|||||||
headers = {"Authorization": "Bearer kk-secret"}
|
headers = {"Authorization": "Bearer kk-secret"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -208,8 +208,8 @@ def test_index_update_fails_if_more_than_1000_files(client, api_user4: KhojApiUs
|
|||||||
headers = {"Authorization": f"Bearer {api_token}"}
|
headers = {"Authorization": f"Bearer {api_token}"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
ok_response = client.post("/api/v1/index/update", files=files[:1000], headers=headers)
|
ok_response = client.patch("/api/content", files=files[:1000], headers=headers)
|
||||||
bad_response = client.post("/api/v1/index/update", files=files, headers=headers)
|
bad_response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert ok_response.status_code == 200
|
assert ok_response.status_code == 200
|
||||||
@@ -226,7 +226,7 @@ def test_regenerate_with_valid_content_type(client):
|
|||||||
headers = {"Authorization": "Bearer kk-secret"}
|
headers = {"Authorization": "Bearer kk-secret"}
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post(f"/api/v1/index/update?t={content_type}", files=files, headers=headers)
|
response = client.patch(f"/api/content?t={content_type}", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
|
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
|
||||||
@@ -243,7 +243,7 @@ def test_regenerate_with_github_fails_without_pat(client):
|
|||||||
files = get_sample_files_data()
|
files = get_sample_files_data()
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.post(f"/api/v1/index/update?t=github", files=files, headers=headers)
|
response = client.patch(f"/api/content?t=github", files=files, headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
|
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
|
||||||
@@ -269,11 +269,11 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
|
|||||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/api/configure/types", headers=headers)
|
response = client.get(f"/api/content/types", headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == ["all", "org", "plaintext"]
|
assert set(response.json()) == {"all", "org", "plaintext"}
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -289,7 +289,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
|||||||
client = TestClient(fastapi_app)
|
client = TestClient(fastapi_app)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/api/configure/types")
|
response = client.get(f"/api/content/types")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def test_index_update_with_user2(client, api_user2: KhojApiUser):
|
|||||||
source_file_symbol = set([f[1][0] for f in files])
|
source_file_symbol = set([f[1][0] for f in files])
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
||||||
update_response = client.post("/api/v1/index/update", files=files, headers=headers)
|
update_response = client.patch("/api/content", files=files, headers=headers)
|
||||||
search_response = client.get("/api/search?q=hardware&t=all", headers=headers)
|
search_response = client.get("/api/search?q=hardware&t=all", headers=headers)
|
||||||
results = search_response.json()
|
results = search_response.json()
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUs
|
|||||||
source_file_symbol = set([f[1][0] for f in files])
|
source_file_symbol = set([f[1][0] for f in files])
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
||||||
update_response = client.post("/api/v1/index/update", files=files, headers=headers)
|
update_response = client.patch("/api/content", files=files, headers=headers)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
headers = {"Authorization": f"Bearer {api_user.token}"}
|
headers = {"Authorization": f"Bearer {api_user.token}"}
|
||||||
|
|||||||
@@ -6,9 +6,16 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from khoj.database.adapters import EntryAdapters
|
||||||
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
|
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
|
||||||
|
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
|
||||||
from khoj.processor.content.github.github_to_entries import GithubToEntries
|
from khoj.processor.content.github.github_to_entries import GithubToEntries
|
||||||
|
from khoj.processor.content.images.image_to_entries import ImageToEntries
|
||||||
|
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
|
||||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||||
|
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
||||||
|
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||||
|
from khoj.processor.content.text_to_entries import TextToEntries
|
||||||
from khoj.search_type import text_search
|
from khoj.search_type import text_search
|
||||||
from khoj.utils.fs_syncer import collect_files, get_org_files
|
from khoj.utils.fs_syncer import collect_files, get_org_files
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
@@ -151,7 +158,6 @@ async def test_text_search(search_config: SearchConfig):
|
|||||||
OrgToEntries,
|
OrgToEntries,
|
||||||
data,
|
data,
|
||||||
True,
|
True,
|
||||||
True,
|
|
||||||
default_user,
|
default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -240,7 +246,6 @@ conda activate khoj
|
|||||||
OrgToEntries,
|
OrgToEntries,
|
||||||
data,
|
data,
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
full_corpus=False,
|
|
||||||
user=default_user,
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -396,6 +401,49 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
|||||||
verify_embeddings(3, default_user)
|
verify_embeddings(3, default_user)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"text_to_entries",
|
||||||
|
[
|
||||||
|
(OrgToEntries),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_update_index_with_deleted_file(
|
||||||
|
org_config_with_only_new_file: LocalOrgConfig, text_to_entries: TextToEntries, default_user: KhojUser
|
||||||
|
):
|
||||||
|
"Delete entries associated with new file when file path with empty content passed."
|
||||||
|
# Arrange
|
||||||
|
file_to_index = "test"
|
||||||
|
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||||
|
initial_data = {file_to_index: new_entry}
|
||||||
|
final_data = {file_to_index: ""}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# load entries after adding file
|
||||||
|
initial_added_entries, _ = text_search.setup(text_to_entries, initial_data, regenerate=True, user=default_user)
|
||||||
|
initial_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()
|
||||||
|
|
||||||
|
# load entries after deleting file
|
||||||
|
final_added_entries, final_deleted_entries = text_search.setup(
|
||||||
|
text_to_entries, final_data, regenerate=False, user=default_user
|
||||||
|
)
|
||||||
|
final_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert initial_total_entries > 0, "File entries not indexed"
|
||||||
|
assert initial_added_entries > 0, "No entries got added"
|
||||||
|
|
||||||
|
assert final_total_entries == 0, "File did not get deleted"
|
||||||
|
assert final_added_entries == 0, "Entries were unexpectedly added in delete entries pass"
|
||||||
|
assert final_deleted_entries == initial_added_entries, "All added entries were not deleted"
|
||||||
|
|
||||||
|
verify_embeddings(0, default_user), "Embeddings still exist for user"
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
EntryAdapters.delete_all_entries(default_user)
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||||
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
||||||
|
|||||||
Reference in New Issue
Block a user