Add support for text to speech in chat responses (#821)

* Enable text to speech responses in khoj chat

- Current issue: reads out all the markdown formatting, plus waits for the whole result to be streamed before playing it

* Extract content from markdown-formatted text

* Add a loader for while you're waiting for Khoj's response

* Add user configuration option for chat model options, allow server side configuration for option list

* Join up APIs, views, admin pages to allow configuring custom voice models
This commit is contained in:
sabaimran
2024-06-20 23:00:28 -07:00
committed by GitHub
parent ff26b19d2b
commit b9966eb3d4
14 changed files with 373 additions and 24 deletions

View File

@@ -47,6 +47,8 @@ from khoj.database.models import (
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.date_filter import DateFilter
@@ -705,6 +707,14 @@ class ConversationAdapters:
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config
@staticmethod
async def aset_user_voice_model(user: KhojUser, model_id: str):
    """Point the user's voice config at the option with `model_id`.

    Returns None when no such voice option exists; otherwise the
    (config, created) result of the upsert.
    """
    option = await VoiceModelOption.objects.filter(model_id=model_id).afirst()
    if option is None:
        return None
    return await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": option})
@staticmethod
def get_conversation_config(user: KhojUser):
config = UserConversationConfig.objects.filter(user=user).first()
@@ -719,6 +729,24 @@ class ConversationAdapters:
return None
return config.setting
@staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
    """Async lookup of the user's selected voice model, or None if unset."""
    user_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
    return user_config.setting if user_config else None
@staticmethod
def get_voice_model_options():
    # Queryset of every server-configured voice (rendered as the
    # options list on the settings page).
    return VoiceModelOption.objects.all()
@staticmethod
def get_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
    """Sync lookup of the user's selected voice model, or None if unset."""
    user_config = UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").first()
    return user_config.setting if user_config else None
@staticmethod
def get_default_conversation_config():
server_chat_settings = ServerChatSettings.objects.first()

View File

@@ -27,6 +27,7 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserSearchModelConfig,
VoiceModelOption,
)
from khoj.utils.helpers import ImageIntentType
@@ -99,6 +100,7 @@ admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
admin.site.register(VoiceModelOption)
@admin.register(Agent)

View File

@@ -0,0 +1,52 @@
# Generated by Django 4.2.11 on 2024-06-21 04:18
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
    # Must apply after the previous migration in the database app's sequence.
    dependencies = [
        ("database", "0047_alter_entry_file_type"),
    ]

    operations = [
        # Server-wide catalog of selectable text-to-speech voices.
        migrations.CreateModel(
            name="VoiceModelOption",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                # Provider-side identifier for the voice.
                ("model_id", models.CharField(max_length=200)),
                # Human-readable label shown in the settings UI.
                ("name", models.CharField(max_length=200)),
            ],
            options={
                "abstract": False,
            },
        ),
        # Per-user voice selection; one row per user (OneToOne below).
        migrations.CreateModel(
            name="UserVoiceModelConfig",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                (
                    # Nullable: a user row may exist before a voice is picked.
                    "setting",
                    models.ForeignKey(
                        blank=True,
                        default=None,
                        null=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        to="database.voicemodeloption",
                    ),
                ),
                (
                    "user",
                    models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View File

@@ -97,6 +97,11 @@ class ChatModelOptions(BaseModel):
)
class VoiceModelOption(BaseModel):
    # Provider-side voice identifier; passed as the voice id to the
    # text-to-speech API by the speech endpoint.
    model_id = models.CharField(max_length=200)
    # Human-readable label shown in the settings UI.
    name = models.CharField(max_length=200)
class Agent(BaseModel):
creator = models.ForeignKey(
KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
@@ -248,6 +253,11 @@ class UserConversationConfig(BaseModel):
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserVoiceModelConfig(BaseModel):
    # One voice selection per user.
    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
    # Selected voice; null until the user picks one.
    setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserSearchModelConfig(BaseModel):
    # One search-model selection per user.
    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
    # Unlike the voice config above in this file, the search setting is
    # required (not nullable).
    setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)

View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M19 6C20.5 7.5 21 10 21 12C21 14 20.5 16.5 19 18M16 8.99998C16.5 9.49998 17 10.5 17 12C17 13.5 16.5 14.5 16 15M3 10.5V13.5C3 14.6046 3.5 15.5 5.5 16C7.5 16.5 9 21 12 21C14 21 14 3 12 3C9 3 7.5 7.5 5.5 8C3.5 8.5 3 9.39543 3 10.5Z" stroke="#000000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 555 B

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6 11L6 13" stroke="#333333" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M9 9L9 15" stroke="#333333" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M15 9L15 15" stroke="#333333" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M18 11L18 13" stroke="#333333" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M12 11L12 13" stroke="#333333" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 758 B

View File

@@ -332,6 +332,7 @@
}
select#search-models,
select#voice-models,
select#chat-models {
margin-bottom: 0;
padding: 8px;

View File

@@ -422,8 +422,65 @@ To get started, just start typing below. You can also type / to see a list of co
sendFeedback(userQuery, khojQuery, "Bad Response");
};
// Only enable the speech feature if the user is subscribed
let speechButton = null;
if ("{{ is_active }}" == "True") {
    // Create a speech button icon to play the message out loud
    speechButton = document.createElement('button');
    speechButton.classList.add("speech-button");
    speechButton.title = "Listen to Message";
    let speechIcon = document.createElement("img");
    speechIcon.src = "/static/assets/icons/speaker.svg";
    speechIcon.classList.add("speech-icon");
    speechButton.appendChild(speechIcon);

    // Restore the speaker icon and re-enable the button.
    const resetSpeechButton = function() {
        speechButton.innerHTML = "";
        speechButton.appendChild(speechIcon);
        speechButton.disabled = false;
    };

    speechButton.addEventListener('click', function() {
        // Replace the speaker with a loading icon while the audio is fetched.
        let loader = document.createElement("span");
        loader.classList.add("loader");
        speechButton.innerHTML = "";
        speechButton.appendChild(loader);
        speechButton.disabled = true;

        const context = new (window.AudioContext || window.webkitAudioContext)();
        fetch(`/api/chat/speech?text=${encodeURIComponent(message)}`, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json'
            },
        })
        .then(response => {
            // Surface HTTP errors instead of handing an error payload
            // to the audio decoder.
            if (!response.ok) throw new Error(`Speech API returned ${response.status}`);
            return response.arrayBuffer();
        })
        .then(arrayBuffer => context.decodeAudioData(arrayBuffer))
        .then(audioBuffer => {
            const source = context.createBufferSource();
            source.buffer = audioBuffer;
            source.connect(context.destination);
            source.start(0);
            source.onended = resetSpeechButton;
        })
        .catch(err => {
            console.error("Error playing speech:", err);
            // Bug fix: previously this left the button disabled forever
            // (disabled = true); re-enable so the user can retry.
            resetSpeechButton();
        });
    });
}

