mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Use global openai client for transcribe, image
This commit is contained in:
@@ -7,6 +7,7 @@ import requests
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
|
import openai
|
||||||
import schedule
|
import schedule
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
@@ -22,6 +23,7 @@ from starlette.authentication import (
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
|
ConversationAdapters,
|
||||||
get_all_users,
|
get_all_users,
|
||||||
get_or_create_search_model,
|
get_or_create_search_model,
|
||||||
aget_user_subscription_state,
|
aget_user_subscription_state,
|
||||||
@@ -138,6 +140,10 @@ def configure_server(
|
|||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
state.config = config
|
state.config = config
|
||||||
|
|
||||||
|
if ConversationAdapters.has_valid_openai_conversation_config():
|
||||||
|
openai_config = ConversationAdapters.get_openai_conversation_config()
|
||||||
|
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
|
||||||
|
|
||||||
# Initialize Search Models from Config and initialize content
|
# Initialize Search Models from Config and initialize content
|
||||||
try:
|
try:
|
||||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def extract_questions_offline(
|
|||||||
|
|
||||||
if use_history:
|
if use_history:
|
||||||
for chat in conversation_log.get("chat", [])[-4:]:
|
for chat in conversation_log.get("chat", [])[-4:]:
|
||||||
if chat["by"] == "khoj":
|
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
|
||||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||||
chat_history += f"A: {chat['message']}\n"
|
chat_history += f"A: {chat['message']}\n"
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def extract_questions(
|
|||||||
[
|
[
|
||||||
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
||||||
for chat in conversation_log.get("chat", [])[-4:]
|
for chat in conversation_log.get("chat", [])[-4:]
|
||||||
if chat["by"] == "khoj"
|
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,11 +6,10 @@ from asgiref.sync import sync_to_async
|
|||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
|
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe audio file using Whisper model via OpenAI's API
|
Transcribe audio file using Whisper model via OpenAI's API
|
||||||
"""
|
"""
|
||||||
# Send the audio data to the Whisper API
|
# Send the audio data to the Whisper API
|
||||||
client = OpenAI(api_key=api_key)
|
|
||||||
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
|
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
|
||||||
return response["text"]
|
return response.text
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from starlette.authentication import requires
|
|||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import ChatModelOptions
|
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
@@ -624,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
|
|||||||
|
|
||||||
# Send the audio data to the Whisper API
|
# Send the audio data to the Whisper API
|
||||||
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
|
||||||
if not speech_to_text_config:
|
if not speech_to_text_config:
|
||||||
# If the user has not configured a speech to text model, return an unsupported on server error
|
# If the user has not configured a speech to text model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
|
||||||
api_key = openai_chat_config.api_key
|
|
||||||
speech2text_model = speech_to_text_config.model_name
|
speech2text_model = speech_to_text_config.model_name
|
||||||
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
|
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
|
||||||
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
|
||||||
speech2text_model = speech_to_text_config.model_name
|
speech2text_model = speech_to_text_config.model_name
|
||||||
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
|
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
|
||||||
finally:
|
finally:
|
||||||
# Close and Delete the temporary audio file
|
# Close and Delete the temporary audio file
|
||||||
audio_file.close()
|
audio_file.close()
|
||||||
@@ -793,7 +791,6 @@ async def extract_references_and_questions(
|
|||||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
|
||||||
if (
|
if (
|
||||||
offline_chat_config
|
offline_chat_config
|
||||||
and offline_chat_config.enabled
|
and offline_chat_config.enabled
|
||||||
@@ -810,7 +807,7 @@ async def extract_references_and_questions(
|
|||||||
inferred_queries = extract_questions_offline(
|
inferred_queries = extract_questions_offline(
|
||||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||||
)
|
)
|
||||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
|
|||||||
@@ -256,15 +256,15 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
|
|||||||
|
|
||||||
# Send the audio data to the Whisper API
|
# Send the audio data to the Whisper API
|
||||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
|
||||||
if not text_to_image_config:
|
if not text_to_image_config:
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
client = openai.OpenAI(api_key=openai_chat_config.api_key)
|
|
||||||
text2image_model = text_to_image_config.model_name
|
text2image_model = text_to_image_config.model_name
|
||||||
try:
|
try:
|
||||||
response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json")
|
response = state.openai_client.images.generate(
|
||||||
|
prompt=message, model=text2image_model, response_format="b64_json"
|
||||||
|
)
|
||||||
image = response.data[0].b64_json
|
image = response.data[0].b64_json
|
||||||
except openai.OpenAIError as e:
|
except openai.OpenAIError as e:
|
||||||
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
|
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
|
||||||
|
|||||||
@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseEncoder):
|
class OpenAI(BaseEncoder):
|
||||||
def __init__(self, model_name, device=None):
|
def __init__(self, model_name, client: openai.OpenAI, device=None):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
if (
|
self.openai_client = client
|
||||||
not state.processor_config
|
|
||||||
or not state.processor_config.conversation
|
|
||||||
or not state.processor_config.conversation.openai_model
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
|
|
||||||
)
|
|
||||||
self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key)
|
|
||||||
self.embedding_dimensions = None
|
self.embedding_dimensions = None
|
||||||
|
|
||||||
def encode(self, entries, device=None, **kwargs):
|
def encode(self, entries, device=None, **kwargs):
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
|
from collections import defaultdict
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import threading
|
import threading
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from pathlib import Path
|
from openai import OpenAI
|
||||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.utils import config as utils_config
|
from khoj.utils import config as utils_config
|
||||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import LRU, get_device
|
from khoj.utils.helpers import LRU, get_device
|
||||||
@@ -21,6 +22,7 @@ search_models = SearchModels()
|
|||||||
embeddings_model: EmbeddingsModel = None
|
embeddings_model: EmbeddingsModel = None
|
||||||
cross_encoder_model: CrossEncoderModel = None
|
cross_encoder_model: CrossEncoderModel = None
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
|
openai_client: OpenAI = None
|
||||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||||
whisper_model: Whisper = None
|
whisper_model: Whisper = None
|
||||||
config_file: Path = None
|
config_file: Path = None
|
||||||
|
|||||||
Reference in New Issue
Block a user