Speak to Khoj via Desktop, Web or Obsidian Client (#566)

- Create speech to text API endpoint
- Use OpenAI Whisper for ASR offline (by downloading Whisper model) or online (via OpenAI API)
- Add speech to text model configuration to Database
- Speak to Khoj from the Web, Desktop or Obsidian client
This commit is contained in:
Debanjum
2023-11-26 14:32:11 -08:00
committed by GitHub
24 changed files with 518 additions and 48 deletions

View File

@@ -75,6 +75,7 @@ dependencies = [
"tzdata == 2023.3", "tzdata == 2023.3",
"rapidocr-onnxruntime == 1.3.8", "rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0", "stripe == 7.3.0",
"openai-whisper >= 20231117",
] ]
dynamic = ["version"] dynamic = ["version"]

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 384 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M192 0C139 0 96 43 96 96V256c0 53 43 96 96 96s96-43 96-96V96c0-53-43-96-96-96zM64 216c0-13.3-10.7-24-24-24s-24 10.7-24 24v40c0 89.1 66.2 162.7 152 174.4V464H120c-13.3 0-24 10.7-24 24s10.7 24 24 24h72 72c13.3 0 24-10.7 24-24s-10.7-24-24-24H216V430.4c85.8-11.7 152-85.3 152-174.4V216c0-13.3-10.7-24-24-24s-24 10.7-24 24v40c0 70.7-57.3 128-128 128s-128-57.3-128-128V216z"/></svg>

After

Width:  |  Height:  |  Size: 616 B

View File

