mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Use global openai client for transcribe, image
This commit is contained in:
@@ -7,6 +7,7 @@ import requests
|
||||
import os
|
||||
|
||||
# External Packages
|
||||
import openai
|
||||
import schedule
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
@@ -22,6 +23,7 @@ from starlette.authentication import (
|
||||
# Internal Packages
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
get_all_users,
|
||||
get_or_create_search_model,
|
||||
aget_user_subscription_state,
|
||||
@@ -138,6 +140,10 @@ def configure_server(
|
||||
config = FullConfig()
|
||||
state.config = config
|
||||
|
||||
if ConversationAdapters.has_valid_openai_conversation_config():
|
||||
openai_config = ConversationAdapters.get_openai_conversation_config()
|
||||
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
|
||||
|
||||
# Initialize Search Models from Config and initialize content
|
||||
try:
|
||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||
|
||||
@@ -47,7 +47,7 @@ def extract_questions_offline(
|
||||
|
||||
if use_history:
|
||||
for chat in conversation_log.get("chat", [])[-4:]:
|
||||
if chat["by"] == "khoj":
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
|
||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||
chat_history += f"A: {chat['message']}\n"
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ def extract_questions(
|
||||
[
|
||||
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj"
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -6,11 +6,10 @@ from asgiref.sync import sync_to_async
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
|
||||
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
|
||||
"""
|
||||
Transcribe audio file using Whisper model via OpenAI's API
|
||||
"""
|
||||
# Send the audio data to the Whisper API
|
||||
client = OpenAI(api_key=api_key)
|
||||
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
|
||||
return response["text"]
|
||||
return response.text
|
||||
|
||||
@@ -19,7 +19,7 @@ from starlette.authentication import requires
|
||||
from khoj.configure import configure_server
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import ChatModelOptions
|
||||
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import (
|
||||
GithubConfig,
|
||||
@@ -624,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
|
||||
|
||||
# Send the audio data to the Whisper API
|
||||
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
if not speech_to_text_config:
|
||||
# If the user has not configured a speech to text model, return an unsupported on server error
|
||||
status_code = 501
|
||||
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
api_key = openai_chat_config.api_key
|
||||
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
|
||||
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
|
||||
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
|
||||
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
|
||||
finally:
|
||||
# Close and Delete the temporary audio file
|
||||
audio_file.close()
|
||||
@@ -793,7 +791,6 @@ async def extract_references_and_questions(
|
||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||
if conversation_config is None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
||||
if (
|
||||
offline_chat_config
|
||||
and offline_chat_config.enabled
|
||||
@@ -810,7 +807,7 @@ async def extract_references_and_questions(
|
||||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
)
|
||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||
api_key = openai_chat_config.api_key
|
||||
|
||||
@@ -256,15 +256,15 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
|
||||
|
||||
# Send the audio data to the Whisper API
|
||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
if not text_to_image_config:
|
||||
# If the user has not configured a text to image model, return an unsupported on server error
|
||||
status_code = 501
|
||||
elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
client = openai.OpenAI(api_key=openai_chat_config.api_key)
|
||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
text2image_model = text_to_image_config.model_name
|
||||
try:
|
||||
response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json")
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=message, model=text2image_model, response_format="b64_json"
|
||||
)
|
||||
image = response.data[0].b64_json
|
||||
except openai.OpenAIError as e:
|
||||
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
|
||||
|
||||
@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
|
||||
|
||||
|
||||
class OpenAI(BaseEncoder):
|
||||
def __init__(self, model_name, device=None):
|
||||
def __init__(self, model_name, client: openai.OpenAI, device=None):
|
||||
self.model_name = model_name
|
||||
if (
|
||||
not state.processor_config
|
||||
or not state.processor_config.conversation
|
||||
or not state.processor_config.conversation.openai_model
|
||||
):
|
||||
raise Exception(
|
||||
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
|
||||
)
|
||||
self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key)
|
||||
self.openai_client = client
|
||||
self.embedding_dimensions = None
|
||||
|
||||
def encode(self, entries, device=None, **kwargs):
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# Standard Packages
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import List, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
# External Packages
|
||||
from pathlib import Path
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from openai import OpenAI
|
||||
from whisper import Whisper
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import LRU, get_device
|
||||
@@ -21,6 +22,7 @@ search_models = SearchModels()
|
||||
embeddings_model: EmbeddingsModel = None
|
||||
cross_encoder_model: CrossEncoderModel = None
|
||||
content_index = ContentIndex()
|
||||
openai_client: OpenAI = None
|
||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||
whisper_model: Whisper = None
|
||||
config_file: Path = None
|
||||
|
||||
Reference in New Issue
Block a user