Mirror of https://github.com/khoaliber/khoj.git, synced 2026-03-08 05:39:13 +00:00
Create speech-to-text API endpoint. Use OpenAI Whisper for ASR.

- Wrap audio transcription in try/finally and delete the audio file after processing
- Use the configured speech-to-text model, else return an error
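
For reference, a client exercises the new endpoint by POSTing the recorded audio as a multipart file upload and reading the transcription from the JSON response. A minimal sketch, assuming the `api` router is mounted under /api, Khoj's default port 42110, and a placeholder credential (the real deployment uses session authentication):

# Hypothetical client call to the new speech-to-text endpoint.
# The /api prefix, port, and credential below are assumptions.
import requests

with open("recording.webm", "rb") as f:
    response = requests.post(
        "http://localhost:42110/api/speak",
        files={"file": ("recording.webm", f, "audio/webm")},
        cookies={"session": "<session-cookie>"},  # placeholder auth
    )

response.raise_for_status()
print(response.json()["text"])  # the transcribed speech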
@@ -1,13 +1,16 @@
 # Standard Packages
 import concurrent.futures
 import math
+import os
 import time
 import logging
 import json
 from typing import Annotated, List, Optional, Union, Any
+import uuid
 
 # External Packages
-from fastapi import APIRouter, Depends, HTTPException, Header, Request
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
+import openai
 from starlette.authentication import requires
 from asgiref.sync import sync_to_async
 
@@ -553,6 +556,54 @@ async def chat_options(
     return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
 
 
+@api.post("/speak")
+@requires(["authenticated"])
+async def transcribe_audio(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
+    user: KhojUser = request.user.object
+    audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
+    user_message: str = None
+
+    # Transcribe the audio from the request
+    try:
+        # Store the audio from the request in a temporary file
+        audio_data = await file.read()
+        with open(audio_filename, "wb") as audio_file_writer:
+            audio_file_writer.write(audio_data)
+        audio_file = open(audio_filename, "rb")
+
+        # Send the audio data to the Whisper API
+        speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
+        openai_chat_config = await ConversationAdapters.get_openai_chat_config()
+        if not openai_chat_config or not speech_to_text_config:
+            # If the user has not configured a speech to text model, return an unprocessable entity error
+            status_code = 422
+        elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
+            api_key = openai_chat_config.api_key
+            speech2text_model = speech_to_text_config.model_name
+            response = await sync_to_async(openai.Audio.translate)(
+                model=speech2text_model, file=audio_file, api_key=api_key
+            )
+            user_message = response["text"]
+    finally:
+        # Close and Delete the temporary audio file
+        audio_file.close()
+        os.remove(audio_filename)
+
+    if user_message is None:
+        return Response(status_code=status_code or 500)
+
+    update_telemetry_state(
+        request=request,
+        telemetry_type="api",
+        api="speech_to_text",
+        **common.__dict__,
+    )
+
+    # Return the spoken text
+    content = json.dumps({"text": user_message})
+    return Response(content=content, media_type="application/json", status_code=200)
+
+
 @api.get("/chat", response_class=Response)
 @requires(["authenticated"])
 async def chat(
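
The endpoint buffers the upload into a uniquely named file, re-opens it for the Whisper client, and removes it in the finally block. The same buffer-and-cleanup pattern can be sketched with the standard library's tempfile, which ties the file's lifetime to a context manager; the helper below is illustrative, not part of this commit:

# Illustrative alternative to the commit's manual open/close/os.remove:
# NamedTemporaryFile deletes the file when the context exits, even if
# transcription raises. `transcribe` is a hypothetical stand-in for any
# client that needs a real file handle.
import tempfile
from fastapi import UploadFile


async def buffer_and_transcribe(file: UploadFile, transcribe) -> str:
    with tempfile.NamedTemporaryFile(suffix=".webm") as audio_file:
        audio_file.write(await file.read())
        audio_file.seek(0)  # rewind so the client reads from the start
        return transcribe(audio_file)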
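
One detail worth flagging: the commit calls openai.Audio.translate, which in the pre-1.0 openai Python client returns English text regardless of the spoken language, whereas openai.Audio.transcribe returns text in the source language. A minimal sketch of the two calls, assuming the pre-1.0 client and OpenAI's standard whisper-1 model name:

# Pre-1.0 openai client; the commit passes api_key per call the same way.
import openai

with open("recording.webm", "rb") as audio_file:
    # Transcription: text comes back in the language that was spoken.
    transcription = openai.Audio.transcribe(
        model="whisper-1", file=audio_file, api_key="sk-..."  # placeholder key
    )

with open("recording.webm", "rb") as audio_file:
    # Translation: speech in any supported language comes back as English.
    translation = openai.Audio.translate(
        model="whisper-1", file=audio_file, api_key="sk-..."  # placeholder key
    )

print(transcription["text"], translation["text"])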