Merge branch 'master' of github.com:khoj-ai/khoj into features/chat-ui-updates-big

This commit is contained in:
sabaimran
2024-07-08 17:00:42 +05:30
60 changed files with 1419 additions and 492 deletions

View File

@@ -15,6 +15,7 @@ Khoj will keep these files in sync to provide contextual responses when you sear
- **Faster answers**: Find answers quickly, from your private notes or the public internet - **Faster answers**: Find answers quickly, from your private notes or the public internet
- **Assisted creativity**: Smoothly weave across retrieving answers and generating content - **Assisted creativity**: Smoothly weave across retrieving answers and generating content
- **Iterative discovery**: Iteratively explore and re-discover your notes - **Iterative discovery**: Iteratively explore and re-discover your notes
- **Quick access**: Use [Khoj Mini](/features/khoj_mini) on the desktop to quickly pull up a mini chat module for quicker answers
- **Search** - **Search**
- **Natural**: Advanced natural language understanding using Transformer based ML Models - **Natural**: Advanced natural language understanding using Transformer based ML Models
- **Incremental**: Incremental search for a fast, search-as-you-type experience - **Incremental**: Incremental search for a fast, search-as-you-type experience

View File

@@ -0,0 +1,9 @@
# Desktop Quick Chat (Khoj Mini)
Once you have the Khoj [desktop application](https://khoj.dev/downloads) installed, you can use the desktop shortcut to quickly pull up a mini chat module for quicker answers. See the desktop setup instructions [in the docs](/clients/desktop.md) for more information.
To use it, you just have to copy the text you want to inject into your query, and then run `Ctrl + Shift + K` (or `Cmd + Shift + K` on Mac) to open the mini chat module. The text you copied will be automatically pasted into the chat module, and you can then hit enter to get the answer. You can edit the text before hitting enter if you want to refine your query.
The desktop shortcut is a great way to quickly get answers to your questions without having to switch between windows or tabs. It's especially useful when you're working on a project and need to quickly look up something without losing your focus.
![Desktop Shortcut](https://assets.khoj.dev/courseload_decision_dekstop.gif)

View File

@@ -1,17 +1,21 @@
# Online Search # Online Search
By default, Khoj will try to infer which information-sourcing tools are required to answer your question. Sometimes, you'll have a need for outside questions that the LLM's knowledge doesn't cover. In that case, it will use the `online` search feature. Khoj will research on the internet to ground its responses, when it determines that it would need fresh information outside its existing knowledge to answer the query. It will always show any online references it used to respond to your requests.
For example, these queries would trigger an online search: By default, Khoj will try to infer which information sources, it needs to read to answer your question. This can include reading your documents or researching information online. You can also explicitly trigger an online search by adding the `/online` prefix to your chat query.
Example queries that should trigger an online search:
- What's the latest news about the Israel-Palestine war? - What's the latest news about the Israel-Palestine war?
- Where can I find the best pizza in New York City? - Where can I find the best pizza in New York City?
- Deadline for filing taxes 2024. - /online Deadline for filing taxes 2024.
- Give me a summary of this article: https://en.wikipedia.org/wiki/Haitian_Revolution - Give me a summary of this article: https://en.wikipedia.org/wiki/Haitian_Revolution
Try it out yourself! https://app.khoj.dev Try it out yourself! https://app.khoj.dev
## Self-Hosting ## Self-Hosting
The general online search function currently requires an API key from Serper.dev. You can grab one here: https://serper.dev/, and then add it as an environment variable with the name `SERPER_DEV_API_KEY`. Online search works out of the box even when self-hosting. Khoj uses [JinaAI's reader API](https://jina.ai/reader/) to search online and read webpages by default. No API key setup is necessary.
Without any API keys, Khoj will use the `requests` library to directly read any webpages you give it a link to. This means that you can use Khoj to read any webpage that you have access in your local network. To improve online search, set the `SERPER_DEV_API_KEY` environment variable to your [Serper.dev](https://serper.dev/) API key. These search results include additional context like answer box, knowledge graph etc.
For advanced webpage reading, set the `OLOSTEP_API_KEY` environment variable to your [Olostep](https://www.olostep.com/) API key. This has a higher success rate at reading webpages than the default webpage reader.

View File

@@ -1,51 +0,0 @@
---
sidebar_position: 2
---
# Demos
Check out a couple of demos and screenshots of Khoj in action.
### Screenshots
| Web | Obsidian | Emacs |
|:---:|:--------:|:-----:|
| ![](/img/khoj_search_on_web.png ':size=300px') | ![](/img/khoj_search_on_obsidian.png ':size=300px') | ![](/img/khoj_search_on_emacs.png ':size=300px') |
| ![](/img/khoj_chat_on_web.png ':size=300px') | ![](/img/khoj_chat_on_obsidian.png ':size=300px') | ![](/img/khoj_chat_on_emacs.png ':size=400px') |
### Videos
#### Khoj in Obsidian
[Link to Video](https://github-production-user-asset-6210df.s3.amazonaws.com/6413477/240061700-3e33d8ea-25bb-46c8-a3bf-c92f78d0f56b.mp4)
##### Installation
1. Install Khoj via `pip` and start Khoj backend in a terminal (Run `khoj`)
```bash
python -m pip install khoj-assistant
khoj
```
2. Install Khoj plugin via Community Plugins settings pane on Obsidian app
- Check the new Khoj plugin settings
- Let Khoj backend index the markdown, pdf, Github markdown files in the current Vault
- Open Khoj plugin on Obsidian via Search button on Left Pane
- Search \"*Announce plugin to folks*\" in the [Obsidian Plugin docs](https://marcus.se.net/obsidian-plugin-docs/)
- Jump to the [search result](https://marcus.se.net/obsidian-plugin-docs/publishing/submit-your-plugin)
#### Khoj in Emacs, Browser
[Link to Video](https://user-images.githubusercontent.com/6413477/184735169-92c78bf1-d827-4663-9087-a1ea194b8f4b.mp4)
##### Installation
- Install Khoj via pip
- Start Khoj app
- Add this readme and [khoj.el readme](https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs) as org-mode for Khoj to index
- Search \"*Setup editor*\" on the Web and Emacs. Re-rank the results for better accuracy
- Top result is what we are looking for, the [section to Install Khoj.el on Emacs](https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs#2-Install-Khojel)
##### Analysis
- The results do not have any words used in the query
- *Based on the top result it seems the re-ranking model understands that Emacs is an editor?*
- The results incrementally update as the query is entered
- The results are re-ranked, for better accuracy, once user hits enter

View File

@@ -27,10 +27,10 @@ keywords: ["khoj", "khoj ai", "khoj docs", "khoj documentation", "khoj features"
Welcome to the Khoj Docs! This is the best place to get setup and explore Khoj's features. Welcome to the Khoj Docs! This is the best place to get setup and explore Khoj's features.
- Khoj is an open source, personal AI - Khoj is an open source, personal AI
- You can [chat](/features/chat) with it about anything. It'll use files you shared with it to respond, when relevant - You can [chat](/features/chat) with it about anything. It'll use files you shared with it to respond, when relevant. It can also access information from the public internet.
- Quickly [find](/features/search) relevant notes and documents using natural language - Quickly [find](/features/search) relevant notes and documents using natural language
- It understands pdf, plaintext, markdown, org-mode files, [notion pages](/data-sources/notion_integration) and [github repositories](/data-sources/github_integration) - It understands pdf, plaintext, markdown, org-mode files, [notion pages](/data-sources/notion_integration) and [github repositories](/data-sources/github_integration)
- Access it from your [Emacs](/clients/emacs), [Obsidian](/clients/obsidian), [Web browser](/clients/web) or the [Khoj Desktop app](/clients/desktop) - Access it from your [Emacs](/clients/emacs), [Obsidian](/clients/obsidian), the [Khoj desktop app](/clients/desktop), or [any web browser](/clients/web)
- Use [cloud](https://app.khoj.dev/login) to access your Khoj anytime from anywhere, [self-host](/get-started/setup) on consumer hardware for privacy - Use [cloud](https://app.khoj.dev/login) to access your Khoj anytime from anywhere, [self-host](/get-started/setup) on consumer hardware for privacy
## Quickstart ## Quickstart
@@ -39,13 +39,3 @@ Welcome to the Khoj Docs! This is the best place to get setup and explore Khoj's
## At a Glance ## At a Glance
![demo_chat](https://assets.khoj.dev/using_khoj_for_studying.gif) ![demo_chat](https://assets.khoj.dev/using_khoj_for_studying.gif)
#### [Search](/features/search)
- **Natural**: Use natural language queries to quickly find relevant notes and documents.
- **Incremental**: Incremental search for a fast, search-as-you-type experience
#### [Chat](/features/chat)
- **Faster answers**: Find answers faster, smoother than search. No need to manually scan through your notes to find answers.
- **Iterative discovery**: Iteratively explore and (re-)discover your notes
- **Assisted creativity**: Smoothly weave across answers retrieval and content generation
- **Online or Offline**: Choose online or offline chat depending on your requirements

View File

@@ -22,6 +22,8 @@ Self-hosting isn't for everyone, so we've still taken steps to make Khoj privacy
1. Your embeddings and the associated raw text are stored in a secure Postgres DB in our private AWS cloud. Your data is sharded on a unique user ID. We store the raw text in your files to improve file syncing and provide context when you chat with Khoj. 1. Your embeddings and the associated raw text are stored in a secure Postgres DB in our private AWS cloud. Your data is sharded on a unique user ID. We store the raw text in your files to improve file syncing and provide context when you chat with Khoj.
1. When you use the single-sign-on option with Google, we only receive your name, a link to your profile photo, and your email address. 1. When you use the single-sign-on option with Google, we only receive your name, a link to your profile photo, and your email address.
You can see our full privacy policy [here](https://khoj.dev/privacy-policy).
:::tip[Info] :::tip[Info]
Your data is yours. We do not sell your data or use it for training models. Khoj is a sustainable, open-source alternative to closed-source, commercial personal AI. We have no interest in selling your data to make a quick buck. Your data is yours. We do not sell your data or use it for training models. Khoj is a sustainable, open-source alternative to closed-source, commercial personal AI. We have no interest in selling your data to make a quick buck.

View File

@@ -210,7 +210,7 @@ Add a `ServerChatSettings` with `Default` and `Summarizer` fields set to your pr
##### Configure OpenAI Chat ##### Configure OpenAI Chat
:::info[Ollama Integration] :::info[Ollama Integration]
Using Ollama? See the [Ollama Integration](/advanced/use-openai-proxy#ollama) section for more custom setup instructions. Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more custom setup instructions.
::: :::
1. Go to the [OpenAI settings](http://localhost:42110/server/admin/database/openaiprocessorconversationconfig/) in the server admin settings to add an OpenAI processor conversation config. This is where you set your API key and server API base URL. The API base URL is optional - it's only relevant if you're using another OpenAI-compatible proxy server. 1. Go to the [OpenAI settings](http://localhost:42110/server/admin/database/openaiprocessorconversationconfig/) in the server admin settings to add an OpenAI processor conversation config. This is where you set your API key and server API base URL. The API base URL is optional - it's only relevant if you're using another OpenAI-compatible proxy server.
@@ -227,11 +227,9 @@ Any chat model on Huggingface in GGUF format can be used for local chat. Here's
- The `tokenizer` and `max-prompt-size` fields are optional. You can set these for non-standard models (i.e not Mistral or Llama based models) or when you know the token limit of the model to improve context stuffing. - The `tokenizer` and `max-prompt-size` fields are optional. You can set these for non-standard models (i.e not Mistral or Llama based models) or when you know the token limit of the model to improve context stuffing.
#### Share your data #### Share your data
You can sync your files and folders with Khoj using the [Desktop](/get-started/setup#2-download-the-desktop-client), Obsidian, or Emacs clients or just drag and drop specific files on the Web client Here's how you can do it: You can sync your files and folders with Khoj using the [Desktop](/clients/desktop#setup), [Obsidian](/clients/obsidian#setup), or [Emacs](/clients/emacs#setup) clients or just drag and drop specific files on the [website](/clients/web#upload-documents). You can also directly sync your [Notion workspace](/data-sources/notion_integration).
1. Select files and folders to index [using the desktop client]. When you click 'Save', the files will be sent to your server for indexing.
- Select Notion workspaces and Github repositories to index using the web interface.
[^1]: Khoj, by default, can use [OpenAI GPT3.5+ chat models](https://platform.openai.com/docs/models/overview) or [GGUF chat models](https://huggingface.co/models?library=gguf). See [this section](/miscellaneous/advanced#use-openai-compatible-llm-api-server-self-hosting) on how to locally use OpenAI-format compatible proxy servers. [^1]: Khoj, by default, can use [OpenAI GPT3.5+ chat models](https://platform.openai.com/docs/models/overview) or [GGUF chat models](https://huggingface.co/models?library=gguf). See [this section](/advanced/use-openai-proxy) on how to locally use OpenAI-format compatible proxy servers.
### 3. Use Khoj 🚀 ### 3. Use Khoj 🚀

View File

@@ -1,7 +1,7 @@
{ {
"id": "khoj", "id": "khoj",
"name": "Khoj", "name": "Khoj",
"version": "1.15.0", "version": "1.16.0",
"minAppVersion": "0.15.0", "minAppVersion": "0.15.0",
"description": "An AI copilot for your Second Brain", "description": "An AI copilot for your Second Brain",
"author": "Khoj Inc.", "author": "Khoj Inc.",

View File

@@ -52,7 +52,8 @@ dependencies = [
"pyyaml ~= 6.0", "pyyaml ~= 6.0",
"rich >= 13.3.1", "rich >= 13.3.1",
"schedule == 1.1.0", "schedule == 1.1.0",
"sentence-transformers == 2.5.1", "sentence-transformers == 3.0.1",
"einops == 0.8.0",
"transformers >= 4.28.0", "transformers >= 4.28.0",
"torch == 2.2.2", "torch == 2.2.2",
"uvicorn == 0.17.6", "uvicorn == 0.17.6",

View File

@@ -18,6 +18,7 @@ do
# Bump Obsidian plugin to current version # Bump Obsidian plugin to current version
cd $project_root/src/interface/obsidian cd $project_root/src/interface/obsidian
yarn build # verify build before bumping version
yarn version --$version_type --no-git-tag-version yarn version --$version_type --no-git-tag-version
# append current version, min Obsidian app version from manifest to versions json # append current version, min Obsidian app version from manifest to versions json
cp $project_root/versions.json . cp $project_root/versions.json .

View File

@@ -19,7 +19,7 @@ const textFileTypes = [
'org', 'md', 'markdown', 'txt', 'html', 'xml', 'org', 'md', 'markdown', 'txt', 'html', 'xml',
// Other valid text file extensions from https://google.github.io/magika/model/config.json // Other valid text file extensions from https://google.github.io/magika/model/config.json
'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml'] 'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml']
const binaryFileTypes = ['pdf'] const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png']
const validFileTypes = textFileTypes.concat(binaryFileTypes); const validFileTypes = textFileTypes.concat(binaryFileTypes);
const schema = { const schema = {

View File

@@ -1,6 +1,6 @@
{ {
"name": "Khoj", "name": "Khoj",
"version": "1.15.0", "version": "1.16.0",
"description": "An AI copilot for your Second Brain", "description": "An AI copilot for your Second Brain",
"author": "Saba Imran, Debanjum Singh Solanky <team@khoj.dev>", "author": "Saba Imran, Debanjum Singh Solanky <team@khoj.dev>",
"license": "GPL-3.0-or-later", "license": "GPL-3.0-or-later",

View File

@@ -6,7 +6,7 @@
;; Saba Imran <saba@khoj.dev> ;; Saba Imran <saba@khoj.dev>
;; Description: An AI copilot for your Second Brain ;; Description: An AI copilot for your Second Brain
;; Keywords: search, chat, org-mode, outlines, markdown, pdf, image ;; Keywords: search, chat, org-mode, outlines, markdown, pdf, image
;; Version: 1.15.0 ;; Version: 1.16.0
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1")) ;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs ;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs

View File

@@ -1,7 +1,7 @@
{ {
"id": "khoj", "id": "khoj",
"name": "Khoj", "name": "Khoj",
"version": "1.15.0", "version": "1.16.0",
"minAppVersion": "0.15.0", "minAppVersion": "0.15.0",
"description": "An AI copilot for your Second Brain", "description": "An AI copilot for your Second Brain",
"author": "Khoj Inc.", "author": "Khoj Inc.",

View File

@@ -1,6 +1,6 @@
{ {
"name": "Khoj", "name": "Khoj",
"version": "1.15.0", "version": "1.16.0",
"description": "An AI copilot for your Second Brain", "description": "An AI copilot for your Second Brain",
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>", "author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
"license": "GPL-3.0-or-later", "license": "GPL-3.0-or-later",

View File

@@ -1,8 +1,9 @@
import { ItemView, MarkdownRenderer, WorkspaceLeaf, request, requestUrl, setIcon } from 'obsidian'; import { ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon } from 'obsidian';
import * as DOMPurify from 'dompurify'; import * as DOMPurify from 'dompurify';
import { KhojSetting } from 'src/settings'; import { KhojSetting } from 'src/settings';
import { KhojPaneView } from 'src/pane_view'; import { KhojPaneView } from 'src/pane_view';
import { KhojView, createCopyParentText, getLinkToEntry, pasteTextAtCursor } from 'src/utils'; import { KhojView, createCopyParentText, getLinkToEntry, pasteTextAtCursor } from 'src/utils';
import { KhojSearchModal } from './search_modal';
export interface ChatJsonResult { export interface ChatJsonResult {
image?: string; image?: string;
@@ -24,10 +25,18 @@ export class KhojChatView extends KhojPaneView {
setting: KhojSetting; setting: KhojSetting;
waitingForLocation: boolean; waitingForLocation: boolean;
location: Location; location: Location;
keyPressTimeout: NodeJS.Timeout | null = null;
constructor(leaf: WorkspaceLeaf, setting: KhojSetting) { constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
super(leaf, setting); super(leaf, setting);
// Register chat view keybindings
this.scope = new Scope(this.app.scope);
this.scope.register(["Ctrl"], 'n', (_) => this.createNewConversation());
this.scope.register(["Ctrl"], 'o', async (_) => await this.toggleChatSessions());
this.scope.register(["Ctrl"], 'f', (_) => new KhojSearchModal(this.app, this.setting).open());
this.scope.register(["Ctrl"], 'r', (_) => new KhojSearchModal(this.app, this.setting, true).open());
this.waitingForLocation = true; this.waitingForLocation = true;
fetch("https://ipapi.co/json") fetch("https://ipapi.co/json")
@@ -61,8 +70,7 @@ export class KhojChatView extends KhojPaneView {
return "message-circle"; return "message-circle";
} }
async chat() { async chat(isVoice: boolean = false) {
// Get text in chat input element // Get text in chat input element
let input_el = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0]; let input_el = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
@@ -72,7 +80,7 @@ export class KhojChatView extends KhojPaneView {
this.autoResize(); this.autoResize();
// Get and render chat response to user message // Get and render chat response to user message
await this.getChatResponse(user_message); await this.getChatResponse(user_message, isVoice);
} }
async onOpen() { async onOpen() {
@@ -92,8 +100,9 @@ export class KhojChatView extends KhojPaneView {
const objectSrc = `object-src 'none';`; const objectSrc = `object-src 'none';`;
const csp = `${defaultSrc} ${scriptSrc} ${connectSrc} ${styleSrc} ${imgSrc} ${childSrc} ${objectSrc}`; const csp = `${defaultSrc} ${scriptSrc} ${connectSrc} ${styleSrc} ${imgSrc} ${childSrc} ${objectSrc}`;
// Add CSP meta tag to the Khoj Chat modal // WARNING: CSP DISABLED for now as it breaks other Obsidian plugins. Enable when can scope CSP to only Khoj plugin.
document.head.createEl("meta", { attr: { "http-equiv": "Content-Security-Policy", "content": `${csp}` } }); // CSP meta tag for the Khoj Chat modal
// document.head.createEl("meta", { attr: { "http-equiv": "Content-Security-Policy", "content": `${csp}` } });
// Create area for chat logs // Create area for chat logs
let chatBodyEl = contentEl.createDiv({ attr: { id: "khoj-chat-body", class: "khoj-chat-body" } }); let chatBodyEl = contentEl.createDiv({ attr: { id: "khoj-chat-body", class: "khoj-chat-body" } });
@@ -104,9 +113,10 @@ export class KhojChatView extends KhojPaneView {
text: "Chat Sessions", text: "Chat Sessions",
attr: { attr: {
class: "khoj-input-row-button clickable-icon", class: "khoj-input-row-button clickable-icon",
title: "Show Conversations (^O)",
}, },
}) })
chatSessions.addEventListener('click', async (_) => { await this.toggleChatSessions(chatBodyEl) }); chatSessions.addEventListener('click', async (_) => { await this.toggleChatSessions() });
setIcon(chatSessions, "history"); setIcon(chatSessions, "history");
let chatInput = inputRow.createEl("textarea", { let chatInput = inputRow.createEl("textarea", {
@@ -119,14 +129,20 @@ export class KhojChatView extends KhojPaneView {
chatInput.addEventListener('input', (_) => { this.onChatInput() }); chatInput.addEventListener('input', (_) => { this.onChatInput() });
chatInput.addEventListener('keydown', (event) => { this.incrementalChat(event) }); chatInput.addEventListener('keydown', (event) => { this.incrementalChat(event) });
// Add event listeners for long press keybinding
this.contentEl.addEventListener('keydown', this.handleKeyDown.bind(this));
this.contentEl.addEventListener('keyup', this.handleKeyUp.bind(this));
let transcribe = inputRow.createEl("button", { let transcribe = inputRow.createEl("button", {
text: "Transcribe", text: "Transcribe",
attr: { attr: {
id: "khoj-transcribe", id: "khoj-transcribe",
class: "khoj-transcribe khoj-input-row-button clickable-icon ", class: "khoj-transcribe khoj-input-row-button clickable-icon ",
title: "Start Voice Chat (^S)",
}, },
}) })
transcribe.addEventListener('mousedown', async (event) => { await this.speechToText(event) }); transcribe.addEventListener('mousedown', (event) => { this.startSpeechToText(event) });
transcribe.addEventListener('mouseup', async (event) => { await this.stopSpeechToText(event) });
transcribe.addEventListener('touchstart', async (event) => { await this.speechToText(event) }); transcribe.addEventListener('touchstart', async (event) => { await this.speechToText(event) });
transcribe.addEventListener('touchend', async (event) => { await this.speechToText(event) }); transcribe.addEventListener('touchend', async (event) => { await this.speechToText(event) });
transcribe.addEventListener('touchcancel', async (event) => { await this.speechToText(event) }); transcribe.addEventListener('touchcancel', async (event) => { await this.speechToText(event) });
@@ -160,6 +176,46 @@ export class KhojChatView extends KhojPaneView {
}); });
} }
startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout=200) {
if (!this.keyPressTimeout) {
this.keyPressTimeout = setTimeout(async () => {
// Reset auto send voice message timer, UI if running
if (this.sendMessageTimeout) {
// Stop the auto send voice message countdown timer UI
clearTimeout(this.sendMessageTimeout);
const sendButton = <HTMLButtonElement>this.contentEl.getElementsByClassName("khoj-chat-send")[0]
setIcon(sendButton, "arrow-up-circle")
let sendImg = <SVGElement>sendButton.getElementsByClassName("lucide-arrow-up-circle")[0]
sendImg.addEventListener('click', async (_) => { await this.chat() });
// Reset chat input value
const chatInput = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
chatInput.value = "";
}
// Start new voice message
await this.speechToText(event);
}, timeout);
}
}
async stopSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent) {
if (this.mediaRecorder) {
await this.speechToText(event);
}
if (this.keyPressTimeout) {
clearTimeout(this.keyPressTimeout);
this.keyPressTimeout = null;
}
}
handleKeyDown(event: KeyboardEvent) {
// Start speech to text if keyboard shortcut is pressed
if (event.key === 's' && event.getModifierState('Control')) this.startSpeechToText(event);
}
async handleKeyUp(event: KeyboardEvent) {
// Stop speech to text if keyboard shortcut is released
if (event.key === 's' && event.getModifierState('Control')) await this.stopSpeechToText(event);
}
processOnlineReferences(referenceSection: HTMLElement, onlineContext: any) { processOnlineReferences(referenceSection: HTMLElement, onlineContext: any) {
let numOnlineReferences = 0; let numOnlineReferences = 0;
for (let subquery in onlineContext) { for (let subquery in onlineContext) {
@@ -294,6 +350,57 @@ export class KhojChatView extends KhojPaneView {
return referenceButton; return referenceButton;
} }
textToSpeech(message: string, event: MouseEvent | null = null): void {
// Replace the speaker with a loading icon.
let loader = document.createElement("span");
loader.classList.add("loader");
let speechButton: HTMLButtonElement;
let speechIcon: Element;
if (event === null) {
// Pick the last speech button if none is provided
let speechButtons = document.getElementsByClassName("speech-button");
speechButton = speechButtons[speechButtons.length - 1] as HTMLButtonElement;
let speechIcons = document.getElementsByClassName("speech-icon");
speechIcon = speechIcons[speechIcons.length - 1];
} else {
speechButton = event.currentTarget as HTMLButtonElement;
speechIcon = event.target as Element;
}
speechButton.appendChild(loader);
speechButton.disabled = true;
const context = new AudioContext();
let textToSpeechApi = `${this.setting.khojUrl}/api/chat/speech?text=${encodeURIComponent(message)}`;
fetch(textToSpeechApi, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
"Authorization": `Bearer ${this.setting.khojApiKey}`,
},
})
.then(response => response.arrayBuffer())
.then(arrayBuffer => context.decodeAudioData(arrayBuffer))
.then(audioBuffer => {
const source = context.createBufferSource();
source.buffer = audioBuffer;
source.connect(context.destination);
source.start(0);
source.onended = function() {
speechButton.removeChild(loader);
speechButton.disabled = false;
};
})
.catch(err => {
console.error("Error playing speech:", err);
speechButton.removeChild(loader);
speechButton.disabled = false; // Consider enabling the button again to allow retrying
});
}
formatHTMLMessage(message: string, raw = false, willReplace = true) { formatHTMLMessage(message: string, raw = false, willReplace = true) {
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for some AI chat model. // Remove any text between <s>[INST] and </s> tags. These are spurious instructions for some AI chat model.
message = message.replace(/<s>\[INST\].+(<\/s>)?/g, ''); message = message.replace(/<s>\[INST\].+(<\/s>)?/g, '');
@@ -461,19 +568,36 @@ export class KhojChatView extends KhojPaneView {
renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) { renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) {
let copyButton = this.contentEl.createEl('button'); let copyButton = this.contentEl.createEl('button');
copyButton.classList.add("copy-button"); copyButton.classList.add("chat-action-button");
copyButton.title = "Copy Message to Clipboard"; copyButton.title = "Copy Message to Clipboard";
setIcon(copyButton, "copy-plus"); setIcon(copyButton, "copy-plus");
copyButton.addEventListener('click', createCopyParentText(message)); copyButton.addEventListener('click', createCopyParentText(message));
chat_message_body_text_el.append(copyButton);
// Add button to paste into current buffer // Add button to paste into current buffer
let pasteToFile = this.contentEl.createEl('button'); let pasteToFile = this.contentEl.createEl('button');
pasteToFile.classList.add("copy-button"); pasteToFile.classList.add("chat-action-button");
pasteToFile.title = "Paste Message to File"; pasteToFile.title = "Paste Message to File";
setIcon(pasteToFile, "clipboard-paste"); setIcon(pasteToFile, "clipboard-paste");
pasteToFile.addEventListener('click', (event) => { pasteTextAtCursor(createCopyParentText(message, 'clipboard-paste')(event)); }); pasteToFile.addEventListener('click', (event) => { pasteTextAtCursor(createCopyParentText(message, 'clipboard-paste')(event)); });
chat_message_body_text_el.append(pasteToFile);
// Only enable the speech feature if the user is subscribed
let speechButton = null;
if (this.setting.userInfo?.is_active) {
// Create a speech button icon to play the message out loud
speechButton = this.contentEl.createEl('button');
speechButton.classList.add("chat-action-button", "speech-button");
speechButton.title = "Listen to Message";
setIcon(speechButton, "speech")
speechButton.addEventListener('click', (event) => this.textToSpeech(message, event));
}
// Append buttons to parent element
chat_message_body_text_el.append(copyButton, pasteToFile);
if (speechButton) {
chat_message_body_text_el.append(speechButton);
}
} }
formatDate(date: Date): string { formatDate(date: Date): string {
@@ -483,14 +607,16 @@ export class KhojChatView extends KhojPaneView {
return `${time_string}, ${date_string}`; return `${time_string}, ${date_string}`;
} }
createNewConversation(chatBodyEl: HTMLElement) { createNewConversation() {
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
chatBodyEl.innerHTML = ""; chatBodyEl.innerHTML = "";
chatBodyEl.dataset.conversationId = ""; chatBodyEl.dataset.conversationId = "";
chatBodyEl.dataset.conversationTitle = ""; chatBodyEl.dataset.conversationTitle = "";
this.renderMessage(chatBodyEl, "Hey 👋🏾, what's up?", "khoj"); this.renderMessage(chatBodyEl, "Hey 👋🏾, what's up?", "khoj");
} }
async toggleChatSessions(chatBodyEl: HTMLElement, forceShow: boolean = false): Promise<boolean> { async toggleChatSessions(forceShow: boolean = false): Promise<boolean> {
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
if (!forceShow && this.contentEl.getElementsByClassName("side-panel")?.length > 0) { if (!forceShow && this.contentEl.getElementsByClassName("side-panel")?.length > 0) {
chatBodyEl.innerHTML = ""; chatBodyEl.innerHTML = "";
return this.getChatHistory(chatBodyEl); return this.getChatHistory(chatBodyEl);
@@ -504,9 +630,10 @@ export class KhojChatView extends KhojPaneView {
const newConversationButtonEl = newConversationEl.createEl("button"); const newConversationButtonEl = newConversationEl.createEl("button");
newConversationButtonEl.classList.add("new-conversation-button"); newConversationButtonEl.classList.add("new-conversation-button");
newConversationButtonEl.classList.add("side-panel-button"); newConversationButtonEl.classList.add("side-panel-button");
newConversationButtonEl.addEventListener('click', (_) => this.createNewConversation(chatBodyEl)); newConversationButtonEl.addEventListener('click', (_) => this.createNewConversation());
setIcon(newConversationButtonEl, "plus"); setIcon(newConversationButtonEl, "plus");
newConversationButtonEl.innerHTML += "New"; newConversationButtonEl.innerHTML += "New";
newConversationButtonEl.title = "New Conversation (^N)";
const existingConversationsEl = sidePanelEl.createDiv("existing-conversations"); const existingConversationsEl = sidePanelEl.createDiv("existing-conversations");
const conversationListEl = existingConversationsEl.createDiv("conversation-list"); const conversationListEl = existingConversationsEl.createDiv("conversation-list");
@@ -666,7 +793,7 @@ export class KhojChatView extends KhojPaneView {
chatBodyEl.innerHTML = ""; chatBodyEl.innerHTML = "";
chatBodyEl.dataset.conversationId = ""; chatBodyEl.dataset.conversationId = "";
chatBodyEl.dataset.conversationTitle = ""; chatBodyEl.dataset.conversationTitle = "";
this.toggleChatSessions(chatBodyEl, true); this.toggleChatSessions(true);
}) })
.catch(err => { .catch(err => {
return; return;
@@ -727,7 +854,7 @@ export class KhojChatView extends KhojPaneView {
return true; return true;
} }
async readChatStream(response: Response, responseElement: HTMLDivElement): Promise<void> { async readChatStream(response: Response, responseElement: HTMLDivElement, isVoice: boolean = false): Promise<void> {
// Exit if response body is empty // Exit if response body is empty
if (response.body == null) return; if (response.body == null) return;
@@ -737,8 +864,12 @@ export class KhojChatView extends KhojPaneView {
while (true) { while (true) {
const { value, done } = await reader.read(); const { value, done } = await reader.read();
// Break if the stream is done if (done) {
if (done) break; // Automatically respond with voice if the subscribed user has sent voice message
if (isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.result);
// Break if the stream is done
break;
}
let responseText = decoder.decode(value); let responseText = decoder.decode(value);
if (responseText.includes("### compiled references:")) { if (responseText.includes("### compiled references:")) {
@@ -756,7 +887,7 @@ export class KhojChatView extends KhojPaneView {
} }
} }
async getChatResponse(query: string | undefined | null): Promise<void> { async getChatResponse(query: string | undefined | null, isVoice: boolean = false): Promise<void> {
// Exit if query is empty // Exit if query is empty
if (!query || query === "") return; if (!query || query === "") return;
@@ -835,7 +966,7 @@ export class KhojChatView extends KhojPaneView {
} }
} else { } else {
// Stream and render chat response // Stream and render chat response
await this.readChatStream(response, responseElement); await this.readChatStream(response, responseElement, isVoice);
} }
} catch (err) { } catch (err) {
console.log(`Khoj chat response failed with\n${err}`); console.log(`Khoj chat response failed with\n${err}`);
@@ -883,7 +1014,7 @@ export class KhojChatView extends KhojPaneView {
sendMessageTimeout: NodeJS.Timeout | undefined; sendMessageTimeout: NodeJS.Timeout | undefined;
mediaRecorder: MediaRecorder | undefined; mediaRecorder: MediaRecorder | undefined;
async speechToText(event: MouseEvent | TouchEvent) { async speechToText(event: MouseEvent | TouchEvent | KeyboardEvent) {
event.preventDefault(); event.preventDefault();
const transcribeButton = <HTMLButtonElement>this.contentEl.getElementsByClassName("khoj-transcribe")[0]; const transcribeButton = <HTMLButtonElement>this.contentEl.getElementsByClassName("khoj-transcribe")[0];
const chatInput = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0]; const chatInput = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
@@ -916,9 +1047,19 @@ export class KhojChatView extends KhojPaneView {
}); });
// Parse response from Khoj backend // Parse response from Khoj backend
let noSpeechText: string[] = [
"Thanks for watching!",
"Thanks for watching.",
"Thank you for watching!",
"Thank you for watching.",
"You",
"Bye."
];
let noSpeech: boolean = false;
if (response.status === 200) { if (response.status === 200) {
console.log(response); console.log(response);
chatInput.value += response.json.text.trimStart(); noSpeech = noSpeechText.includes(response.json.text.trimStart());
if (!noSpeech) chatInput.value += response.json.text.trimStart();
this.autoResize(); this.autoResize();
} else if (response.status === 501) { } else if (response.status === 501) {
throw new Error("⛔️ Configure speech-to-text model on server."); throw new Error("⛔️ Configure speech-to-text model on server.");
@@ -928,8 +1069,8 @@ export class KhojChatView extends KhojPaneView {
throw new Error("⛔️ Failed to transcribe audio."); throw new Error("⛔️ Failed to transcribe audio.");
} }
// Don't auto-send empty messages // Don't auto-send empty messages or when no speech is detected
if (chatInput.value.length === 0) return; if (chatInput.value.length === 0 || noSpeech) return;
// Show stop auto-send button. It stops auto-send when clicked // Show stop auto-send button. It stops auto-send when clicked
setIcon(sendButton, "stop-circle"); setIcon(sendButton, "stop-circle");
@@ -938,6 +1079,7 @@ export class KhojChatView extends KhojPaneView {
// Start the countdown timer UI // Start the countdown timer UI
stopSendButtonImg.getElementsByTagName("circle")[0].style.animation = "countdown 3s linear 1 forwards"; stopSendButtonImg.getElementsByTagName("circle")[0].style.animation = "countdown 3s linear 1 forwards";
stopSendButtonImg.getElementsByTagName("circle")[0].style.color = "var(--icon-color-active)";
// Auto send message after 3 seconds // Auto send message after 3 seconds
this.sendMessageTimeout = setTimeout(() => { this.sendMessageTimeout = setTimeout(() => {
@@ -947,7 +1089,7 @@ export class KhojChatView extends KhojPaneView {
sendImg.addEventListener('click', async (_) => { await this.chat() }); sendImg.addEventListener('click', async (_) => { await this.chat() });
// Send message // Send message
this.chat(); this.chat(true);
}, 3000); }, 3000);
}; };
@@ -966,21 +1108,23 @@ export class KhojChatView extends KhojPaneView {
}); });
this.mediaRecorder.start(); this.mediaRecorder.start();
setIcon(transcribeButton, "mic-off"); // setIcon(transcribeButton, "mic-off");
transcribeButton.classList.add("loading-encircle")
}; };
// Toggle recording // Toggle recording
if (!this.mediaRecorder || this.mediaRecorder.state === 'inactive' || event.type === 'touchstart') { if (!this.mediaRecorder || this.mediaRecorder.state === 'inactive' || event.type === 'touchstart' || event.type === 'mousedown' || event.type === 'keydown') {
navigator.mediaDevices navigator.mediaDevices
.getUserMedia({ audio: true }) .getUserMedia({ audio: true })
?.then(handleRecording) ?.then(handleRecording)
.catch((e) => { .catch((e) => {
this.flashStatusInChatInput("⛔️ Failed to access microphone"); this.flashStatusInChatInput("⛔️ Failed to access microphone");
}); });
} else if (this.mediaRecorder.state === 'recording' || event.type === 'touchend' || event.type === 'touchcancel') { } else if (this.mediaRecorder?.state === 'recording' || event.type === 'touchend' || event.type === 'touchcancel' || event.type === 'mouseup' || event.type === 'keyup') {
this.mediaRecorder.stop(); this.mediaRecorder.stop();
this.mediaRecorder.stream.getTracks().forEach(track => track.stop()); this.mediaRecorder.stream.getTracks().forEach(track => track.stop());
this.mediaRecorder = undefined; this.mediaRecorder = undefined;
transcribeButton.classList.remove("loading-encircle");
setIcon(transcribeButton, "mic"); setIcon(transcribeButton, "mic");
} }
} }

View File

@@ -2,7 +2,8 @@ import { Plugin, WorkspaceLeaf } from 'obsidian';
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings' import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
import { KhojSearchModal } from 'src/search_modal' import { KhojSearchModal } from 'src/search_modal'
import { KhojChatView } from 'src/chat_view' import { KhojChatView } from 'src/chat_view'
import { updateContentIndex, canConnectToBackend, KhojView } from './utils'; import { updateContentIndex, canConnectToBackend, KhojView, jumpToPreviousView } from './utils';
import { KhojPaneView } from './pane_view';
export default class Khoj extends Plugin { export default class Khoj extends Plugin {
@@ -79,16 +80,30 @@ export default class Khoj extends Plugin {
const leaves = workspace.getLeavesOfType(viewType); const leaves = workspace.getLeavesOfType(viewType);
if (leaves.length > 0) { if (leaves.length > 0) {
// A leaf with our view already exists, use that // A leaf with our view already exists, use that
leaf = leaves[0]; leaf = leaves[0];
} else { } else {
// Our view could not be found in the workspace, create a new leaf // Our view could not be found in the workspace, create a new leaf
// in the right sidebar for it // in the right sidebar for it
leaf = workspace.getRightLeaf(false); leaf = workspace.getRightLeaf(false);
await leaf?.setViewState({ type: viewType, active: true }); await leaf?.setViewState({ type: viewType, active: true });
} }
// "Reveal" the leaf in case it is in a collapsed sidebar if (leaf) {
if (leaf) workspace.revealLeaf(leaf); const activeKhojLeaf = workspace.getActiveViewOfType(KhojPaneView)?.leaf;
} // Jump to the previous view if the current view is Khoj Side Pane
if (activeKhojLeaf === leaf) jumpToPreviousView();
// Else Reveal the leaf in case it is in a collapsed sidebar
else {
workspace.revealLeaf(leaf);
if (viewType === KhojView.CHAT) {
// focus on the chat input when the chat view is opened
let chatView = leaf.view as KhojChatView;
let chatInput = <HTMLTextAreaElement>chatView.contentEl.getElementsByClassName("khoj-chat-input")[0];
if (chatInput) chatInput.focus();
}
}
}
}
} }

View File

@@ -38,16 +38,24 @@ export abstract class KhojPaneView extends ItemView {
const leaves = workspace.getLeavesOfType(viewType); const leaves = workspace.getLeavesOfType(viewType);
if (leaves.length > 0) { if (leaves.length > 0) {
// A leaf with our view already exists, use that // A leaf with our view already exists, use that
leaf = leaves[0]; leaf = leaves[0];
} else { } else {
// Our view could not be found in the workspace, create a new leaf // Our view could not be found in the workspace, create a new leaf
// in the right sidebar for it // in the right sidebar for it
leaf = workspace.getRightLeaf(false); leaf = workspace.getRightLeaf(false);
await leaf?.setViewState({ type: viewType, active: true }); await leaf?.setViewState({ type: viewType, active: true });
} }
// "Reveal" the leaf in case it is in a collapsed sidebar if (leaf) {
if (leaf) workspace.revealLeaf(leaf); if (viewType === KhojView.CHAT) {
} // focus on the chat input when the chat view is opened
let chatInput = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
if (chatInput) chatInput.focus();
}
// "Reveal" the leaf in case it is in a collapsed sidebar
workspace.revealLeaf(leaf);
}
}
} }

View File

@@ -333,6 +333,12 @@ export function createCopyParentText(message: string, originalButton: string = '
} }
} }
export function jumpToPreviousView() {
const editor: Editor = this.app.workspace.getActiveFileView()?.editor
if (!editor) return;
editor.focus();
}
export function pasteTextAtCursor(text: string | undefined) { export function pasteTextAtCursor(text: string | undefined) {
// Get the current active file's editor // Get the current active file's editor
const editor: Editor = this.app.workspace.getActiveFileView()?.editor const editor: Editor = this.app.workspace.getActiveFileView()?.editor

View File

@@ -477,7 +477,7 @@ span.khoj-nav-item-text {
} }
/* Copy button */ /* Copy button */
button.copy-button { button.chat-action-button {
display: block; display: block;
border-radius: 4px; border-radius: 4px;
color: var(--text-muted); color: var(--text-muted);
@@ -491,20 +491,54 @@ button.copy-button {
margin-top: 8px; margin-top: 8px;
float: right; float: right;
} }
button.copy-button span { button.chat-action-button span {
cursor: pointer; cursor: pointer;
display: inline-block; display: inline-block;
position: relative; position: relative;
transition: 0.5s; transition: 0.5s;
} }
button.chat-action-button:hover {
background-color: var(--background-modifier-active-hover);
color: var(--text-normal);
}
img.copy-icon { img.copy-icon {
width: 16px; width: 16px;
height: 16px; height: 16px;
} }
button.copy-button:hover { /* Circular Loading Spinner */
background-color: var(--background-modifier-active-hover); .loader {
color: var(--text-normal); width: 18px;
height: 18px;
border: 3px solid #FFF;
border-radius: 50%;
display: inline-block;
position: relative;
box-sizing: border-box;
animation: rotation 1s linear infinite;
}
.loader::after {
content: '';
box-sizing: border-box;
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
width: 18px;
height: 18px;
border-radius: 50%;
border: 3px solid transparent;
border-bottom-color: var(--flower);
}
@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
} }
/* Loading Spinner */ /* Loading Spinner */
@@ -564,6 +598,44 @@ button.copy-button:hover {
} }
} }
/* Loading Encircle */
.loading-encircle {
position: relative;
}
.loading-encircle::before {
content: '';
position: absolute;
top: 50%;
left: 50%;
width: 24px;
height: 24px;
margin-top: -16px;
margin-left: -16px;
border: 4px solid transparent;
border-color: var(--icon-color-active);
border-radius: 50%;
animation: pulse 3s ease-in-out infinite;
}
@keyframes pulse {
0% {
transform: scale(1);
opacity: 1;
}
50% {
transform: scale(1.2);
opacity: 0.2;
}
100% {
transform: scale(1);
opacity: 1;
}
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
@media only screen and (max-width: 600px) { @media only screen and (max-width: 600px) {
div.khoj-header { div.khoj-header {
display: grid; display: grid;

View File

@@ -52,5 +52,6 @@
"1.12.1": "0.15.0", "1.12.1": "0.15.0",
"1.13.0": "0.15.0", "1.13.0": "0.15.0",
"1.14.0": "0.15.0", "1.14.0": "0.15.0",
"1.15.0": "0.15.0" "1.15.0": "0.15.0",
"1.16.0": "0.15.0"
} }

View File

@@ -112,7 +112,7 @@ ASGI_APPLICATION = "app.asgi.application"
# Database # Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases # https://docs.djangoproject.com/en/4.2/ref/settings/#databases
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20000
DATABASES = { DATABASES = {
"default": { "default": {
"ENGINE": "django.db.backends.postgresql", "ENGINE": "django.db.backends.postgresql",
@@ -122,6 +122,7 @@ DATABASES = {
"NAME": os.getenv("POSTGRES_DB", "khoj"), "NAME": os.getenv("POSTGRES_DB", "khoj"),
"PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"), "PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
"CONN_MAX_AGE": 0, "CONN_MAX_AGE": 0,
"CONN_HEALTH_CHECKS": True,
} }
} }

View File

@@ -48,6 +48,7 @@ from khoj.database.models import (
UserConversationConfig, UserConversationConfig,
UserRequests, UserRequests,
UserSearchModelConfig, UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig, UserVoiceModelConfig,
VoiceModelOption, VoiceModelOption,
) )
@@ -907,7 +908,45 @@ class ConversationAdapters:
@staticmethod @staticmethod
async def aget_text_to_image_model_config(): async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst() return await TextToImageModelConfig.objects.filter().prefetch_related("openai_config").afirst()
@staticmethod
def get_text_to_image_model_config():
return TextToImageModelConfig.objects.filter().first()
@staticmethod
def get_text_to_image_model_options():
return TextToImageModelConfig.objects.all()
@staticmethod
def get_user_text_to_image_model_config(user: KhojUser):
config = UserTextToImageModelConfig.objects.filter(user=user).first()
if not config:
default_config = ConversationAdapters.get_text_to_image_model_config()
if not default_config:
return None
return default_config
return config.setting
@staticmethod
async def aget_user_text_to_image_model(user: KhojUser) -> Optional[TextToImageModelConfig]:
config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
default_config = await ConversationAdapters.aget_text_to_image_model_config()
if not default_config:
return None
return default_config
return config.setting
@staticmethod
async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
user=user, defaults={"setting": config}
)
return new_config
@staticmethod @staticmethod
def add_files_to_filter(user: KhojUser, conversation_id: int, files: List[str]): def add_files_to_filter(user: KhojUser, conversation_id: int, files: List[str]):
@@ -949,7 +988,7 @@ class FileObjectAdapters:
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text) return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
def get_file_objects_by_name(user: KhojUser, file_name: str): def get_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).first() return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod @staticmethod
@@ -1005,27 +1044,39 @@ class EntryAdapters:
return deleted_count return deleted_count
@staticmethod @staticmethod
def delete_all_entries_by_type(user: KhojUser, file_type: str = None): def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
if file_type is None: queryset = Entry.objects.filter(user=user)
deleted_count, _ = Entry.objects.filter(user=user).delete()
else: if file_type is not None:
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete() queryset = queryset.filter(file_type=file_type)
if file_source is not None:
queryset = queryset.filter(file_source=file_source)
return queryset
@staticmethod
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while queryset.exists():
batch_ids = list(queryset.values_list("id", flat=True)[:batch_size])
batch = Entry.objects.filter(id__in=batch_ids, user=user)
count, _ = batch.delete()
deleted_count += count
return deleted_count return deleted_count
@staticmethod @staticmethod
def delete_all_entries(user: KhojUser, file_source: str = None): async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
if file_source is None: deleted_count = 0
deleted_count, _ = Entry.objects.filter(user=user).delete() queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
else: while await queryset.aexists():
deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete() batch_ids = await sync_to_async(list)(queryset.values_list("id", flat=True)[:batch_size])
batch = Entry.objects.filter(id__in=batch_ids, user=user)
count, _ = await batch.adelete()
deleted_count += count
return deleted_count return deleted_count
@staticmethod
async def adelete_all_entries(user: KhojUser, file_source: str = None):
if file_source is None:
return await Entry.objects.filter(user=user).adelete()
return await Entry.objects.filter(user=user, file_source=file_source).adelete()
@staticmethod @staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)

View File

@@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
admin.site.register(SearchModelConfig) admin.site.register(SearchModelConfig)
admin.site.register(ReflectiveQuestion) admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig) admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication) admin.site.register(ClientApplication)
admin.site.register(GithubConfig) admin.site.register(GithubConfig)
admin.site.register(NotionConfig) admin.site.register(NotionConfig)
@@ -126,7 +125,10 @@ class EntryAdmin(admin.ModelAdmin):
"file_path", "file_path",
) )
search_fields = ("id", "user__email", "user__username", "file_path") search_fields = ("id", "user__email", "user__username", "file_path")
list_filter = ("file_type",) list_filter = (
"file_type",
"user__email",
)
ordering = ("-created_at",) ordering = ("-created_at",)
@@ -153,6 +155,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
search_fields = ("id", "chat_model", "model_type") search_fields = ("id", "chat_model", "model_type")
@admin.register(TextToImageModelConfig)
class TextToImageModelOptionsAdmin(admin.ModelAdmin):
list_display = (
"id",
"model_name",
"model_type",
)
search_fields = ("id", "model_name", "model_type")
@admin.register(OpenAIProcessorConversationConfig) @admin.register(OpenAIProcessorConversationConfig)
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin): class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
list_display = ( list_display = (

View File

@@ -0,0 +1,58 @@
# Generated by Django 4.2.11 on 2024-06-26 03:27
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0048_voicemodeloption_uservoicemodelconfig"),
]
operations = [
migrations.AddField(
model_name="texttoimagemodelconfig",
name="api_key",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="openai_config",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.openaiprocessorconversationconfig",
),
),
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
),
),
migrations.CreateModel(
name="UserTextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"setting",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
),
),
(
"user",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
],
options={
"abstract": False,
},
),
]

View File

@@ -0,0 +1,14 @@
# Generated by Django 4.2.11 on 2024-07-02 12:20
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0049_texttoimagemodelconfig_api_key_and_more"),
("database", "0050_alter_processlock_name"),
]
operations: List[str] = []

View File

@@ -215,11 +215,11 @@ class SearchModelConfig(BaseModel):
# Bi-encoder model of sentence-transformer type to load from HuggingFace # Bi-encoder model of sentence-transformer type to load from HuggingFace
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
# Config passed to the sentence-transformer model constructor. E.g. device="cuda:0", trust_remote_server=True etc. # Config passed to the sentence-transformer model constructor. E.g. device="cuda:0", trust_remote_server=True etc.
bi_encoder_model_config = models.JSONField(default=dict) bi_encoder_model_config = models.JSONField(default=dict, blank=True)
# Query encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models # Query encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_query_encode_config = models.JSONField(default=dict) bi_encoder_query_encode_config = models.JSONField(default=dict, blank=True)
# Docs encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models # Docs encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_docs_encode_config = models.JSONField(default=dict) bi_encoder_docs_encode_config = models.JSONField(default=dict, blank=True)
# Cross-encoder model of sentence-transformer type to load from HuggingFace # Cross-encoder model of sentence-transformer type to load from HuggingFace
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1") cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server # Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
@@ -235,9 +235,37 @@ class SearchModelConfig(BaseModel):
class TextToImageModelConfig(BaseModel): class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
OPENAI = "openai" OPENAI = "openai"
STABILITYAI = "stability-ai"
model_name = models.CharField(max_length=200, default="dall-e-3") model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
openai_config = models.ForeignKey(
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
)
def clean(self):
# Custom validation logic
error = {}
if self.model_type == self.ModelType.OPENAI:
if self.api_key and self.openai_config:
error[
"api_key"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
error[
"openai_config"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
if self.model_type != self.ModelType.OPENAI:
if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI models."
if self.openai_config:
error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
class SpeechToTextModelOptions(BaseModel): class SpeechToTextModelOptions(BaseModel):
@@ -264,6 +292,11 @@ class UserSearchModelConfig(BaseModel):
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
class Conversation(BaseModel): class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict) conversation_log = models.JSONField(default=dict)

View File

@@ -242,18 +242,25 @@
<script> <script>
async function openChat(agentSlug) { async function openChat(agentSlug) {
// Create a loading animation // Create a loading animation
let loading = document.createElement("div"); let loadingTextEl = document.createElement("div");
loading.innerHTML = '<div>Booting your agent...</div><span class="loader"></span>'; loadingTextEl.textContent = 'Booting your agent...';
loading.style.position = "fixed";
loading.style.top = "0"; let loadingAnimationEl = document.createElement("span");
loading.style.right = "0"; loadingAnimationEl.className = "loader";
loading.style.bottom = "0";
loading.style.left = "0"; let loadingEl = document.createElement("div");
loading.style.display = "flex"; loadingEl.style.position = "fixed";
loading.style.justifyContent = "center"; loadingEl.style.top = "0";
loading.style.alignItems = "center"; loadingEl.style.right = "0";
loading.style.backgroundColor = "rgba(0, 0, 0, 0.5)"; // Semi-transparent black loadingEl.style.bottom = "0";
document.body.appendChild(loading); loadingEl.style.left = "0";
loadingEl.style.display = "flex";
loadingEl.style.justifyContent = "center";
loadingEl.style.alignItems = "center";
loadingEl.style.backgroundColor = "rgba(0, 0, 0, 0.5)"; // Semi-transparent black
loadingEl.append(loadingTextEl, loadingAnimationEl);
document.body.appendChild(loadingEl);
let response = await fetch(`/api/chat/sessions?agent_slug=${agentSlug}`, { method: "POST" }); let response = await fetch(`/api/chat/sessions?agent_slug=${agentSlug}`, { method: "POST" });
let data = await response.json(); let data = await response.json();

View File

@@ -5,13 +5,22 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}"> <link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
<title>Khoj</title> <title>Khoj</title>
<meta http-equiv="Content-Security-Policy"
content="default-src 'self' https://assets.khoj.dev;
script-src 'self' https://assets.khoj.dev 'unsafe-inline';
connect-src 'self' https://ipapi.co/json;
style-src 'self' https://assets.khoj.dev 'unsafe-inline' https://fonts.googleapis.com;
img-src 'self' data: https://*.khoj.dev https://*.googleusercontent.com;
font-src https://assets.khoj.dev https://fonts.gstatic.com;
child-src 'none';
object-src 'none';">
<link rel="stylesheet" href="/static/assets/pico.min.css?v={{ khoj_version }}"> <link rel="stylesheet" href="/static/assets/pico.min.css?v={{ khoj_version }}">
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}"> <link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
<script <script
integrity="sha384-05IkdNHoAlkhrFVUCCN805WC/h4mcI98GUBssmShF2VJAXKyZTrO/TmJ+4eBo0Cy" integrity="sha384-05IkdNHoAlkhrFVUCCN805WC/h4mcI98GUBssmShF2VJAXKyZTrO/TmJ+4eBo0Cy"
crossorigin="anonymous" crossorigin="anonymous"
src="https://cdnjs.cloudflare.com/ajax/libs/intl-tel-input/17.0.13/js/intlTelInput.min.js"></script> src="https://assets.khoj.dev/intl-tel-input/intlTelInput.min.js"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/intl-tel-input/17.0.13/css/intlTelInput.css"> <link rel="stylesheet" href="https://assets.khoj.dev/intl-tel-input/intlTelInput.css">
</head> </head>
<script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script> <script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script>
<script type="text/javascript" src="/static/assets/purify.min.js?v={{ khoj_version }}"></script> <script type="text/javascript" src="/static/assets/purify.min.js?v={{ khoj_version }}"></script>
@@ -332,6 +341,7 @@
margin: 20px; margin: 20px;
} }
select#paint-models,
select#search-models, select#search-models,
select#voice-models, select#voice-models,
select#chat-models { select#chat-models {

View File

@@ -48,8 +48,8 @@ Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj
To get started, just start typing below. You can also type / to see a list of commands. To get started, just start typing below. You can also type / to see a list of commands.
`.trim() `.trim()
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document']; const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'image/jpeg', 'image/png', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'docx']; const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'jpg', 'jpeg', 'png', 'docx'];
let chatOptions = []; let chatOptions = [];
function createCopyParentText(message) { function createCopyParentText(message) {
return function(event) { return function(event) {
@@ -149,7 +149,6 @@ To get started, just start typing below. You can also type / to see a list of co
} }
function generateOnlineReference(reference, index) { function generateOnlineReference(reference, index) {
// Generate HTML for Chat Reference // Generate HTML for Chat Reference
let title = reference.title || reference.link; let title = reference.title || reference.link;
let link = reference.link; let link = reference.link;
@@ -170,7 +169,7 @@ To get started, just start typing below. You can also type / to see a list of co
linkElement.textContent = title; linkElement.textContent = title;
let referenceButton = document.createElement('button'); let referenceButton = document.createElement('button');
referenceButton.innerHTML = linkElement.outerHTML; referenceButton.appendChild(linkElement);
referenceButton.id = `ref-${index}`; referenceButton.id = `ref-${index}`;
referenceButton.classList.add("reference-button"); referenceButton.classList.add("reference-button");
referenceButton.classList.add("collapsed"); referenceButton.classList.add("collapsed");
@@ -181,11 +180,12 @@ To get started, just start typing below. You can also type / to see a list of co
if (this.classList.contains("collapsed")) { if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed"); this.classList.remove("collapsed");
this.classList.add("expanded"); this.classList.add("expanded");
this.innerHTML = linkElement.outerHTML + `<br><br>${question + snippet}`; this.innerHTML = `${linkElement.outerHTML}<br><br>${question}${snippet}`;
} else { } else {
this.classList.add("collapsed"); this.classList.add("collapsed");
this.classList.remove("expanded"); this.classList.remove("expanded");
this.innerHTML = linkElement.outerHTML; this.innerHTML = "";
this.appendChild(linkElement);
} }
}); });
@@ -578,7 +578,7 @@ To get started, just start typing below. You can also type / to see a list of co
let referenceExpandButton = document.createElement('button'); let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button"); referenceExpandButton.classList.add("reference-expand-button");
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`; referenceExpandButton.textContent = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.addEventListener('click', function() { referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) { if (referenceSection.classList.contains("collapsed")) {
@@ -888,7 +888,7 @@ To get started, just start typing below. You can also type / to see a list of co
if (overlayText == null) { if (overlayText == null) {
dropzone.classList.add('dragover'); dropzone.classList.add('dragover');
var overlayText = document.createElement("div"); var overlayText = document.createElement("div");
overlayText.innerHTML = "Select file(s) or drag + drop it here to share it with Khoj"; overlayText.textContent = "Select file(s) or drag + drop it here to share it with Khoj";
overlayText.className = "dropzone-overlay"; overlayText.className = "dropzone-overlay";
overlayText.id = "dropzone-overlay"; overlayText.id = "dropzone-overlay";
dropzone.appendChild(overlayText); dropzone.appendChild(overlayText);
@@ -949,7 +949,7 @@ To get started, just start typing below. You can also type / to see a list of co
if (overlayText != null) { if (overlayText != null) {
// Display loading spinner // Display loading spinner
var loadingSpinner = document.createElement("div"); var loadingSpinner = document.createElement("div");
overlayText.innerHTML = "Uploading file(s) for indexing"; overlayText.textContent = "Uploading file(s) for indexing";
loadingSpinner.className = "spinner"; loadingSpinner.className = "spinner";
overlayText.appendChild(loadingSpinner); overlayText.appendChild(loadingSpinner);
} }
@@ -974,7 +974,12 @@ To get started, just start typing below. You can also type / to see a list of co
fileType = "text/html"; fileType = "text/html";
} else if (fileExtension === "pdf") { } else if (fileExtension === "pdf") {
fileType = "application/pdf"; fileType = "application/pdf";
} else { } else if (fileExtension === "jpg" || fileExtension === "jpeg"){
fileType = "image/jpeg";
} else if (fileExtension === "png") {
fileType = "image/png";
}
else {
// Skip this file if its type is not supported // Skip this file if its type is not supported
resolve(); resolve();
return; return;
@@ -1037,7 +1042,7 @@ To get started, just start typing below. You can also type / to see a list of co
if (overlayText == null) { if (overlayText == null) {
var overlayText = document.createElement("div"); var overlayText = document.createElement("div");
overlayText.innerHTML = "Drop file to share it with Khoj"; overlayText.textContent = "Drop file to share it with Khoj";
overlayText.className = "dropzone-overlay"; overlayText.className = "dropzone-overlay";
overlayText.id = "dropzone-overlay"; overlayText.id = "dropzone-overlay";
this.appendChild(overlayText); this.appendChild(overlayText);
@@ -1174,11 +1179,15 @@ To get started, just start typing below. You can also type / to see a list of co
websocket.onclose = function(event) { websocket.onclose = function(event) {
websocket = null; websocket = null;
console.log("WebSocket is closed now."); console.log("WebSocket is closed now.");
let setupWebSocketButton = document.createElement("button");
setupWebSocketButton.textContent = "Reconnect to Server";
setupWebSocketButton.onclick = setupWebSocket;
let statusDotIcon = document.getElementById("connection-status-icon"); let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red"; statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text"); let statusDotText = document.getElementById("connection-status-text");
statusDotText.innerHTML = "";
statusDotText.style.marginTop = "5px"; statusDotText.style.marginTop = "5px";
statusDotText.innerHTML = '<button onclick="setupWebSocket()">Reconnect to Server</button>'; statusDotText.appendChild(setupWebSocketButton);
} }
websocket.onerror = function(event) { websocket.onerror = function(event) {
console.log("WebSocket error observed:", event); console.log("WebSocket error observed:", event);
@@ -1429,7 +1438,7 @@ To get started, just start typing below. You can also type / to see a list of co
questionStarterSuggestions.innerHTML = ""; questionStarterSuggestions.innerHTML = "";
data.forEach((questionStarter) => { data.forEach((questionStarter) => {
let questionStarterButton = document.createElement('button'); let questionStarterButton = document.createElement('button');
questionStarterButton.innerHTML = questionStarter; questionStarterButton.textContent = questionStarter;
questionStarterButton.classList.add("question-starter"); questionStarterButton.classList.add("question-starter");
questionStarterButton.addEventListener('click', function() { questionStarterButton.addEventListener('click', function() {
questionStarterSuggestions.style.display = "none"; questionStarterSuggestions.style.display = "none";
@@ -1601,7 +1610,7 @@ To get started, just start typing below. You can also type / to see a list of co
let closeButton = document.createElement('button'); let closeButton = document.createElement('button');
closeButton.id = "close-button"; closeButton.id = "close-button";
closeButton.innerHTML = "Close"; closeButton.textContent = "Close";
closeButton.classList.add("close-button"); closeButton.classList.add("close-button");
closeButton.addEventListener('click', function() { closeButton.addEventListener('click', function() {
modal.remove(); modal.remove();
@@ -1655,7 +1664,7 @@ To get started, just start typing below. You can also type / to see a list of co
let threeDotMenu = document.createElement('div'); let threeDotMenu = document.createElement('div');
threeDotMenu.classList.add("three-dot-menu"); threeDotMenu.classList.add("three-dot-menu");
let threeDotMenuButton = document.createElement('button'); let threeDotMenuButton = document.createElement('button');
threeDotMenuButton.innerHTML = "⋮"; threeDotMenuButton.textContent = "⋮";
threeDotMenuButton.classList.add("three-dot-menu-button"); threeDotMenuButton.classList.add("three-dot-menu-button");
threeDotMenuButton.addEventListener('click', function(event) { threeDotMenuButton.addEventListener('click', function(event) {
event.stopPropagation(); event.stopPropagation();
@@ -1674,7 +1683,7 @@ To get started, just start typing below. You can also type / to see a list of co
conversationMenu.classList.add("conversation-menu"); conversationMenu.classList.add("conversation-menu");
let editTitleButton = document.createElement('button'); let editTitleButton = document.createElement('button');
editTitleButton.innerHTML = "Rename"; editTitleButton.textContent = "Rename";
editTitleButton.classList.add("edit-title-button"); editTitleButton.classList.add("edit-title-button");
editTitleButton.classList.add("three-dot-menu-button-item"); editTitleButton.classList.add("three-dot-menu-button-item");
editTitleButton.addEventListener('click', function(event) { editTitleButton.addEventListener('click', function(event) {
@@ -1708,7 +1717,7 @@ To get started, just start typing below. You can also type / to see a list of co
conversationTitleInputBox.appendChild(conversationTitleInput); conversationTitleInputBox.appendChild(conversationTitleInput);
let conversationTitleInputButton = document.createElement('button'); let conversationTitleInputButton = document.createElement('button');
conversationTitleInputButton.innerHTML = "Save"; conversationTitleInputButton.textContent = "Save";
conversationTitleInputButton.classList.add("three-dot-menu-button-item"); conversationTitleInputButton.classList.add("three-dot-menu-button-item");
conversationTitleInputButton.addEventListener('click', function(event) { conversationTitleInputButton.addEventListener('click', function(event) {
event.stopPropagation(); event.stopPropagation();
@@ -1732,7 +1741,7 @@ To get started, just start typing below. You can also type / to see a list of co
threeDotMenu.appendChild(conversationMenu); threeDotMenu.appendChild(conversationMenu);
let shareButton = document.createElement('button'); let shareButton = document.createElement('button');
shareButton.innerHTML = "Share"; shareButton.textContent = "Share";
shareButton.type = "button"; shareButton.type = "button";
shareButton.classList.add("share-conversation-button"); shareButton.classList.add("share-conversation-button");
shareButton.classList.add("three-dot-menu-button-item"); shareButton.classList.add("three-dot-menu-button-item");
@@ -1799,7 +1808,7 @@ To get started, just start typing below. You can also type / to see a list of co
let deleteButton = document.createElement('button'); let deleteButton = document.createElement('button');
deleteButton.type = "button"; deleteButton.type = "button";
deleteButton.innerHTML = "Delete"; deleteButton.textContent = "Delete";
deleteButton.classList.add("delete-conversation-button"); deleteButton.classList.add("delete-conversation-button");
deleteButton.classList.add("three-dot-menu-button-item"); deleteButton.classList.add("three-dot-menu-button-item");
deleteButton.addEventListener('click', function(event) { deleteButton.addEventListener('click', function(event) {
@@ -1963,12 +1972,16 @@ To get started, just start typing below. You can also type / to see a list of co
} }
allFiles = data; allFiles = data;
var nofilesmessage = document.getElementsByClassName("no-files-message")[0]; var nofilesmessage = document.getElementsByClassName("no-files-message")[0];
nofilesmessage.innerHTML = "";
if(allFiles.length === 0){ if(allFiles.length === 0){
nofilesmessage.innerHTML = `<a class="inline-chat-link" href="https://docs.khoj.dev/category/clients/">How to upload files</a>`; let inlineChatLinkEl = document.createElement('a');
inlineChatLinkEl.className = "inline-chat-link";
inlineChatLinkEl.href = "https://docs.khoj.dev/category/clients/";
inlineChatLinkEl.textContent = "How to upload files";
nofilesmessage.appendChild(inlineChatLinkEl);
document.getElementsByClassName("file-toggle-button")[0].style.display = "none"; document.getElementsByClassName("file-toggle-button")[0].style.display = "none";
} }
else{ else{
nofilesmessage.innerHTML = "";
document.getElementsByClassName("file-toggle-button")[0].style.display = "block"; document.getElementsByClassName("file-toggle-button")[0].style.display = "block";
} }
}) })

View File

@@ -163,10 +163,6 @@
<div class="section-cards"> <div class="section-cards">
<div class="finalize-buttons"> <div class="finalize-buttons">
<button id="sync" type="submit" title="Regenerate index from scratch for Notion, GitHub configuration" style="display: flex; justify-content: center;"> <button id="sync" type="submit" title="Regenerate index from scratch for Notion, GitHub configuration" style="display: flex; justify-content: center;">
<img class="card-icon" src="/static/assets/icons/sync.svg" alt="Sync">
<h3 class="card-title">
Sync
</h3>
</button> </button>
</div> </div>
</div> </div>
@@ -192,11 +188,37 @@
</div> </div>
<div class="card-action-row"> <div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-model" class="card-button happy" onclick="updateChatModel()"> <button id="save-chat-model" class="card-button happy" onclick="updateChatModel()">
Save Save
</button> </button>
{% else %} {% else %}
<button id="save-model" class="card-button" disabled> <button id="save-chat-model" class="card-button" disabled>
Subscribe to use different models
</button>
{% endif %}
</div>
</div>
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
<span>Paint</span>
</h3>
</div>
<div class="card-description-row">
<select id="paint-models">
{% for option in paint_model_options %}
<option value="{{ option.id }}" {% if option.id == selected_paint_model_config %}selected{% endif %}>{{ option.model_name }}</option>
{% endfor %}
</select>
</div>
<div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-paint-model" class="card-button happy" onclick="updatePaintModel()">
Save
</button>
{% else %}
<button id="save-paint-model" class="card-button" disabled>
Subscribe to use different models Subscribe to use different models
</button> </button>
{% endif %} {% endif %}
@@ -382,7 +404,8 @@
.then(data => { .then(data => {
if (data.status == "ok") { if (data.status == "ok") {
let notificationBanner = document.getElementById("notification-banner"); let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Profile name has been updated!"; notificationBanner.innerHTML = "";
notificationBanner.textContent = "Profile name has been updated!";
notificationBanner.style.display = "block"; notificationBanner.style.display = "block";
setTimeout(function() { setTimeout(function() {
notificationBanner.style.display = "none"; notificationBanner.style.display = "none";
@@ -394,8 +417,9 @@
function updateVoiceModel() { function updateVoiceModel() {
const voiceModel = document.getElementById("voice-models").value; const voiceModel = document.getElementById("voice-models").value;
const saveVoiceModelButton = document.getElementById("save-voice-model"); const saveVoiceModelButton = document.getElementById("save-voice-model");
saveVoiceModelButton.innerHTML = "";
saveVoiceModelButton.disabled = true; saveVoiceModelButton.disabled = true;
saveVoiceModelButton.innerHTML = "Saving..."; saveVoiceModelButton.textContent = "Saving...";
fetch('/api/config/data/voice/model?id=' + voiceModel, { fetch('/api/config/data/voice/model?id=' + voiceModel, {
method: 'POST', method: 'POST',
@@ -406,18 +430,19 @@
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
if (data.status == "ok") { if (data.status == "ok") {
saveVoiceModelButton.innerHTML = "Save"; saveVoiceModelButton.textContent = "Save";
saveVoiceModelButton.disabled = false; saveVoiceModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner"); let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Voice model has been updated!"; notificationBanner.innerHTML = "";
notificationBanner.textContent = "Voice model has been updated!";
notificationBanner.style.display = "block"; notificationBanner.style.display = "block";
setTimeout(function() { setTimeout(function() {
notificationBanner.style.display = "none"; notificationBanner.style.display = "none";
}, 5000); }, 5000);
} else { } else {
saveVoiceModelButton.innerHTML = "Error"; saveVoiceModelButton.textContent = "Error";
saveVoiceModelButton.disabled = false; saveVoiceModelButton.disabled = false;
} }
}) })
@@ -425,9 +450,10 @@
function updateChatModel() { function updateChatModel() {
const chatModel = document.getElementById("chat-models").value; const chatModel = document.getElementById("chat-models").value;
const saveModelButton = document.getElementById("save-model"); const saveModelButton = document.getElementById("save-chat-model");
saveModelButton.disabled = true; saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving..."; saveModelButton.innerHTML = "";
saveModelButton.textContent = "Saving...";
fetch('/api/config/data/conversation/model?id=' + chatModel, { fetch('/api/config/data/conversation/model?id=' + chatModel, {
method: 'POST', method: 'POST',
@@ -438,18 +464,19 @@
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
if (data.status == "ok") { if (data.status == "ok") {
saveModelButton.innerHTML = "Save"; saveModelButton.textContent = "Save";
saveModelButton.disabled = false; saveModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner"); let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Conversation model has been updated!"; notificationBanner.innerHTML = "";
notificationBanner.textContent = "Conversation model has been updated!";
notificationBanner.style.display = "block"; notificationBanner.style.display = "block";
setTimeout(function() { setTimeout(function() {
notificationBanner.style.display = "none"; notificationBanner.style.display = "none";
}, 5000); }, 5000);
} else { } else {
saveModelButton.innerHTML = "Error"; saveModelButton.textContent = "Error";
saveModelButton.disabled = false; saveModelButton.disabled = false;
} }
}) })
@@ -463,8 +490,9 @@
const searchModel = document.getElementById("search-models").value; const searchModel = document.getElementById("search-models").value;
const saveSearchModelButton = document.getElementById("save-search-model"); const saveSearchModelButton = document.getElementById("save-search-model");
saveSearchModelButton.innerHTML = "";
saveSearchModelButton.disabled = true; saveSearchModelButton.disabled = true;
saveSearchModelButton.innerHTML = "Saving..."; saveSearchModelButton.textContent = "Saving...";
fetch('/api/config/data/search/model?id=' + searchModel, { fetch('/api/config/data/search/model?id=' + searchModel, {
method: 'POST', method: 'POST',
@@ -475,15 +503,16 @@
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
if (data.status == "ok") { if (data.status == "ok") {
saveSearchModelButton.innerHTML = "Save"; saveSearchModelButton.textContent = "Save";
saveSearchModelButton.disabled = false; saveSearchModelButton.disabled = false;
} else { } else {
saveSearchModelButton.innerHTML = "Error"; saveSearchModelButton.textContent = "Error";
saveSearchModelButton.disabled = false; saveSearchModelButton.disabled = false;
} }
let notificationBanner = document.getElementById("notification-banner"); let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Khoj can now better understand the language of your content! Manually sync your data from one of the Khoj clients to update your knowledge base."; notificationBanner.innerHTML = "";
notificationBanner.textContent = "Khoj can now better understand the language of your content! Manually sync your data from one of the Khoj clients to update your knowledge base.";
notificationBanner.style.display = "block"; notificationBanner.style.display = "block";
setTimeout(function() { setTimeout(function() {
notificationBanner.style.display = "none"; notificationBanner.style.display = "none";
@@ -491,6 +520,38 @@
}) })
}; };
function updatePaintModel() {
const paintModel = document.getElementById("paint-models").value;
const saveModelButton = document.getElementById("save-paint-model");
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
fetch('/api/config/data/paint/model?id=' + paintModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
saveModelButton.innerHTML = "Save";
saveModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Paint model has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
} else {
saveModelButton.innerHTML = "Error";
saveModelButton.disabled = false;
}
})
};
function clearContentType(content_source) { function clearContentType(content_source) {
fetch('/api/config/data/content-source/' + content_source, { fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE', method: 'DELETE',
@@ -549,23 +610,38 @@
}) })
} }
var sync = document.getElementById("sync"); function populateSyncButton() {
sync.addEventListener("click", function(event) { let syncIconEl = document.createElement("img");
syncIconEl.className = "card-icon";
syncIconEl.src = "/static/assets/icons/sync.svg";
syncIconEl.alt = "Sync";
let syncButtonTitleEl = document.createElement("h3");
syncButtonTitleEl.className = "card-title";
syncButtonTitleEl.textContent = "Sync";
return [syncButtonTitleEl, syncIconEl];
}
var syncButtonEl = document.getElementById("sync");
syncButtonEl.innerHTML = "";
syncButtonEl.append(...populateSyncButton());
syncButtonEl.addEventListener("click", function(event) {
event.preventDefault(); event.preventDefault();
updateIndex( updateIndex(
force=true, force=true,
successText="Synced!", successText="Synced!",
errorText="Unable to sync. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.", errorText="Unable to sync. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
button=sync, button=syncButtonEl,
loadingText="Syncing...", loadingText="Syncing...",
emoji=""); emoji="");
}); });
function updateIndex(force, successText, errorText, button, loadingText, emoji) { function updateIndex(force, successText, errorText, button, loadingText, emoji) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
const original_html = button.innerHTML;
button.disabled = true; button.disabled = true;
button.innerHTML = emoji + " " + loadingText; button.innerHTML = ""
button.textContent = emoji + " " + loadingText;
fetch('/api/update?&client=web&force=' + force, { fetch('/api/update?&client=web&force=' + force, {
method: 'GET', method: 'GET',
headers: { headers: {
@@ -582,19 +658,19 @@
document.getElementById("status").style.display = "none"; document.getElementById("status").style.display = "none";
button.disabled = false; button.disabled = false;
button.innerHTML = `${successText}`; button.textContent = `${successText}`;
setTimeout(function() { setTimeout(function() {
button.innerHTML = original_html; button.append(...populateSyncButton());
}, 2000); }, 2000);
}) })
.catch((error) => { .catch((error) => {
console.error('Error:', error); console.error('Error:', error);
document.getElementById("status").innerHTML = emoji + " " + errorText document.getElementById("status").textContent = emoji + " " + errorText
document.getElementById("status").style.display = "block"; document.getElementById("status").style.display = "block";
button.disabled = false; button.disabled = false;
button.innerHTML = '⚠️ Unsuccessful'; button.textContent = '⚠️ Unsuccessful';
setTimeout(function() { setTimeout(function() {
button.innerHTML = original_html; button.append(...populateSyncButton());
}, 2000); }, 2000);
}); });
@@ -629,7 +705,7 @@
}) })
.then(response => response.json()) .then(response => response.json())
.then(tokenObj => { .then(tokenObj => {
apiKeyList.innerHTML += generateTokenRow(tokenObj); apiKeyList.appendChild(generateTokenRow(tokenObj));
}); });
} }
@@ -638,16 +714,16 @@
navigator.clipboard.writeText(token); navigator.clipboard.writeText(token);
// Flash the API key copied icon // Flash the API key copied icon
const apiKeyColumn = document.getElementById(`api-key-${token}`); const apiKeyColumn = document.getElementById(`api-key-${token}`);
const original_html = apiKeyColumn.innerHTML; const original_text = apiKeyColumn.textContent;
const copyApiKeyButton = document.getElementById(`api-key-copy-${token}`); const copyApiKeyButton = document.getElementById(`api-key-copy-${token}`);
setTimeout(function() { setTimeout(function() {
copyApiKeyButton.src = "/static/assets/icons/copy-button-success.svg"; copyApiKeyButton.src = "/static/assets/icons/copy-button-success.svg";
setTimeout(() => { setTimeout(() => {
copyApiKeyButton.src = "/static/assets/icons/copy-button.svg"; copyApiKeyButton.src = "/static/assets/icons/copy-button.svg";
}, 1000); }, 1000);
apiKeyColumn.innerHTML = "✅ Copied!"; apiKeyColumn.textContent = "✅ Copied!";
setTimeout(function() { setTimeout(function() {
apiKeyColumn.innerHTML = original_html; apiKeyColumn.textContent = original_text;
}, 1000); }, 1000);
}, 100); }, 100);
} }
@@ -670,16 +746,50 @@
let tokenName = tokenObj.name; let tokenName = tokenObj.name;
let truncatedToken = token.slice(0, 4) + "..." + token.slice(-4); let truncatedToken = token.slice(0, 4) + "..." + token.slice(-4);
let tokenId = `${tokenName}-${truncatedToken}`; let tokenId = `${tokenName}-${truncatedToken}`;
return `
<tr id="api-key-item-${token}"> // Create API Key Row
<td><b>${tokenName}</b></td> let apiKeyItemEl = document.createElement("tr");
<td id="api-key-${token}">${truncatedToken}</td> apiKeyItemEl.id = `api-key-item-${token}`;
<td>
<img id="api-key-copy-${token}" onclick="copyAPIKey('${token}')" class="configured-icon api-key-action enabled" src="/static/assets/icons/copy-button.svg" alt="Copy API Key" title="Copy API Key"> // API Key Name Row
<img id="api-key-delete-${token}" onclick="deleteAPIKey('${token}')" class="configured-icon api-key-action enabled" src="/static/assets/icons/delete.svg" alt="Delete API Key" title="Delete API Key"> let apiKeyNameEl = document.createElement("td");
</td> let apiKeyNameTextEl = document.createElement("b");
</tr> apiKeyNameTextEl.textContent = tokenName;
`;
// API Key Token Row
let apiKeyTokenEl = document.createElement("td");
apiKeyTokenEl.id = `api-key-${token}`;
apiKeyTokenEl.textContent = truncatedToken;
// API Key Actions Row
let apiKeyActionsEl = document.createElement("td");
// Copy API Key Button
let copyApiKeyButtonEl = document.createElement("img");
copyApiKeyButtonEl.id = `api-key-copy-${token}`;
copyApiKeyButtonEl.className = "configured-icon api-key-action enabled";
copyApiKeyButtonEl.src = "/static/assets/icons/copy-button.svg";
copyApiKeyButtonEl.alt = "Copy API Key";
copyApiKeyButtonEl.title = "Copy API Key";
copyApiKeyButtonEl.onclick = function() {
copyAPIKey(token);
};
// Delete API Key Button
let deleteApiKeyButtonEl = document.createElement("img");
deleteApiKeyButtonEl.id = `api-key-delete-${token}`;
deleteApiKeyButtonEl.className = "configured-icon api-key-action enabled";
deleteApiKeyButtonEl.src = "/static/assets/icons/delete.svg";
deleteApiKeyButtonEl.alt = "Delete API Key";
deleteApiKeyButtonEl.title = "Delete API Key";
deleteApiKeyButtonEl.onclick = function() {
deleteAPIKey(token);
};
// Construct the API Key Row
apiKeyNameEl.append(apiKeyNameTextEl);
apiKeyActionsEl.append(copyApiKeyButtonEl, deleteApiKeyButtonEl);
apiKeyItemEl.append(apiKeyNameEl, apiKeyTokenEl, apiKeyActionsEl);
return apiKeyItemEl;
} }
function listApiKeys() { function listApiKeys() {
@@ -688,7 +798,7 @@
.then(response => response.json()) .then(response => response.json())
.then(tokens => { .then(tokens => {
if (!tokens?.length > 0) return; if (!tokens?.length > 0) return;
apiKeyList.innerHTML = tokens?.map(generateTokenRow).join(""); apiKeyList.append(...tokens?.map(generateTokenRow));
}); });
} }
@@ -696,11 +806,11 @@
listApiKeys(); listApiKeys();
function getIndexedDataSize() { function getIndexedDataSize() {
document.getElementById("indexed-data-size").innerHTML = "Calculating..."; document.getElementById("indexed-data-size").textContent = "Calculating...";
fetch('/api/config/index/size') fetch('/api/config/index/size')
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
document.getElementById("indexed-data-size").innerHTML = data.indexed_data_size_in_mb + " MB used"; document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used";
}); });
} }
@@ -729,7 +839,7 @@
.catch(() => callback("us")) .catch(() => callback("us"))
}, },
separateDialCode: true, separateDialCode: true,
utilsScript: "https://cdn.jsdelivr.net/npm/intl-tel-input@18.2.1/build/js/utils.js", utilsScript: "https://assets.khoj.dev/intl-tel-input/utils.js",
}); });
const errorMap = ["Invalid number", "Invalid country code", "Too short", "Too long", "Invalid number"]; const errorMap = ["Invalid number", "Invalid country code", "Too short", "Too long", "Invalid number"];
@@ -800,7 +910,7 @@
phonenumberVerifyButton.addEventListener("click", () => { phonenumberVerifyButton.addEventListener("click", () => {
console.log(iti.getValidationError()); console.log(iti.getValidationError());
if (iti.isValidNumber() == false) { if (iti.isValidNumber() == false) {
phoneNumberUpdateCallback.innerHTML = "Invalid phone number: " + errorMap[iti.getValidationError()]; phoneNumberUpdateCallback.textContent = "Invalid phone number: " + errorMap[iti.getValidationError()];
phoneNumberUpdateCallback.style.display = "block"; phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() { setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";
@@ -817,12 +927,12 @@
.then(data => { .then(data => {
if (data.status == "ok") { if (data.status == "ok") {
if (isTwilioEnabled == "True" || isTwilioEnabled == "true") { if (isTwilioEnabled == "True" || isTwilioEnabled == "true") {
phoneNumberUpdateCallback.innerHTML = "OTP sent to your phone number"; phoneNumberUpdateCallback.textContent = "OTP sent to your phone number";
phonenumberVerifyOTPButton.style.display = "block"; phonenumberVerifyOTPButton.style.display = "block";
phonenumberOTPInput.style.display = "block"; phonenumberOTPInput.style.display = "block";
} else { } else {
phonenumberVerifiedText.style.display = "block"; phonenumberVerifiedText.style.display = "block";
phoneNumberUpdateCallback.innerHTML = "Phone number updated"; phoneNumberUpdateCallback.textContent = "Phone number updated";
phonenumberUnverifiedText.style.display = "none"; phonenumberUnverifiedText.style.display = "none";
} }
phonenumberVerifyButton.style.display = "none"; phonenumberVerifyButton.style.display = "none";
@@ -831,7 +941,7 @@
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";
}, 5000); }, 5000);
} else { } else {
phoneNumberUpdateCallback.innerHTML = "Error updating phone number"; phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block"; phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() { setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";
@@ -840,7 +950,7 @@
}) })
.catch((error) => { .catch((error) => {
console.error('Error:', error); console.error('Error:', error);
phoneNumberUpdateCallback.innerHTML = "Error updating phone number"; phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block"; phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() { setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";
@@ -852,7 +962,7 @@
phonenumberVerifyOTPButton.addEventListener("click", () => { phonenumberVerifyOTPButton.addEventListener("click", () => {
const otp = phonenumberOTPInput.value; const otp = phonenumberOTPInput.value;
if (otp.length != 6) { if (otp.length != 6) {
phoneNumberUpdateCallback.innerHTML = "Your OTP should be exactly 6 digits"; phoneNumberUpdateCallback.textContent = "Your OTP should be exactly 6 digits";
phoneNumberUpdateCallback.style.display = "block"; phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() { setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";
@@ -869,7 +979,7 @@
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
if (data.status == "ok") { if (data.status == "ok") {
phoneNumberUpdateCallback.innerHTML = "Phone number updated"; phoneNumberUpdateCallback.textContent = "Phone number updated";
phonenumberVerifiedText.style.display = "block"; phonenumberVerifiedText.style.display = "block";
phonenumberUnverifiedText.style.display = "none"; phonenumberUnverifiedText.style.display = "none";
phoneNumberUpdateCallback.style.display = "block"; phoneNumberUpdateCallback.style.display = "block";
@@ -881,7 +991,7 @@
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";
}, 5000); }, 5000);
} else { } else {
phoneNumberUpdateCallback.innerHTML = "Error updating phone number"; phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block"; phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() { setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";
@@ -890,7 +1000,7 @@
}) })
.catch((error) => { .catch((error) => {
console.error('Error:', error); console.error('Error:', error);
phoneNumberUpdateCallback.innerHTML = "Error updating phone number"; phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block"; phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() { setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none"; phoneNumberUpdateCallback.style.display = "none";

View File

@@ -12,7 +12,7 @@
</h2> </h2>
<div class="section-manage-files"> <div class="section-manage-files">
<div id="delete-all-files" class="delete-all-files"> <div id="delete-all-files" class="delete-all-files">
<button id="delete-all-files" type="submit" title="Remove all computer files from Khoj">🗑️ Delete all</button> <button id="delete-all-files-button" type="submit" title="Remove all computer files from Khoj">🗑️ Delete all</button>
</div> </div>
<div class="indexed-files"> <div class="indexed-files">
</div> </div>
@@ -56,7 +56,10 @@
if (data.length == 0) { if (data.length == 0) {
document.getElementById("delete-all-files").style.display = "none"; document.getElementById("delete-all-files").style.display = "none";
indexedFiles.innerHTML = "<div class='card-description'>No documents synced with Khoj</div>"; let noFilesElement = document.createElement("div");
noFilesElement.classList.add("card-description");
noFilesElement.textContent = "No documents synced with Khoj";
indexedFiles.appendChild(noFilesElement);
} else { } else {
document.getElementById("get-desktop-client").style.display = "none"; document.getElementById("get-desktop-client").style.display = "none";
document.getElementById("delete-all-files").style.display = "block"; document.getElementById("delete-all-files").style.display = "block";
@@ -86,14 +89,14 @@
let fileNameElement = document.createElement("div"); let fileNameElement = document.createElement("div");
fileNameElement.classList.add("content-name"); fileNameElement.classList.add("content-name");
fileNameElement.innerHTML = filename; fileNameElement.textContent = filename;
fileElement.appendChild(fileNameElement); fileElement.appendChild(fileNameElement);
let buttonContainer = document.createElement("div"); let buttonContainer = document.createElement("div");
buttonContainer.classList.add("remove-button-container"); buttonContainer.classList.add("remove-button-container");
let removeFileButton = document.createElement("button"); let removeFileButton = document.createElement("button");
removeFileButton.classList.add("remove-file-button"); removeFileButton.classList.add("remove-file-button");
removeFileButton.innerHTML = "🗑️"; removeFileButton.textContent = "🗑️";
removeFileButton.addEventListener("click", ((filename) => { removeFileButton.addEventListener("click", ((filename) => {
return () => { return () => {
removeFile(filename); removeFile(filename);
@@ -112,9 +115,13 @@
// Get all currently indexed files on page load // Get all currently indexed files on page load
getAllComputerFilenames(); getAllComputerFilenames();
let deleteAllComputerFilesButton = document.getElementById("delete-all-files"); let deleteAllComputerFilesButton = document.getElementById("delete-all-files-button");
deleteAllComputerFilesButton.addEventListener("click", function(event) { deleteAllComputerFilesButton.addEventListener("click", function(event) {
event.preventDefault(); event.preventDefault();
originalDeleteAllComputerFilesButtonText = deleteAllComputerFilesButton.textContent;
deleteAllComputerFilesButton.textContent = "🗑️ Deleting...";
deleteAllComputerFilesButton.disabled = true;
fetch('/api/config/data/content-source/computer', { fetch('/api/config/data/content-source/computer', {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
@@ -122,11 +129,11 @@
} }
}) })
.then(response => response.json()) .then(response => response.json())
.then(data => { .finally(() => {
if (data.status == "ok") { getAllComputerFilenames();
getAllComputerFilenames(); deleteAllComputerFilesButton.textContent = originalDeleteAllComputerFilesButtonText;
} deleteAllComputerFilesButton.disabled = false;
}) });
}); });
</script> </script>
{% endblock %} {% endblock %}

View File

@@ -70,18 +70,50 @@
repo.classList.add("repo"); repo.classList.add("repo");
const id = Date.now(); const id = Date.now();
repo.id = "repo-card-" + id; repo.id = "repo-card-" + id;
repo.innerHTML = `
<label for="repo-owner">Repository Owner</label> // Create repo owner, name, branch elements
<input type="text" id="repo-owner" name="repo_owner"> let repoOwnerLabel = document.createElement("label");
<label for="repo-name">Repository Name</label> repoOwnerLabel.textContent = "Repository Owner";
<input type="text" id="repo-name" name="repo_name"> repoOwnerLabel.for = "repo-owner";
<label for="repo-branch">Repository Branch</label>
<input type="text" id="repo-branch" name="repo_branch"> let repoOwner = document.createElement("input");
<button type="button" repoOwner.type = "text";
class="remove-repo-button" repoOwner.id = "repo-owner-" + id;
onclick="remove_repo(${id})" repoOwner.name = "repo_owner";
id="remove-repo-button-${id}">Remove Repository</button>
`; let repoNameLabel = document.createElement("label");
repoNameLabel.textContent = "Repository Name";
repoNameLabel.for = "repo-name";
let repoName = document.createElement("input");
repoName.type = "text";
repoName.id = "repo-name-" + id;
repoName.name = "repo_name";
let repoBranchLabel = document.createElement("label");
repoBranchLabel.textContent = "Repository Branch";
repoBranchLabel.for = "repo-branch";
let repoBranch = document.createElement("input");
repoBranch.type = "text";
repoBranch.id = "repo-branch-" + id;
repoBranch.name = "repo_branch";
let removeRepoButton = document.createElement("button");
removeRepoButton.type = "button";
removeRepoButton.classList.add("remove-repo-button");
removeRepoButton.onclick = function() { remove_repo(id); };
removeRepoButton.id = "remove-repo-button-" + id;
removeRepoButton.textContent = "Remove Repository";
// Append elements to repo card
repo.append(
repoOwnerLabel, repoOwner,
repoNameLabel, repoName,
repoBranchLabel, repoBranch,
removeRepoButton
);
document.getElementById("repositories").appendChild(repo); document.getElementById("repositories").appendChild(repo);
}) })
@@ -95,7 +127,7 @@
const pat_token = document.getElementById("pat-token").value; const pat_token = document.getElementById("pat-token").value;
if (pat_token == "") { if (pat_token == "") {
document.getElementById("success").innerHTML = "❌ Please enter a Personal Access Token."; document.getElementById("success").textContent = "❌ Please enter a Personal Access Token.";
document.getElementById("success").style.display = "block"; document.getElementById("success").style.display = "block";
return; return;
} }
@@ -122,14 +154,14 @@
} }
if (repos.length == 0) { if (repos.length == 0) {
document.getElementById("success").innerHTML = "❌ Please add at least one repository."; document.getElementById("success").textContent = "❌ Please add at least one repository.";
document.getElementById("success").style.display = "block"; document.getElementById("success").style.display = "block";
return; return;
} }
const submitButton = document.getElementById("submit"); const submitButton = document.getElementById("submit");
submitButton.disabled = true; submitButton.disabled = true;
submitButton.innerHTML = "Saving..."; submitButton.textContent = "Saving...";
// Save Github config on server // Save Github config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
@@ -147,11 +179,11 @@
.then(response => response.json()) .then(response => response.json())
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) }) .then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
.catch(error => { .catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github settings."; document.getElementById("success").textContent = "⚠️ Failed to save Github settings.";
document.getElementById("success").style.display = "block"; document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings"; submitButton.textContent = "⚠️ Failed to save settings";
setTimeout(function() { setTimeout(function() {
submitButton.innerHTML = "Save"; submitButton.textContent = "Save";
submitButton.disabled = false; submitButton.disabled = false;
}, 2000); }, 2000);
return; return;
@@ -163,18 +195,18 @@
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) }) .then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
.then(data => { .then(data => {
document.getElementById("success").style.display = "none"; document.getElementById("success").style.display = "none";
submitButton.innerHTML = "✅ Successfully updated"; submitButton.textContent = "✅ Successfully updated";
setTimeout(function() { setTimeout(function() {
submitButton.innerHTML = "Save"; submitButton.textContent = "Save";
submitButton.disabled = false; submitButton.disabled = false;
}, 2000); }, 2000);
}) })
.catch(error => { .catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github content."; document.getElementById("success").textContent = "⚠️ Failed to save Github content.";
document.getElementById("success").style.display = "block"; document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content"; submitButton.textContent = "⚠️ Failed to save content";
setTimeout(function() { setTimeout(function() {
submitButton.innerHTML = "Save"; submitButton.textContent = "Save";
submitButton.disabled = false; submitButton.disabled = false;
}, 2000); }, 2000);
}); });

View File

@@ -34,14 +34,14 @@
const token = document.getElementById("token").value; const token = document.getElementById("token").value;
if (token == "") { if (token == "") {
document.getElementById("success").innerHTML = "❌ Please enter a Notion Token."; document.getElementById("success").textContent = "❌ Please enter a Notion Token.";
document.getElementById("success").style.display = "block"; document.getElementById("success").style.display = "block";
return; return;
} }
const submitButton = document.getElementById("submit"); const submitButton = document.getElementById("submit");
submitButton.disabled = true; submitButton.disabled = true;
submitButton.innerHTML = "Syncing..."; submitButton.textContent = "Syncing...";
// Save Notion config on server // Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
@@ -58,11 +58,11 @@
.then(response => response.json()) .then(response => response.json())
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) }) .then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
.catch(error => { .catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings."; document.getElementById("success").textContent = "⚠️ Failed to save Notion settings.";
document.getElementById("success").style.display = "block"; document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings"; submitButton.textContent = "⚠️ Failed to save settings";
setTimeout(function() { setTimeout(function() {
submitButton.innerHTML = "Save"; submitButton.textContent = "Save";
submitButton.disabled = false; submitButton.disabled = false;
}, 2000); }, 2000);
return; return;
@@ -74,18 +74,18 @@
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) }) .then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
.then(data => { .then(data => {
document.getElementById("success").style.display = "none"; document.getElementById("success").style.display = "none";
submitButton.innerHTML = "✅ Successfully updated"; submitButton.textContent = "✅ Successfully updated";
setTimeout(function() { setTimeout(function() {
submitButton.innerHTML = "Save"; submitButton.textContent = "Save";
submitButton.disabled = false; submitButton.disabled = false;
}, 2000); }, 2000);
}) })
.catch(error => { .catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content."; document.getElementById("success").textContent = "⚠️ Failed to save Notion content.";
document.getElementById("success").style.display = "block"; document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content"; submitButton.textContent = "⚠️ Failed to save content";
setTimeout(function() { setTimeout(function() {
submitButton.innerHTML = "Save"; submitButton.textContent = "Save";
submitButton.disabled = false; submitButton.disabled = false;
}, 2000); }, 2000);
}); });

View File

@@ -127,7 +127,7 @@ To get started, just start typing below. You can also type / to see a list of co
linkElement.textContent = title; linkElement.textContent = title;
let referenceButton = document.createElement('button'); let referenceButton = document.createElement('button');
referenceButton.innerHTML = linkElement.outerHTML; referenceButton.appendChild(linkElement);
referenceButton.id = `ref-${index}`; referenceButton.id = `ref-${index}`;
referenceButton.classList.add("reference-button"); referenceButton.classList.add("reference-button");
referenceButton.classList.add("collapsed"); referenceButton.classList.add("collapsed");
@@ -138,11 +138,12 @@ To get started, just start typing below. You can also type / to see a list of co
if (this.classList.contains("collapsed")) { if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed"); this.classList.remove("collapsed");
this.classList.add("expanded"); this.classList.add("expanded");
this.innerHTML = linkElement.outerHTML + `<br><br>${question + snippet}`; this.innerHTML = `${linkElement.outerHTML}<br><br>${question + snippet}`;
} else { } else {
this.classList.add("collapsed"); this.classList.add("collapsed");
this.classList.remove("expanded"); this.classList.remove("expanded");
this.innerHTML = linkElement.outerHTML; this.innerHTML = "";
this.appendChild(linkElement);
} }
}); });
@@ -296,7 +297,7 @@ To get started, just start typing below. You can also type / to see a list of co
} }
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText; referenceExpandButton.textContent = expandButtonText;
references.appendChild(referenceSection); references.appendChild(referenceSection);
@@ -447,7 +448,7 @@ To get started, just start typing below. You can also type / to see a list of co
let referenceExpandButton = document.createElement('button'); let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button"); referenceExpandButton.classList.add("reference-expand-button");
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`; referenceExpandButton.textContent = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.addEventListener('click', function() { referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) { if (referenceSection.classList.contains("collapsed")) {
@@ -815,7 +816,7 @@ Learn more [here](https://khoj.dev).
let closeButton = document.createElement('button'); let closeButton = document.createElement('button');
closeButton.id = "close-button"; closeButton.id = "close-button";
closeButton.innerHTML = "Close"; closeButton.textContent = "Close";
closeButton.classList.add("close-button"); closeButton.classList.add("close-button");
closeButton.addEventListener('click', function() { closeButton.addEventListener('click', function() {
modal.remove(); modal.remove();

View File

@@ -0,0 +1,118 @@
import base64
import logging
import os
from datetime import datetime
from typing import Dict, List, Tuple
from rapidocr_onnxruntime import RapidOCR
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
class ImageToEntries(TextToEntries):
def __init__(self):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger):
file_to_text_map, current_entries = ImageToEntries.extract_image_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_image_entries(image_files) -> Tuple[Dict, List[Entry]]: # important function
"""Extract entries by page from specified image files"""
file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for image_file in image_files:
try:
loader = RapidOCR()
bytes = image_files[image_file]
# write the image to a temporary file
timestamp_now = datetime.utcnow().timestamp()
# use either png or jpg
if image_file.endswith(".png"):
tmp_file = f"tmp_image_file_{timestamp_now}.png"
elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"):
tmp_file = f"tmp_image_file_{timestamp_now}.jpg"
with open(tmp_file, "wb") as f:
bytes = image_files[image_file]
f.write(bytes)
try:
image_entries_per_file = ""
result, _ = loader(tmp_file)
if result:
expanded_entries = [text[1] for text in result]
image_entries_per_file = " ".join(expanded_entries)
except ImportError:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
continue
entry_to_location_map.append((image_entries_per_file, image_file))
entries.extend([image_entries_per_file])
file_to_text_map[image_file] = image_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(tmp_file):
os.remove(tmp_file)
return file_to_text_map, ImageToEntries.convert_image_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_image_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each image entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entry_filename = entry_to_file_map[parsed_entry]
# Append base filename to compiled entry for context to model
heading = f"{entry_filename}\n"
compiled_entry = f"{heading}{parsed_entry}"
entries.append(
Entry(
compiled=compiled_entry,
raw=parsed_entry,
heading=heading,
file=f"{entry_filename}",
)
)
logger.debug(f"Converted {len(parsed_entries)} image entries to dictionaries")
return entries

View File

@@ -146,7 +146,7 @@ class MarkdownToEntries(TextToEntries):
else: else:
entry_filename = str(Path(raw_filename)) entry_filename = str(Path(raw_filename))
heading = parsed_entry.splitlines()[0] if re.search("^#+\s", parsed_entry) else "" heading = parsed_entry.splitlines()[0] if re.search(r"^#+\s", parsed_entry) else ""
# Append base filename to compiled entry for context to model # Append base filename to compiled entry for context to model
# Increment heading level for heading entries and make filename as its top level heading # Increment heading level for heading entries and make filename as its top level heading
prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n" prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n"

View File

@@ -115,14 +115,20 @@ class OrgToEntries(TextToEntries):
return entries, entry_to_file_map return entries, entry_to_file_map
# Split this entry tree into sections by the next heading level in it # Split this entry tree into sections by the next heading level in it
# Increment heading level until able to split entry into sections # Increment heading level until able to split entry into sections or reach max heading level
# A successful split will result in at least 2 sections # A successful split will result in at least 2 sections
max_heading_level = 100
next_heading_level = len(ancestry) next_heading_level = len(ancestry)
sections: List[str] = [] sections: List[str] = []
while len(sections) < 2: while len(sections) < 2 and next_heading_level < max_heading_level:
next_heading_level += 1 next_heading_level += 1
sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE) sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE)
# If unable to split entry into sections, log error and skip indexing it
if next_heading_level == max_heading_level:
logger.error(f"Unable to split current entry chunk: {org_content_with_ancestry[:20]}. Skip indexing it.")
return entries, entry_to_file_map
# Recurse down each non-empty section after parsing its body, heading and ancestry # Recurse down each non-empty section after parsing its body, heading and ancestry
for section in sections: for section in sections:
# Skip empty sections # Skip empty sections
@@ -135,7 +141,7 @@ class OrgToEntries(TextToEntries):
# If first non-empty line is a heading with expected heading level # If first non-empty line is a heading with expected heading level
if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line): if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line):
# Extract the section body without the heading # Extract the section body without the heading
current_section_body = "\n".join(section.split(first_non_empty_line)[1:]) current_section_body = "\n".join(section.split(first_non_empty_line, 1)[1:])
# Parse the section heading into current section ancestry # Parse the section heading into current section ancestry
current_section_title = first_non_empty_line[next_heading_level:].strip() current_section_title = first_non_empty_line[next_heading_level:].strip()
current_ancestry[next_heading_level] = current_section_title current_ancestry[next_heading_level] = current_section_title

View File

@@ -124,7 +124,7 @@ class TextToEntries(ABC):
deletion_filenames: Set[str] = None, deletion_filenames: Set[str] = None,
user: KhojUser = None, user: KhojUser = None,
regenerate: bool = False, regenerate: bool = False,
file_to_text_map: dict[str, List[str]] = None, file_to_text_map: dict[str, str] = None,
): ):
with timer("Constructed current entry hashes in", logger): with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]() hashes_by_file = dict[str, set[str]]()
@@ -137,7 +137,7 @@ class TextToEntries(ABC):
if regenerate: if regenerate:
with timer("Cleared existing dataset for regeneration in", logger): with timer("Cleared existing dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}") logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type) num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type=file_type)
hashes_to_process = set() hashes_to_process = set()
with timer("Identified entries to add to database in", logger): with timer("Identified entries to add to database in", logger):
@@ -192,16 +192,17 @@ class TextToEntries(ABC):
logger.debug(f"Added {len(added_entries)} {file_type} entries to database") logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
if file_to_text_map: if file_to_text_map:
# get the list of file_names using added_entries with timer("Indexed text of modified file in", logger):
filenames_to_update = [entry.file_path for entry in added_entries] # get the set of modified files from added_entries
# for each file_name in filenames_to_update, try getting the file object and updating raw_text and if it fails create a new file object modified_files = {entry.file_path for entry in added_entries}
for file_name in filenames_to_update: # create or update text of each updated file indexed on DB
raw_text = " ".join(file_to_text_map[file_name]) for modified_file in modified_files:
file_object = FileObjectAdapters.get_file_objects_by_name(user, file_name) raw_text = file_to_text_map[modified_file]
if file_object: file_object = FileObjectAdapters.get_file_object_by_name(user, modified_file)
FileObjectAdapters.update_raw_text(file_object, raw_text) if file_object:
else: FileObjectAdapters.update_raw_text(file_object, raw_text)
FileObjectAdapters.create_file_object(user, file_name, raw_text) else:
FileObjectAdapters.create_file_object(user, modified_file, raw_text)
new_dates = [] new_dates = []
with timer("Indexed dates from added entries in", logger): with timer("Indexed dates from added entries in", logger):

View File

@@ -99,15 +99,13 @@ def anthropic_llm_thread(
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
] ]
max_prompt_size = max_prompt_size or DEFAULT_MAX_TOKENS_ANTHROPIC
with client.messages.stream( with client.messages.stream(
messages=formatted_messages, messages=formatted_messages,
model=model_name, # type: ignore model=model_name, # type: ignore
temperature=temperature, temperature=temperature,
system=system_prompt, system=system_prompt,
timeout=20, timeout=20,
max_tokens=max_prompt_size, max_tokens=DEFAULT_MAX_TOKENS_ANTHROPIC,
**(model_kwargs or dict()), **(model_kwargs or dict()),
) as stream: ) as stream:
for text in stream.text_stream: for text in stream.text_stream:

View File

@@ -154,7 +154,7 @@ def converse(
completion_func(chat_response=prompts.no_online_results_found.format()) completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()]) return iter([prompts.no_online_results_found.format()])
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: if not is_none_or_empty(online_results):
conversation_primer = ( conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
) )

