Support customization of the OpenAI base URL in admin settings (#725)

- Allow self-hosted users to customize their OpenAI base URL. This makes it easy to route requests through a proxy service and to extend support to other OpenAI-compatible models.
- Include a migration that associates every existing OpenAI chat model configuration with an OpenAI processor configuration.
- Make changing the chat model a paid/subscriber feature.
- Remove usage of langchain's OpenAI wrapper for better control over parsing input/output.
This commit is contained in:
sabaimran
2024-04-27 05:54:35 -07:00
committed by GitHub
parent 49834e3b00
commit 2047b0c973
14 changed files with 219 additions and 100 deletions

View File

@@ -623,10 +623,6 @@ class ConversationAdapters:
def get_openai_conversation_config(): def get_openai_conversation_config():
return OpenAIProcessorConversationConfig.objects.filter().first() return OpenAIProcessorConversationConfig.objects.filter().first()
@staticmethod
async def aget_openai_conversation_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod @staticmethod
def has_valid_openai_conversation_config(): def has_valid_openai_conversation_config():
return OpenAIProcessorConversationConfig.objects.filter().exists() return OpenAIProcessorConversationConfig.objects.filter().exists()
@@ -659,7 +655,7 @@ class ConversationAdapters:
@staticmethod @staticmethod
async def aget_default_conversation_config(): async def aget_default_conversation_config():
return await ChatModelOptions.objects.filter().afirst() return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
@staticmethod @staticmethod
def save_conversation( def save_conversation(
@@ -697,29 +693,15 @@ class ConversationAdapters:
user_conversation_config.setting = new_config user_conversation_config.setting = new_config
user_conversation_config.save() user_conversation_config.save()
@staticmethod
async def get_default_offline_llm():
return await ChatModelOptions.objects.filter(model_type="offline").afirst()
@staticmethod @staticmethod
async def aget_user_conversation_config(user: KhojUser): async def aget_user_conversation_config(user: KhojUser):
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() config = (
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()
)
if not config: if not config:
return None return None
return config.setting return config.setting
@staticmethod
async def has_openai_chat():
return await OpenAIProcessorConversationConfig.objects.filter().aexists()
@staticmethod
async def aget_default_openai_llm():
return await ChatModelOptions.objects.filter(model_type="openai").afirst()
@staticmethod
async def get_openai_chat_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod @staticmethod
async def get_speech_to_text_config(): async def get_speech_to_text_config():
return await SpeechToTextModelOptions.objects.filter().afirst() return await SpeechToTextModelOptions.objects.filter().afirst()
@@ -744,7 +726,8 @@ class ConversationAdapters:
@staticmethod @staticmethod
def get_valid_conversation_config(user: KhojUser, conversation: Conversation): def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
if conversation.agent and conversation.agent.chat_model: agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
if agent and agent.chat_model:
conversation_config = conversation.agent.chat_model conversation_config = conversation.agent.chat_model
else: else:
conversation_config = ConversationAdapters.get_conversation_config(user) conversation_config = ConversationAdapters.get_conversation_config(user)
@@ -760,8 +743,7 @@ class ConversationAdapters:
return conversation_config return conversation_config
openai_chat_config = ConversationAdapters.get_openai_conversation_config() if conversation_config.model_type == "openai" and conversation_config.openai_config:
if openai_chat_config and conversation_config.model_type == "openai":
return conversation_config return conversation_config
else: else:

View File

@@ -0,0 +1,51 @@
# Generated by Django 4.2.10 on 2024-04-24 05:46
import django.db.models.deletion
from django.db import migrations, models
def attach_openai_config(apps, schema_editor):
    """Backfill `openai_config` on existing OpenAI chat model options.

    Looks up the first `OpenAIProcessorConversationConfig` row and assigns it
    to every `ChatModelOptions` row whose `model_type` is "openai", saving each
    one. Does nothing when no processor config row exists.

    Args:
        apps: Migration app registry used to resolve the models by name.
        schema_editor: Unused; part of the RunPython callable signature.
    """
    processor_config_model = apps.get_model("database", "OpenAIProcessorConversationConfig")
    first_processor_config = processor_config_model.objects.first()
    if not first_processor_config:
        # No OpenAI processor config to attach, so leave chat model options untouched.
        return

    chat_model_options_model = apps.get_model("database", "ChatModelOptions")
    openai_options = (
        option for option in chat_model_options_model.objects.all() if option.model_type == "openai"
    )
    for option in openai_options:
        option.openai_config = first_processor_config
        option.save()
def separate_openai_config(apps, schema_editor):
    """Reverse step for the data backfill; intentionally does nothing.

    Both arguments are unused and exist only to satisfy the RunPython
    `reverse_code` callable signature.
    """
    return None
class Migration(migrations.Migration):
    """Add OpenAI config linkage and base URL support, then backfill existing rows.

    Schema changes:
      * `ChatModelOptions.openai_config`: nullable foreign key to the OpenAI processor config.
      * `OpenAIProcessorConversationConfig.api_base_url`: optional URL field.
      * `OpenAIProcessorConversationConfig.name`: required name field.

    Data change: runs `attach_openai_config` forward and `separate_openai_config` in reverse.
    """

    dependencies = [
        ("database", "0036_delete_offlinechatprocessorconversationconfig"),
    ]

    operations = [
        # Optional link from a chat model option to the OpenAI processor config it should use.
        migrations.AddField(
            model_name="chatmodeloptions",
            name="openai_config",
            field=models.ForeignKey(
                to="database.openaiprocessorconversationconfig",
                on_delete=models.CASCADE,
                null=True,
                blank=True,
                default=None,
            ),
        ),
        # Optional custom base URL for the OpenAI processor config.
        migrations.AddField(
            model_name="openaiprocessorconversationconfig",
            name="api_base_url",
            field=models.URLField(null=True, blank=True, default=None),
        ),
        # The "default" value is supplied for this AddField only; preserve_default=False
        # means it is not kept as the field's default afterwards.
        migrations.AddField(
            model_name="openaiprocessorconversationconfig",
            name="name",
            field=models.CharField(max_length=200, default="default"),
            preserve_default=False,
        ),
        # Backfill the new foreign key once the columns above exist.
        migrations.RunPython(attach_openai_config, reverse_code=separate_openai_config),
    ]

View File

@@ -0,0 +1,14 @@
# Generated by Django 4.2.10 on 2024-04-25 08:57
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
    """Merge migration joining the two 0037 branches; it applies no operations of its own."""

    # Both 0037 migrations must be applied before this merge point.
    dependencies = [
        ("database", "0037_chatmodeloptions_openai_config_and_more"),
        ("database", "0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more"),
    ]

    operations: List[str] = []

View File

@@ -73,6 +73,12 @@ class Subscription(BaseModel):
renewal_date = models.DateTimeField(null=True, default=None, blank=True) renewal_date = models.DateTimeField(null=True, default=None, blank=True)
class OpenAIProcessorConversationConfig(BaseModel):
name = models.CharField(max_length=200)
api_key = models.CharField(max_length=200)
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
class ChatModelOptions(BaseModel): class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
OPENAI = "openai" OPENAI = "openai"
@@ -82,6 +88,9 @@ class ChatModelOptions(BaseModel):
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF") chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
openai_config = models.ForeignKey(
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
)
class Agent(BaseModel): class Agent(BaseModel):
@@ -211,10 +220,6 @@ class TextToImageModelConfig(BaseModel):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)
class SpeechToTextModelOptions(BaseModel): class SpeechToTextModelOptions(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
OPENAI = "openai" OPENAI = "openai"

View File

@@ -191,9 +191,15 @@
</select> </select>
</div> </div>
<div class="card-action-row"> <div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-model" class="card-button happy" onclick="updateChatModel()"> <button id="save-model" class="card-button happy" onclick="updateChatModel()">
Save Save
</button> </button>
{% else %}
<button id="save-model" class="card-button" disabled>
Subscribe to use different models
</button>
{% endif %}
</div> </div>
</div> </div>
</div> </div>

View File

@@ -121,14 +121,16 @@ def migrate_server_pg(args):
if openai.get("chat-model") is None: if openai.get("chat-model") is None:
openai["chat-model"] = "gpt-3.5-turbo" openai["chat-model"] = "gpt-3.5-turbo"
OpenAIProcessorConversationConfig.objects.create( openai_config = OpenAIProcessorConversationConfig.objects.create(
api_key=openai.get("api-key"), api_key=openai.get("api-key"), name="default"
) )
ChatModelOptions.objects.create( ChatModelOptions.objects.create(
chat_model=openai.get("chat-model"), chat_model=openai.get("chat-model"),
tokenizer=processor_conversation.get("tokenizer"), tokenizer=processor_conversation.get("tokenizer"),
max_prompt_size=processor_conversation.get("max-prompt-size"), max_prompt_size=processor_conversation.get("max-prompt-size"),
model_type=ChatModelOptions.ModelType.OPENAI, model_type=ChatModelOptions.ModelType.OPENAI,
openai_config=openai_config,
) )
save_config_to_file(raw_config, args.config_file) save_config_to_file(raw_config, args.config_file)

View File

@@ -23,6 +23,7 @@ def extract_questions(
model: Optional[str] = "gpt-4-turbo-preview", model: Optional[str] = "gpt-4-turbo-preview",
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
api_base_url=None,
temperature=0, temperature=0,
max_tokens=100, max_tokens=100,
location_data: LocationData = None, location_data: LocationData = None,
@@ -64,12 +65,12 @@ def extract_questions(
# Get Response from GPT # Get Response from GPT
response = completion_with_backoff( response = completion_with_backoff(
messages=messages, messages=messages,
completion_kwargs={"temperature": temperature, "max_tokens": max_tokens}, model=model,
model_kwargs={ temperature=temperature,
"model_name": model, max_tokens=max_tokens,
"openai_api_key": api_key, api_base_url=api_base_url,
"model_kwargs": {"response_format": {"type": "json_object"}}, model_kwargs={"response_format": {"type": "json_object"}},
}, openai_api_key=api_key,
) )
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
@@ -89,7 +90,7 @@ def extract_questions(
return questions return questions
def send_message_to_model(messages, api_key, model, response_type="text"): def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None):
""" """
Send message to model Send message to model
""" """
@@ -97,11 +98,10 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
# Get Response from GPT # Get Response from GPT
return completion_with_backoff( return completion_with_backoff(
messages=messages, messages=messages,
model_kwargs={ model=model,
"model_name": model, openai_api_key=api_key,
"openai_api_key": api_key, api_base_url=api_base_url,
"model_kwargs": {"response_format": {"type": response_type}}, model_kwargs={"response_format": {"type": response_type}},
},
) )
@@ -112,6 +112,7 @@ def converse(
conversation_log={}, conversation_log={},
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 0.2, temperature: float = 0.2,
completion_func=None, completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
@@ -181,6 +182,7 @@ def converse(
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
openai_api_key=api_key, openai_api_key=api_key,
api_base_url=api_base_url,
completion_func=completion_func, completion_func=completion_func,
model_kwargs={"stop": ["Notes:\n["]}, model_kwargs={"stop": ["Notes:\n["]},
) )

View File

@@ -1,12 +1,9 @@
import logging import logging
import os import os
from threading import Thread from threading import Thread
from typing import Any from typing import Dict
import openai import openai
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_openai import ChatOpenAI
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
@@ -20,14 +17,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
openai_clients: Dict[str, openai.OpenAI] = {}
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(self, gen: ThreadedGenerator):
super().__init__()
self.gen = gen
def on_llm_new_token(self, token: str, **kwargs) -> Any:
self.gen.send(token)
@retry( @retry(
@@ -43,13 +33,37 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str: def completion_with_backoff(
if not "openai_api_key" in model_kwargs: messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, max_tokens=None
model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") ) -> str:
llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1) client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI = openai_clients.get(client_key)
if not client:
client = openai.OpenAI(
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
base_url=api_base_url,
)
openai_clients[client_key] = client
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
chat = client.chat.completions.create(
stream=True,
messages=formatted_messages, # type: ignore
model=model, # type: ignore
temperature=temperature,
timeout=20,
max_tokens=max_tokens,
**(model_kwargs or dict()),
)
aggregated_response = "" aggregated_response = ""
for chunk in llm.stream(messages, **completion_kwargs): for chunk in chat:
aggregated_response += chunk.content delta_chunk = chunk.choices[0].delta # type: ignore
if isinstance(delta_chunk, str):
aggregated_response += delta_chunk
elif delta_chunk.content:
aggregated_response += delta_chunk.content
return aggregated_response return aggregated_response
@@ -73,30 +87,45 @@ def chat_completion_with_backoff(
model_name, model_name,
temperature, temperature,
openai_api_key=None, openai_api_key=None,
api_base_url=None,
completion_func=None, completion_func=None,
model_kwargs=None, model_kwargs=None,
): ):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs)) t = Thread(
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
)
t.start() t.start()
return g return g
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None): def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
callback_handler = StreamingChatCallbackHandler(g) client_key = f"{openai_api_key}--{api_base_url}"
chat = ChatOpenAI( if client_key not in openai_clients:
streaming=True, client: openai.OpenAI = openai.OpenAI(
verbose=True, api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
callback_manager=BaseCallbackManager([callback_handler]), base_url=api_base_url,
model_name=model_name, # type: ignore )
openai_clients[client_key] = client
else:
client: openai.OpenAI = openai_clients[client_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
chat = client.chat.completions.create(
stream=True,
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature, temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), timeout=20,
model_kwargs=model_kwargs, **(model_kwargs or dict()),
request_timeout=20,
max_retries=1,
client=None,
) )
chat(messages=messages) for chunk in chat:
delta_chunk = chunk.choices[0].delta
if isinstance(delta_chunk, str):
g.send(delta_chunk)
elif delta_chunk.content:
g.send(delta_chunk.content)
g.close() g.close()

View File

@@ -14,6 +14,7 @@ from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ClientApplication, KhojUser from khoj.database.models import ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts from khoj.utils.helpers import is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -186,19 +187,31 @@ def truncate_messages(
max_prompt_size, max_prompt_size,
model_name: str, model_name: str,
loaded_model: Optional[Llama] = None, loaded_model: Optional[Llama] = None,
tokenizer_name="hf-internal-testing/llama-tokenizer", tokenizer_name=None,
) -> list[ChatMessage]: ) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model""" """Truncate messages to fit within max prompt size supported by model"""
default_tokenizer = "hf-internal-testing/llama-tokenizer"
try: try:
if loaded_model: if loaded_model:
encoder = loaded_model.tokenizer() encoder = loaded_model.tokenizer()
elif model_name.startswith("gpt-"): elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name) encoder = tiktoken.encoding_for_model(model_name)
elif tokenizer_name:
if tokenizer_name in state.pretrained_tokenizers:
encoder = state.pretrained_tokenizers[tokenizer_name]
else:
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
state.pretrained_tokenizers[tokenizer_name] = encoder
else: else:
encoder = download_model(model_name).tokenizer() encoder = download_model(model_name).tokenizer()
except: except:
encoder = AutoTokenizer.from_pretrained(tokenizer_name) if default_tokenizer in state.pretrained_tokenizers:
encoder = state.pretrained_tokenizers[default_tokenizer]
else:
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
state.pretrained_tokenizers[default_tokenizer] = encoder
logger.warning( logger.warning(
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
) )

View File

@@ -267,7 +267,6 @@ async def transcribe(
async def extract_references_and_questions( async def extract_references_and_questions(
request: Request, request: Request,
common: CommonQueryParams,
meta_log: dict, meta_log: dict,
q: str, q: str,
n: int, n: int,
@@ -303,14 +302,12 @@ async def extract_references_and_questions(
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled. # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
conversation_config = await ConversationAdapters.aget_conversation_config(user) conversation_config = await ConversationAdapters.aget_default_conversation_config()
if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
using_offline_chat = True using_offline_chat = True
default_offline_llm = await ConversationAdapters.get_default_offline_llm() chat_model = conversation_config.chat_model
chat_model = default_offline_llm.chat_model max_tokens = conversation_config.max_prompt_size
max_tokens = default_offline_llm.max_prompt_size
if state.offline_chat_processor_config is None: if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
@@ -324,11 +321,10 @@ async def extract_references_and_questions(
location_data=location_data, location_data=location_data,
max_prompt_size=conversation_config.max_prompt_size, max_prompt_size=conversation_config.max_prompt_size,
) )
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat_config = conversation_config.openai_config
default_openai_llm = await ConversationAdapters.aget_default_openai_llm()
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
chat_model = default_openai_llm.chat_model chat_model = conversation_config.chat_model
inferred_queries = extract_questions( inferred_queries = extract_questions(
defiltered_query, defiltered_query,
model=chat_model, model=chat_model,

View File

@@ -380,7 +380,7 @@ async def websocket_endpoint(
q = q.replace(f"/{cmd.value}", "").strip() q = q.replace(f"/{cmd.value}", "").strip()
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update websocket, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
) )
if compiled_references: if compiled_references:
@@ -575,7 +575,7 @@ async def chat(
user_name = await aget_user_name(user) user_name = await aget_user_name(user)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location request, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
) )
online_results: Dict[str, Dict] = {} online_results: Dict[str, Dict] = {}

View File

@@ -7,7 +7,7 @@ from asgiref.sync import sync_to_async
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response from fastapi.responses import Response
from starlette.authentication import requires from starlette.authentication import has_required_scope, requires
from khoj.database import adapters from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.adapters import ConversationAdapters, EntryAdapters
@@ -20,6 +20,7 @@ from khoj.database.models import (
LocalPdfConfig, LocalPdfConfig,
LocalPlaintextConfig, LocalPlaintextConfig,
NotionConfig, NotionConfig,
Subscription,
) )
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
from khoj.utils import constants, state from khoj.utils import constants, state
@@ -236,6 +237,10 @@ async def update_chat_model(
client: Optional[str] = None, client: Optional[str] = None,
): ):
user = request.user.object user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id)) new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))

View File

@@ -70,13 +70,14 @@ def validate_conversation_config():
if default_config is None: if default_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
if default_config.model_type == "openai" and not ConversationAdapters.has_valid_openai_conversation_config(): if default_config.model_type == "openai" and not default_config.openai_config:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
async def is_ready_to_chat(user: KhojUser): async def is_ready_to_chat(user: KhojUser):
has_openai_config = await ConversationAdapters.has_openai_chat() user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user) await ConversationAdapters.aget_default_conversation_config()
)
if user_conversation_config and user_conversation_config.model_type == "offline": if user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model chat_model = user_conversation_config.chat_model
@@ -86,8 +87,14 @@ async def is_ready_to_chat(user: KhojUser):
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True return True
if not has_openai_config: if (
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.") user_conversation_config
and user_conversation_config.model_type == "openai"
and user_conversation_config.openai_config
):
return True
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
def update_telemetry_state( def update_telemetry_state(
@@ -407,8 +414,9 @@ async def send_message_to_model_wrapper(
) )
elif conversation_config.model_type == "openai": elif conversation_config.model_type == "openai":
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() openai_chat_config = conversation_config.openai_config
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
api_base_url = openai_chat_config.api_base_url
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
user_message=message, user_message=message,
system_message=system_message, system_message=system_message,
@@ -418,7 +426,11 @@ async def send_message_to_model_wrapper(
) )
openai_response = send_message_to_model( openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
) )
return openai_response return openai_response
@@ -480,7 +492,7 @@ def generate_chat_response(
) )
elif conversation_config.model_type == "openai": elif conversation_config.model_type == "openai":
openai_chat_config = ConversationAdapters.get_openai_conversation_config() openai_chat_config = conversation_config.openai_config
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
chat_response = converse( chat_response = converse(
@@ -490,6 +502,7 @@ def generate_chat_response(
conversation_log=meta_log, conversation_log=meta_log,
model=chat_model, model=chat_model,
api_key=api_key, api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
completion_func=partial_completion, completion_func=partial_completion,
conversation_commands=conversation_commands, conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size, max_prompt_size=conversation_config.max_prompt_size,

View File

@@ -2,7 +2,7 @@ import os
import threading import threading
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Any, Dict, List
from openai import OpenAI from openai import OpenAI
from whisper import Whisper from whisper import Whisper
@@ -34,6 +34,7 @@ khoj_version: str = None
device = get_device() device = get_device()
chat_on_gpu: bool = True chat_on_gpu: bool = True
anonymous_mode: bool = False anonymous_mode: bool = False
pretrained_tokenizers: Dict[str, Any] = dict()
billing_enabled: bool = ( billing_enabled: bool = (
os.getenv("STRIPE_API_KEY") is not None os.getenv("STRIPE_API_KEY") is not None
and os.getenv("STRIPE_SIGNING_SECRET") is not None and os.getenv("STRIPE_SIGNING_SECRET") is not None