Use global openai client for transcribe, image

This commit is contained in:
Debanjum Singh Solanky
2023-12-05 02:40:28 -05:00
parent 162b219f2b
commit 408b7413e9
8 changed files with 27 additions and 31 deletions

View File

@@ -7,6 +7,7 @@ import requests
import os
# External Packages
import openai
import schedule
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
@@ -22,6 +23,7 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import (
ConversationAdapters,
get_all_users,
get_or_create_search_model,
aget_user_subscription_state,
@@ -138,6 +140,10 @@ def configure_server(
config = FullConfig()
state.config = config
if ConversationAdapters.has_valid_openai_conversation_config():
openai_config = ConversationAdapters.get_openai_conversation_config()
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
# Initialize Search Models from Config and initialize content
try:
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)

View File

@@ -47,7 +47,7 @@ def extract_questions_offline(
if use_history:
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj":
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"

View File

@@ -41,7 +41,7 @@ def extract_questions(
[
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
]
)

View File

@@ -6,11 +6,10 @@ from asgiref.sync import sync_to_async
from openai import OpenAI
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
"""
Transcribe audio file using Whisper model via OpenAI's API
"""
# Send the audio data to the Whisper API
client = OpenAI(api_key=api_key)
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
return response["text"]
return response.text

View File

@@ -19,7 +19,7 @@ from starlette.authentication import requires
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import ChatModelOptions
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
@@ -624,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
# Send the audio data to the Whisper API
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
if not speech_to_text_config:
# If the user has not configured a speech to text model, return an unsupported on server error
status_code = 501
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = openai_chat_config.api_key
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
finally:
# Close and Delete the temporary audio file
audio_file.close()
@@ -793,7 +791,6 @@ async def extract_references_and_questions(
conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
if (
offline_chat_config
and offline_chat_config.enabled
@@ -810,7 +807,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key

View File

@@ -256,15 +256,15 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
# Generate the image with the configured text to image model
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
client = openai.OpenAI(api_key=openai_chat_config.api_key)
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
text2image_model = text_to_image_config.model_name
try:
response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json")
response = state.openai_client.images.generate(
prompt=message, model=text2image_model, response_format="b64_json"
)
image = response.data[0].b64_json
except openai.OpenAIError as e:
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")

View File

@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None):
def __init__(self, model_name, client: openai.OpenAI, device=None):
self.model_name = model_name
if (
not state.processor_config
or not state.processor_config.conversation
or not state.processor_config.conversation.openai_model
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
)
self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key)
self.openai_client = client
self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs):

View File

@@ -1,15 +1,16 @@
# Standard Packages
from collections import defaultdict
import os
from pathlib import Path
import threading
from typing import List, Dict
from collections import defaultdict
# External Packages
from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from openai import OpenAI
from whisper import Whisper
# Internal Packages
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device
@@ -21,6 +22,7 @@ search_models = SearchModels()
embeddings_model: EmbeddingsModel = None
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
openai_client: OpenAI = None
gpt4all_processor_config: GPT4AllProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None