View File

@@ -1,5 +1,4 @@
import logging import logging
import os
from threading import Thread from threading import Thread
from typing import Dict from typing import Dict
@@ -40,7 +39,7 @@ def completion_with_backoff(
client: openai.OpenAI = openai_clients.get(client_key) client: openai.OpenAI = openai_clients.get(client_key)
if not client: if not client:
client = openai.OpenAI( client = openai.OpenAI(
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), api_key=openai_api_key,
base_url=api_base_url, base_url=api_base_url,
) )
openai_clients[client_key] = client openai_clients[client_key] = client
@@ -102,7 +101,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
client_key = f"{openai_api_key}--{api_base_url}" client_key = f"{openai_api_key}--{api_base_url}"
if client_key not in openai_clients: if client_key not in openai_clients:
client: openai.OpenAI = openai.OpenAI( client: openai.OpenAI = openai.OpenAI(
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), api_key=openai_api_key,
base_url=api_base_url, base_url=api_base_url,
) )
openai_clients[client_key] = client openai_clients[client_key] = client

View File

@@ -121,7 +121,7 @@ User's Notes:
## Image Generation ## Image Generation
## -- ## --
image_generation_improve_prompt = PromptTemplate.from_template( image_generation_improve_prompt_dalle = PromptTemplate.from_template(
""" """
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt: You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
@@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a
Improved Query:""" Improved Query:"""
) )
image_generation_improve_prompt_sd = PromptTemplate.from_template(
"""
You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image.
Use the provided context below to add specific, fine details to the image composition.
Retain any important information and follow any instructions from the original prompt.
Put any text to be rendered in the image within double quotes in your improved prompt.
You are provided with the following context to help enhance the original prompt:
Today's Date: {current_date}
User's Location: {location}
User's Notes:
{references}
Online References:
{online_results}
Conversation Log:
{chat_history}
Original Prompt: "{query}"
Now create an improved prompt using the context provided above to generate an image.
Retain any important information and follow any instructions from the original prompt.
Use the additional context from the user's notes, online references and conversation log to improve the image generation.
Improved Prompt:"""
)
## Online Search Conversation ## Online Search Conversation
## -- ## --
online_search_conversation = PromptTemplate.from_template( online_search_conversation = PromptTemplate.from_template(

View File

@@ -2,11 +2,11 @@ import asyncio
import json import json
import logging import logging
import os import os
import urllib.parse
from collections import defaultdict from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import aiohttp import aiohttp
import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from markdownify import markdownify from markdownify import markdownify
@@ -23,6 +23,10 @@ logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY") SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
SERPER_DEV_URL = "https://google.serper.dev/search" SERPER_DEV_URL = "https://google.serper.dev/search"
JINA_READER_API_URL = "https://r.jina.ai/"
JINA_SEARCH_API_URL = "https://s.jina.ai/"
JINA_API_KEY = os.getenv("JINA_API_KEY")
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY") OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI" OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
OLOSTEP_QUERY_PARAMS = { OLOSTEP_QUERY_PARAMS = {
@@ -50,9 +54,6 @@ async def search_online(
custom_filters: List[str] = [], custom_filters: List[str] = [],
): ):
query += " ".join(custom_filters) query += " ".join(custom_filters)
if not online_search_enabled():
logger.warn("SERPER_DEV_API_KEY is not set")
return {}
if not is_internet_connected(): if not is_internet_connected():
logger.warn("Cannot search online as not connected to internet") logger.warn("Cannot search online as not connected to internet")
return {} return {}
@@ -61,27 +62,35 @@ async def search_online(
subqueries = await generate_online_subqueries(query, conversation_history, location) subqueries = await generate_online_subqueries(query, conversation_history, location)
response_dict = {} response_dict = {}
for subquery in subqueries: if subqueries:
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func: if send_status_func:
await send_status_func(f"**🌐 Searching the Internet for**: {subquery}") subqueries_str = "\n- " + "\n- ".join(list(subqueries))
logger.info(f"🌐 Searching the Internet for '{subquery}'") await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}")
response_dict[subquery] = search_with_google(subquery)
# Gather distinct web pages from organic search results of each subquery without an instant answer with timer(f"Internet searches for {list(subqueries)} took", logger):
webpage_links = { search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
organic["link"]: subquery search_tasks = [search_func(subquery) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
response_dict = {subquery: search_result for subquery, search_result in search_results}
# Gather distinct web page data from organic results of each subquery without an instant answer.
# Content of web pages is directly available when Jina is used for search.
webpages = {
(organic.get("link"), subquery, organic.get("content"))
for subquery in response_dict for subquery in response_dict
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ] for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
if "answerBox" not in response_dict[subquery] if "answerBox" not in response_dict[subquery]
} }
# Read, extract relevant info from the retrieved web pages # Read, extract relevant info from the retrieved web pages
if webpage_links: if webpages:
webpage_links = [link for link, _, _ in webpages]
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}") logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
if send_status_func: if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links)) webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}") await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
tasks = [read_webpage_and_extract_content(subquery, link) for link, subquery in webpage_links.items()] tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
# Collect extracted info from the retrieved web pages # Collect extracted info from the retrieved web pages
@@ -92,23 +101,24 @@ async def search_online(
return response_dict return response_dict
def search_with_google(subquery: str): async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
payload = json.dumps({"q": subquery}) payload = json.dumps({"q": query})
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"} headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload) async with aiohttp.ClientSession() as session:
async with session.post(SERPER_DEV_URL, headers=headers, data=payload) as response:
if response.status != 200:
logger.error(await response.text())
return query, {}
json_response = await response.json()
extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
extracted_search_result = {
field: json_response[field]
for field in extraction_fields
if not is_none_or_empty(json_response.get(field))
}
if response.status_code != 200: return query, extracted_search_result
logger.error(response.text)
return {}
json_response = response.json()
extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
extracted_search_result = {
field: json_response[field] for field in extraction_fields if not is_none_or_empty(json_response.get(field))
}
return extracted_search_result
async def read_webpages( async def read_webpages(
@@ -134,10 +144,13 @@ async def read_webpages(
return response return response
async def read_webpage_and_extract_content(subquery: str, url: str) -> Tuple[str, Union[None, str], str]: async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None
) -> Tuple[str, Union[None, str], str]:
try: try:
with timer(f"Reading web page at '{url}' took", logger): if is_none_or_empty(content):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_at_url(url) with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger): with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content) extracted_info = await extract_relevant_info(subquery, content)
return subquery, extracted_info, url return subquery, extracted_info, url
@@ -172,5 +185,41 @@ async def read_webpage_with_olostep(web_url: str) -> str:
return response_json["markdown_content"] return response_json["markdown_content"]
def online_search_enabled(): async def read_webpage_with_jina(web_url: str) -> str:
return SERPER_DEV_API_KEY is not None jina_reader_api_url = f"{JINA_READER_API_URL}/{web_url}"
headers = {"Accept": "application/json", "X-Timeout": "30"}
if JINA_API_KEY:
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
async with aiohttp.ClientSession() as session:
async with session.get(jina_reader_api_url, headers=headers) as response:
response.raise_for_status()
response_json = await response.json()
return response_json["data"]["content"]
async def search_with_jina(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
encoded_query = urllib.parse.quote(query)
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
headers = {"Accept": "application/json"}
if JINA_API_KEY:
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
async with aiohttp.ClientSession() as session:
async with session.get(jina_search_api_url, headers=headers) as response:
if response.status != 200:
logger.error(await response.text())
return query, {}
response_json = await response.json()
parsed_response = [
{
"title": item["title"],
"content": item.get("content"),
# rename description -> snippet for consistency
"snippet": item["description"],
# rename url -> link for consistency
"link": item["url"],
}
for item in response_json["data"]
]
return query, {"organic": parsed_response}

View File

@@ -13,6 +13,7 @@ from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK from websockets import ConnectionClosedOK
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import ( from khoj.database.adapters import (
ConversationAdapters, ConversationAdapters,
EntryAdapters, EntryAdapters,
@@ -28,11 +29,7 @@ from khoj.processor.conversation.prompts import (
) )
from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import ( from khoj.processor.tools.online_search import read_webpages, search_online
online_search_enabled,
read_webpages,
search_online,
)
from khoj.routers.api import extract_references_and_questions from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiUserRateLimiter, ApiUserRateLimiter,
@@ -153,7 +150,17 @@ async def sendfeedback(request: Request, data: FeedbackData):
@api_chat.post("/speech") @api_chat.post("/speech")
@requires(["authenticated", "premium"]) @requires(["authenticated", "premium"])
async def text_to_speech(request: Request, common: CommonQueryParams, text: str): async def text_to_speech(
request: Request,
common: CommonQueryParams,
text: str,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=20, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object) voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)
params = {"text_to_speak": text} params = {"text_to_speak": text}
@@ -350,17 +357,19 @@ def duplicate_chat_history_public_conversation(
conversation_id: int, conversation_id: int,
): ):
user = request.user.object user = request.user.object
domain = request.headers.get("host")
scheme = request.url.scheme
# Throw unauthorized exception if domain not in ALLOWED_HOSTS
host_domain = domain.split(":")[0]
if host_domain not in ALLOWED_HOSTS:
raise HTTPException(status_code=401, detail="Unauthorized domain")
# Duplicate Conversation History to Public Conversation # Duplicate Conversation History to Public Conversation
conversation = ConversationAdapters.get_conversation_by_user(user, request.user.client_app, conversation_id) conversation = ConversationAdapters.get_conversation_by_user(user, request.user.client_app, conversation_id)
public_conversation = ConversationAdapters.make_public_conversation_copy(conversation) public_conversation = ConversationAdapters.make_public_conversation_copy(conversation)
public_conversation_url = PublicConversationAdapters.get_public_conversation_url(public_conversation) public_conversation_url = PublicConversationAdapters.get_public_conversation_url(public_conversation)
domain = request.headers.get("host")
scheme = request.url.scheme
update_telemetry_state( update_telemetry_state(
request=request, request=request,
telemetry_type="api", telemetry_type="api",
@@ -610,6 +619,7 @@ async def websocket_endpoint(
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task) conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
@@ -625,8 +635,18 @@ async def websocket_endpoint(
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd) await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
q = q.replace(f"/{cmd.value}", "").strip() q = q.replace(f"/{cmd.value}", "").strip()
if ConversationCommand.Summarize in conversation_commands: file_filters = conversation.file_filters if conversation else []
file_filters = conversation.file_filters # Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = "" response_log = ""
if len(file_filters) == 0: if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left." response_log = "No files selected for summarization. Please add files using the section on the left."
@@ -741,22 +761,16 @@ async def websocket_endpoint(
conversation_commands.remove(ConversationCommand.Notes) conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
if not online_search_enabled(): try:
conversation_commands.remove(ConversationCommand.Online) online_results = await search_online(
# If online search is not enabled, try to read webpages directly defiltered_query, meta_log, location, send_status_update, custom_filters
if ConversationCommand.Webpage not in conversation_commands: )
conversation_commands.append(ConversationCommand.Webpage) except ValueError as e:
else: logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
try: await send_complete_llm_response(
online_results = await search_online( f"Error searching online: {e}. Attempting to respond without online results"
defiltered_query, meta_log, location, send_status_update, custom_filters )
) continue
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
await send_complete_llm_response(
f"Error searching online: {e}. Attempting to respond without online results"
)
continue
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
@@ -1041,18 +1055,10 @@ async def chat(
conversation_commands.remove(ConversationCommand.Notes) conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
if not online_search_enabled(): try:
conversation_commands.remove(ConversationCommand.Online) online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
# If online search is not enabled, try to read webpages directly except ValueError as e:
if ConversationCommand.Webpage not in conversation_commands: logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
conversation_commands.append(ConversationCommand.Webpage)
else:
try:
online_results = await search_online(
defiltered_query, meta_log, location, custom_filters=_custom_filters
)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:

View File

@@ -183,7 +183,7 @@ async def remove_content_source_data(
raise ValueError(f"Invalid content source: {content_source}") raise ValueError(f"Invalid content source: {content_source}")
elif content_object != "Computer": elif content_object != "Computer":
await content_object.objects.filter(user=user).adelete() await content_object.objects.filter(user=user).adelete()
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source) await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
return {"status": "ok"} return {"status": "ok"}
@@ -341,6 +341,35 @@ async def update_search_model(
return {"status": "ok"} return {"status": "ok"}
@api_config.post("/data/paint/model", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_paint_model",
client=client,
metadata={"paint_model": new_config.setting.model_name},
)
if new_config is None:
return {"status": "error", "message": "Model not found"}
return {"status": "ok"}
@api_config.get("/index/size", response_model=Dict[str, int]) @api_config.get("/index/size", response_model=Dict[str, int])
@requires(["authenticated"]) @requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams): async def get_indexed_data_size(request: Request, common: CommonQueryParams):

View File

@@ -42,8 +42,12 @@ if not state.anonymous_mode:
from google.oauth2 import id_token from google.oauth2 import id_token
except ImportError: except ImportError:
missing_requirements += ["Install the Khoj production package with `pip install khoj-assistant[prod]`"] missing_requirements += ["Install the Khoj production package with `pip install khoj-assistant[prod]`"]
if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"): if not os.environ.get("RESEND_API_KEY") and (
missing_requirements += ["Set your GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET as environment variables"] not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET")
):
missing_requirements += [
"Set your RESEND_API_KEY or GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET as environment variables"
]
if missing_requirements: if missing_requirements:
requirements_string = "\n - " + "\n - ".join(missing_requirements) requirements_string = "\n - " + "\n - ".join(missing_requirements)
error_msg = f"🚨 Start Khoj with --anonymous-mode flag or to enable authentication:{requirements_string}" error_msg = f"🚨 Start Khoj with --anonymous-mode flag or to enable authentication:{requirements_string}"

View File

@@ -453,12 +453,14 @@ async def generate_better_image_prompt(
location_data: LocationData, location_data: LocationData,
note_references: List[Dict[str, Any]], note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None, online_results: Optional[dict] = None,
model_type: Optional[str] = None,
) -> str: ) -> str:
""" """
Generate a better image prompt from the given query Generate a better image prompt from the given query
""" """
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location = f"{location_data.city}, {location_data.region}, {location_data.country}"
@@ -477,21 +479,34 @@ async def generate_better_image_prompt(
elif online_results[result].get("webpages"): elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"] simplified_online_results[result] = online_results[result]["webpages"]
image_prompt = prompts.image_generation_improve_prompt.format( if model_type == TextToImageModelConfig.ModelType.OPENAI:
query=q, image_prompt = prompts.image_generation_improve_prompt_dalle.format(
chat_history=conversation_history, query=q,
location=location_prompt, chat_history=conversation_history,
current_date=today_date, location=location_prompt,
references=user_references, current_date=today_date,
online_results=simplified_online_results, references=user_references,
) online_results=simplified_online_results,
)
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q,
chat_history=conversation_history,
location=location_prompt,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
)
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model) response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
return response.strip() return response
async def send_message_to_model_wrapper( async def send_message_to_model_wrapper(
@@ -747,74 +762,110 @@ async def text_to_image(
image_url = None image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config: if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error # If the user has not configured a text to image model, return an unsupported on server error
status_code = 501 status_code = 501
message = "Failed to generate image. Setup image generation on the server." message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type.value return image_url or image, status_code, message, intent_type.value
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name
text2image_model = text_to_image_config.model_name chat_history = ""
chat_history = "" for chat in conversation_log.get("chat", [])[-4:]:
for chat in conversation_log.get("chat", [])[-4:]: if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]: chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"A: {chat['message']}\n"
chat_history += f"A: {chat['message']}\n" elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): chat_history += f"Q: Query: {chat['intent']['query']}\n"
chat_history += f"Q: Query: {chat['intent']['query']}\n" chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
try: with timer("Improve the original user query", logger):
with timer("Improve the original user query", logger): if send_status_func:
if send_status_func: await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
await send_status_func("**✍🏽 Enhancing the Painting Prompt**") improved_image_prompt = await generate_better_image_prompt(
improved_image_prompt = await generate_better_image_prompt( message,
message, chat_history,
chat_history, location_data=location_data,
location_data=location_data, note_references=references,
note_references=references, online_results=online_results,
online_results=online_results, model_type=text_to_image_config.model_type,
) )
with timer("Generate image with OpenAI", logger):
if send_status_func: if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}") await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
response = state.openai_client.images.generate( response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" prompt=improved_image_prompt,
model=text2image_model,
response_format="b64_json",
extra_headers=auth_header,
) )
image = response.data[0].b64_json image = response.data[0].b64_json
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
decoded_image = base64.b64decode(image) decoded_image = base64.b64decode(image)
image_io = io.BytesIO(decoded_image) except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
png_image = Image.open(image_io) if "content_policy_violation" in e.message:
webp_image_io = io.BytesIO() logger.error(f"Image Generation blocked by OpenAI: {e}")
png_image.save(webp_image_io, "WEBP") status_code = e.status_code # type: ignore
webp_image_bytes = webp_image_io.getvalue() message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
webp_image_io.close() return image_url or image, status_code, message, intent_type.value
image_io.close() else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
with timer("Upload image to S3", logger): elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
image_url = upload_image(webp_image_bytes, user.uuid) with timer("Generate image with Stability AI", logger):
if image_url: try:
intent_type = ImageIntentType.TEXT_TO_IMAGE2 response = requests.post(
else: f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
image = base64.b64encode(webp_image_bytes).decode("utf-8") files={"none": ""},
data={
return image_url or image, status_code, improved_image_prompt, intent_type.value "prompt": improved_image_prompt,
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: "model": text2image_model,
if "content_policy_violation" in e.message: "mode": "text-to-image",
logger.error(f"Image Generation blocked by OpenAI: {e}") "output_format": "png",
status_code = e.status_code # type: ignore "seed": 1032622926,
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore "aspect_ratio": "1:1",
return image_url or image, status_code, message, intent_type.value },
else: )
decoded_image = response.content
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value return image_url or image, status_code, message, intent_type.value
return image_url or image, status_code, response, intent_type.value
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
image_io = io.BytesIO(decoded_image)
png_image = Image.open(image_io)
webp_image_io = io.BytesIO()
png_image.save(webp_image_io, "WEBP")
webp_image_bytes = webp_image_io.getvalue()
webp_image_io.close()
image_io.close()
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter: class ApiUserRateLimiter:

View File

@@ -9,6 +9,7 @@ from starlette.authentication import requires
from khoj.database.models import GithubConfig, KhojUser, NotionConfig from khoj.database.models import GithubConfig, KhojUser, NotionConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
@@ -41,6 +42,7 @@ class IndexerInput(BaseModel):
markdown: Optional[dict[str, str]] = None markdown: Optional[dict[str, str]] = None
pdf: Optional[dict[str, bytes]] = None pdf: Optional[dict[str, bytes]] = None
plaintext: Optional[dict[str, str]] = None plaintext: Optional[dict[str, str]] = None
image: Optional[dict[str, bytes]] = None
docx: Optional[dict[str, bytes]] = None docx: Optional[dict[str, bytes]] = None
@@ -65,7 +67,14 @@ async def update(
), ),
): ):
user = request.user.object user = request.user.object
index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}, "docx": {}} index_files: Dict[str, Dict[str, str]] = {
"org": {},
"markdown": {},
"pdf": {},
"plaintext": {},
"image": {},
"docx": {},
}
try: try:
logger.info(f"📬 Updating content index via API call by {client} client") logger.info(f"📬 Updating content index via API call by {client} client")
for file in files: for file in files:
@@ -81,6 +90,7 @@ async def update(
markdown=index_files["markdown"], markdown=index_files["markdown"],
pdf=index_files["pdf"], pdf=index_files["pdf"],
plaintext=index_files["plaintext"], plaintext=index_files["plaintext"],
image=index_files["image"],
docx=index_files["docx"], docx=index_files["docx"],
) )
@@ -133,6 +143,7 @@ async def update(
"num_markdown": len(index_files["markdown"]), "num_markdown": len(index_files["markdown"]),
"num_pdf": len(index_files["pdf"]), "num_pdf": len(index_files["pdf"]),
"num_plaintext": len(index_files["plaintext"]), "num_plaintext": len(index_files["plaintext"]),
"num_image": len(index_files["image"]),
"num_docx": len(index_files["docx"]), "num_docx": len(index_files["docx"]),
} }
@@ -300,6 +311,23 @@ def configure_content(
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True) logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False success = False
try:
# Initialize Image Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[
"image"
]:
logger.info("🖼️ Setting up search for images")
# Extract Entries, Generate Image Embeddings
text_search.setup(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try: try:
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]: if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]:
logger.info("📄 Setting up search for docx") logger.info("📄 Setting up search for docx")

