Split Configure API into Content, Model API paths (#857)

## Major: Breaking Changes
- Move API endpoints under /api/configure/<type>/model to /api/model/<type>
- Move API endpoints under /api/configure/content/ to /api/content/
- Accept file deletion requests by clients during sync
- Split /api/v1/index/update into /api/content PUT, PATCH API endpoints

## Minor: Create New API Endpoint
- Create API endpoints to get user content configurations

Related: #852
This commit is contained in:
Debanjum
2024-07-26 23:48:41 -07:00
committed by GitHub
36 changed files with 857 additions and 739 deletions

View File

@@ -233,11 +233,15 @@ function pushDataToKhoj (regenerate = false) {
// Request indexing files on server. With up to 1000 files in each request // Request indexing files on server. With up to 1000 files in each request
for (let i = 0; i < filesDataToPush.length; i += 1000) { for (let i = 0; i < filesDataToPush.length; i += 1000) {
const syncUrl = `${hostURL}/api/content?client=desktop`;
const filesDataGroup = filesDataToPush.slice(i, i + 1000); const filesDataGroup = filesDataToPush.slice(i, i + 1000);
const formData = new FormData(); const formData = new FormData();
filesDataGroup.forEach(fileData => { formData.append('files', fileData.blob, fileData.path) }); filesDataGroup.forEach(fileData => { formData.append('files', fileData.blob, fileData.path) });
let request = axios.post(`${hostURL}/api/v1/index/update?force=${regenerate}&client=desktop`, formData, { headers }); requests.push(
requests.push(request); regenerate
? axios.put(syncUrl, formData, { headers })
: axios.patch(syncUrl, formData, { headers })
);
} }
// Wait for requests batch to finish // Wait for requests batch to finish

View File

@@ -212,7 +212,7 @@
const headers = { 'Authorization': `Bearer ${khojToken}` }; const headers = { 'Authorization': `Bearer ${khojToken}` };
// Populate type dropdown field with enabled content types only // Populate type dropdown field with enabled content types only
fetch(`${hostURL}/api/configure/types`, { headers }) fetch(`${hostURL}/api/content/types`, { headers })
.then(response => response.json()) .then(response => response.json())
.then(enabled_types => { .then(enabled_types => {
// Show warning if no content types are enabled // Show warning if no content types are enabled

View File

@@ -424,12 +424,12 @@ Auto invokes setup steps on calling main entrypoint."
"Send multi-part form `BODY' of `CONTENT-TYPE' in request to khoj server. "Send multi-part form `BODY' of `CONTENT-TYPE' in request to khoj server.
Append 'TYPE-QUERY' as query parameter in request url. Append 'TYPE-QUERY' as query parameter in request url.
Specify `BOUNDARY' used to separate files in request header." Specify `BOUNDARY' used to separate files in request header."
(let ((url-request-method "POST") (let ((url-request-method (if force "PUT" "PATCH"))
(url-request-data body) (url-request-data body)
(url-request-extra-headers `(("content-type" . ,(format "multipart/form-data; boundary=%s" boundary)) (url-request-extra-headers `(("content-type" . ,(format "multipart/form-data; boundary=%s" boundary))
("Authorization" . ,(format "Bearer %s" khoj-api-key))))) ("Authorization" . ,(format "Bearer %s" khoj-api-key)))))
(with-current-buffer (with-current-buffer
(url-retrieve (format "%s/api/v1/index/update?%s&force=%s&client=emacs" khoj-server-url type-query (or force "false")) (url-retrieve (format "%s/api/content?%s&client=emacs" khoj-server-url type-query)
;; render response from indexing API endpoint on server ;; render response from indexing API endpoint on server
(lambda (status) (lambda (status)
(if (not (plist-get status :error)) (if (not (plist-get status :error))
@@ -697,7 +697,7 @@ Optionally apply CALLBACK with JSON parsed response and CBARGS."
(defun khoj--get-enabled-content-types () (defun khoj--get-enabled-content-types ()
"Get content types enabled for search from API." "Get content types enabled for search from API."
(khoj--call-api "/api/configure/types" "GET" nil `(lambda (item) (mapcar #'intern item)))) (khoj--call-api "/api/content/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
(defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar) (defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar)
"Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params. "Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params.

View File

@@ -89,10 +89,11 @@ export async function updateContentIndex(vault: Vault, setting: KhojSetting, las
for (let i = 0; i < fileData.length; i += 1000) { for (let i = 0; i < fileData.length; i += 1000) {
const filesGroup = fileData.slice(i, i + 1000); const filesGroup = fileData.slice(i, i + 1000);
const formData = new FormData(); const formData = new FormData();
const method = regenerate ? "PUT" : "PATCH";
filesGroup.forEach(fileItem => { formData.append('files', fileItem.blob, fileItem.path) }); filesGroup.forEach(fileItem => { formData.append('files', fileItem.blob, fileItem.path) });
// Call Khoj backend to update index with all markdown, pdf files // Call Khoj backend to update index with all markdown, pdf files
const response = await fetch(`${setting.khojUrl}/api/v1/index/update?force=${regenerate}&client=obsidian`, { const response = await fetch(`${setting.khojUrl}/api/content?client=obsidian`, {
method: 'POST', method: method,
headers: { headers: {
'Authorization': `Bearer ${setting.khojApiKey}`, 'Authorization': `Bearer ${setting.khojApiKey}`,
}, },

View File

@@ -277,8 +277,8 @@ export function uploadDataForIndexing(
// Wait for all files to be read before making the fetch request // Wait for all files to be read before making the fetch request
Promise.all(fileReadPromises) Promise.all(fileReadPromises)
.then(() => { .then(() => {
return fetch("/api/v1/index/update?force=false&client=web", { return fetch("/api/content?client=web", {
method: "POST", method: "PATCH",
body: formData, body: formData,
}); });
}) })

View File

@@ -68,8 +68,8 @@ interface ModelPickerProps {
} }
export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => { export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
const { data: models } = useOptionsRequest('/api/configure/chat/model/options'); const { data: models } = useOptionsRequest('/api/model/chat/options');
const { data: selectedModel } = useSelectedModel('/api/configure/chat/model'); const { data: selectedModel } = useSelectedModel('/api/model/chat');
const [openLoginDialog, setOpenLoginDialog] = React.useState(false); const [openLoginDialog, setOpenLoginDialog] = React.useState(false);
let userData = useAuthenticatedData(); let userData = useAuthenticatedData();
@@ -94,7 +94,7 @@ export const ModelPicker: React.FC<any> = (props: ModelPickerProps) => {
props.setModelUsed(model); props.setModelUsed(model);
} }
fetch('/api/configure/chat/model' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) }) fetch('/api/model/chat' + '?id=' + String(model.id), { method: 'POST', body: JSON.stringify(model) })
.then((response) => { .then((response) => {
if (!response.ok) { if (!response.ok) {
throw new Error('Failed to select model'); throw new Error('Failed to select model');

View File

@@ -148,7 +148,7 @@ interface FilesMenuProps {
function FilesMenu(props: FilesMenuProps) { function FilesMenu(props: FilesMenuProps) {
// Use SWR to fetch files // Use SWR to fetch files
const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/configure/content/computer' : null, fetcher); const { data: files, error } = useSWR<string[]>(props.conversationId ? '/api/content/computer' : null, fetcher);
const { data: selectedFiles, error: selectedFilesError } = useSWR(props.conversationId ? `/api/chat/conversation/file-filters/${props.conversationId}` : null, fetcher); const { data: selectedFiles, error: selectedFilesError } = useSWR(props.conversationId ? `/api/chat/conversation/file-filters/${props.conversationId}` : null, fetcher);
const [isOpen, setIsOpen] = useState(false); const [isOpen, setIsOpen] = useState(false);
const [unfilteredFiles, setUnfilteredFiles] = useState<string[]>([]); const [unfilteredFiles, setUnfilteredFiles] = useState<string[]>([]);

View File

@@ -42,7 +42,7 @@ from khoj.database.adapters import (
) )
from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription from khoj.database.models import ClientApplication, KhojUser, ProcessLock, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search from khoj.routers.api_content import configure_content, configure_search
from khoj.routers.twilio import is_twilio_enabled from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
@@ -308,16 +308,16 @@ def configure_routes(app):
from khoj.routers.api import api from khoj.routers.api import api
from khoj.routers.api_agents import api_agents from khoj.routers.api_agents import api_agents
from khoj.routers.api_chat import api_chat from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config from khoj.routers.api_content import api_content
from khoj.routers.indexer import indexer from khoj.routers.api_model import api_model
from khoj.routers.notion import notion_router from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client from khoj.routers.web_client import web_client
app.include_router(api, prefix="/api") app.include_router(api, prefix="/api")
app.include_router(api_chat, prefix="/api/chat") app.include_router(api_chat, prefix="/api/chat")
app.include_router(api_agents, prefix="/api/agents") app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_config, prefix="/api/configure") app.include_router(api_model, prefix="/api/model")
app.include_router(indexer, prefix="/api/v1/index") app.include_router(api_content, prefix="/api/content")
app.include_router(notion_router, prefix="/api/notion") app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client) app.include_router(web_client)
@@ -336,7 +336,7 @@ def configure_routes(app):
if is_twilio_enabled(): if is_twilio_enabled():
from khoj.routers.api_phone import api_phone from khoj.routers.api_phone import api_phone
app.include_router(api_phone, prefix="/api/configure/phone") app.include_router(api_phone, prefix="/api/phone")
logger.info("📞 Enabled Twilio") logger.info("📞 Enabled Twilio")

View File

@@ -998,8 +998,8 @@ To get started, just start typing below. You can also type / to see a list of co
// Wait for all files to be read before making the fetch request // Wait for all files to be read before making the fetch request
Promise.all(fileReadPromises) Promise.all(fileReadPromises)
.then(() => { .then(() => {
return fetch("/api/v1/index/update?force=false&client=web", { return fetch("/api/content?client=web", {
method: "POST", method: "PATCH",
body: formData, body: formData,
}); });
}) })
@@ -1954,7 +1954,7 @@ To get started, just start typing below. You can also type / to see a list of co
} }
var allFiles; var allFiles;
function renderAllFiles() { function renderAllFiles() {
fetch('/api/configure/content/computer') fetch('/api/content/computer')
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
var indexedFiles = document.getElementsByClassName("indexed-files")[0]; var indexedFiles = document.getElementsByClassName("indexed-files")[0];

View File

@@ -32,7 +32,7 @@
</style> </style>
<script> <script>
function removeFile(path) { function removeFile(path) {
fetch('/api/configure/content/file?filename=' + path, { fetch('/api/content/file?filename=' + path, {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -48,7 +48,7 @@
// Get all currently indexed files // Get all currently indexed files
function getAllComputerFilenames() { function getAllComputerFilenames() {
fetch('/api/configure/content/computer') fetch('/api/content/computer')
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
var indexedFiles = document.getElementsByClassName("indexed-files")[0]; var indexedFiles = document.getElementsByClassName("indexed-files")[0];
@@ -122,7 +122,7 @@
deleteAllComputerFilesButton.textContent = "🗑️ Deleting..."; deleteAllComputerFilesButton.textContent = "🗑️ Deleting...";
deleteAllComputerFilesButton.disabled = true; deleteAllComputerFilesButton.disabled = true;
fetch('/api/configure/content/computer', { fetch('/api/content/computer', {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@@ -165,7 +165,7 @@
// Save Github config on server // Save Github config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/configure/content/github', { fetch('/api/content/github', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@@ -45,7 +45,7 @@
// Save Notion config on server // Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/configure/content/notion', { fetch('/api/content/notion', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@@ -209,7 +209,7 @@
function populate_type_dropdown() { function populate_type_dropdown() {
// Populate type dropdown field with enabled content types only // Populate type dropdown field with enabled content types only
fetch("/api/configure/types") fetch("/api/content/types")
.then(response => response.json()) .then(response => response.json())
.then(enabled_types => { .then(enabled_types => {
// Show warning if no content types are enabled, or just one ("all") // Show warning if no content types are enabled, or just one ("all")

View File

@@ -394,8 +394,8 @@
function saveProfileGivenName() { function saveProfileGivenName() {
const givenName = document.getElementById("profile_given_name").value; const givenName = document.getElementById("profile_given_name").value;
fetch('/api/configure/user/name?name=' + givenName, { fetch('/api/user/name?name=' + givenName, {
method: 'POST', method: 'PATCH',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
} }
@@ -421,7 +421,7 @@
saveVoiceModelButton.disabled = true; saveVoiceModelButton.disabled = true;
saveVoiceModelButton.textContent = "Saving..."; saveVoiceModelButton.textContent = "Saving...";
fetch('/api/configure/voice/model?id=' + voiceModel, { fetch('/api/model/voice?id=' + voiceModel, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -455,7 +455,7 @@
saveModelButton.innerHTML = ""; saveModelButton.innerHTML = "";
saveModelButton.textContent = "Saving..."; saveModelButton.textContent = "Saving...";
fetch('/api/configure/chat/model?id=' + chatModel, { fetch('/api/model/chat?id=' + chatModel, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -494,7 +494,7 @@
saveSearchModelButton.disabled = true; saveSearchModelButton.disabled = true;
saveSearchModelButton.textContent = "Saving..."; saveSearchModelButton.textContent = "Saving...";
fetch('/api/configure/search/model?id=' + searchModel, { fetch('/api/model/search?id=' + searchModel, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -526,7 +526,7 @@
saveModelButton.disabled = true; saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving..."; saveModelButton.innerHTML = "Saving...";
fetch('/api/configure/paint/model?id=' + paintModel, { fetch('/api/model/paint?id=' + paintModel, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -553,7 +553,7 @@
}; };
function clearContentType(content_source) { function clearContentType(content_source) {
fetch('/api/configure/content/' + content_source, { fetch('/api/content/' + content_source, {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -676,7 +676,7 @@
content_sources = ["computer", "github", "notion"]; content_sources = ["computer", "github", "notion"];
content_sources.forEach(content_source => { content_sources.forEach(content_source => {
fetch(`/api/configure/content/${content_source}`, { fetch(`/api/content/${content_source}`, {
method: 'GET', method: 'GET',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -807,7 +807,7 @@
function getIndexedDataSize() { function getIndexedDataSize() {
document.getElementById("indexed-data-size").textContent = "Calculating..."; document.getElementById("indexed-data-size").textContent = "Calculating...";
fetch('/api/configure/content/size') fetch('/api/content/size')
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used"; document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used";
@@ -815,7 +815,7 @@
} }
function removeFile(path) { function removeFile(path) {
fetch('/api/configure/content/file?filename=' + path, { fetch('/api/content/file?filename=' + path, {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -890,7 +890,7 @@
}) })
phonenumberRemoveButton.addEventListener("click", () => { phonenumberRemoveButton.addEventListener("click", () => {
fetch('/api/configure/phone', { fetch('/api/phone', {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -917,7 +917,7 @@
}, 5000); }, 5000);
} else { } else {
const mobileNumber = iti.getNumber(); const mobileNumber = iti.getNumber();
fetch('/api/configure/phone?phone_number=' + mobileNumber, { fetch('/api/phone?phone_number=' + mobileNumber, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -970,7 +970,7 @@
return; return;
} }
fetch('/api/configure/phone/verify?code=' + otp, { fetch('/api/phone/verify?code=' + otp, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@@ -19,16 +19,11 @@ class DocxToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == b""])
deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified Docx files # Extract Entries from specified Docx files
with timer("Extract entries from specified DOCX files", logger): with timer("Extract entries from specified DOCX files", logger):

View File

@@ -48,9 +48,7 @@ class GithubToEntries(TextToEntries):
else: else:
return return
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "": if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content") logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content")

View File

@@ -20,16 +20,11 @@ class ImageToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == b""])
deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified image files # Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger): with timer("Extract entries from specified Image files", logger):

View File

@@ -19,16 +19,11 @@ class MarkdownToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == ""])
deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
max_tokens = 256 max_tokens = 256
# Extract Entries from specified Markdown files # Extract Entries from specified Markdown files

View File

@@ -78,9 +78,7 @@ class NotionToEntries(TextToEntries):
self.body_params = {"page_size": 100} self.body_params = {"page_size": 100}
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
current_entries = [] current_entries = []
# Get all pages # Get all pages

View File

@@ -20,15 +20,10 @@ class OrgToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False deletion_file_names = set([file for file in files if files[file] == ""])
) -> Tuple[int, int]: files_to_process = set(files) - deletion_file_names
if not full_corpus: files = {file: files[file] for file in files_to_process}
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified Org files # Extract Entries from specified Org files
max_tokens = 256 max_tokens = 256

View File

@@ -22,16 +22,11 @@ class PdfToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == b""])
deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified Pdf files # Extract Entries from specified Pdf files
with timer("Extract entries from specified PDF files", logger): with timer("Extract entries from specified PDF files", logger):

View File

@@ -20,15 +20,10 @@ class PlaintextToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False deletion_file_names = set([file for file in files if files[file] == ""])
) -> Tuple[int, int]: files_to_process = set(files) - deletion_file_names
if not full_corpus: files = {file: files[file] for file in files_to_process}
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified plaintext files # Extract Entries from specified plaintext files
with timer("Extract entries from specified Plaintext files", logger): with timer("Extract entries from specified Plaintext files", logger):

View File

@@ -31,9 +31,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter() self.date_filter = DateFilter()
@abstractmethod @abstractmethod
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
... ...
@staticmethod @staticmethod

View File

@@ -19,6 +19,7 @@ from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import ( from khoj.database.adapters import (
AutomationAdapters, AutomationAdapters,
ConversationAdapters, ConversationAdapters,
@@ -39,6 +40,7 @@ from khoj.routers.helpers import (
CommonQueryParams, CommonQueryParams,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
acreate_title_from_query, acreate_title_from_query,
get_user_config,
schedule_automation, schedule_automation,
update_telemetry_state, update_telemetry_state,
) )
@@ -276,6 +278,49 @@ async def transcribe(
return Response(content=content, media_type="application/json", status_code=200) return Response(content=content, media_type="application/json", status_code=200)
@api.get("/settings", response_class=Response)
@requires(["authenticated"])
def get_settings(request: Request, detailed: Optional[bool] = False) -> Response:
user = request.user.object
user_config = get_user_config(user, request, is_detailed=detailed)
del user_config["request"]
# Return config data as a JSON response
return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
@api.patch("/user/name", status_code=200)
@requires(["authenticated"])
def set_user_name(
request: Request,
name: str,
client: Optional[str] = None,
):
user = request.user.object
split_name = name.split(" ")
if len(split_name) > 2:
raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
if len(split_name) == 1:
first_name = split_name[0]
last_name = ""
else:
first_name, last_name = split_name[0], split_name[-1]
adapters.set_user_name(user, first_name, last_name)
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_user_name",
client=client,
)
return {"status": "ok"}
async def extract_references_and_questions( async def extract_references_and_questions(
request: Request, request: Request,
meta_log: dict, meta_log: dict,

View File

@@ -1,434 +0,0 @@
import json
import logging
import math
from typing import Dict, List, Optional, Union
from asgiref.sync import sync_to_async
from fastapi import APIRouter, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
Subscription,
)
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
from khoj.utils import constants, state
from khoj.utils.rawconfig import (
FullConfig,
GithubContentConfig,
NotionContentConfig,
SearchConfig,
)
from khoj.utils.state import SearchType
# Router for user configuration endpoints; mounted under /api/configure.
api_config = APIRouter()
# Module-level logger for this router.
logger = logging.getLogger(__name__)
def map_config_to_object(content_source: str):
    """Map a content source name to its backing Django config model.

    Returns the sentinel string "Computer" for computer-synced content and
    None (implicitly, via dict lookup miss) for unrecognized sources.
    """
    source_to_config = {
        DbEntry.EntrySource.GITHUB: GithubConfig,
        DbEntry.EntrySource.NOTION: NotionConfig,
        DbEntry.EntrySource.COMPUTER: "Computer",
    }
    return source_to_config.get(content_source)
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Persist each section of the given full config as per-user database rows.

    Local file-backed configs (org, markdown, pdf, plaintext) are replaced
    wholesale: delete the user's existing rows, then recreate from config.
    Github and Notion configs are delegated to their adapter helpers.
    """
    if not config.content_type:
        return
    # The four local content configs share an identical shape; handle them uniformly,
    # preserving the original org -> markdown -> pdf -> plaintext order.
    local_config_models = [
        (config.content_type.org, LocalOrgConfig),
        (config.content_type.markdown, LocalMarkdownConfig),
        (config.content_type.pdf, LocalPdfConfig),
        (config.content_type.plaintext, LocalPlaintextConfig),
    ]
    for content_config, model in local_config_models:
        if content_config:
            await model.objects.filter(user=user).adelete()
            await model.objects.acreate(
                input_files=content_config.input_files,
                input_filter=content_config.input_filter,
                index_heading_entries=content_config.index_heading_entries,
                user=user,
            )
    if config.content_type.github:
        await adapters.set_user_github_config(
            user=user,
            pat_token=config.content_type.github.pat_token,
            repos=config.content_type.github.repos,
        )
    if config.content_type.notion:
        await adapters.set_notion_config(
            user=user,
            token=config.content_type.notion.token,
        )
def _initialize_config():
    """Create the global server config with default search settings if absent."""
    if state.config is not None:
        return
    state.config = FullConfig()
    state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
@api_config.post("/content/github", status_code=200)
@requires(["authenticated"])
async def set_content_github(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Save the user's Github PAT token and repository list.

    NOTE(review): a null request body reaches updated_config.pat_token inside
    the try block and surfaces as a 500 rather than a 400/422 — confirm whether
    that is intended.
    """
    _initialize_config()
    user = request.user.object
    try:
        await adapters.set_user_github_config(
            user=user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "github"},
    )
    return {"status": "ok"}
@api_config.post("/content/notion", status_code=200)
@requires(["authenticated"])
async def set_content_notion(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Save the user's Notion integration token."""
    _initialize_config()
    user = request.user.object
    try:
        await adapters.set_notion_config(
            user=user,
            token=updated_config.token,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        # Fixed copy-paste error: this endpoint configures Notion, not Github.
        raise HTTPException(status_code=500, detail="Failed to set Notion config")
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )
    return {"status": "ok"}
@api_config.delete("/content/file", status_code=201)
@requires(["authenticated"])
async def delete_content_file(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete all indexed entries sourced from a single file.

    Registered before the "/content/{content_source}" wildcard route so that
    the fixed "file" path is matched first; Starlette matches routes in
    registration order, so the previous ordering made this route unreachable
    (requests hit the wildcard with content_source="file").
    """
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_file",
        client=client,
    )
    await EntryAdapters.adelete_entry_by_file(user, filename)
    return {"status": "ok"}


@api_config.delete("/content/{content_source}", status_code=200)
@requires(["authenticated"])
async def delete_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete the user's configuration and all indexed entries for a content source."""
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_source": content_source},
    )
    content_object = map_config_to_object(content_source)
    if content_object is None:
        raise ValueError(f"Invalid content source: {content_source}")
    elif content_object != "Computer":
        # "Computer" is a sentinel with no backing config model to delete.
        await content_object.objects.filter(user=user).adelete()
    await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
    # Dropped the unused `enabled_content` query that followed here.
    return {"status": "ok"}
@api_config.get("/content/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """List filenames of all indexed entries from the given content source.

    NOTE(review): this wildcard route is registered before the fixed
    GET "/content/size" route, so /content/size appears to be captured here
    with content_source="size" — confirm route ordering.
    """
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="get_all_filenames",
        client=client,
    )
    return await sync_to_async(list)(EntryAdapters.get_all_filenames_by_source(user, content_source))  # type: ignore[call-arg]
@api_config.get("/chat/model/options", response_model=Dict[str, Union[str, int]])
def get_chat_model_options(
    request: Request,
    client: Optional[str] = None,
):
    """List the chat models configured on this server with their database ids."""
    options = [
        {"chat_model": option.chat_model, "id": option.id}
        for option in ConversationAdapters.get_conversation_processor_options().all()
    ]
    return Response(content=json.dumps(options), media_type="application/json", status_code=200)
@api_config.get("/chat/model")
@requires(["authenticated"])
def get_user_chat_model(
    request: Request,
    client: Optional[str] = None,
):
    """Return the chat model in effect for the user (their own, else the server default)."""
    user = request.user.object
    chat_model = ConversationAdapters.get_conversation_config(user)
    if chat_model is None:
        # Fall back to the server-wide default model.
        chat_model = ConversationAdapters.get_default_conversation_config()
    # NOTE(review): if no default is configured either, chat_model may still be
    # None here and raise AttributeError — confirm a default always exists.
    return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
@api_config.post("/chat/model", status_code=200)
@requires(["authenticated", "premium"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the authenticated user's conversation chat model to the given id."""
    khoj_user = request.user.object
    updated_processor = await ConversationAdapters.aset_user_conversation_processor(khoj_user, int(id))

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_conversation_chat_model",
        client=client,
        metadata={"processor_conversation_type": "conversation"},
    )

    if updated_processor is None:
        return {"status": "error", "message": "Model not found"}
    return {"status": "ok"}
@api_config.post("/voice/model", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the user's voice model; responds 202 on success, 404 for an unknown id."""
    user = request.user.object
    new_config = await ConversationAdapters.aset_user_voice_model(user, id)
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_voice_model",
        client=client,
    )
    if new_config is None:
        return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
    return Response(status_code=202, content=json.dumps({"status": "ok"}))
@api_config.post("/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the user's search (embedding) model.

    Changing embedding models invalidates existing vectors, so the user's
    indexed entries are deleted and must be re-indexed by their clients.
    """
    user = request.user.object
    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, int(id))
    if prev_config and int(id) != prev_config.id and new_config:
        # Model actually changed: drop entries embedded with the old model.
        await EntryAdapters.adelete_all_entries(user)
    if not prev_config:
        # If the user was just using the default config, delete all the entries and set the new config.
        await EntryAdapters.adelete_all_entries(user)
    if new_config is None:
        return {"status": "error", "message": "Model not found"}
    else:
        update_telemetry_state(
            request=request,
            telemetry_type="api",
            api="set_search_model",
            client=client,
            metadata={"search_model": new_config.setting.name},
        )
        return {"status": "ok"}
@api_config.post("/paint/model", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the user's text-to-image model. Requires a premium subscription."""
    user = request.user.object
    subscribed = has_required_scope(request, ["premium"])
    if not subscribed:
        raise HTTPException(status_code=403, detail="User is not subscribed to premium")
    new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

    # Check for a missing model BEFORE reading its attributes. Previously the
    # telemetry call dereferenced new_config.setting.model_name first, so an
    # unknown id raised AttributeError (HTTP 500) instead of returning the
    # intended "Model not found" payload.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_paint_model",
        client=client,
        metadata={"paint_model": new_config.setting.model_name},
    )
    return {"status": "ok"}
@api_config.get("/content/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_content_size(request: Request, common: CommonQueryParams):
    """Report the total size of the user's indexed content in MB, rounded up."""
    user = request.user.object
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
    payload = {"indexed_data_size_in_mb": math.ceil(size_in_mb)}
    return Response(content=json.dumps(payload), media_type="application/json", status_code=200)
@api_config.post("/user/name", status_code=200)
@requires(["authenticated"])
def set_user_name(
    request: Request,
    name: str,
    client: Optional[str] = None,
):
    """Set the user's first and last name from a "Firstname Lastname" string.

    Raises HTTP 400 when the name contains more than two words.
    """
    user = request.user.object
    # split() (instead of split(" ")) collapses runs of whitespace, so names
    # with double or trailing spaces are no longer mis-counted and rejected.
    split_name = name.split()
    if len(split_name) > 2:
        raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
    if len(split_name) == 0:
        # Blank input clears both names, matching the previous behavior for "".
        first_name, last_name = "", ""
    elif len(split_name) == 1:
        first_name = split_name[0]
        last_name = ""
    else:
        first_name, last_name = split_name[0], split_name[-1]
    adapters.set_user_name(user, first_name, last_name)
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_user_name",
        client=client,
    )
    return {"status": "ok"}
@api_config.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_config_types(
    request: Request,
):
    """List the search types the user can query, always including 'all'."""
    user = request.user.object
    configured = set(EntryAdapters.get_unique_file_types(user))
    if state.config and state.config.content_type:
        configured.update(state.config.content_type.model_dump(exclude_none=True))
    # Preserve SearchType declaration order in the response.
    return [st.value for st in SearchType if st == SearchType.All or st.value in configured]

View File

@@ -0,0 +1,508 @@
import asyncio
import json
import logging
import math
from typing import Dict, List, Optional, Union
from asgiref.sync import sync_to_async
from fastapi import (
APIRouter,
Depends,
Header,
HTTPException,
Request,
Response,
UploadFile,
)
from pydantic import BaseModel
from starlette.authentication import requires
from khoj.database import adapters
from khoj.database.adapters import (
EntryAdapters,
get_user_github_config,
get_user_notion_config,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
GithubRepoConfig,
KhojUser,
LocalMarkdownConfig,
LocalOrgConfig,
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
)
from khoj.routers.helpers import (
ApiIndexedDataLimiter,
CommonQueryParams,
configure_content,
get_user_config,
update_telemetry_state,
)
from khoj.utils import constants, state
from khoj.utils.config import SearchModels
from khoj.utils.helpers import get_file_type
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
GithubContentConfig,
NotionContentConfig,
SearchConfig,
)
from khoj.utils.state import SearchType
from khoj.utils.yaml import save_config_to_file_updated_state
logger = logging.getLogger(__name__)
api_content = APIRouter()
class File(BaseModel):
    """A single file payload: client-side path plus raw text or binary content."""

    path: str
    content: Union[str, bytes]
class IndexBatchRequest(BaseModel):
    """A batch of files submitted for indexing in one request."""

    files: list[File]
class IndexerInput(BaseModel):
    """Files to index, bucketed by format; text formats map to str, binary formats to bytes."""

    org: Optional[dict[str, str]] = None
    markdown: Optional[dict[str, str]] = None
    pdf: Optional[dict[str, bytes]] = None
    plaintext: Optional[dict[str, str]] = None
    image: Optional[dict[str, bytes]] = None
    docx: Optional[dict[str, bytes]] = None
@api_content.put("")
@requires(["authenticated"])
async def put_content(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    # Per-request and total indexed-data size caps, with higher limits for subscribers.
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Replace the user's indexed content with the uploaded files (regenerate=True)."""
    return await indexer(request, files, t, True, client, user_agent, referer, host)
@api_content.patch("")
@requires(["authenticated"])
async def patch_content(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    # Per-request and total indexed-data size caps, with higher limits for subscribers.
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Incrementally sync the uploaded files into the user's index (regenerate=False)."""
    return await indexer(request, files, t, False, client, user_agent, referer, host)
@api_content.get("/github", response_class=Response)
@requires(["authenticated"])
def get_content_github(request: Request) -> Response:
    """Return the user's config page payload with their Github settings attached."""
    user = request.user.object
    user_config = get_user_config(user, request)
    # The request object is not JSON-serializable; drop it before serializing.
    del user_config["request"]

    current_github_config = get_user_github_config(user)

    if current_github_config:
        raw_repos = current_github_config.githubrepoconfig.all()
        repos = []
        for repo in raw_repos:
            repos.append(
                GithubRepoConfig(
                    name=repo.name,
                    owner=repo.owner,
                    branch=repo.branch,
                )
            )
        current_config = GithubContentConfig(
            pat_token=current_github_config.pat_token,
            repos=repos,
        )
        # Use the pydantic v2 serializer, consistent with the Notion handler
        # (the deprecated v1 .json() method was used here before).
        current_config = json.loads(current_config.model_dump_json())
    else:
        current_config = {}  # type: ignore

    user_config["current_config"] = current_config

    # Return config data as a JSON response
    return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
@api_content.get("/notion", response_class=Response)
@requires(["authenticated"])
def get_content_notion(request: Request) -> Response:
    """Return the user's config page payload with their Notion settings attached."""
    user = request.user.object
    user_config = get_user_config(user, request)
    # The request object is not JSON-serializable; drop it before serializing.
    del user_config["request"]

    notion_config = get_user_notion_config(user)
    token = notion_config.token if notion_config else ""
    user_config["current_config"] = json.loads(NotionContentConfig(token=token).model_dump_json())

    return Response(content=json.dumps(user_config), media_type="application/json", status_code=200)
@api_content.post("/github", status_code=200)
@requires(["authenticated"])
async def set_content_github(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Save the user's Github PAT token and repository list.

    NOTE(review): a null request body reaches updated_config.pat_token inside
    the try block and surfaces as a 500 rather than a 400/422 — confirm whether
    that is intended.
    """
    _initialize_config()
    user = request.user.object
    try:
        await adapters.set_user_github_config(
            user=user,
            pat_token=updated_config.pat_token,
            repos=updated_config.repos,
        )
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Github config")
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "github"},
    )
    return {"status": "ok"}
@api_content.post("/notion", status_code=200)
@requires(["authenticated"])
async def set_content_notion(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Save the user's Notion integration token."""
    _initialize_config()
    khoj_user = request.user.object

    try:
        await adapters.set_notion_config(user=khoj_user, token=updated_config.token)
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to set Notion config")

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )

    return {"status": "ok"}
@api_content.delete("/file", status_code=201)
@requires(["authenticated"])
async def delete_content_file(
    request: Request,
    filename: str,
    client: Optional[str] = None,
):
    """Delete all indexed entries sourced from a single file.

    Registered before the "/{content_source}" wildcard route so the fixed
    "file" path is matched first; Starlette matches routes in registration
    order, so the previous ordering made this route unreachable (requests hit
    the wildcard with content_source="file").
    """
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_file",
        client=client,
    )
    await EntryAdapters.adelete_entry_by_file(user, filename)
    return {"status": "ok"}


@api_content.delete("/{content_source}", status_code=200)
@requires(["authenticated"])
async def delete_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """Delete the user's configuration and all indexed entries for a content source."""
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_source": content_source},
    )
    content_object = map_config_to_object(content_source)
    if content_object is None:
        raise ValueError(f"Invalid content source: {content_source}")
    elif content_object != "Computer":
        # "Computer" is a sentinel with no backing config model to delete.
        await content_object.objects.filter(user=user).adelete()
    await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
    # Dropped the unused `enabled_content` query that followed here.
    return {"status": "ok"}
@api_content.get("/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_content_size(request: Request, common: CommonQueryParams, client: Optional[str] = None):
    """Report the total size of the user's indexed content in MB, rounded up."""
    user = request.user.object
    size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
    payload = {"indexed_data_size_in_mb": math.ceil(size_in_mb)}
    return Response(content=json.dumps(payload), media_type="application/json", status_code=200)
@api_content.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_content_types(request: Request, client: Optional[str] = None):
    """List the content types the user can query, always including 'all'."""
    user = request.user.object
    valid_types = {search_type.value for search_type in SearchType}
    enabled = set(EntryAdapters.get_unique_file_types(user))
    enabled.add("all")
    if state.config and state.config.content_type:
        enabled.update(state.config.content_type.model_dump(exclude_none=True))
    # Only report types that are both enabled for this user and valid SearchTypes.
    return list(enabled & valid_types)
@api_content.get("/{content_source}", response_model=List[str])
@requires(["authenticated"])
async def get_content_source(
    request: Request,
    content_source: str,
    client: Optional[str] = None,
):
    """List filenames of all indexed entries from the given content source.

    This wildcard route must stay registered after the fixed GET routes
    ("/size", "/types", ...) so it does not capture their paths.
    """
    user = request.user.object
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="get_all_filenames",
        client=client,
    )
    return await sync_to_async(list)(EntryAdapters.get_all_filenames_by_source(user, content_source))  # type: ignore[call-arg]
async def indexer(
    request: Request,
    files: list[UploadFile],
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    regenerate: bool = False,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
):
    """Shared implementation for the PUT (regenerate) and PATCH (sync) content endpoints.

    Sorts uploaded files into per-format buckets, initializes the server config
    on first run, then runs content indexing in a thread executor. Responds
    with the comma-separated indexed filenames, or a 500 response on failure.
    """
    user = request.user.object
    method = "regenerate" if regenerate else "sync"
    index_files: Dict[str, Dict[str, str]] = {
        "org": {},
        "markdown": {},
        "pdf": {},
        "plaintext": {},
        "image": {},
        "docx": {},
    }
    try:
        logger.info(f"📬 Updating content index via API call by {client} client")
        for file in files:
            file_content = file.file.read()
            file_type, encoding = get_file_type(file.content_type, file_content)
            if file_type in index_files:
                # Decode text formats; binary formats (pdf, image, docx) stay as bytes.
                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
            else:
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
        indexer_input = IndexerInput(
            org=index_files["org"],
            markdown=index_files["markdown"],
            pdf=index_files["pdf"],
            plaintext=index_files["plaintext"],
            image=index_files["image"],
            docx=index_files["docx"],
        )
        if state.config is None:
            logger.info("📬 Initializing content index on first run.")
            default_full_config = FullConfig(
                content_type=None,
                search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
                processor=None,
            )
            state.config = default_full_config
            default_content_config = ContentConfig(
                org=None,
                markdown=None,
                pdf=None,
                docx=None,
                image=None,
                github=None,
                notion=None,
                plaintext=None,
            )
            state.config.content_type = default_content_config
            save_config_to_file_updated_state()
            configure_search(state.search_models, state.config.search_type)

        # Run the blocking content indexing off the event loop.
        loop = asyncio.get_event_loop()
        success = await loop.run_in_executor(
            None,
            configure_content,
            indexer_input.model_dump(),
            regenerate,
            t,
            user,
        )
        if not success:
            raise RuntimeError(f"Failed to {method} {t} data sent by {client} client into content index")
        logger.info(f"Finished {method} {t} data sent by {client} client into content index")
    except Exception as e:
        # Log once with full context; previously the same failure was logged
        # twice back to back.
        logger.error(
            f"🚨 Failed to {method} {t} data sent by {client} client into content index: {e}",
            exc_info=True,
        )
        return Response(content="Failed", status_code=500)
    indexing_metadata = {
        "num_org": len(index_files["org"]),
        "num_markdown": len(index_files["markdown"]),
        "num_pdf": len(index_files["pdf"]),
        "num_plaintext": len(index_files["plaintext"]),
        "num_image": len(index_files["image"]),
        "num_docx": len(index_files["docx"]),
    }
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="index/update",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
        metadata=indexing_metadata,
    )
    logger.info(f"📪 Content index updated via API call by {client} client")
    indexed_filenames = ",".join(file for ctype in index_files for file in index_files[ctype]) or ""
    return Response(content=indexed_filenames, status_code=200)
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
    """Ensure a SearchModels container exists, creating an empty one if needed.

    search_config is currently unread; the parameter is kept for interface
    compatibility with callers.
    """
    # Run Validation Checks
    if search_models is None:
        search_models = SearchModels()
    return search_models
def map_config_to_object(content_source: str):
    """Map a content source identifier to its backing config model.

    Returns the Django config model class for Github/Notion sources, the
    sentinel string "Computer" for locally synced files, and None for any
    unrecognized source.
    """
    source_to_config = {
        DbEntry.EntrySource.GITHUB: GithubConfig,
        DbEntry.EntrySource.NOTION: NotionConfig,
        DbEntry.EntrySource.COMPUTER: "Computer",
    }
    return source_to_config.get(content_source)
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Persist each section of a raw FullConfig into the user's database records.

    Local content sections (org, markdown, pdf, plaintext) replace any existing
    rows for the user before recreating them; Github and Notion sections are
    saved through their adapter helpers.
    """
    if config.content_type:
        if config.content_type.org:
            # Replace, rather than merge, the user's existing org config.
            await LocalOrgConfig.objects.filter(user=user).adelete()
            await LocalOrgConfig.objects.acreate(
                input_files=config.content_type.org.input_files,
                input_filter=config.content_type.org.input_filter,
                index_heading_entries=config.content_type.org.index_heading_entries,
                user=user,
            )
        if config.content_type.markdown:
            await LocalMarkdownConfig.objects.filter(user=user).adelete()
            await LocalMarkdownConfig.objects.acreate(
                input_files=config.content_type.markdown.input_files,
                input_filter=config.content_type.markdown.input_filter,
                index_heading_entries=config.content_type.markdown.index_heading_entries,
                user=user,
            )
        if config.content_type.pdf:
            await LocalPdfConfig.objects.filter(user=user).adelete()
            await LocalPdfConfig.objects.acreate(
                input_files=config.content_type.pdf.input_files,
                input_filter=config.content_type.pdf.input_filter,
                index_heading_entries=config.content_type.pdf.index_heading_entries,
                user=user,
            )
        if config.content_type.plaintext:
            await LocalPlaintextConfig.objects.filter(user=user).adelete()
            await LocalPlaintextConfig.objects.acreate(
                input_files=config.content_type.plaintext.input_files,
                input_filter=config.content_type.plaintext.input_filter,
                index_heading_entries=config.content_type.plaintext.index_heading_entries,
                user=user,
            )
        if config.content_type.github:
            # Github/Notion use adapter helpers rather than direct ORM calls.
            await adapters.set_user_github_config(
                user=user,
                pat_token=config.content_type.github.pat_token,
                repos=config.content_type.github.repos,
            )
        if config.content_type.notion:
            await adapters.set_notion_config(
                user=user,
                token=config.content_type.notion.token,
            )
def _initialize_config():
    """Create the global server config with default search settings if absent."""
    if state.config is not None:
        return
    state.config = FullConfig()
    state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])

View File

@@ -0,0 +1,156 @@
import json
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.routers.helpers import update_telemetry_state
api_model = APIRouter()
logger = logging.getLogger(__name__)
@api_model.get("/chat/options", response_model=Dict[str, Union[str, int]])
def get_chat_model_options(
    request: Request,
    client: Optional[str] = None,
):
    """List the chat models configured on this server with their database ids."""
    options = [
        {"chat_model": option.chat_model, "id": option.id}
        for option in ConversationAdapters.get_conversation_processor_options().all()
    ]
    return Response(content=json.dumps(options), media_type="application/json", status_code=200)
@api_model.get("/chat")
@requires(["authenticated"])
def get_user_chat_model(
    request: Request,
    client: Optional[str] = None,
):
    """Return the chat model in effect for the user (their own, else the server default)."""
    user = request.user.object
    chat_model = ConversationAdapters.get_conversation_config(user)
    if chat_model is None:
        # Fall back to the server-wide default model.
        chat_model = ConversationAdapters.get_default_conversation_config()
    # NOTE(review): if no default is configured either, chat_model may still be
    # None here and raise AttributeError — confirm a default always exists.
    return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
@api_model.post("/chat", status_code=200)
@requires(["authenticated", "premium"])
async def update_chat_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the authenticated user's conversation chat model to the given id."""
    khoj_user = request.user.object
    updated_processor = await ConversationAdapters.aset_user_conversation_processor(khoj_user, int(id))

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_conversation_chat_model",
        client=client,
        metadata={"processor_conversation_type": "conversation"},
    )

    if updated_processor is None:
        return {"status": "error", "message": "Model not found"}
    return {"status": "ok"}
@api_model.post("/voice", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the user's voice model; responds 202 on success, 404 for an unknown id."""
    user = request.user.object
    new_config = await ConversationAdapters.aset_user_voice_model(user, id)
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_voice_model",
        client=client,
    )
    if new_config is None:
        return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
    return Response(status_code=202, content=json.dumps({"status": "ok"}))
@api_model.post("/search", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Switch the user's search (embedding) model.

    Changing embedding models invalidates existing vectors, so the user's
    indexed entries are deleted and must be re-indexed by their clients.
    """
    user = request.user.object
    prev_config = await adapters.aget_user_search_model(user)
    new_config = await adapters.aset_user_search_model(user, int(id))
    if prev_config and int(id) != prev_config.id and new_config:
        # Model actually changed: drop entries embedded with the old model.
        await EntryAdapters.adelete_all_entries(user)
    if not prev_config:
        # If the user was just using the default config, delete all the entries and set the new config.
        await EntryAdapters.adelete_all_entries(user)
    if new_config is None:
        return {"status": "error", "message": "Model not found"}
    else:
        update_telemetry_state(
            request=request,
            telemetry_type="api",
            api="set_search_model",
            client=client,
            metadata={"search_model": new_config.setting.name},
        )
        return {"status": "ok"}
@api_model.post("/paint", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the user's text-to-image model. Requires a premium subscription."""
    user = request.user.object
    subscribed = has_required_scope(request, ["premium"])
    if not subscribed:
        raise HTTPException(status_code=403, detail="User is not subscribed to premium")
    new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

    # Check for a missing model BEFORE reading its attributes. Previously the
    # telemetry call dereferenced new_config.setting.model_name first, so an
    # unknown id raised AttributeError (HTTP 500) instead of returning the
    # intended "Model not found" payload.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_paint_model",
        client=client,
        metadata={"paint_model": new_config.setting.model_name},
    )
    return {"status": "ok"}

View File

@@ -1313,7 +1313,6 @@ def configure_content(
files: Optional[dict[str, dict[str, str]]], files: Optional[dict[str, dict[str, str]]],
regenerate: bool = False, regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All, t: Optional[state.SearchType] = state.SearchType.All,
full_corpus: bool = True,
user: KhojUser = None, user: KhojUser = None,
) -> bool: ) -> bool:
success = True success = True
@@ -1344,7 +1343,6 @@ def configure_content(
OrgToEntries, OrgToEntries,
files.get("org"), files.get("org"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
except Exception as e: except Exception as e:
@@ -1362,7 +1360,6 @@ def configure_content(
MarkdownToEntries, MarkdownToEntries,
files.get("markdown"), files.get("markdown"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
@@ -1379,7 +1376,6 @@ def configure_content(
PdfToEntries, PdfToEntries,
files.get("pdf"), files.get("pdf"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
@@ -1398,7 +1394,6 @@ def configure_content(
PlaintextToEntries, PlaintextToEntries,
files.get("plaintext"), files.get("plaintext"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
@@ -1418,7 +1413,6 @@ def configure_content(
GithubToEntries, GithubToEntries,
None, None,
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
config=github_config, config=github_config,
) )
@@ -1439,7 +1433,6 @@ def configure_content(
NotionToEntries, NotionToEntries,
None, None,
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
config=notion_config, config=notion_config,
) )
@@ -1459,7 +1452,6 @@ def configure_content(
ImageToEntries, ImageToEntries,
files.get("image"), files.get("image"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
except Exception as e: except Exception as e:
@@ -1472,7 +1464,6 @@ def configure_content(
DocxToEntries, DocxToEntries,
files.get("docx"), files.get("docx"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
except Exception as e: except Exception as e:

View File

@@ -1,166 +0,0 @@
import asyncio
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, Depends, Header, Request, Response, UploadFile
from pydantic import BaseModel
from starlette.authentication import requires
from khoj.routers.helpers import (
ApiIndexedDataLimiter,
configure_content,
update_telemetry_state,
)
from khoj.utils import constants, state
from khoj.utils.config import SearchModels
from khoj.utils.helpers import get_file_type
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state
logger = logging.getLogger(__name__)
indexer = APIRouter()
class File(BaseModel):
    """A single file payload: client-side path plus raw text or binary content."""

    path: str
    content: Union[str, bytes]
class IndexBatchRequest(BaseModel):
    """A batch of files submitted for indexing in one request."""

    files: list[File]
class IndexerInput(BaseModel):
    """Files to index, bucketed by format; text formats map to str, binary formats to bytes."""

    org: Optional[dict[str, str]] = None
    markdown: Optional[dict[str, str]] = None
    pdf: Optional[dict[str, bytes]] = None
    plaintext: Optional[dict[str, str]] = None
    image: Optional[dict[str, bytes]] = None
    docx: Optional[dict[str, bytes]] = None
@indexer.post("/update")
@requires(["authenticated"])
async def update(
    request: Request,
    files: list[UploadFile],
    force: bool = False,
    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
    # Per-request and total indexed-data size caps, with higher limits for subscribers.
    indexed_data_limiter: ApiIndexedDataLimiter = Depends(
        ApiIndexedDataLimiter(
            incoming_entries_size_limit=10,
            subscribed_incoming_entries_size_limit=25,
            total_entries_size_limit=10,
            subscribed_total_entries_size_limit=100,
        )
    ),
):
    """Index a batch of uploaded files into the user's content index.

    Files are bucketed by detected format, the server config is initialized on
    first run, then indexing runs in a thread executor. Responds with the
    comma-separated indexed filenames, or a 500 response on failure.
    """
    user = request.user.object
    index_files: Dict[str, Dict[str, str]] = {
        "org": {},
        "markdown": {},
        "pdf": {},
        "plaintext": {},
        "image": {},
        "docx": {},
    }
    try:
        logger.info(f"📬 Updating content index via API call by {client} client")
        for file in files:
            file_content = file.file.read()
            file_type, encoding = get_file_type(file.content_type, file_content)
            if file_type in index_files:
                # Decode text formats; binary formats (pdf, image, docx) stay as bytes.
                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
            else:
                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
        indexer_input = IndexerInput(
            org=index_files["org"],
            markdown=index_files["markdown"],
            pdf=index_files["pdf"],
            plaintext=index_files["plaintext"],
            image=index_files["image"],
            docx=index_files["docx"],
        )
        if state.config is None:
            logger.info("📬 Initializing content index on first run.")
            default_full_config = FullConfig(
                content_type=None,
                search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
                processor=None,
            )
            state.config = default_full_config
            default_content_config = ContentConfig(
                org=None,
                markdown=None,
                pdf=None,
                docx=None,
                image=None,
                github=None,
                notion=None,
                plaintext=None,
            )
            state.config.content_type = default_content_config
            save_config_to_file_updated_state()
            configure_search(state.search_models, state.config.search_type)

        # Run the blocking content indexing off the event loop.
        loop = asyncio.get_event_loop()
        success = await loop.run_in_executor(
            None,
            configure_content,
            indexer_input.model_dump(),
            force,
            t,
            False,
            user,
        )
        if not success:
            raise RuntimeError("Failed to update content index")
        logger.info(f"Finished processing batch indexing request")
    except Exception as e:
        # Log once with full context; previously the same failure was logged
        # twice back to back.
        logger.error(
            f'🚨 Failed to {"force " if force else ""}update {t} content index triggered via API call by {client} client: {e}',
            exc_info=True,
        )
        return Response(content="Failed", status_code=500)
    indexing_metadata = {
        "num_org": len(index_files["org"]),
        "num_markdown": len(index_files["markdown"]),
        "num_pdf": len(index_files["pdf"]),
        "num_plaintext": len(index_files["plaintext"]),
        "num_image": len(index_files["image"]),
        "num_docx": len(index_files["docx"]),
    }
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="index/update",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
        metadata=indexing_metadata,
    )
    logger.info(f"📪 Content index updated via API call by {client} client")
    indexed_filenames = ",".join(file for ctype in index_files for file in index_files[ctype]) or ""
    return Response(content=indexed_filenames, status_code=200)
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
    """Return a usable SearchModels container, constructing a default one when absent.

    NOTE(review): `search_config` is not read here; it is kept for interface
    compatibility with existing callers that pass the app search configuration.
    """
    # Lazily create the container on first use; otherwise hand back the caller's instance.
    return search_models if search_models is not None else SearchModels()

View File

@@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
notion_redirect = str(request.app.url_path_for("notion_config_page")) notion_redirect = str(request.app.url_path_for("notion_config_page"))
# Trigger an async job to configure_content. Let it run without blocking the response. # Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user) background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
return RedirectResponse(notion_redirect) return RedirectResponse(notion_redirect)

View File

@@ -199,17 +199,16 @@ def setup(
text_to_entries: Type[TextToEntries], text_to_entries: Type[TextToEntries],
files: dict[str, str], files: dict[str, str],
regenerate: bool, regenerate: bool,
full_corpus: bool = True,
user: KhojUser = None, user: KhojUser = None,
config=None, config=None,
) -> None: ) -> Tuple[int, int]:
if config: if config:
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process( num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate files=files, user=user, regenerate=regenerate
) )
else: else:
num_new_embeddings, num_deleted_embeddings = text_to_entries().process( num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate files=files, user=user, regenerate=regenerate
) )
if files: if files:
@@ -219,6 +218,8 @@ def setup(
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..." f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
) )
return num_new_embeddings, num_deleted_embeddings
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder""" """Score all retrieved entries using the cross-encoder"""

View File

@@ -124,7 +124,7 @@ def get_org_files(config: TextContentConfig):
logger.debug("At least one of org-files or org-file-filter is required to be specified") logger.debug("At least one of org-files or org-file-filter is required to be specified")
return {} return {}
"Get Org files to process" # Get Org files to process
absolute_org_files, filtered_org_files = set(), set() absolute_org_files, filtered_org_files = set(), set()
if org_files: if org_files:
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files} absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}

View File

@@ -25,7 +25,7 @@ from khoj.database.models import (
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content from khoj.routers.api_content import configure_content
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils import fs_syncer, state from khoj.utils import fs_syncer, state
from khoj.utils.config import SearchModels from khoj.utils.config import SearchModels

View File

@@ -75,7 +75,7 @@ def test_index_update_with_no_auth_key(client):
files = get_sample_files_data() files = get_sample_files_data()
# Act # Act
response = client.post("/api/v1/index/update", files=files) response = client.patch("/api/content", files=files)
# Assert # Assert
assert response.status_code == 403 assert response.status_code == 403
@@ -89,7 +89,7 @@ def test_index_update_with_invalid_auth_key(client):
headers = {"Authorization": "Bearer kk-invalid-token"} headers = {"Authorization": "Bearer kk-invalid-token"}
# Act # Act
response = client.post("/api/v1/index/update", files=files, headers=headers) response = client.patch("/api/content", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 403 assert response.status_code == 403
@@ -130,7 +130,7 @@ def test_index_update_big_files(client):
headers = {"Authorization": "Bearer kk-secret"} headers = {"Authorization": "Bearer kk-secret"}
# Act # Act
response = client.post("/api/v1/index/update", files=files, headers=headers) response = client.patch("/api/content", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 429 assert response.status_code == 429
@@ -146,7 +146,7 @@ def test_index_update_medium_file_unsubscribed(client, api_user4: KhojApiUser):
headers = {"Authorization": f"Bearer {api_token}"} headers = {"Authorization": f"Bearer {api_token}"}
# Act # Act
response = client.post("/api/v1/index/update", files=files, headers=headers) response = client.patch("/api/content", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 429 assert response.status_code == 429
@@ -162,7 +162,7 @@ def test_index_update_normal_file_unsubscribed(client, api_user4: KhojApiUser):
headers = {"Authorization": f"Bearer {api_token}"} headers = {"Authorization": f"Bearer {api_token}"}
# Act # Act
response = client.post("/api/v1/index/update", files=files, headers=headers) response = client.patch("/api/content", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@@ -177,7 +177,7 @@ def test_index_update_big_files_no_billing(client):
headers = {"Authorization": "Bearer kk-secret"} headers = {"Authorization": "Bearer kk-secret"}
# Act # Act
response = client.post("/api/v1/index/update", files=files, headers=headers) response = client.patch("/api/content", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@@ -191,7 +191,7 @@ def test_index_update(client):
headers = {"Authorization": "Bearer kk-secret"} headers = {"Authorization": "Bearer kk-secret"}
# Act # Act
response = client.post("/api/v1/index/update", files=files, headers=headers) response = client.patch("/api/content", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@@ -208,8 +208,8 @@ def test_index_update_fails_if_more_than_1000_files(client, api_user4: KhojApiUs
headers = {"Authorization": f"Bearer {api_token}"} headers = {"Authorization": f"Bearer {api_token}"}
# Act # Act
ok_response = client.post("/api/v1/index/update", files=files[:1000], headers=headers) ok_response = client.patch("/api/content", files=files[:1000], headers=headers)
bad_response = client.post("/api/v1/index/update", files=files, headers=headers) bad_response = client.patch("/api/content", files=files, headers=headers)
# Assert # Assert
assert ok_response.status_code == 200 assert ok_response.status_code == 200
@@ -226,7 +226,7 @@ def test_regenerate_with_valid_content_type(client):
headers = {"Authorization": "Bearer kk-secret"} headers = {"Authorization": "Bearer kk-secret"}
# Act # Act
response = client.post(f"/api/v1/index/update?t={content_type}", files=files, headers=headers) response = client.patch(f"/api/content?t={content_type}", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}" assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
@@ -243,7 +243,7 @@ def test_regenerate_with_github_fails_without_pat(client):
files = get_sample_files_data() files = get_sample_files_data()
# Act # Act
response = client.post(f"/api/v1/index/update?t=github", files=files, headers=headers) response = client.patch(f"/api/content?t=github", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github" assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
@@ -269,11 +269,11 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user) text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
# Act # Act
response = client.get(f"/api/configure/types", headers=headers) response = client.get(f"/api/content/types", headers=headers)
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == ["all", "org", "plaintext"] assert set(response.json()) == {"all", "org", "plaintext"}
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -289,7 +289,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
client = TestClient(fastapi_app) client = TestClient(fastapi_app)
# Act # Act
response = client.get(f"/api/configure/types") response = client.get(f"/api/content/types")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -29,7 +29,7 @@ def test_index_update_with_user2(client, api_user2: KhojApiUser):
source_file_symbol = set([f[1][0] for f in files]) source_file_symbol = set([f[1][0] for f in files])
headers = {"Authorization": f"Bearer {api_user2.token}"} headers = {"Authorization": f"Bearer {api_user2.token}"}
update_response = client.post("/api/v1/index/update", files=files, headers=headers) update_response = client.patch("/api/content", files=files, headers=headers)
search_response = client.get("/api/search?q=hardware&t=all", headers=headers) search_response = client.get("/api/search?q=hardware&t=all", headers=headers)
results = search_response.json() results = search_response.json()
@@ -47,7 +47,7 @@ def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUs
source_file_symbol = set([f[1][0] for f in files]) source_file_symbol = set([f[1][0] for f in files])
headers = {"Authorization": f"Bearer {api_user2.token}"} headers = {"Authorization": f"Bearer {api_user2.token}"}
update_response = client.post("/api/v1/index/update", files=files, headers=headers) update_response = client.patch("/api/content", files=files, headers=headers)
# Act # Act
headers = {"Authorization": f"Bearer {api_user.token}"} headers = {"Authorization": f"Bearer {api_user.token}"}

View File

@@ -6,9 +6,16 @@ from pathlib import Path
import pytest import pytest
from khoj.database.adapters import EntryAdapters
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils.fs_syncer import collect_files, get_org_files from khoj.utils.fs_syncer import collect_files, get_org_files
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
@@ -151,7 +158,6 @@ async def test_text_search(search_config: SearchConfig):
OrgToEntries, OrgToEntries,
data, data,
True, True,
True,
default_user, default_user,
) )
@@ -240,7 +246,6 @@ conda activate khoj
OrgToEntries, OrgToEntries,
data, data,
regenerate=False, regenerate=False,
full_corpus=False,
user=default_user, user=default_user,
) )
@@ -396,6 +401,49 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
verify_embeddings(3, default_user) verify_embeddings(3, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.parametrize(
    "text_to_entries",
    [
        (OrgToEntries),
    ],
)
def test_update_index_with_deleted_file(
    org_config_with_only_new_file: LocalOrgConfig, text_to_entries: TextToEntries, default_user: KhojUser
):
    "Entries for a file should be deleted when that file path is synced with empty content."
    # Arrange
    file_to_index = "test"
    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
    initial_data = {file_to_index: new_entry}
    # An empty string for an existing file path signals deletion to the indexer.
    final_data = {file_to_index: ""}

    # Act
    # Load entries after adding file
    initial_added_entries, _ = text_search.setup(text_to_entries, initial_data, regenerate=True, user=default_user)
    initial_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()

    # Load entries after deleting file
    final_added_entries, final_deleted_entries = text_search.setup(
        text_to_entries, final_data, regenerate=False, user=default_user
    )
    final_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()

    # Assert
    assert initial_total_entries > 0, "File entries not indexed"
    assert initial_added_entries > 0, "No entries got added"

    assert final_total_entries == 0, "File did not get deleted"
    assert final_added_entries == 0, "Entries were unexpectedly added in delete entries pass"
    assert final_deleted_entries == initial_added_entries, "All added entries were not deleted"

    # Fix: the original wrote `verify_embeddings(0, default_user), "Embeddings still exist for user"`,
    # which builds a discarded (result, str) tuple — the message was dead code, never part of any
    # assertion. Call the helper plainly; it performs its own verification internally.
    verify_embeddings(0, default_user)

    # Clean up
    EntryAdapters.delete_all_entries(default_user)
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser): def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):