Rename Chat Model Options table to Chat Model as short & readable (#1003)

- Previous name was incorrectly plural even though it defined only a single model
- Rename chat model table field to name
- Update documentation
- Update referencing functions and variables to match the new name
This commit is contained in:
Debanjum
2024-12-12 11:24:16 -08:00
committed by GitHub
parent 9be26e1bd2
commit 01bc6d35dc
26 changed files with 369 additions and 308 deletions

View File

@@ -36,7 +36,7 @@ from torch import Tensor
from khoj.database.models import (
Agent,
AiModelApi,
ChatModelOptions,
ChatModel,
ClientApplication,
Conversation,
Entry,
@@ -736,8 +736,8 @@ class AgentAdapters:
@staticmethod
def create_default_agent(user: KhojUser):
default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
if default_conversation_config is None:
default_chat_model = ConversationAdapters.get_default_chat_model(user)
if default_chat_model is None:
logger.info("No default conversation config found, skipping default agent creation")
return None
default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder")
@@ -746,7 +746,7 @@ class AgentAdapters:
if agent:
agent.personality = default_personality
agent.chat_model = default_conversation_config
agent.chat_model = default_chat_model
agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
agent.name = AgentAdapters.DEFAULT_AGENT_NAME
agent.privacy_level = Agent.PrivacyLevel.PUBLIC
@@ -760,7 +760,7 @@ class AgentAdapters:
name=AgentAdapters.DEFAULT_AGENT_NAME,
privacy_level=Agent.PrivacyLevel.PUBLIC,
managed_by_admin=True,
chat_model=default_conversation_config,
chat_model=default_chat_model,
personality=default_personality,
slug=AgentAdapters.DEFAULT_AGENT_SLUG,
)
@@ -787,7 +787,7 @@ class AgentAdapters:
output_modes: List[str],
slug: Optional[str] = None,
):
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst()
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
@@ -972,29 +972,29 @@ class ConversationAdapters:
@staticmethod
@require_valid_user
def has_any_conversation_config(user: KhojUser):
return ChatModelOptions.objects.filter(user=user).exists()
def has_any_chat_model(user: KhojUser):
return ChatModel.objects.filter(user=user).exists()
@staticmethod
def get_all_conversation_configs():
return ChatModelOptions.objects.all()
def get_all_chat_models():
return ChatModel.objects.all()
@staticmethod
async def aget_all_conversation_configs():
return await sync_to_async(list)(ChatModelOptions.objects.prefetch_related("ai_model_api").all())
async def aget_all_chat_models():
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
@staticmethod
def get_vision_enabled_config():
conversation_configurations = ConversationAdapters.get_all_conversation_configs()
for config in conversation_configurations:
chat_models = ConversationAdapters.get_all_chat_models()
for config in chat_models:
if config.vision_enabled:
return config
return None
@staticmethod
async def aget_vision_enabled_config():
conversation_configurations = await ConversationAdapters.aget_all_conversation_configs()
for config in conversation_configurations:
chat_models = await ConversationAdapters.aget_all_chat_models()
for config in chat_models:
if config.vision_enabled:
return config
return None
@@ -1010,7 +1010,7 @@ class ConversationAdapters:
@staticmethod
@arequire_valid_user
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
config = await ChatModel.objects.filter(id=conversation_processor_config_id).afirst()
if not config:
return None
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
@@ -1026,24 +1026,24 @@ class ConversationAdapters:
return new_config
@staticmethod
def get_conversation_config(user: KhojUser):
def get_chat_model(user: KhojUser):
subscribed = is_user_subscribed(user)
if not subscribed:
return ConversationAdapters.get_default_conversation_config(user)
return ConversationAdapters.get_default_chat_model(user)
config = UserConversationConfig.objects.filter(user=user).first()
if config:
return config.setting
return ConversationAdapters.get_advanced_conversation_config(user)
return ConversationAdapters.get_advanced_chat_model(user)
@staticmethod
async def aget_conversation_config(user: KhojUser):
async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user)
if not subscribed:
return await ConversationAdapters.aget_default_conversation_config(user)
return await ConversationAdapters.aget_default_chat_model(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if config:
return config.setting
return ConversationAdapters.aget_advanced_conversation_config(user)
return ConversationAdapters.aget_advanced_chat_model(user)
@staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
@@ -1064,7 +1064,7 @@ class ConversationAdapters:
return VoiceModelOption.objects.first()
@staticmethod
def get_default_conversation_config(user: KhojUser = None):
def get_default_chat_model(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings = ServerChatSettings.objects.first()
@@ -1084,10 +1084,10 @@ class ConversationAdapters:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return ChatModelOptions.objects.filter().first()
return ChatModel.objects.filter().first()
@staticmethod
async def aget_default_conversation_config(user: KhojUser = None):
async def aget_default_chat_model(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings: ServerChatSettings = (
@@ -1117,17 +1117,17 @@ class ConversationAdapters:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return await ChatModelOptions.objects.filter().prefetch_related("ai_model_api").afirst()
return await ChatModel.objects.filter().prefetch_related("ai_model_api").afirst()
@staticmethod
def get_advanced_conversation_config(user: KhojUser):
def get_advanced_chat_model(user: KhojUser):
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced
return ConversationAdapters.get_default_conversation_config(user)
return ConversationAdapters.get_default_chat_model(user)
@staticmethod
async def aget_advanced_conversation_config(user: KhojUser = None):
async def aget_advanced_chat_model(user: KhojUser = None):
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related("chat_advanced", "chat_advanced__ai_model_api")
@@ -1135,7 +1135,7 @@ class ConversationAdapters:
)
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced
return await ConversationAdapters.aget_default_conversation_config(user)
return await ConversationAdapters.aget_default_chat_model(user)
@staticmethod
async def aget_server_webscraper():
@@ -1247,16 +1247,16 @@ class ConversationAdapters:
@staticmethod
def get_conversation_processor_options():
return ChatModelOptions.objects.all()
return ChatModel.objects.all()
@staticmethod
def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
def set_user_chat_model(user: KhojUser, chat_model: ChatModel):
user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
user_conversation_config.setting = new_config
user_conversation_config.setting = chat_model
user_conversation_config.save()
@staticmethod
async def aget_user_conversation_config(user: KhojUser):
async def aget_user_chat_model(user: KhojUser):
config = (
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()
)
@@ -1288,33 +1288,33 @@ class ConversationAdapters:
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
def get_valid_chat_model(user: KhojUser, conversation: Conversation):
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
if agent and agent.chat_model:
conversation_config = conversation.agent.chat_model
chat_model = conversation.agent.chat_model
else:
conversation_config = ConversationAdapters.get_conversation_config(user)
chat_model = ConversationAdapters.get_chat_model(user)
if conversation_config is None:
conversation_config = ConversationAdapters.get_default_conversation_config()
if chat_model is None:
chat_model = ConversationAdapters.get_default_chat_model()
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
return conversation_config
return chat_model
if (
conversation_config.model_type
chat_model.model_type
in [
ChatModelOptions.ModelType.ANTHROPIC,
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC,
ChatModel.ModelType.OPENAI,
ChatModel.ModelType.GOOGLE,
]
) and conversation_config.ai_model_api:
return conversation_config
) and chat_model.ai_model_api:
return chat_model
else:
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")

View File

@@ -16,7 +16,7 @@ from unfold import admin as unfold_admin
from khoj.database.models import (
Agent,
AiModelApi,
ChatModelOptions,
ChatModel,
ClientApplication,
Conversation,
Entry,
@@ -212,15 +212,15 @@ class KhojUserSubscription(unfold_admin.ModelAdmin):
list_filter = ("type",)
@admin.register(ChatModelOptions)
class ChatModelOptionsAdmin(unfold_admin.ModelAdmin):
@admin.register(ChatModel)
class ChatModelAdmin(unfold_admin.ModelAdmin):
list_display = (
"id",
"chat_model",
"name",
"ai_model_api",
"max_prompt_size",
)
search_fields = ("id", "chat_model", "ai_model_api__name")
search_fields = ("id", "name", "ai_model_api__name")
@admin.register(TextToImageModelConfig)
@@ -385,7 +385,7 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
"get_chat_model",
"get_subscription_type",
)
search_fields = ("id", "user__email", "setting__chat_model", "user__subscription__type")
search_fields = ("id", "user__email", "setting__name", "user__subscription__type")
ordering = ("-updated_at",)
def get_user_email(self, obj):
@@ -395,10 +395,10 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
get_user_email.admin_order_field = "user__email" # type: ignore
def get_chat_model(self, obj):
return obj.setting.chat_model if obj.setting else None
return obj.setting.name if obj.setting else None
get_chat_model.short_description = "Chat Model" # type: ignore
get_chat_model.admin_order_field = "setting__chat_model" # type: ignore
get_chat_model.admin_order_field = "setting__name" # type: ignore
def get_subscription_type(self, obj):
if hasattr(obj.user, "subscription"):

View File

@@ -0,0 +1,62 @@
# Generated by Django 5.0.9 on 2024-12-09 04:21
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0076_rename_openaiprocessorconversationconfig_aimodelapi_and_more"),
]
operations = [
migrations.RenameModel(
old_name="ChatModelOptions",
new_name="ChatModel",
),
migrations.RenameField(
model_name="chatmodel",
old_name="chat_model",
new_name="name",
),
migrations.AlterField(
model_name="agent",
name="chat_model",
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodel"),
),
migrations.AlterField(
model_name="serverchatsettings",
name="chat_advanced",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="chat_advanced",
to="database.chatmodel",
),
),
migrations.AlterField(
model_name="serverchatsettings",
name="chat_default",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="chat_default",
to="database.chatmodel",
),
),
migrations.AlterField(
model_name="userconversationconfig",
name="setting",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.chatmodel",
),
),
]