View File

@@ -261,6 +261,12 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user) current_search_model_option = adapters.get_user_search_model_or_default(user)
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
notion_oauth_url = get_notion_auth_url(user) notion_oauth_url = get_notion_auth_url(user)
eleven_labs_enabled = is_eleven_labs_enabled() eleven_labs_enabled = is_eleven_labs_enabled()
@@ -283,10 +289,12 @@ def config_page(request: Request):
"anonymous_mode": state.anonymous_mode, "anonymous_mode": state.anonymous_mode,
"username": user.username, "username": user.username,
"given_name": given_name, "given_name": given_name,
"conversation_options": all_conversation_options,
"search_model_options": all_search_model_options, "search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id, "selected_search_model_config": current_search_model_option.id,
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"paint_model_options": all_paint_model_options,
"selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
"user_photo": user_picture, "user_photo": user_picture,
"billing_enabled": state.billing_enabled, "billing_enabled": state.billing_enabled,
"subscription_state": user_subscription_state, "subscription_state": user_subscription_state,

View File

@@ -118,9 +118,9 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]: elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]:
return "docx", encoding return "docx", encoding
elif file_type in ["image/jpeg"]: elif file_type in ["image/jpeg"]:
return "jpeg", encoding return "image", encoding
elif file_type in ["image/png"]: elif file_type in ["image/png"]:
return "png", encoding return "image", encoding
elif content_group in ["code", "text"]: elif content_group in ["code", "text"]:
return "plaintext", encoding return "plaintext", encoding
else: else:

View File

@@ -70,6 +70,7 @@ class ContentConfig(ConfigBase):
plaintext: Optional[TextContentConfig] = None plaintext: Optional[TextContentConfig] = None
github: Optional[GithubContentConfig] = None github: Optional[GithubContentConfig] = None
notion: Optional[NotionContentConfig] = None notion: Optional[NotionContentConfig] = None
image: Optional[TextContentConfig] = None
docx: Optional[TextContentConfig] = None docx: Optional[TextContentConfig] = None

BIN
tests/data/images/nasdaq.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

BIN
tests/data/images/testocr.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

View File

@@ -0,0 +1,21 @@
import os
from khoj.processor.content.images.image_to_entries import ImageToEntries
def test_png_to_jsonl():
with open("tests/data/images/testocr.png", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/testocr.png": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "opencv-python" in entries[1][0].raw
def test_jpg_to_jsonl():
with open("tests/data/images/nasdaq.jpg", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/nasdaq.jpg": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "investments" in entries[1][0].raw

View File

@@ -62,7 +62,6 @@ def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_c
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(client_offline_chat): def test_chat_with_online_content(client_offline_chat):
@@ -75,18 +74,18 @@ def test_chat_with_online_content(client_offline_chat):
response_message = response_message.split("### compiled references")[0] response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"] expected_responses = [
"https://paulgraham.com/greatwork.html",
"https://www.paulgraham.com/greatwork.html",
"http://www.paulgraham.com/greatwork.html",
]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any(
"Expected links or serper not setup in response but got: " + response_message [expected_response in response_message for expected_response in expected_responses]
) ), f"Expected links: {expected_responses}. Actual response: {response_message}"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(
os.getenv("SERPER_DEV_API_KEY") is None or os.getenv("OLOSTEP_API_KEY") is None,
reason="requires SERPER_DEV_API_KEY and OLOSTEP_API_KEY",
)
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_online_webpage_content(client_offline_chat): def test_chat_with_online_webpage_content(client_offline_chat):
@@ -101,9 +100,9 @@ def test_chat_with_online_webpage_content(client_offline_chat):
# Assert # Assert
expected_responses = ["185", "1871", "horse"] expected_responses = ["185", "1871", "horse"]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any(
"Expected links or serper not setup in response but got: " + response_message [expected_response in response_message for expected_response in expected_responses]
) ), f"Expected response with {expected_responses}. But actual response had: {response_message}"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------

