diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index a0fc1be9..00ad75f1 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -47,6 +47,8 @@ from khoj.database.models import ( UserConversationConfig, UserRequests, UserSearchModelConfig, + UserVoiceModelConfig, + VoiceModelOption, ) from khoj.processor.conversation import prompts from khoj.search_filter.date_filter import DateFilter @@ -705,6 +707,14 @@ class ConversationAdapters: new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) return new_config + @staticmethod + async def aset_user_voice_model(user: KhojUser, model_id: str): + config = await VoiceModelOption.objects.filter(model_id=model_id).afirst() + if not config: + return None + new_config = await UserVoiceModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) + return new_config + @staticmethod def get_conversation_config(user: KhojUser): config = UserConversationConfig.objects.filter(user=user).first() @@ -719,6 +729,24 @@ class ConversationAdapters: return None return config.setting + @staticmethod + async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: + voice_model_config = await UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() + if voice_model_config: + return voice_model_config.setting + return None + + @staticmethod + def get_voice_model_options(): + return VoiceModelOption.objects.all() + + @staticmethod + def get_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: + voice_model_config = UserVoiceModelConfig.objects.filter(user=user).prefetch_related("setting").first() + if voice_model_config: + return voice_model_config.setting + return None + @staticmethod def get_default_conversation_config(): server_chat_settings = ServerChatSettings.objects.first() diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 64a0a7fd..3bc0f76d 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -27,6 +27,7 @@ from khoj.database.models import ( Subscription, TextToImageModelConfig, UserSearchModelConfig, + VoiceModelOption, ) from khoj.utils.helpers import ImageIntentType @@ -99,6 +100,7 @@ admin.site.register(TextToImageModelConfig) admin.site.register(ClientApplication) admin.site.register(GithubConfig) admin.site.register(NotionConfig) +admin.site.register(VoiceModelOption) @admin.register(Agent) diff --git a/src/khoj/database/migrations/0048_voicemodeloption_uservoicemodelconfig.py b/src/khoj/database/migrations/0048_voicemodeloption_uservoicemodelconfig.py new file mode 100644 index 00000000..8f86c88a --- /dev/null +++ b/src/khoj/database/migrations/0048_voicemodeloption_uservoicemodelconfig.py @@ -0,0 +1,52 @@ +# Generated by Django 4.2.11 on 2024-06-21 04:18 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0047_alter_entry_file_type"), + ] + + operations = [ + migrations.CreateModel( + name="VoiceModelOption", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("model_id", models.CharField(max_length=200)), + ("name", models.CharField(max_length=200)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="UserVoiceModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "setting", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.voicemodeloption", + ), + ), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 52685471..fcda8c10 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -97,6 +97,11 @@ class ChatModelOptions(BaseModel): ) +class VoiceModelOption(BaseModel): + model_id = models.CharField(max_length=200) + name = models.CharField(max_length=200) + + class Agent(BaseModel): creator = models.ForeignKey( KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True @@ -248,6 +253,11 @@ class UserConversationConfig(BaseModel): setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True) +class UserVoiceModelConfig(BaseModel): + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True) + + class UserSearchModelConfig(BaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) diff --git a/src/khoj/interface/web/assets/icons/speaker.svg b/src/khoj/interface/web/assets/icons/speaker.svg new file mode 100644 index 00000000..cfe4542c --- /dev/null +++ b/src/khoj/interface/web/assets/icons/speaker.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/khoj/interface/web/assets/icons/voice.svg b/src/khoj/interface/web/assets/icons/voice.svg new file mode 100644 index 00000000..e4e4649a --- /dev/null +++ b/src/khoj/interface/web/assets/icons/voice.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 9002d5c1..19ba1389 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -332,6 +332,7 @@ } select#search-models, + select#voice-models, select#chat-models { margin-bottom: 0; padding: 8px; diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 71ccea07..62c91f7d 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -422,8 +422,65 @@ To get started, just start typing below. You can also type / to see a list of co sendFeedback(userQuery, khojQuery, "Bad Response"); }; + // Only enable the speech feature if the user is subscribed + let speechButton = null; + + if ("{{ is_active }}" == "True") { + // Create a speech button icon to play the message out loud + speechButton = document.createElement('button'); + speechButton.classList.add("speech-button"); + speechButton.title = "Listen to Message"; + let speechIcon = document.createElement("img"); + speechIcon.src = "/static/assets/icons/speaker.svg"; + speechIcon.classList.add("speech-icon"); + speechButton.appendChild(speechIcon); + speechButton.addEventListener('click', function() { + // Replace the speaker with a loading icon. + let loader = document.createElement("span"); + loader.classList.add("loader"); + + speechButton.innerHTML = ""; + speechButton.appendChild(loader); + speechButton.disabled = true; + + const context = new (window.AudioContext || window.webkitAudioContext)(); + fetch(`/api/chat/speech?text=${encodeURIComponent(message)}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + }) + .then(response => response.arrayBuffer()) + .then(arrayBuffer => { + return context.decodeAudioData(arrayBuffer); + }) + .then(audioBuffer => { + const source = context.createBufferSource(); + source.buffer = audioBuffer; + source.connect(context.destination); + source.start(0); + source.onended = function() { + speechButton.innerHTML = ""; + speechButton.appendChild(speechIcon); + speechButton.disabled = false; + }; + }) + .catch(err => { + console.error("Error playing speech:", err); + speechButton.innerHTML = ""; + speechButton.appendChild(speechIcon); + speechButton.disabled = true; + }); + }); + } + + // Append buttons to parent element element.append(copyButton, thumbsDownButton, thumbsUpButton); + + if (speechButton) { + element.append(speechButton); + } } renderMathInElement(element, { @@ -2830,7 +2887,13 @@ To get started, just start typing below. You can also type / to see a list of co float: right; } - button.thumbs-up-button { + img.speech-icon { + width: 18px; + } + + button.thumbs-up-button, + button.thumbs-down-button, + button.speech-button { border-radius: 4px; background-color: var(--background-color); border: 1px solid var(--main-text-color); @@ -2843,19 +2906,6 @@ To get started, just start typing below. You can also type / to see a list of co margin-right: 4px; } - button.thumbs-down-button { - border-radius: 4px; - background-color: var(--background-color); - border: 1px solid var(--main-text-color); - text-align: center; - font-size: medium; - transition: all 0.5s; - cursor: pointer; - padding: 4px; - float: right; - margin-right:4px; - } - button.copy-button span { cursor: pointer; display: inline-block; @@ -2878,20 +2928,14 @@ To get started, just start typing below. You can also type / to see a list of co height: 18px; } - button.copy-button:hover { + button.copy-button:hover, + button.thumbs-up-button:hover, + button.thumbs-down-button:hover, + button.speech-button:hover { background-color: var(--primary-hover); color: #f5f5f5; } - button.thumbs-up-button:hover { - background-color: var(--primary-hover); - color: #f5f5f5; - } - - button.thumbs-down-button:hover { - background-color: var(--primary-hover); - color: #f5f5f5; - } pre { text-wrap: unset; @@ -3156,6 +3200,40 @@ To get started, just start typing below. You can also type / to see a list of co white-space: pre-wrap; } + .loader { + width: 18px; + height: 18px; + border: 3px solid #FFF; + border-radius: 50%; + display: inline-block; + position: relative; + box-sizing: border-box; + animation: rotation 1s linear infinite; + } + .loader::after { + content: ''; + box-sizing: border-box; + position: absolute; + left: 50%; + top: 50%; + transform: translate(-50%, -50%); + width: 18px; + height: 18px; + border-radius: 50%; + border: 3px solid transparent; + border-bottom-color: var(--flower); + } + + @keyframes rotation { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } + } + + .loading-spinner { display: inline-block; position: relative; diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 64841f97..2c3c98db 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -202,6 +202,34 @@ {% endif %} + {% if is_eleven_labs_enabled %} +
+
+ Voice configuration +

