diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 19e7d403..717ad859 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -7,6 +7,7 @@ import requests import os # External Packages +import openai import schedule from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -22,6 +23,7 @@ from starlette.authentication import ( # Internal Packages from khoj.database.models import KhojUser, Subscription from khoj.database.adapters import ( + ConversationAdapters, get_all_users, get_or_create_search_model, aget_user_subscription_state, @@ -138,6 +140,10 @@ def configure_server( config = FullConfig() state.config = config + if ConversationAdapters.has_valid_openai_conversation_config(): + openai_config = ConversationAdapters.get_openai_conversation_config() + state.openai_client = openai.OpenAI(api_key=openai_config.api_key) + # Initialize Search Models from Config and initialize content try: state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 41b1844b..211e3cac 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -47,7 +47,7 @@ def extract_questions_offline( if use_history: for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj": + if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image": chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"A: {chat['message']}\n" diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index dc708ab7..7bebc26c 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -41,7 +41,7 @@ def extract_questions( [ f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n' for chat in conversation_log.get("chat", [])[-4:] - if chat["by"] == "khoj" + if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image" ] ) diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py index 351319b7..bd0e66df 100644 --- a/src/khoj/processor/conversation/openai/whisper.py +++ b/src/khoj/processor/conversation/openai/whisper.py @@ -6,11 +6,10 @@ from asgiref.sync import sync_to_async from openai import OpenAI -async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str: +async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str: """ Transcribe audio file using Whisper model via OpenAI's API """ # Send the audio data to the Whisper API - client = OpenAI(api_key=api_key) response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file) - return response["text"] + return response.text diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4d2c80a2..7efd8bfd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -19,7 +19,7 @@ from starlette.authentication import requires from khoj.configure import configure_server from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters -from khoj.database.models import ChatModelOptions +from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions from khoj.database.models import Entry as DbEntry from khoj.database.models import ( GithubConfig, @@ -624,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi # Send the audio data to the Whisper API speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() - openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not speech_to_text_config: # If the user has not configured a speech to text model, return an unsupported on server error status_code = 501 - elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: - api_key = openai_chat_config.api_key + elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key) - elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE: + user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client) + elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model) + user_message = await transcribe_audio_offline(audio_filename, speech2text_model) finally: # Close and Delete the temporary audio file audio_file.close() @@ -793,7 +791,6 @@ async def extract_references_and_questions( conversation_config = await ConversationAdapters.aget_conversation_config(user) if conversation_config is None: conversation_config = await ConversationAdapters.aget_default_conversation_config() - openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() if ( offline_chat_config and offline_chat_config.enabled @@ -810,7 +807,7 @@ async def extract_references_and_questions( inferred_queries = extract_questions_offline( defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False ) - elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: + elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat = await ConversationAdapters.get_openai_chat() api_key = openai_chat_config.api_key diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a780eb20..4e883f35 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -256,15 +256,15 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]: # Send the audio data to the Whisper API text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() - openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 - elif openai_chat_config and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: - client = openai.OpenAI(api_key=openai_chat_config.api_key) + elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: text2image_model = text_to_image_config.model_name try: - response = client.images.generate(prompt=message, model=text2image_model, response_format="b64_json") + response = state.openai_client.images.generate( + prompt=message, model=text2image_model, response_format="b64_json" + ) image = response.data[0].b64_json except openai.OpenAIError as e: logger.error(f"Image Generation failed with {e.http_status}: {e.error}") diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py index 37d09418..e1298b08 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -22,17 +22,9 @@ class BaseEncoder(ABC): class OpenAI(BaseEncoder): - def __init__(self, model_name, device=None): + def __init__(self, model_name, client: openai.OpenAI, device=None): self.model_name = model_name - if ( - not state.processor_config - or not state.processor_config.conversation - or not state.processor_config.conversation.openai_model - ): - raise Exception( - f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}" - ) - self.openai_client = openai.OpenAI(api_key=state.processor_config.conversation.openai_model.api_key) + self.openai_client = client self.embedding_dimensions = None def encode(self, entries, device=None, **kwargs): diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index b54cf4b3..d5358868 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -1,15 +1,16 @@ # Standard Packages +from collections import defaultdict import os +from pathlib import Path import threading from typing import List, Dict -from collections import defaultdict # External Packages -from pathlib import Path -from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel +from openai import OpenAI from whisper import Whisper # Internal Packages +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.helpers import LRU, get_device @@ -21,6 +22,7 @@ search_models = SearchModels() embeddings_model: EmbeddingsModel = None cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() +openai_client: OpenAI = None gpt4all_processor_config: GPT4AllProcessorModel = None whisper_model: Whisper = None config_file: Path = None