// Append buttons to parent element
element.append(copyButton, thumbsDownButton, thumbsUpButton);

if (speechButton) {
    element.append(speechButton);
}
}
renderMathInElement(element, {
@@ -2830,7 +2887,13 @@ To get started, just start typing below. You can also type / to see a list of co
float: right;
}
button.thumbs-up-button {
img.speech-icon {
width: 18px;
}
button.thumbs-up-button,
button.thumbs-down-button,
button.speech-button {
border-radius: 4px;
background-color: var(--background-color);
border: 1px solid var(--main-text-color);
@@ -2843,19 +2906,6 @@ To get started, just start typing below. You can also type / to see a list of co
margin-right: 4px;
}
button.thumbs-down-button {
border-radius: 4px;
background-color: var(--background-color);
border: 1px solid var(--main-text-color);
text-align: center;
font-size: medium;
transition: all 0.5s;
cursor: pointer;
padding: 4px;
float: right;
margin-right:4px;
}
button.copy-button span {
cursor: pointer;
display: inline-block;
@@ -2878,20 +2928,14 @@ To get started, just start typing below. You can also type / to see a list of co
height: 18px;
}
button.copy-button:hover {
button.copy-button:hover,
button.thumbs-up-button:hover,
button.thumbs-down-button:hover,
button.speech-button:hover {
background-color: var(--primary-hover);
color: #f5f5f5;
}
button.thumbs-up-button:hover {
background-color: var(--primary-hover);
color: #f5f5f5;
}
button.thumbs-down-button:hover {
background-color: var(--primary-hover);
color: #f5f5f5;
}
pre {
text-wrap: unset;
@@ -3156,6 +3200,40 @@ To get started, just start typing below. You can also type / to see a list of co
white-space: pre-wrap;
}
.loader {
width: 18px;
height: 18px;
border: 3px solid #FFF;
border-radius: 50%;
display: inline-block;
position: relative;
box-sizing: border-box;
animation: rotation 1s linear infinite;
}
.loader::after {
content: '';
box-sizing: border-box;
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
width: 18px;
height: 18px;
border-radius: 50%;
border: 3px solid transparent;
border-bottom-color: var(--flower);
}
@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
.loading-spinner {
display: inline-block;
position: relative;

View File

@@ -202,6 +202,34 @@
{% endif %}
</div>
</div>
{% if is_eleven_labs_enabled %}
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/voice.svg" alt="Voice configuration">
<h3 class="card-title">
<span>Voice</span>
</h3>
</div>
<div class="card-description-row">
<select id="voice-models">
{% for option in voice_model_options %}
<option value="{{ option.id }}" {% if option.id == selected_voice_config %}selected{% endif %}>{{ option.name }}</option>
{% endfor %}
</select>
</div>
<div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-voice-model" class="card-button happy" onclick="updateVoiceModel()">
Save
</button>
{% else %}
<button id="save-voice-model" class="card-button" disabled>
You must be subscribed to use this feature
</button>
{% endif %}
</div>
</div>
{% endif %}
</div>
</div>
{% if not anonymous_mode or is_twilio_enabled %}
@@ -363,6 +391,38 @@
})
}
// Persist the voice model chosen in the settings dropdown, showing
// progress on the save button and a transient success banner.
function updateVoiceModel() {
    const voiceModel = document.getElementById("voice-models").value;
    const saveVoiceModelButton = document.getElementById("save-voice-model");
    saveVoiceModelButton.disabled = true;
    saveVoiceModelButton.innerHTML = "Saving...";

    fetch('/api/config/data/voice/model?id=' + voiceModel, {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        }
    })
    .then(response => response.json())
    .then(data => {
        saveVoiceModelButton.disabled = false;
        if (data.status == "ok") {
            saveVoiceModelButton.innerHTML = "Save";
            let notificationBanner = document.getElementById("notification-banner");
            notificationBanner.innerHTML = "Voice model has been updated!";
            notificationBanner.style.display = "block";
            setTimeout(function() {
                notificationBanner.style.display = "none";
            }, 5000);
        } else {
            saveVoiceModelButton.innerHTML = "Error";
        }
    })
    .catch(err => {
        // Bug fix: without a catch, a network failure left the button
        // permanently disabled and stuck on "Saving...".
        console.error("Error updating voice model:", err);
        saveVoiceModelButton.innerHTML = "Error";
        saveVoiceModelButton.disabled = false;
    });
}
function updateChatModel() {
const chatModel = document.getElementById("chat-models").value;
const saveModelButton = document.getElementById("save-model");

View File

View File

@@ -0,0 +1,51 @@
import json # Used for working with JSON data
import os
import requests # Used for making HTTP requests
from bs4 import BeautifulSoup
from markdown_it import MarkdownIt
# Define constants for the script
CHUNK_SIZE = 1024 # Size of chunks to read/write at a time
ELEVEN_LABS_API_KEY = os.getenv("ELEVEN_LABS_API_KEY", None) # Your API key for authentication
VOICE_ID = "RPEIZnKMqlQiZyZd1Dae" # ID of the voice model to use. MALE - Christopher - friendly guy next door.
ELEVEN_API_URL = "https://api.elevenlabs.io/v1/text-to-speech" # Base URL for the Text-to-Speech API
markdown_renderer = MarkdownIt()
def is_eleven_labs_enabled():
    # Text-to-speech is available only when an API key was supplied via
    # the ELEVEN_LABS_API_KEY environment variable.
    return ELEVEN_LABS_API_KEY is not None
def generate_text_to_speech(
    text_to_speak: str,
    voice_id: str = VOICE_ID,
):
    """Stream text-to-speech audio for `text_to_speak` from the Eleven Labs API.

    Markdown formatting is stripped before synthesis so the voice does not
    read literal markup aloud. Returns the streaming `requests.Response`
    (audio); callers iterate it with `iter_content`.

    Raises:
        ValueError: if no Eleven Labs API key is configured.
        Exception: if the TTS API call fails.
    """
    if not is_eleven_labs_enabled():
        # Bug fix: previously returned an error *string*, which crashed
        # callers that invoke .iter_content() on the result. Raise instead.
        raise ValueError("Eleven Labs API key is not set")

    # Convert the incoming text from markdown to plain text: render to HTML,
    # then keep only the text nodes. (find_all(string=...) replaces the
    # deprecated findAll(text=...) spelling.)
    html = markdown_renderer.render(text_to_speak)
    text = "".join(BeautifulSoup(html, features="lxml").find_all(string=True))

    # The /stream endpoint lets playback begin before the whole clip is done.
    tts_url = f"{ELEVEN_API_URL}/{voice_id}/stream"

    # API key authenticates the request.
    headers = {"Accept": "application/json", "xi-api-key": ELEVEN_LABS_API_KEY}

    # Payload: the plain text plus voice tuning parameters.
    data = {
        "text": text,
        # "model_id": "eleven_multilingual_v2",
        "voice_settings": {"stability": 0.5, "similarity_boost": 0.8, "style": 0.0, "use_speaker_boost": True},
    }

    # Stream the response; a timeout guards against a hung connection
    # (requests has no default timeout). With stream=True this bounds the
    # wait per read, not total transfer time.
    response = requests.post(tts_url, headers=headers, json=data, stream=True, timeout=30)
    if response.ok:
        return response
    raise Exception(f"Failed to generate text-to-speech: {response.text}")

View File

@@ -27,6 +27,7 @@ from khoj.processor.conversation.prompts import (
no_notes_found,
)
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import (
online_search_enabled,
read_webpages,
@@ -142,6 +143,20 @@ async def sendfeedback(request: Request, data: FeedbackData):
await send_query_feedback(data.uquery, data.kquery, data.sentiment, user.email)
@api_chat.post("/speech")
@requires(["authenticated", "premium"])
async def text_to_speech(request: Request, common: CommonQueryParams, text: str):
    # Use the user's configured voice if one is set; otherwise the TTS
    # helper falls back to its default voice id.
    voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)

    params = {"text_to_speak": text}

    if voice_model:
        params["voice_id"] = voice_model.model_id

    speech_stream = generate_text_to_speech(**params)
    # Proxy the provider's streamed audio back to the client as MPEG audio.
    return StreamingResponse(speech_stream.iter_content(chunk_size=1024), media_type="audio/mpeg")
@api_chat.get("/starters", response_class=Response)
@requires(["authenticated"])
async def chat_starters(

View File

@@ -258,6 +258,30 @@ async def update_chat_model(
return {"status": "ok"}
@api_config.post("/data/voice/model", status_code=200)
@requires(["authenticated", "premium"])
async def update_voice_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Persist the caller's preferred voice model selection.

    Responds 202 on success, 404 when `id` matches no voice option.
    """
    user = request.user.object
    saved_config = await ConversationAdapters.aset_user_voice_model(user, id)

    # Record the API call regardless of whether the model id was valid.
    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_voice_model",
        client=client,
    )

    if saved_config is None:
        return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"}))
    return Response(status_code=202, content=json.dumps({"status": "ok"}))
@api_config.post("/data/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(

View File

@@ -21,6 +21,7 @@ from khoj.database.adapters import (
get_user_subscription_state,
)
from khoj.database.models import KhojUser
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.helpers import get_next_url
from khoj.routers.notion import get_notion_auth_url
from khoj.routers.twilio import is_twilio_enabled
@@ -252,6 +253,18 @@ def config_page(request: Request):
notion_oauth_url = get_notion_auth_url(user)
eleven_labs_enabled = is_eleven_labs_enabled()
voice_models = ConversationAdapters.get_voice_model_options()
voice_model_options = list()
for voice_model in voice_models:
voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id})
if len(voice_model_options) == 0:
eleven_labs_enabled = False
selected_voice_config = ConversationAdapters.get_voice_model_config(user)
return templates.TemplateResponse(
"config.html",
context={
@@ -272,6 +285,9 @@ def config_page(request: Request):
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
"is_twilio_enabled": is_twilio_enabled(),
"is_eleven_labs_enabled": eleven_labs_enabled,
"voice_model_options": voice_model_options,
"selected_voice_config": selected_voice_config.model_id if selected_voice_config else None,
"phone_number": user.phone_number,
"is_phone_number_verified": user.verified_phone_number,
"khoj_version": state.khoj_version,