@@ -0,0 +1,37 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
viewBox="0 0 384 512"
version="1.1"
id="svg1"
sodipodi:docname="stop-solid.svg"
inkscape:version="1.3 (0e150ed, 2023-07-21)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs1" />
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="0.4609375"
inkscape:cx="192"
inkscape:cy="256"
inkscape:window-width="1312"
inkscape:window-height="449"
inkscape:window-x="0"
inkscape:window-y="88"
inkscape:window-maximized="0"
inkscape:current-layer="svg1" />
<!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. -->
<path
d="M0 128C0 92.7 28.7 64 64 64H320c35.3 0 64 28.7 64 64V384c0 35.3-28.7 64-64 64H64c-35.3 0-64-28.7-64-64V128z"
id="path1"
style="fill:#aa0000" />
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -516,6 +516,18 @@
} }
} }
// Briefly surface a status message in the chat input's placeholder text,
// restoring the previous placeholder after a 2 second flash.
function flashStatusInChatInput(message) {
    const inputBox = document.getElementById("chat-input");
    const previousPlaceholder = inputBox.placeholder;
    // Show the status message where the user is already looking
    inputBox.placeholder = message;
    // Restore the original hint text once the flash period elapses
    setTimeout(() => { inputBox.placeholder = previousPlaceholder; }, 2000);
}
async function clearConversationHistory() { async function clearConversationHistory() {
let chatInput = document.getElementById("chat-input"); let chatInput = document.getElementById("chat-input");
let originalPlaceholder = chatInput.placeholder; let originalPlaceholder = chatInput.placeholder;
@@ -530,17 +542,71 @@
.then(data => { .then(data => {
chatBody.innerHTML = ""; chatBody.innerHTML = "";
loadChat(); loadChat();
chatInput.placeholder = "Cleared conversation history"; flashStatusInChatInput("🗑 Cleared conversation history");
}) })
.catch(err => { .catch(err => {
chatInput.placeholder = "Failed to clear conversation history"; flashStatusInChatInput("⛔️ Failed to clear conversation history");
}) })
.finally(() => {
setTimeout(() => {
chatInput.placeholder = originalPlaceholder;
}, 2000);
});
} }
let mediaRecorder;
// Toggle microphone recording; on stop, send the captured audio to the Khoj
// server's /api/transcribe endpoint and append the transcript to the chat input.
async function speechToText() {
    const speakButtonImg = document.getElementById('speak-button-img');
    const chatInput = document.getElementById('chat-input');

    // Build the transcription request against the user's configured Khoj server
    const hostURL = await window.hostURLAPI.getURL();
    const url = `${hostURL}/api/transcribe?client=desktop`;
    const khojToken = await window.tokenAPI.getToken();
    const headers = { 'Authorization': `Bearer ${khojToken}` };

    // POST recorded audio to the server; append transcript to the chat input
    const sendToServer = (audioBlob) => {
        const formData = new FormData();
        formData.append('file', audioBlob);

        fetch(url, { method: 'POST', body: formData, headers })
            .then(response => response.ok ? response.json() : Promise.reject(response))
            .then(data => { chatInput.value += data.text; })
            .catch(err => {
                // Server responds 422 when no speech-to-text model is configured
                err.status === 422
                    ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
                    : flashStatusInChatInput("⛔️ Failed to transcribe audio");
            });
    };

    // Start recording microphone audio and swap the button to a stop icon
    const handleRecording = (stream) => {
        const audioChunks = [];
        const recordingConfig = { mimeType: 'audio/webm' };
        mediaRecorder = new MediaRecorder(stream, recordingConfig);

        mediaRecorder.addEventListener("dataavailable", function(event) {
            if (event.data.size > 0) audioChunks.push(event.data);
        });
        mediaRecorder.addEventListener("stop", function() {
            // Release the microphone so the OS recording indicator turns off
            stream.getTracks().forEach(track => track.stop());
            const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
            sendToServer(audioBlob);
        });

        mediaRecorder.start();
        speakButtonImg.src = './assets/icons/stop-solid.svg';
        speakButtonImg.alt = 'Stop Transcription';
    };

    // Toggle on click: start if idle, stop (and transcribe) if recording
    if (!mediaRecorder || mediaRecorder.state === 'inactive') {
        navigator.mediaDevices
            .getUserMedia({ audio: true })
            .then(handleRecording)
            .catch((e) => {
                flashStatusInChatInput("⛔️ Failed to access microphone");
            });
    } else if (mediaRecorder.state === 'recording') {
        mediaRecorder.stop();
        speakButtonImg.src = './assets/icons/microphone-solid.svg';
        speakButtonImg.alt = 'Transcribe';
    }
}
</script> </script>
<body> <body>
<div id="khoj-empty-container" class="khoj-empty-container"> <div id="khoj-empty-container" class="khoj-empty-container">
@@ -569,8 +635,11 @@
<div id="chat-tooltip" style="display: none;"></div> <div id="chat-tooltip" style="display: none;"></div>
<div id="input-row"> <div id="input-row">
<textarea id="chat-input" class="option" oninput="onChatInput()" onkeydown=incrementalChat(event) autofocus="autofocus" placeholder="Type / to see a list of commands, or just type your questions and hit enter."></textarea> <textarea id="chat-input" class="option" oninput="onChatInput()" onkeydown=incrementalChat(event) autofocus="autofocus" placeholder="Type / to see a list of commands, or just type your questions and hit enter."></textarea>
<button class="input-row-button" onclick="clearConversationHistory()"> <button id="speak-button" class="input-row-button" onclick="speechToText()">
<img class="input-rown-button-img" src="./assets/icons/trash-solid.svg" alt="Clear Chat History"></img> <img id="speak-button-img" class="input-row-button-img" src="./assets/icons/microphone-solid.svg" alt="Transcribe"></img>
</button>
<button id="clear-chat" class="input-row-button" onclick="clearConversationHistory()">
<img class="input-row-button-img" src="./assets/icons/trash-solid.svg" alt="Clear Chat History"></img>
</button> </button>
</div> </div>
</div> </div>
@@ -620,7 +689,6 @@
.chat-message.you { .chat-message.you {
margin-right: auto; margin-right: auto;
text-align: right; text-align: right;
white-space: pre-line;
} }
/* basic style chat message text */ /* basic style chat message text */
.chat-message-text { .chat-message-text {
@@ -637,7 +705,6 @@
color: var(--primary-inverse); color: var(--primary-inverse);
background: var(--primary); background: var(--primary);
margin-left: auto; margin-left: auto;
white-space: pre-line;
} }
/* Spinner symbol when the chat message is loading */ /* Spinner symbol when the chat message is loading */
.spinner { .spinner {
@@ -694,7 +761,7 @@
} }
#input-row { #input-row {
display: grid; display: grid;
grid-template-columns: auto 32px; grid-template-columns: auto 32px 32px;
grid-column-gap: 10px; grid-column-gap: 10px;
grid-row-gap: 10px; grid-row-gap: 10px;
background: #f9fafc background: #f9fafc