View File

@@ -61,7 +61,6 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(chat_client): def test_chat_with_online_content(chat_client):
@@ -74,18 +73,18 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0] response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"] expected_responses = [
"https://paulgraham.com/greatwork.html",
"https://www.paulgraham.com/greatwork.html",
"http://www.paulgraham.com/greatwork.html",
]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any(
"Expected links or serper not setup in response but got: " + response_message [expected_response in response_message for expected_response in expected_responses]
) ), f"Expected links: {expected_responses}. Actual response: {response_message}"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(
os.getenv("SERPER_DEV_API_KEY") is None or os.getenv("OLOSTEP_API_KEY") is None,
reason="requires SERPER_DEV_API_KEY and OLOSTEP_API_KEY",
)
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_online_webpage_content(chat_client): def test_chat_with_online_webpage_content(chat_client):
@@ -100,9 +99,9 @@ def test_chat_with_online_webpage_content(chat_client):
# Assert # Assert
expected_responses = ["185", "1871", "horse"] expected_responses = ["185", "1871", "horse"]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any(
"Expected links or serper not setup in response but got: " + response_message [expected_response in response_message for expected_response in expected_responses]
) ), f"Expected links: {expected_responses}. Actual response: {response_message}"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------

View File

@@ -1,5 +1,6 @@
import os import os
import re import re
import time
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.text_to_entries import TextToEntries from khoj.processor.content.text_to_entries import TextToEntries
@@ -41,6 +42,35 @@ def test_configure_indexing_heading_only_entries(tmp_path):
assert is_none_or_empty(entries[1]) assert is_none_or_empty(entries[1])
def test_extract_entries_when_child_headings_have_same_prefix():
"""Extract org entries from entries having child headings with same prefix.
Prevents regressions like the one fixed in PR #840.
"""
# Arrange
tmp_path = "tests/data/org/same_prefix_headings.org"
entry: str = """
** 1
*** 1.1
**** 1.1.2
""".strip()
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Org files
start = time.time()
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=2)
end = time.time()
indexing_time = end - start
# Assert
explanation_msg = (
"It should not take more than 6 seconds to index. Entry extraction may have gone into an infinite loop."
)
assert indexing_time < 6 * len(entries), explanation_msg
def test_entry_split_when_exceeds_max_tokens(): def test_entry_split_when_exceeds_max_tokens():
"Ensure entries with compiled words exceeding max_tokens are split." "Ensure entries with compiled words exceeding max_tokens are split."
# Arrange # Arrange

View File

@@ -52,5 +52,6 @@
"1.12.1": "0.15.0", "1.12.1": "0.15.0",
"1.13.0": "0.15.0", "1.13.0": "0.15.0",
"1.14.0": "0.15.0", "1.14.0": "0.15.0",
"1.15.0": "0.15.0" "1.15.0": "0.15.0",
"1.16.0": "0.15.0"
} }