Rename Chat Model Options table to Chat Model as short & readable (#1003)

- Previous was incorrectly plural but was defining only a single model
- Rename chat model table field to name
- Update documentation
- Update references functions and variables to match new name
This commit is contained in:
Debanjum
2024-12-12 11:24:16 -08:00
committed by GitHub
parent 9be26e1bd2
commit 01bc6d35dc
26 changed files with 369 additions and 308 deletions

View File

@@ -13,7 +13,7 @@ from khoj.configure import (
)
from khoj.database.models import (
Agent,
ChatModelOptions,
ChatModel,
GithubConfig,
GithubRepoConfig,
KhojApiUser,
@@ -35,7 +35,7 @@ from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
from tests.helpers import (
AiModelApiFactory,
ChatModelOptionsFactory,
ChatModelFactory,
ProcessLockFactory,
SubscriptionFactory,
UserConversationProcessorConfigFactory,
@@ -184,14 +184,14 @@ def api_user4(default_user4):
@pytest.mark.django_db
@pytest.fixture
def default_openai_chat_model_option():
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
return chat_model
@pytest.mark.django_db
@pytest.fixture
def offline_agent():
chat_model = ChatModelOptionsFactory()
chat_model = ChatModelFactory()
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
@@ -202,7 +202,7 @@ def offline_agent():
@pytest.mark.django_db
@pytest.fixture
def openai_agent():
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
@@ -311,13 +311,13 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Initialize Processor from Config
chat_provider = get_chat_provider()
online_chat_model: ChatModelOptionsFactory = None
if chat_provider == ChatModelOptions.ModelType.OPENAI:
online_chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
elif chat_provider == ChatModelOptions.ModelType.GOOGLE:
online_chat_model = ChatModelOptionsFactory(chat_model="gemini-1.5-flash", model_type="google")
elif chat_provider == ChatModelOptions.ModelType.ANTHROPIC:
online_chat_model = ChatModelOptionsFactory(chat_model="claude-3-5-haiku-20241022", model_type="anthropic")
online_chat_model: ChatModelFactory = None
if chat_provider == ChatModel.ModelType.OPENAI:
online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
elif chat_provider == ChatModel.ModelType.GOOGLE:
online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google")
elif chat_provider == ChatModel.ModelType.ANTHROPIC:
online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")
if online_chat_model:
online_chat_model.ai_model_api = AiModelApiFactory(api_key=get_chat_api_key(chat_provider))
UserConversationProcessorConfigFactory(user=user, setting=online_chat_model)
@@ -394,8 +394,8 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
configure_content(default_user2, all_files)
# Initialize Processor from Config
ChatModelOptionsFactory(
chat_model="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
ChatModelFactory(
name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
tokenizer=None,
max_prompt_size=None,
model_type="offline",

View File

@@ -6,7 +6,7 @@ from django.utils.timezone import make_aware
from khoj.database.models import (
AiModelApi,
ChatModelOptions,
ChatModel,
Conversation,
KhojApiUser,
KhojUser,
@@ -18,27 +18,27 @@ from khoj.database.models import (
from khoj.processor.conversation.utils import message_to_log
def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE):
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
if provider and provider in ChatModelOptions.ModelType:
return ChatModelOptions.ModelType(provider)
if provider and provider in ChatModel.ModelType:
return ChatModel.ModelType(provider)
elif os.getenv("OPENAI_API_KEY"):
return ChatModelOptions.ModelType.OPENAI
return ChatModel.ModelType.OPENAI
elif os.getenv("GEMINI_API_KEY"):
return ChatModelOptions.ModelType.GOOGLE
return ChatModel.ModelType.GOOGLE
elif os.getenv("ANTHROPIC_API_KEY"):
return ChatModelOptions.ModelType.ANTHROPIC
return ChatModel.ModelType.ANTHROPIC
else:
return default
def get_chat_api_key(provider: ChatModelOptions.ModelType = None):
def get_chat_api_key(provider: ChatModel.ModelType = None):
provider = provider or get_chat_provider()
if provider == ChatModelOptions.ModelType.OPENAI:
if provider == ChatModel.ModelType.OPENAI:
return os.getenv("OPENAI_API_KEY")
elif provider == ChatModelOptions.ModelType.GOOGLE:
elif provider == ChatModel.ModelType.GOOGLE:
return os.getenv("GEMINI_API_KEY")
elif provider == ChatModelOptions.ModelType.ANTHROPIC:
elif provider == ChatModel.ModelType.ANTHROPIC:
return os.getenv("ANTHROPIC_API_KEY")
else:
return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
@@ -83,13 +83,13 @@ class AiModelApiFactory(factory.django.DjangoModelFactory):
api_key = get_chat_api_key()
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class ChatModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModelOptions
model = ChatModel
max_prompt_size = 20000
tokenizer = None
chat_model = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
model_type = get_chat_provider()
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
@@ -99,7 +99,7 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
model = UserConversationConfig
user = factory.SubFactory(UserFactory)
setting = factory.SubFactory(ChatModelOptionsFactory)
setting = factory.SubFactory(ChatModelFactory)
class ConversationFactory(factory.django.DjangoModelFactory):

View File

@@ -5,14 +5,14 @@ import pytest
from asgiref.sync import sync_to_async
from khoj.database.adapters import AgentAdapters
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
from khoj.routers.api import execute_search
from khoj.utils.helpers import get_absolute_path
from tests.helpers import ChatModelOptionsFactory
from tests.helpers import ChatModelFactory
def test_create_default_agent(default_user: KhojUser):
ChatModelOptionsFactory()
ChatModelFactory()
agent = AgentAdapters.create_default_agent(default_user)
assert agent is not None
@@ -24,7 +24,7 @@ def test_create_default_agent(default_user: KhojUser):
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions):
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModel):
new_agent = await AgentAdapters.aupdate_agent(
default_user,
"Test Agent",
@@ -32,7 +32,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[],
[],
[],
@@ -46,7 +46,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent_with_knowledge_base(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -56,7 +56,7 @@ async def test_create_or_update_agent_with_knowledge_base(
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -78,7 +78,7 @@ async def test_create_or_update_agent_with_knowledge_base(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent_with_knowledge_base_and_search(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -88,7 +88,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -102,7 +102,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -112,7 +112,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
Agent.PrivacyLevel.PUBLIC,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -126,7 +126,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -136,7 +136,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -150,7 +150,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -160,7 +160,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -174,7 +174,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_multiple_agents_with_knowledge_base_and_users(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -184,7 +184,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
Agent.PrivacyLevel.PUBLIC,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -198,7 +198,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
Agent.PrivacyLevel.PUBLIC,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename2],
[],
[],

View File

@@ -2,12 +2,12 @@ from datetime import datetime
import pytest
from khoj.database.models import ChatModelOptions
from khoj.database.models import ChatModel
from khoj.routers.helpers import aget_data_sources_and_output_format
from khoj.utils.helpers import ConversationCommand
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="Disable in CI to avoid long test runs.",

View File

@@ -4,12 +4,12 @@ import pytest
from faker import Faker
from freezegun import freeze_time
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from tests.helpers import ConversationFactory, get_chat_provider
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="Disable in CI to avoid long test runs.",