View File

@@ -1,4 +1,4 @@
import { App, Modal, request, setIcon } from 'obsidian'; import { App, Modal, RequestUrlParam, request, requestUrl, setIcon } from 'obsidian';
import { KhojSetting } from 'src/settings'; import { KhojSetting } from 'src/settings';
import fetch from "node-fetch"; import fetch from "node-fetch";
@@ -51,6 +51,16 @@ export class KhojChatModal extends Modal {
}) })
chatInput.addEventListener('change', (event) => { this.result = (<HTMLInputElement>event.target).value }); chatInput.addEventListener('change', (event) => { this.result = (<HTMLInputElement>event.target).value });
let transcribe = inputRow.createEl("button", {
text: "Transcribe",
attr: {
id: "khoj-transcribe",
class: "khoj-transcribe khoj-input-row-button",
},
})
transcribe.addEventListener('click', async (_) => { await this.speechToText() });
setIcon(transcribe, "mic");
let clearChat = inputRow.createEl("button", { let clearChat = inputRow.createEl("button", {
text: "Clear History", text: "Clear History",
attr: { attr: {
@@ -205,9 +215,19 @@ export class KhojChatModal extends Modal {
} }
} }
async clearConversationHistory() { flashStatusInChatInput(message: string) {
// Get chat input element and original placeholder
let chatInput = <HTMLInputElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0]; let chatInput = <HTMLInputElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
let originalPlaceholder = chatInput.placeholder; let originalPlaceholder = chatInput.placeholder;
// Set placeholder to message
chatInput.placeholder = message;
// Reset placeholder after 2 seconds
setTimeout(() => {
chatInput.placeholder = originalPlaceholder;
}, 2000);
}
async clearConversationHistory() {
let chatBody = this.contentEl.getElementsByClassName("khoj-chat-body")[0]; let chatBody = this.contentEl.getElementsByClassName("khoj-chat-body")[0];
let response = await request({ let response = await request({
@@ -224,15 +244,84 @@ export class KhojChatModal extends Modal {
// If conversation history is cleared successfully, clear chat logs from modal // If conversation history is cleared successfully, clear chat logs from modal
chatBody.innerHTML = ""; chatBody.innerHTML = "";
await this.getChatHistory(); await this.getChatHistory();
chatInput.placeholder = result.message; this.flashStatusInChatInput(result.message);
} }
} catch (err) { } catch (err) {
chatInput.placeholder = "Failed to clear conversation history"; this.flashStatusInChatInput("Failed to clear conversation history");
} finally { }
// Reset to original placeholder text after some time }
setTimeout(() => {
chatInput.placeholder = originalPlaceholder; mediaRecorder: MediaRecorder | undefined;
async speechToText() {
    const transcribeButton = <HTMLButtonElement>this.contentEl.getElementsByClassName("khoj-transcribe")[0];
    const chatInput = <HTMLInputElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];

    // Manually assemble a multipart/form-data body; Obsidian's requestUrl
    // takes a raw ArrayBuffer body rather than a FormData object.
    const generateRequestBody = async (audioBlob: Blob, boundary_string: string) => {
        const boundary = `------${boundary_string}`;
        const chunks: ArrayBuffer[] = [];

        chunks.push(new TextEncoder().encode(`${boundary}\r\n`));
        chunks.push(new TextEncoder().encode(`Content-Disposition: form-data; name="file"; filename="blob"\r\nContent-Type: "application/octet-stream"\r\n\r\n`));
        chunks.push(await audioBlob.arrayBuffer());
        chunks.push(new TextEncoder().encode('\r\n'));
        chunks.push(new TextEncoder().encode(`${boundary}--\r\n`));

        return await new Blob(chunks).arrayBuffer();
    };

    // POST the recorded audio to the Khoj server; append transcript to chat input
    const sendToServer = async (audioBlob: Blob) => {
        const boundary_string = `Boundary${Math.random().toString(36).slice(2)}`;
        const requestBody = await generateRequestBody(audioBlob, boundary_string);

        const response = await requestUrl({
            url: `${this.setting.khojUrl}/api/transcribe?client=obsidian`,
            method: 'POST',
            headers: { "Authorization": `Bearer ${this.setting.khojApiKey}` },
            contentType: `multipart/form-data; boundary=----${boundary_string}`,
            body: requestBody,
        });

        // Parse response from Khoj backend
        if (response.status === 200) {
            chatInput.value += response.json.text;
        } else if (response.status === 422) {
            // Server responds 422 when no speech-to-text model is configured
            throw new Error("⛔️ Configure speech-to-text model on server.");
        } else {
            throw new Error("⛔️ Failed to transcribe audio");
        }
    };

    // Start recording microphone audio and swap the button to a stop icon
    const handleRecording = (stream: MediaStream) => {
        const audioChunks: Blob[] = [];
        const recordingConfig = { mimeType: 'audio/webm' };
        this.mediaRecorder = new MediaRecorder(stream, recordingConfig);

        this.mediaRecorder.addEventListener("dataavailable", (event) => {
            if (event.data.size > 0) audioChunks.push(event.data);
        });
        this.mediaRecorder.addEventListener("stop", async () => {
            const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
            try {
                await sendToServer(audioBlob);
            } catch (err) {
                // Surface transcription failures to the user instead of an
                // unhandled promise rejection
                this.flashStatusInChatInput((err as Error).message);
            }
        });

        this.mediaRecorder.start();
        setIcon(transcribeButton, "mic-off");
    };

    // Toggle on click: start if idle, stop (and transcribe) if recording
    if (!this.mediaRecorder || this.mediaRecorder.state === 'inactive') {
        navigator.mediaDevices
            .getUserMedia({ audio: true })
            .then(handleRecording)
            .catch((e) => {
                this.flashStatusInChatInput("⛔️ Failed to access microphone");
            });
    } else if (this.mediaRecorder.state === 'recording') {
        this.mediaRecorder.stop();
        setIcon(transcribeButton, "mic");
    }
}
} }
} }

View File

@@ -112,7 +112,7 @@ If your plugin does not need CSS, delete this file.
} }
.khoj-input-row { .khoj-input-row {
display: grid; display: grid;
grid-template-columns: auto 32px; grid-template-columns: auto 32px 32px;
grid-column-gap: 10px; grid-column-gap: 10px;
grid-row-gap: 10px; grid-row-gap: 10px;
background: var(--background-primary); background: var(--background-primary);

View File

@@ -26,6 +26,7 @@ from khoj.database.models import (
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
SearchModelConfig, SearchModelConfig,
SpeechToTextModelOptions,
Subscription, Subscription,
UserConversationConfig, UserConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
@@ -344,6 +345,10 @@ class ConversationAdapters:
async def get_openai_chat_config(): async def get_openai_chat_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst() return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod
async def get_speech_to_text_config():
return await SpeechToTextModelOptions.objects.filter().afirst()
@staticmethod @staticmethod
async def aget_conversation_starters(user: KhojUser): async def aget_conversation_starters(user: KhojUser):
all_questions = [] all_questions = []

View File

@@ -9,6 +9,7 @@ from khoj.database.models import (
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
SearchModelConfig, SearchModelConfig,
SpeechToTextModelOptions,
Subscription, Subscription,
ReflectiveQuestion, ReflectiveQuestion,
) )
@@ -16,6 +17,7 @@ from khoj.database.models import (
admin.site.register(KhojUser, UserAdmin) admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(SpeechToTextModelOptions)
admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModelConfig) admin.site.register(SearchModelConfig)

View File

@@ -0,0 +1,42 @@
# Generated by Django 4.2.7 on 2023-11-26 13:54
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add SpeechToTextModelOptions; default chat model config to offline."""

    dependencies = [
        ("database", "0020_reflectivequestion"),
    ]

    operations = [
        # New table storing the server's speech-to-text model configuration
        migrations.CreateModel(
            name="SpeechToTextModelOptions",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("model_name", models.CharField(default="base", max_length=200)),
                (
                    "model_type",
                    models.CharField(
                        choices=[("openai", "Openai"), ("offline", "Offline")], default="offline", max_length=200
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        # Chat model now defaults to the offline Mistral GGUF weights
        migrations.AlterField(
            model_name="chatmodeloptions",
            name="chat_model",
            field=models.CharField(default="mistral-7b-instruct-v0.1.Q4_0.gguf", max_length=200),
        ),
        migrations.AlterField(
            model_name="chatmodeloptions",
            name="model_type",
            field=models.CharField(
                choices=[("openai", "Openai"), ("offline", "Offline")], default="offline", max_length=200
            ),
        ),
    ]

View File

@@ -120,6 +120,15 @@ class OfflineChatProcessorConversationConfig(BaseModel):
enabled = models.BooleanField(default=False) enabled = models.BooleanField(default=False)
class SpeechToTextModelOptions(BaseModel):
    """Server-wide configuration of the speech-to-text model to use."""

    class ModelType(models.TextChoices):
        OPENAI = "openai"
        OFFLINE = "offline"

    # Model identifier; "base" is the default (presumably a Whisper model
    # size -- confirm against the transcription processors)
    model_name = models.CharField(max_length=200, default="base")
    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class ChatModelOptions(BaseModel): class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
OPENAI = "openai" OPENAI = "openai"
@@ -127,8 +136,8 @@ class ChatModelOptions(BaseModel):
max_prompt_size = models.IntegerField(default=None, null=True, blank=True) max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default=None, null=True, blank=True) chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class UserConversationConfig(BaseModel): class UserConversationConfig(BaseModel):

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 384 512"><!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><path d="M192 0C139 0 96 43 96 96V256c0 53 43 96 96 96s96-43 96-96V96c0-53-43-96-96-96zM64 216c0-13.3-10.7-24-24-24s-24 10.7-24 24v40c0 89.1 66.2 162.7 152 174.4V464H120c-13.3 0-24 10.7-24 24s10.7 24 24 24h72 72c13.3 0 24-10.7 24-24s-10.7-24-24-24H216V430.4c85.8-11.7 152-85.3 152-174.4V216c0-13.3-10.7-24-24-24s-24 10.7-24 24v40c0 70.7-57.3 128-128 128s-128-57.3-128-128V216z"/></svg>

After

Width:  |  Height:  |  Size: 616 B

View File

@@ -0,0 +1,37 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
viewBox="0 0 384 512"
version="1.1"
id="svg1"
sodipodi:docname="stop-solid.svg"
inkscape:version="1.3 (0e150ed, 2023-07-21)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs1" />
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="0.4609375"
inkscape:cx="192"
inkscape:cy="256"
inkscape:window-width="1312"
inkscape:window-height="449"
inkscape:window-x="0"
inkscape:window-y="88"
inkscape:window-maximized="0"
inkscape:current-layer="svg1" />
<!--! Font Awesome Pro 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. -->
<path
d="M0 128C0 92.7 28.7 64 64 64H320c35.3 0 64 28.7 64 64V384c0 35.3-28.7 64-64 64H64c-35.3 0-64-28.7-64-64V128z"
id="path1"
style="fill:#aa0000" />
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -543,6 +543,18 @@ To get started, just start typing below. You can also type / to see a list of co
} }
} }
// Briefly surface a status message in the chat input's placeholder text,
// restoring the previous placeholder after a 2 second flash.
function flashStatusInChatInput(message) {
    const inputBox = document.getElementById("chat-input");
    const previousPlaceholder = inputBox.placeholder;
    // Show the status message where the user is already looking
    inputBox.placeholder = message;
    // Restore the original hint text once the flash period elapses
    setTimeout(() => { inputBox.placeholder = previousPlaceholder; }, 2000);
}
function clearConversationHistory() { function clearConversationHistory() {
let chatInput = document.getElementById("chat-input"); let chatInput = document.getElementById("chat-input");
let originalPlaceholder = chatInput.placeholder; let originalPlaceholder = chatInput.placeholder;
@@ -553,17 +565,65 @@ To get started, just start typing below. You can also type / to see a list of co
.then(data => { .then(data => {
chatBody.innerHTML = ""; chatBody.innerHTML = "";
loadChat(); loadChat();
chatInput.placeholder = "Cleared conversation history"; flashStatusInChatInput("🗑 Cleared conversation history");
}) })
.catch(err => { .catch(err => {
chatInput.placeholder = "Failed to clear conversation history"; flashStatusInChatInput("⛔️ Failed to clear conversation history");
})
.finally(() => {
setTimeout(() => {
chatInput.placeholder = originalPlaceholder;
}, 2000);
}); });
} }
let mediaRecorder;
// Toggle microphone recording; on stop, send the captured audio to the
// /api/transcribe endpoint and append the transcript to the chat input.
function speechToText() {
    const speakButtonImg = document.getElementById('speak-button-img');
    const chatInput = document.getElementById('chat-input');

    // POST recorded audio to the server; append transcript to the chat input
    const sendToServer = (audioBlob) => {
        const formData = new FormData();
        formData.append('file', audioBlob);

        fetch('/api/transcribe?client=web', { method: 'POST', body: formData })
            .then(response => response.ok ? response.json() : Promise.reject(response))
            .then(data => { chatInput.value += data.text; })
            .catch(err => {
                // Server responds 422 when no speech-to-text model is configured
                err.status === 422
                    ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
                    : flashStatusInChatInput("⛔️ Failed to transcribe audio");
            });
    };

    // Start recording microphone audio and swap the button to a stop icon
    const handleRecording = (stream) => {
        const audioChunks = [];
        const recordingConfig = { mimeType: 'audio/webm' };
        mediaRecorder = new MediaRecorder(stream, recordingConfig);

        mediaRecorder.addEventListener("dataavailable", function(event) {
            if (event.data.size > 0) audioChunks.push(event.data);
        });
        mediaRecorder.addEventListener("stop", function() {
            // Release the microphone so the browser recording indicator turns off
            stream.getTracks().forEach(track => track.stop());
            const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
            sendToServer(audioBlob);
        });

        mediaRecorder.start();
        speakButtonImg.src = '/static/assets/icons/stop-solid.svg';
        speakButtonImg.alt = 'Stop Transcription';
    };

    // Toggle on click: start if idle, stop (and transcribe) if recording
    if (!mediaRecorder || mediaRecorder.state === 'inactive') {
        navigator.mediaDevices
            .getUserMedia({ audio: true })
            .then(handleRecording)
            .catch((e) => {
                flashStatusInChatInput("⛔️ Failed to access microphone");
            });
    } else if (mediaRecorder.state === 'recording') {
        mediaRecorder.stop();
        speakButtonImg.src = '/static/assets/icons/microphone-solid.svg';
        speakButtonImg.alt = 'Transcribe';
    }
}
</script> </script>
<body> <body>
<div id="khoj-empty-container" class="khoj-empty-container"> <div id="khoj-empty-container" class="khoj-empty-container">
@@ -584,8 +644,11 @@ To get started, just start typing below. You can also type / to see a list of co
<div id="chat-tooltip" style="display: none;"></div> <div id="chat-tooltip" style="display: none;"></div>
<div id="input-row"> <div id="input-row">
<textarea id="chat-input" class="option" oninput="onChatInput()" onkeydown=incrementalChat(event) autofocus="autofocus" placeholder="Type / to see a list of commands, or just type your questions and hit enter."></textarea> <textarea id="chat-input" class="option" oninput="onChatInput()" onkeydown=incrementalChat(event) autofocus="autofocus" placeholder="Type / to see a list of commands, or just type your questions and hit enter."></textarea>
<button id="speak-button" class="input-row-button" onclick="speechToText()">
<img id="speak-button-img" class="input-row-button-img" src="/static/assets/icons/microphone-solid.svg" alt="Transcribe"></img>
</button>
<button class="input-row-button" onclick="clearConversationHistory()"> <button class="input-row-button" onclick="clearConversationHistory()">
<img class="input-rown-button-img" src="/static/assets/icons/trash-solid.svg" alt="Clear Chat History"></img> <img class="input-row-button-img" src="/static/assets/icons/trash-solid.svg" alt="Clear Chat History"></img>
</button> </button>
</div> </div>
</div> </div>
@@ -749,7 +812,6 @@ To get started, just start typing below. You can also type / to see a list of co
.chat-message.you { .chat-message.you {
margin-right: auto; margin-right: auto;
text-align: right; text-align: right;
white-space: pre-line;
} }
/* basic style chat message text */ /* basic style chat message text */
.chat-message-text { .chat-message-text {
@@ -766,7 +828,6 @@ To get started, just start typing below. You can also type / to see a list of co
color: var(--primary-inverse); color: var(--primary-inverse);
background: var(--primary); background: var(--primary);
margin-left: auto; margin-left: auto;
white-space: pre-line;
} }
/* Spinner symbol when the chat message is loading */ /* Spinner symbol when the chat message is loading */
.spinner { .spinner {
@@ -815,6 +876,7 @@ To get started, just start typing below. You can also type / to see a list of co
#chat-footer { #chat-footer {
padding: 0; padding: 0;
margin: 8px;
display: grid; display: grid;
grid-template-columns: minmax(70px, 100%); grid-template-columns: minmax(70px, 100%);
grid-column-gap: 10px; grid-column-gap: 10px;
@@ -822,7 +884,7 @@ To get started, just start typing below. You can also type / to see a list of co
} }
#input-row { #input-row {
display: grid; display: grid;
grid-template-columns: auto 32px; grid-template-columns: auto 32px 32px;
grid-column-gap: 10px; grid-column-gap: 10px;
grid-row-gap: 10px; grid-row-gap: 10px;
background: #f9fafc background: #f9fafc

View File

@@ -0,0 +1,17 @@
# External Packages
from asgiref.sync import sync_to_async
import whisper
# Internal Packages
from khoj.utils import state
async def transcribe_audio_offline(audio_filename: str, model: str) -> str:
    """
    Transcribe audio file offline using Whisper

    Lazily loads the requested Whisper model into shared app state on first
    use, then runs transcription without blocking the event loop.
    """
    # Load the Whisper model once and cache it in shared state for reuse
    if not state.whisper_model:
        state.whisper_model = whisper.load_model(model)
    # Transcription is CPU/GPU-bound; run it in a thread via sync_to_async
    transcription = await sync_to_async(state.whisper_model.transcribe)(audio_filename)
    return transcription["text"]

View File

@@ -0,0 +1,15 @@
# Standard Packages
from io import BufferedReader
# External Packages
from asgiref.sync import sync_to_async
import openai
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
    """
    Transcribe audio file using Whisper model via OpenAI's API

    Runs the blocking OpenAI client call in a thread via sync_to_async so the
    event loop is not stalled while the audio uploads and transcribes.
    """
    result = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
    return result["text"]

View File

@@ -3,13 +3,14 @@ import concurrent.futures
import json import json
import logging import logging
import math import math
import os
import time import time
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import uuid
from asgiref.sync import sync_to_async
# External Packages # External Packages
from fastapi import APIRouter, Depends, Header, HTTPException, Request from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires
@@ -29,8 +30,10 @@ from khoj.database.models import (
LocalPlaintextConfig, LocalPlaintextConfig,
NotionConfig, NotionConfig,
) )
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline from khoj.processor.conversation.offline.chat_model import extract_questions_offline
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.tools.online_search import search_with_google from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import ( from khoj.routers.helpers import (
@@ -585,6 +588,59 @@ async def chat_options(
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200) return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
@api.post("/transcribe")
@requires(["authenticated"])
async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
    """Transcribe uploaded audio to text with the configured speech-to-text model.

    Accepts an audio upload (10Mb max), transcribes it via OpenAI's Whisper API
    or an offline Whisper model depending on server configuration, and returns
    the transcript as JSON. Responds 422 when the file is too large or no
    speech-to-text model is configured; 500 when transcription fails.
    """
    user: KhojUser = request.user.object
    audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
    user_message: Optional[str] = None
    # Default error status; raised to 422 when no model is configured
    status_code = 500
    audio_file = None

    # If the file is too large, return an unprocessable entity error
    if file.size > 10 * 1024 * 1024:
        logger.warning(f"Audio file too large to transcribe. Audio file size: {file.size}. Exceeds 10Mb limit.")
        return Response(content="Audio size larger than 10Mb limit", status_code=422)

    # Transcribe the audio from the request
    try:
        # Store the audio from the request in a temporary file
        audio_data = await file.read()
        with open(audio_filename, "wb") as audio_file_writer:
            audio_file_writer.write(audio_data)
        audio_file = open(audio_filename, "rb")

        # Send the audio data to the configured speech-to-text model
        speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
        openai_chat_config = await ConversationAdapters.get_openai_chat_config()
        if not speech_to_text_config:
            # If the user has not configured a speech to text model, return an unprocessable entity error
            status_code = 422
        elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
            api_key = openai_chat_config.api_key
            speech2text_model = speech_to_text_config.model_name
            user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
        elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
            speech2text_model = speech_to_text_config.model_name
            user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
    finally:
        # Close and delete the temporary audio file, even if transcription failed
        if audio_file:
            audio_file.close()
        if os.path.exists(audio_filename):
            os.remove(audio_filename)

    if user_message is None:
        return Response(status_code=status_code)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="transcribe",
        **common.__dict__,
    )

    # Return the spoken text
    content = json.dumps({"text": user_message})
    return Response(content=content, media_type="application/json", status_code=200)
@api.get("/chat", response_class=Response) @api.get("/chat", response_class=Response)
@requires(["authenticated"]) @requires(["authenticated"])
async def chat( async def chat(

View File

@@ -15,7 +15,7 @@ from fastapi import Depends, Header, HTTPException, Request
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log

View File

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, List, Optional, Union, Any
import torch import torch
# Internal Packages # Internal Packages
from khoj.processor.conversation.gpt4all.utils import download_model from khoj.processor.conversation.offline.utils import download_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ class GPT4AllProcessorConfig:
class GPT4AllProcessorModel: class GPT4AllProcessorModel:
def __init__( def __init__(
self, self,
chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin", chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
): ):
self.chat_model = chat_model self.chat_model = chat_model
self.loaded_model = None self.loaded_model = None

View File

@@ -6,6 +6,7 @@ from khoj.database.models import (
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ChatModelOptions, ChatModelOptions,
SpeechToTextModelOptions,
) )
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@@ -73,10 +74,9 @@ def initialization():
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
logger.warning("Offline models are not supported on this device.") logger.warning("Offline models are not supported on this device.")
use_openai_model = input("Use OpenAI chat model? (y/n): ") use_openai_model = input("Use OpenAI models? (y/n): ")
if use_openai_model == "y": if use_openai_model == "y":
logger.info("🗣️ Setting up OpenAI chat model") logger.info("🗣️ Setting up your OpenAI configuration")
api_key = input("Enter your OpenAI API key: ") api_key = input("Enter your OpenAI API key: ")
OpenAIProcessorConversationConfig.objects.create(api_key=api_key) OpenAIProcessorConversationConfig.objects.create(api_key=api_key)
@@ -94,7 +94,34 @@ def initialization():
chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_prompt_size=max_tokens chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_prompt_size=max_tokens
) )
logger.info("🗣️ Chat model configuration complete") default_speech2text_model = "whisper-1"
openai_speech2text_model = input(
f"Enter the OpenAI speech to text model you want to use (default: {default_speech2text_model}): "
)
openai_speech2text_model = openai_speech2text_model or default_speech2text_model
SpeechToTextModelOptions.objects.create(
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
)
if use_offline_model == "y" or use_openai_model == "y":
logger.info("🗣️ Chat model configuration complete")
use_offline_speech2text_model = input("Use offline speech to text model? (y/n): ")
if use_offline_speech2text_model == "y":
logger.info("🗣️ Setting up offline speech to text model")
# Delete any existing speech to text model options. There can only be one.
SpeechToTextModelOptions.objects.all().delete()
default_offline_speech2text_model = "base"
offline_speech2text_model = input(
f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): "
)
offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model
SpeechToTextModelOptions.objects.create(
model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE
)
logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
admin_user = KhojUser.objects.filter(is_staff=True).first() admin_user = KhojUser.objects.filter(is_staff=True).first()
if admin_user is None: if admin_user is None:

View File

@@ -7,6 +7,7 @@ from collections import defaultdict
# External Packages # External Packages
from pathlib import Path from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from whisper import Whisper
# Internal Packages # Internal Packages
from khoj.utils import config as utils_config from khoj.utils import config as utils_config
@@ -21,6 +22,7 @@ embeddings_model: EmbeddingsModel = None
cross_encoder_model: CrossEncoderModel = None cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex() content_index = ContentIndex()
gpt4all_processor_config: GPT4AllProcessorModel = None gpt4all_processor_config: GPT4AllProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None config_file: Path = None
verbose: int = 0 verbose: int = 0
host: str = None host: str = None

View File

@@ -19,8 +19,8 @@ except ModuleNotFoundError as e:
print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
# Internal Packages # Internal Packages
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, extract_questions_offline, filter_questions from khoj.processor.conversation.offline.chat_model import converse_offline, extract_questions_offline, filter_questions
from khoj.processor.conversation.gpt4all.utils import download_model from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log from khoj.processor.conversation.utils import message_to_log