diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 89ffdc54..8e595dae 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -1288,7 +1288,7 @@ class ConversationAdapters:
 
     @staticmethod
     async def get_speech_to_text_config():
-        return await SpeechToTextModelOptions.objects.filter().afirst()
+        return await SpeechToTextModelOptions.objects.filter().select_related("ai_model_api").afirst()
 
     @staticmethod
     @arequire_valid_user
diff --git a/src/khoj/database/migrations/0080_speechtotextmodeloptions_ai_model_api.py b/src/khoj/database/migrations/0080_speechtotextmodeloptions_ai_model_api.py
new file mode 100644
index 00000000..5ed66e16
--- /dev/null
+++ b/src/khoj/database/migrations/0080_speechtotextmodeloptions_ai_model_api.py
@@ -0,0 +1,24 @@
+# Generated by Django 5.0.10 on 2025-01-15 11:05
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0079_searchmodelconfig_embeddings_inference_endpoint_type"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="speechtotextmodeloptions",
+            name="ai_model_api",
+            field=models.ForeignKey(
+                blank=True,
+                default=None,
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                to="database.aimodelapi",
+            ),
+        ),
+    ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 29dc5e58..3e693e15 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -566,6 +566,7 @@ class SpeechToTextModelOptions(DbBaseModel):
 
     model_name = models.CharField(max_length=200, default="base")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
+    ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
 
     def __str__(self):
         return f"{self.model_name} - {self.model_type}"
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index d2df38a3..44375bca 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -9,6 +9,7 @@ import uuid
 from typing import Any, Callable, List, Optional, Set, Union
 
 import cron_descriptor
+import openai
 import pytz
 from apscheduler.job import Job
 from apscheduler.triggers.cron import CronTrigger
@@ -264,12 +265,23 @@ async def transcribe(
         if not speech_to_text_config:
             # If the user has not configured a speech to text model, return an unsupported on server error
             status_code = 501
-        elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
-            speech2text_model = speech_to_text_config.model_name
-            user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
         elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
             speech2text_model = speech_to_text_config.model_name
            user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
+        elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
+            speech2text_model = speech_to_text_config.model_name
+            # Initialize to None so the check below cannot hit an unbound local name
+            openai_client = None
+            if speech_to_text_config.ai_model_api:
+                api_key = speech_to_text_config.ai_model_api.api_key
+                api_base_url = speech_to_text_config.ai_model_api.api_base_url
+                openai_client = openai.OpenAI(api_key=api_key, base_url=api_base_url)
+            elif state.openai_client:
+                openai_client = state.openai_client
+            if openai_client:
+                user_message = await transcribe_audio(audio_file, speech2text_model, client=openai_client)
+            else:
+                status_code = 501
     finally:
         # Close and Delete the temporary audio file
         audio_file.close()