Use global openai client for transcribe, image

Author: Debanjum Singh Solanky
Date: 2023-12-05 02:40:28 -05:00
parent 162b219f2b
commit 408b7413e9
8 changed files with 27 additions and 31 deletions

View File

@@ -7,6 +7,7 @@ import requests
 import os

 # External Packages
+import openai
 import schedule
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.middleware.authentication import AuthenticationMiddleware
@@ -22,6 +23,7 @@ from starlette.authentication import (
 # Internal Packages
 from khoj.database.models import KhojUser, Subscription
 from khoj.database.adapters import (
+    ConversationAdapters,
     get_all_users,
     get_or_create_search_model,
     aget_user_subscription_state,
@@ -138,6 +140,10 @@ def configure_server(
         config = FullConfig()
         state.config = config

+    if ConversationAdapters.has_valid_openai_conversation_config():
+        openai_config = ConversationAdapters.get_openai_conversation_config()
+        state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
+
     # Initialize Search Models from Config and initialize content
     try:
         state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
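
Note: the client is now constructed once at server startup and cached on module state, so every OpenAI-backed feature shares a single client (and its connection pool) instead of building one per request. Downstream callers have to handle the unconfigured case themselves; a minimal sketch of such a guard (hypothetical helper, not part of this commit):

    import openai

    from khoj.utils import state

    def require_openai_client() -> openai.OpenAI:
        # state.openai_client stays None when no valid OpenAI config exists;
        # endpoints then fall back to offline models or return HTTP 501.
        if state.openai_client is None:
            raise RuntimeError("OpenAI is not configured for this server")
        return state.openai_client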

View File

@@ -47,7 +47,7 @@ def extract_questions_offline(
     if use_history:
         for chat in conversation_log.get("chat", [])[-4:]:
-            if chat["by"] == "khoj":
+            if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
                 chat_history += f"Q: {chat['intent']['query']}\n"
                 chat_history += f"A: {chat['message']}\n"

View File

@@ -41,7 +41,7 @@ def extract_questions(
         [
             f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
             for chat in conversation_log.get("chat", [])[-4:]
-            if chat["by"] == "khoj"
+            if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
         ]
     )
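
Note: both the offline and the OpenAI question extractors now skip text-to-image turns when rebuilding chat history, since a generated image is not useful context for inferring search queries. A self-contained illustration of the new predicate (the sample log is made up):

    conversation_log = {
        "chat": [
            {"by": "khoj", "intent": {"type": "remember", "query": "plans for today"}, "message": "You planned to hike."},
            {"by": "khoj", "intent": {"type": "text-to-image", "query": "draw a mountain"}, "message": "<base64 image>"},
        ]
    }

    # Only the non-image turn survives the filter used in both extractors.
    kept = [
        chat
        for chat in conversation_log.get("chat", [])[-4:]
        if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
    ]
    assert len(kept) == 1 and kept[0]["intent"]["type"] == "remember"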

View File

@@ -6,11 +6,10 @@ from asgiref.sync import sync_to_async
 from openai import OpenAI


-async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
+async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
     """
     Transcribe audio file using Whisper model via OpenAI's API
     """
     # Send the audio data to the Whisper API
-    client = OpenAI(api_key=api_key)
     response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
-    return response["text"]
+    return response.text
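
Note: besides accepting the shared client, this fixes the return value for the openai>=1.0 SDK, where translation responses are typed objects (response.text) rather than dicts (response["text"]). A hedged usage sketch (module path, file name, and model name are assumptions):

    import asyncio

    from openai import OpenAI

    # Assumed import path for the function changed above.
    from khoj.processor.conversation.openai.whisper import transcribe_audio

    async def main():
        client = OpenAI(api_key="sk-...")  # in khoj this would be the shared state.openai_client
        with open("recording.webm", "rb") as audio_file:
            text = await transcribe_audio(audio_file, model="whisper-1", client=client)
        print(text)

    asyncio.run(main())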

View File

@@ -19,7 +19,7 @@ from starlette.authentication import requires
 from khoj.configure import configure_server
 from khoj.database import adapters
 from khoj.database.adapters import ConversationAdapters, EntryAdapters
-from khoj.database.models import ChatModelOptions
+from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import (
     GithubConfig,
@@ -624,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile
         # Send the audio data to the Whisper API
         speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
-        openai_chat_config = await ConversationAdapters.get_openai_chat_config()
         if not speech_to_text_config:
             # If the user has not configured a speech to text model, return an unsupported on server error
             status_code = 501
-        elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
-            api_key = openai_chat_config.api_key
+        elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
             speech2text_model = speech_to_text_config.model_name
-            user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
-        elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
+            user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
+        elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
             speech2text_model = speech_to_text_config.model_name
-            user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
+            user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
     finally:
         # Close and Delete the temporary audio file
         audio_file.close()
@@ -793,7 +791,6 @@ async def extract_references_and_questions(
         conversation_config = await ConversationAdapters.aget_conversation_config(user)
         if conversation_config is None:
             conversation_config = await ConversationAdapters.aget_default_conversation_config()
-        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
         if (
             offline_chat_config
             and offline_chat_config.enabled
@@ -810,7 +807,7 @@ async def extract_references_and_questions(
             inferred_queries = extract_questions_offline(
                 defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
             )
-        elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
+        elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
             openai_chat_config = await ConversationAdapters.get_openai_chat_config()
             openai_chat = await ConversationAdapters.get_openai_chat()
             api_key = openai_chat_config.api_key

View File

@@ -256,15 +256,15 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
     # Send the audio data to the Whisper API
     text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
-    openai_chat_config = await ConversationAdapters.get_openai_chat_config()
     if not text_to_image_config:
         # If the user has not configured a text to image model, return an unsupported on server error
         status_code = 501
-    elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
-        client = openai.OpenAI(api_key=openai_chat_config.api_key)
+    elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        text2image_model = text_to_image_config.model_name
        try:
-            response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json")
+            response = state.openai_client.images.generate(
+                prompt=message, model=text2image_model, response_format="b64_json"
+            )
             image = response.data[0].b64_json
         except openai.OpenAIError as e:
             logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
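
Note: with response_format="b64_json" the Images API returns the picture inline as base64 instead of a short-lived URL, so no second fetch is needed. A self-contained decode sketch (prompt and model name are placeholders):

    import base64

    from openai import OpenAI

    client = OpenAI(api_key="sk-...")  # in khoj this is the shared state.openai_client
    response = client.images.generate(prompt="a mountain at dawn", model="dall-e-3", response_format="b64_json")

    # b64_json carries the image payload, base64-encoded; decode it to raw bytes.
    image_bytes = base64.b64decode(response.data[0].b64_json)
    with open("generated.png", "wb") as f:
        f.write(image_bytes)

One caveat: in the 1.x SDK, openai.OpenAIError no longer exposes the pre-1.0 http_status and error attributes, so the logger.error line in the except branch above likely needs a follow-up fix.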

View File

@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
 class OpenAI(BaseEncoder):
-    def __init__(self, model_name, device=None):
+    def __init__(self, model_name, client: openai.OpenAI, device=None):
         self.model_name = model_name
-        if (
-            not state.processor_config
-            or not state.processor_config.conversation
-            or not state.processor_config.conversation.openai_model
-        ):
-            raise Exception(
-                f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
-            )
-        self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key)
+        self.openai_client = client
         self.embedding_dimensions = None

     def encode(self, entries, device=None, **kwargs):
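
Note: the encoder no longer reaches into global processor config to construct its own client; the caller injects one, which keeps construction explicit and testable. A hypothetical instantiation (import path and model name are assumptions):

    import openai

    from khoj.utils.models import OpenAI as OpenAIEncoder  # assumed module path

    client = openai.OpenAI(api_key="sk-...")
    encoder = OpenAIEncoder("text-embedding-ada-002", client=client)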

View File

@@ -1,15 +1,16 @@
 # Standard Packages
-from collections import defaultdict
 import os
+from pathlib import Path
 import threading
 from typing import List, Dict
+from collections import defaultdict

 # External Packages
-from pathlib import Path
-from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
+from openai import OpenAI
 from whisper import Whisper

 # Internal Packages
+from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
 from khoj.utils import config as utils_config
 from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
 from khoj.utils.helpers import LRU, get_device
@@ -21,6 +22,7 @@ search_models = SearchModels()
 embeddings_model: EmbeddingsModel = None
 cross_encoder_model: CrossEncoderModel = None
 content_index = ContentIndex()
+openai_client: OpenAI = None
 gpt4all_processor_config: GPT4AllProcessorModel = None
 whisper_model: Whisper = None
 config_file: Path = None
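
Note: openai_client defaults to None and is only populated by configure_server(), so its annotation is really Optional, matching the neighboring globals. A stricter variant of the added line would be:

    from typing import Optional

    from openai import OpenAI

    openai_client: Optional[OpenAI] = None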