Files
khoj/tests/helpers.py
Debanjum 05d4e19cb8 Pass deep typed chat history for more ergonomic, readable, safe code
The chat dictionary is an artifact from earlier non-db chat history
storage. We've been ensuring new chat messages have valid type before
being written to DB for more than 6 months now.

Moving to the deeply typed chat history helps avoid null refs and
makes the code more readable and easier to reason about.

Next Steps:
The current update entangles chat_history written to DB
with any virtual chat history message generated for intermediate
steps. The chat message type written to DB should be decoupled from
type that can be passed to AI model APIs (maybe?).

For now we've made the ChatMessage.message type looser to allow
for the list[dict] type (apart from string). But later it may be a good idea
to decouple the chat_history received by send_message_to_model from
the chat_history saved to DB (which can then have its own stricter type check)
2025-06-04 00:03:14 -07:00

138 lines
3.8 KiB
Python

import os
from datetime import datetime
import factory
from django.utils.timezone import make_aware
from khoj.database.models import (
AiModelApi,
ChatMessageModel,
ChatModel,
Conversation,
KhojApiUser,
KhojUser,
ProcessLock,
SearchModelConfig,
Subscription,
UserConversationConfig,
)
from khoj.processor.conversation.utils import message_to_log
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
    """Pick the chat model provider to use, based on environment variables.

    Resolution order:
      1. ``KHOJ_TEST_CHAT_PROVIDER``, when it is set and is a member of
         ``ChatModel.ModelType``.
      2. The first provider whose API key variable is set, checked in the
         order ``OPENAI_API_KEY``, ``GEMINI_API_KEY``, ``ANTHROPIC_API_KEY``.
      3. ``default``.

    Args:
        default: Value returned when nothing in the environment selects a provider.

    Returns:
        The selected ``ChatModel.ModelType``, or ``default`` (which may be None).
    """
    # An explicit override wins, but only when it names a known model type
    requested = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
    if requested and requested in ChatModel.ModelType:
        return ChatModel.ModelType(requested)

    # Otherwise infer the provider from whichever API key is available first
    providers_by_key_var = (
        ("OPENAI_API_KEY", ChatModel.ModelType.OPENAI),
        ("GEMINI_API_KEY", ChatModel.ModelType.GOOGLE),
        ("ANTHROPIC_API_KEY", ChatModel.ModelType.ANTHROPIC),
    )
    for key_var, model_type in providers_by_key_var:
        if os.getenv(key_var):
            return model_type

    return default
def get_chat_api_key(provider: ChatModel.ModelType | None = None) -> str | None:
    """Return the API key for a chat model provider, read from the environment.

    Args:
        provider: Provider to look up the key for. When falsy (e.g. None),
            the provider is resolved with ``get_chat_provider()``.

    Returns:
        The value of ``OPENAI_API_KEY``, ``GEMINI_API_KEY`` or
        ``ANTHROPIC_API_KEY`` for the OPENAI, GOOGLE or ANTHROPIC provider
        respectively. For any other provider, the first of those three
        variables that holds a non-empty value. None when the relevant
        variable(s) are unset.
    """
    provider = provider or get_chat_provider()

    if provider == ChatModel.ModelType.OPENAI:
        return os.getenv("OPENAI_API_KEY")
    if provider == ChatModel.ModelType.GOOGLE:
        return os.getenv("GEMINI_API_KEY")
    if provider == ChatModel.ModelType.ANTHROPIC:
        return os.getenv("ANTHROPIC_API_KEY")

    # Provider has no dedicated key variable: fall back to any available key
    return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
def generate_chat_history(message_list):
    """Build a chat history from (user_message, chat_response, context) triples.

    Each triple is handed to ``message_to_log`` together with one shared
    list, which is passed as ``chat_history`` and returned once every
    triple has been processed.

    Args:
        message_list: Iterable of 3-item entries, each unpacked as
            ``(user_message, chat_response, context)``.

    Returns:
        The list that was passed to every ``message_to_log`` call.
    """
    # Shared accumulator handed to message_to_log on every iteration
    history: list[ChatMessageModel] = []

    for user_message, chat_response, context in message_list:
        # Metadata recorded with the turn: its context and the query intent
        turn_metadata = {
            "context": context,
            "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
        }
        message_to_log(user_message, chat_response, turn_metadata, chat_history=history)

    return history
class UserFactory(factory.django.DjangoModelFactory):
    """Factory for ``KhojUser`` objects with Faker-generated field values."""

    class Meta:
        model = KhojUser

    # Every field below is produced by the named Faker provider
    username = factory.Faker("name")
    email = factory.Faker("email")
    password = factory.Faker("password")
    uuid = factory.Faker("uuid4")
class ApiUserFactory(factory.django.DjangoModelFactory):
    """Factory for ``KhojApiUser`` objects with a Faker-generated name and token."""

    class Meta:
        model = KhojApiUser

    # No user is attached by default; presumably callers pass one in explicitly
    user = None
    name = factory.Faker("name")
    token = factory.Faker("password")
class AiModelApiFactory(factory.django.DjangoModelFactory):
    """Factory for ``AiModelApi`` objects using the API key found in the environment."""

    class Meta:
        model = AiModelApi

    # Evaluated once, when this class body runs, not on every build
    api_key = get_chat_api_key()
class ChatModelFactory(factory.django.DjangoModelFactory):
    """Factory for ``ChatModel`` objects configured from the environment."""

    class Meta:
        model = ChatModel

    max_prompt_size = 20000
    tokenizer = None
    name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
    # Provider is resolved once, when this class body runs
    model_type = get_chat_provider()
    # Resolved per build: use AiModelApiFactory only when an API key is available
    ai_model_api = factory.LazyAttribute(lambda _: AiModelApiFactory() if get_chat_api_key() else None)
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
    """Factory for ``UserConversationConfig`` objects linking a user to a chat model."""

    class Meta:
        model = UserConversationConfig

    # Both related objects come from their own factories unless overridden
    user = factory.SubFactory(UserFactory)
    setting = factory.SubFactory(ChatModelFactory)
class ConversationFactory(factory.django.DjangoModelFactory):
    """Factory for ``Conversation`` objects owned by a ``UserFactory`` user."""

    class Meta:
        model = Conversation

    user = factory.SubFactory(UserFactory)
class SearchModelFactory(factory.django.DjangoModelFactory):
    """Factory for ``SearchModelConfig`` objects with fixed text-search settings."""

    class Meta:
        model = SearchModelConfig

    name = "default"
    model_type = "text"
    # Model identifiers stored on the config for the bi-encoder and cross-encoder
    bi_encoder = "thenlper/gte-small"
    cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
class SubscriptionFactory(factory.django.DjangoModelFactory):
    """Factory for non-recurring STANDARD ``Subscription`` objects."""

    class Meta:
        model = Subscription

    user = factory.SubFactory(UserFactory)
    type = Subscription.Type.STANDARD
    is_recurring = False
    # Fixed timezone-aware renewal date of 2100-04-01, computed once at class definition
    renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
class ProcessLockFactory(factory.django.DjangoModelFactory):
    """Factory for ``ProcessLock`` objects with the fixed name ``test_lock``."""

    class Meta:
        model = ProcessLock

    name = "test_lock"