+ Voice +

+
+
+ +
+
+ {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %} + + {% else %} + + {% endif %} +
+
+ {% endif %} {% if not anonymous_mode or is_twilio_enabled %} @@ -363,6 +391,38 @@ }) } + function updateVoiceModel() { + const voiceModel = document.getElementById("voice-models").value; + const saveVoiceModelButton = document.getElementById("save-voice-model"); + saveVoiceModelButton.disabled = true; + saveVoiceModelButton.innerHTML = "Saving..."; + + fetch('/api/config/data/voice/model?id=' + voiceModel, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + } + }) + .then(response => response.json()) + .then(data => { + if (data.status == "ok") { + saveVoiceModelButton.innerHTML = "Save"; + saveVoiceModelButton.disabled = false; + + let notificationBanner = document.getElementById("notification-banner"); + notificationBanner.innerHTML = "Voice model has been updated!"; + notificationBanner.style.display = "block"; + setTimeout(function() { + notificationBanner.style.display = "none"; + }, 5000); + + } else { + saveVoiceModelButton.innerHTML = "Error"; + saveVoiceModelButton.disabled = false; + } + }) + } + function updateChatModel() { const chatModel = document.getElementById("chat-models").value; const saveModelButton = document.getElementById("save-model"); diff --git a/src/khoj/processor/speech/__init__.py b/src/khoj/processor/speech/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/processor/speech/text_to_speech.py b/src/khoj/processor/speech/text_to_speech.py new file mode 100644 index 00000000..3aa6bf72 --- /dev/null +++ b/src/khoj/processor/speech/text_to_speech.py @@ -0,0 +1,51 @@ +import json # Used for working with JSON data +import os + +import requests # Used for making HTTP requests +from bs4 import BeautifulSoup +from markdown_it import MarkdownIt + +# Define constants for the script +CHUNK_SIZE = 1024 # Size of chunks to read/write at a time +ELEVEN_LABS_API_KEY = os.getenv("ELEVEN_LABS_API_KEY", None) # Your API key for authentication +VOICE_ID = "RPEIZnKMqlQiZyZd1Dae" # ID of the voice model to use. MALE - Christopher - friendly guy next door. +ELEVEN_API_URL = "https://api.elevenlabs.io/v1/text-to-speech" # Base URL for the Text-to-Speech API + +markdown_renderer = MarkdownIt() + + +def is_eleven_labs_enabled(): + return ELEVEN_LABS_API_KEY is not None + + +def generate_text_to_speech( + text_to_speak: str, + voice_id: str = VOICE_ID, +): + if not is_eleven_labs_enabled(): + return "Eleven Labs API key is not set" + + # Convert the incoming text from markdown format to plain text + html = markdown_renderer.render(text_to_speak) + text = "".join(BeautifulSoup(html, features="lxml").findAll(text=True)) + + # Construct the URL for the Text-to-Speech API request + tts_url = f"{ELEVEN_API_URL}/{voice_id}/stream" + + # Set up headers for the API request, including the API key for authentication + headers = {"Accept": "application/json", "xi-api-key": ELEVEN_LABS_API_KEY} + + # Set up the data payload for the API request, including the text and voice settings + data = { + "text": text, + # "model_id": "eleven_multilingual_v2", + "voice_settings": {"stability": 0.5, "similarity_boost": 0.8, "style": 0.0, "use_speaker_boost": True}, + } + + # Make the POST request to the TTS API with headers and data, enabling streaming response + response = requests.post(tts_url, headers=headers, json=data, stream=True) + + if response.ok: + return response + else: + raise Exception(f"Failed to generate text-to-speech: {response.text}") diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index f88fba96..c9fbc28d 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -27,6 +27,7 @@ from khoj.processor.conversation.prompts import ( no_notes_found, ) from khoj.processor.conversation.utils import save_to_conversation_log +from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import ( online_search_enabled, read_webpages, @@ -142,6 +143,20 @@ async def sendfeedback(request: Request, data: FeedbackData): await send_query_feedback(data.uquery, data.kquery, data.sentiment, user.email) +@api_chat.post("/speech") +@requires(["authenticated", "premium"]) +async def text_to_speech(request: Request, common: CommonQueryParams, text: str): + voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object) + + params = {"text_to_speak": text} + + if voice_model: + params["voice_id"] = voice_model.model_id + + speech_stream = generate_text_to_speech(**params) + return StreamingResponse(speech_stream.iter_content(chunk_size=1024), media_type="audio/mpeg") + + @api_chat.get("/starters", response_class=Response) @requires(["authenticated"]) async def chat_starters( diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index dd84e317..68757de6 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -258,6 +258,30 @@ async def update_chat_model( return {"status": "ok"} +@api_config.post("/data/voice/model", status_code=200) +@requires(["authenticated", "premium"]) +async def update_voice_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user = request.user.object + + new_config = await ConversationAdapters.aset_user_voice_model(user, id) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_voice_model", + client=client, + ) + + if new_config is None: + return Response(status_code=404, content=json.dumps({"status": "error", "message": "Model not found"})) + + return Response(status_code=202, content=json.dumps({"status": "ok"})) + + @api_config.post("/data/search/model", status_code=200) @requires(["authenticated"]) async def update_search_model( diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 321353d9..6def7ba8 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -21,6 +21,7 @@ from khoj.database.adapters import ( get_user_subscription_state, ) from khoj.database.models import KhojUser +from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.routers.helpers import get_next_url from khoj.routers.notion import get_notion_auth_url from khoj.routers.twilio import is_twilio_enabled @@ -252,6 +253,18 @@ def config_page(request: Request): notion_oauth_url = get_notion_auth_url(user) + eleven_labs_enabled = is_eleven_labs_enabled() + + voice_models = ConversationAdapters.get_voice_model_options() + voice_model_options = list() + for voice_model in voice_models: + voice_model_options.append({"name": voice_model.name, "id": voice_model.model_id}) + + if len(voice_model_options) == 0: + eleven_labs_enabled = False + + selected_voice_config = ConversationAdapters.get_voice_model_config(user) + return templates.TemplateResponse( "config.html", context={ @@ -272,6 +285,9 @@ def config_page(request: Request): "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, "is_twilio_enabled": is_twilio_enabled(), + "is_eleven_labs_enabled": eleven_labs_enabled, + "voice_model_options": voice_model_options, + "selected_voice_config": selected_voice_config.model_id if selected_voice_config else None, "phone_number": user.phone_number, "is_phone_number_verified": user.verified_phone_number, "khoj_version": state.khoj_version,