View File

@@ -193,7 +193,7 @@ class AiModelApi(DbBaseModel):
return self.name
class ChatModelOptions(DbBaseModel):
class ChatModel(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@@ -203,13 +203,13 @@ class ChatModelOptions(DbBaseModel):
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
vision_enabled = models.BooleanField(default=False)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
def __str__(self):
return self.chat_model
return self.name
class VoiceModelOption(DbBaseModel):
@@ -297,7 +297,7 @@ class Agent(DbBaseModel):
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
chat_model = models.ForeignKey(ChatModel, on_delete=models.CASCADE)
slug = models.CharField(max_length=200, unique=True)
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
@@ -438,10 +438,10 @@ class WebScraper(DbBaseModel):
class ServerChatSettings(DbBaseModel):
chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
chat_advanced = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
)
web_scraper = models.ForeignKey(
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
@@ -563,7 +563,7 @@ class SpeechToTextModelOptions(DbBaseModel):
class UserConversationConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserVoiceModelConfig(DbBaseModel):

View File

@@ -60,7 +60,7 @@ import logging
from packaging import version
from khoj.database.models import AiModelApi, ChatModelOptions, SearchModelConfig
from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig
from khoj.utils.yaml import load_config_from_file, save_config_to_file
logger = logging.getLogger(__name__)
@@ -98,11 +98,11 @@ def migrate_server_pg(args):
if "offline-chat" in raw_config["processor"]["conversation"]:
offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
ChatModelOptions.objects.create(
chat_model=offline_chat.get("chat-model"),
ChatModel.objects.create(
name=offline_chat.get("chat-model"),
tokenizer=processor_conversation.get("tokenizer"),
max_prompt_size=processor_conversation.get("max-prompt-size"),
model_type=ChatModelOptions.ModelType.OFFLINE,
model_type=ChatModel.ModelType.OFFLINE,
)
if (
@@ -119,11 +119,11 @@ def migrate_server_pg(args):
openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")
ChatModelOptions.objects.create(
chat_model=openai.get("chat-model"),
ChatModel.objects.create(
name=openai.get("chat-model"),
tokenizer=processor_conversation.get("tokenizer"),
max_prompt_size=processor_conversation.get("max-prompt-size"),
model_type=ChatModelOptions.ModelType.OPENAI,
model_type=ChatModel.ModelType.OPENAI,
ai_model_api=openai_model_api,
)

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
@@ -85,7 +85,7 @@ def extract_questions_anthropic(
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
model_type=ChatModel.ModelType.ANTHROPIC,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
@@ -218,7 +218,7 @@ def converse_anthropic(
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
model_type=ChatModel.ModelType.ANTHROPIC,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
@@ -86,7 +86,7 @@ def extract_questions_gemini(
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.GOOGLE,
model_type=ChatModel.ModelType.GOOGLE,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
@@ -229,7 +229,7 @@ def converse_gemini(
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
model_type=ChatModel.ModelType.GOOGLE,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,

View File

@@ -9,7 +9,7 @@ import pyjson5
from langchain.schema import ChatMessage
from llama_cpp import Llama
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
@@ -96,7 +96,7 @@ def extract_questions_offline(
model_name=model,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
model_type=ChatModelOptions.ModelType.OFFLINE,
model_type=ChatModel.ModelType.OFFLINE,
query_files=query_files,
)
@@ -232,7 +232,7 @@ def converse_offline(
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE,
model_type=ChatModel.ModelType.OFFLINE,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
@@ -83,7 +83,7 @@ def extract_questions(
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI,
model_type=ChatModel.ModelType.OPENAI,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
@@ -220,7 +220,7 @@ def converse_openai(
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
model_type=ChatModel.ModelType.OPENAI,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,

View File

@@ -24,7 +24,7 @@ from llama_cpp.llama import Llama
from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.database.models import ChatModel, ClientApplication, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.search_filter.base_filter import BaseFilter
@@ -330,9 +330,9 @@ def construct_structured_message(
Format messages into appropriate multimedia format for supported chat model types
"""
if model_type in [
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.GOOGLE,
ChatModelOptions.ModelType.ANTHROPIC,
ChatModel.ModelType.OPENAI,
ChatModel.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC,
]:
if not attached_file_context and not (vision_enabled and images):
return message

View File

@@ -28,12 +28,7 @@ from khoj.database.adapters import (
get_default_search_model,
get_user_photo,
)
from khoj.database.models import (
Agent,
ChatModelOptions,
KhojUser,
SpeechToTextModelOptions,
)
from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
@@ -404,15 +399,15 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
vision_enabled = conversation_config.vision_enabled
chat_model = await ConversationAdapters.aget_default_chat_model(user)
vision_enabled = chat_model.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
using_offline_chat = True
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
@@ -424,18 +419,18 @@ async def extract_references_and_questions(
should_extract_questions=True,
location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size,
max_prompt_size=chat_model.max_prompt_size,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = conversation_config.ai_model_api.api_key
base_url = conversation_config.ai_model_api.api_base_url
chat_model = conversation_config.chat_model
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
api_key = chat_model.ai_model_api.api_key
base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions(
defiltered_query,
model=chat_model,
model=chat_model_name,
api_key=api_key,
api_base_url=base_url,
conversation_log=meta_log,
@@ -447,13 +442,13 @@ async def extract_references_and_questions(
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
chat_model = conversation_config.chat_model
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic(
defiltered_query,
query_images=query_images,
model=chat_model,
model=chat_model_name,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
@@ -463,17 +458,17 @@ async def extract_references_and_questions(
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
chat_model = conversation_config.chat_model
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
chat_model_name = chat_model.name
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model,
model=chat_model_name,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
max_tokens=chat_model.max_prompt_size,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,

View File

@@ -62,7 +62,7 @@ async def all_agents(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@@ -150,7 +150,7 @@ async def get_agent(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@@ -225,7 +225,7 @@ async def create_agent(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@@ -286,7 +286,7 @@ async def update_agent(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,

View File

@@ -58,7 +58,7 @@ from khoj.routers.helpers import (
is_ready_to_chat,
read_chat_stream,
update_telemetry_state,
validate_conversation_config,
validate_chat_model,
)
from khoj.routers.research import (
InformationCollectionIteration,
@@ -205,7 +205,7 @@ def chat_history(
n: Optional[int] = None,
):
user = request.user.object
validate_conversation_config(user)
validate_chat_model(user)
# Load Conversation History
conversation = ConversationAdapters.get_conversation_by_user(
@@ -898,10 +898,10 @@ async def chat(
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
model_type = conversation_config.model_type
chat_model = await ConversationAdapters.aget_user_chat_model(user)
if chat_model == None:
chat_model = await ConversationAdapters.aget_default_chat_model(user)
model_type = chat_model.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help, tracer.get("usage")):
yield result

View File

@@ -24,7 +24,7 @@ def get_chat_model_options(
all_conversation_options = list()
for conversation_option in conversation_options:
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
all_conversation_options.append({"chat_model": conversation_option.name, "id": conversation_option.id})
return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
@@ -37,12 +37,12 @@ def get_user_chat_model(
):
user = request.user.object
chat_model = ConversationAdapters.get_conversation_config(user)
chat_model = ConversationAdapters.get_chat_model(user)
if chat_model is None:
chat_model = ConversationAdapters.get_default_conversation_config(user)
chat_model = ConversationAdapters.get_default_chat_model(user)
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.name}))
@api_model.post("/chat", status_code=200)

View File

@@ -56,7 +56,7 @@ from khoj.database.adapters import (
)
from khoj.database.models import (
Agent,
ChatModelOptions,
ChatModel,
ClientApplication,
Conversation,
GithubConfig,
@@ -133,40 +133,40 @@ def is_query_empty(query: str) -> bool:
return is_none_or_empty(query.strip())
def validate_conversation_config(user: KhojUser):
default_config = ConversationAdapters.get_default_conversation_config(user)
def validate_chat_model(user: KhojUser):
default_chat_model = ConversationAdapters.get_default_chat_model(user)
if default_config is None:
if default_chat_model is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
if default_config.model_type == "openai" and not default_config.ai_model_api:
if default_chat_model.model_type == "openai" and not default_chat_model.ai_model_api:
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
async def is_ready_to_chat(user: KhojUser):
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if user_conversation_config == None:
user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
if user_chat_model == None:
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size
if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
chat_model_name = user_chat_model.name
max_tokens = user_chat_model.max_prompt_size
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
return True
if (
user_conversation_config
user_chat_model
and (
user_conversation_config.model_type
user_chat_model.model_type
in [
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.ANTHROPIC,
ChatModelOptions.ModelType.GOOGLE,
ChatModel.ModelType.OPENAI,
ChatModel.ModelType.ANTHROPIC,
ChatModel.ModelType.GOOGLE,
]
)
and user_conversation_config.ai_model_api
and user_chat_model.ai_model_api
):
return True
@@ -942,120 +942,124 @@ async def send_message_to_model_wrapper(
query_files: str = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
chat_model = vision_enabled_config
vision_available = True
if vision_available and query_images:
logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
subscribed = await ais_user_subscribed(user)
chat_model = conversation_config.chat_model
chat_model_name = chat_model.name
max_tokens = (
conversation_config.subscribed_max_prompt_size
if subscribed and conversation_config.subscribed_max_prompt_size
else conversation_config.max_prompt_size
chat_model.subscribed_max_prompt_size
if subscribed and chat_model.subscribed_max_prompt_size
else chat_model.max_prompt_size
)
tokenizer = conversation_config.tokenizer
model_type = conversation_config.model_type
vision_available = conversation_config.vision_enabled
tokenizer = chat_model.tokenizer
model_type = chat_model.model_type
vision_available = chat_model.vision_enabled
if model_type == ChatModelOptions.ModelType.OFFLINE:
if model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
loaded_model=loaded_model,
tokenizer_name=tokenizer,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model=chat_model,
model=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.ai_model_api
elif model_type == ChatModel.ModelType.OPENAI:
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
api_base_url = openai_chat_config.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
elif model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
elif model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
messages=truncated_messages,
api_key=api_key,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -1069,99 +1073,99 @@ def send_message_to_model_wrapper_sync(
query_files: str = "",
tracer: dict = {},
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)
if conversation_config is None:
if chat_model is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
vision_available = conversation_config.vision_enabled
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
vision_available = chat_model.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
loaded_model=loaded_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model=chat_model,
model=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
openai_response = send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
return openai_response
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return gemini_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
@@ -1229,15 +1233,15 @@ def generate_chat_response(
online_results = {}
code_results = {}
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
chat_model = vision_enabled_config
vision_available = True
if conversation_config.model_type == "offline":
if chat_model.model_type == "offline":
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
user_query=query_to_run,
@@ -1247,9 +1251,9 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_commands=conversation_commands,
model=conversation_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
model=chat_model.name,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@@ -1259,10 +1263,10 @@ def generate_chat_response(
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.ai_model_api
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
chat_model_name = chat_model.name
chat_response = converse_openai(
compiled_references,
query_to_run,
@@ -1270,13 +1274,13 @@ def generate_chat_response(
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=chat_model,
model=chat_model_name,
api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@@ -1288,8 +1292,8 @@ def generate_chat_response(
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
chat_response = converse_anthropic(
compiled_references,
query_to_run,
@@ -1297,12 +1301,12 @@ def generate_chat_response(
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=conversation_config.chat_model,
model=chat_model.name,
api_key=api_key,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@@ -1313,20 +1317,20 @@ def generate_chat_response(
program_execution_context=program_execution_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
chat_response = converse_gemini(
compiled_references,
query_to_run,
online_results,
code_results,
meta_log,
model=conversation_config.chat_model,
model=chat_model.name,
api_key=api_key,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@@ -1339,7 +1343,7 @@ def generate_chat_response(
tracer=tracer,
)
metadata.update({"chat_model": conversation_config.chat_model})
metadata.update({"chat_model": chat_model.name})
except Exception as e:
logger.error(e, exc_info=True)
@@ -1939,13 +1943,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = ConversationAdapters.get_conversation_config(
selected_chat_model_config = ConversationAdapters.get_chat_model(
user
) or ConversationAdapters.get_default_conversation_config(user)
) or ConversationAdapters.get_default_chat_model(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models:
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
chat_model_options.append({"name": chat_model.name, "id": chat_model.id})
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()

View File

@@ -7,7 +7,7 @@ import openai
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import (
AiModelApi,
ChatModelOptions,
ChatModel,
KhojUser,
SpeechToTextModelOptions,
TextToImageModelConfig,
@@ -63,7 +63,7 @@ def initialization(interactive: bool = True):
# Set up OpenAI's online chat models
openai_configured, openai_provider = _setup_chat_model_provider(
ChatModelOptions.ModelType.OPENAI,
ChatModel.ModelType.OPENAI,
default_chat_models,
default_api_key=openai_api_key,
api_base_url=openai_api_base,
@@ -105,7 +105,7 @@ def initialization(interactive: bool = True):
# Set up Google's Gemini online chat models
_setup_chat_model_provider(
ChatModelOptions.ModelType.GOOGLE,
ChatModel.ModelType.GOOGLE,
default_gemini_chat_models,
default_api_key=os.getenv("GEMINI_API_KEY"),
vision_enabled=True,
@@ -116,7 +116,7 @@ def initialization(interactive: bool = True):
# Set up Anthropic's online chat models
_setup_chat_model_provider(
ChatModelOptions.ModelType.ANTHROPIC,
ChatModel.ModelType.ANTHROPIC,
default_anthropic_chat_models,
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
vision_enabled=True,
@@ -126,7 +126,7 @@ def initialization(interactive: bool = True):
# Set up offline chat models
_setup_chat_model_provider(
ChatModelOptions.ModelType.OFFLINE,
ChatModel.ModelType.OFFLINE,
default_offline_chat_models,
default_api_key=None,
vision_enabled=False,
@@ -135,9 +135,9 @@ def initialization(interactive: bool = True):
)
# Explicitly set default chat model
chat_models_configured = ChatModelOptions.objects.count()
chat_models_configured = ChatModel.objects.count()
if chat_models_configured > 0:
default_chat_model_name = ChatModelOptions.objects.first().chat_model
default_chat_model_name = ChatModel.objects.first().name
# If there are multiple chat models, ask the user to choose the default chat model
if chat_models_configured > 1 and interactive:
user_chat_model_name = input(
@@ -147,7 +147,7 @@ def initialization(interactive: bool = True):
user_chat_model_name = None
# If the user's choice is valid, set it as the default chat model
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
if user_chat_model_name and ChatModel.objects.filter(name=user_chat_model_name).exists():
default_chat_model_name = user_chat_model_name
logger.info("🗣️ Chat model configuration complete")
@@ -171,7 +171,7 @@ def initialization(interactive: bool = True):
logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
def _setup_chat_model_provider(
model_type: ChatModelOptions.ModelType,
model_type: ChatModel.ModelType,
default_chat_models: list,
default_api_key: str,
interactive: bool,
@@ -226,7 +226,7 @@ def initialization(interactive: bool = True):
"ai_model_api": ai_model_api,
}
ChatModelOptions.objects.create(**chat_model_options)
ChatModel.objects.create(**chat_model_options)
logger.info(f"🗣️ {provider_name} chat model configuration complete")
return True, ai_model_api
@@ -250,16 +250,16 @@ def initialization(interactive: bool = True):
available_models = [model.id for model in openai_client.models.list()]
# Get existing chat model options for this config
existing_models = ChatModelOptions.objects.filter(
ai_model_api=config, model_type=ChatModelOptions.ModelType.OPENAI
existing_models = ChatModel.objects.filter(
ai_model_api=config, model_type=ChatModel.ModelType.OPENAI
)
# Add new models
for model in available_models:
if not existing_models.filter(chat_model=model).exists():
ChatModelOptions.objects.create(
chat_model=model,
model_type=ChatModelOptions.ModelType.OPENAI,
if not existing_models.filter(name=model).exists():
ChatModel.objects.create(
name=model,
model_type=ChatModel.ModelType.OPENAI,
max_prompt_size=model_to_prompt_size.get(model),
vision_enabled=model in default_openai_chat_models,
tokenizer=model_to_tokenizer.get(model),
@@ -284,7 +284,7 @@ def initialization(interactive: bool = True):
except Exception as e:
logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)
chat_config = ConversationAdapters.get_default_conversation_config()
chat_config = ConversationAdapters.get_default_chat_model()
if admin_user is None and chat_config is None:
while True:
try: