mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
The chat dictionary is an artifact from earlier non-db chat history storage. We've been ensuring new chat messages have valid type before being written to DB for more than 6 months now. Moving to the deeply typed chat history helps avoid null refs, and makes code more readable and easier to reason about. Next Steps: The current update entangles chat_history written to DB with any virtual chat history message generated for intermediate steps. The chat message type written to DB should be decoupled from the type that can be passed to AI model APIs (maybe?). For now we've made the ChatMessage.message type looser to allow for the list[dict] type (apart from string). But later it may be a good idea to decouple the chat_history received by send_message_to_model from the chat_history saved to DB (which can then have its stricter type check)
138 lines
3.8 KiB
Python
138 lines
3.8 KiB
Python
import os
|
|
from datetime import datetime
|
|
|
|
import factory
|
|
from django.utils.timezone import make_aware
|
|
|
|
from khoj.database.models import (
|
|
AiModelApi,
|
|
ChatMessageModel,
|
|
ChatModel,
|
|
Conversation,
|
|
KhojApiUser,
|
|
KhojUser,
|
|
ProcessLock,
|
|
SearchModelConfig,
|
|
Subscription,
|
|
UserConversationConfig,
|
|
)
|
|
from khoj.processor.conversation.utils import message_to_log
|
|
|
|
|
|
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
    """Resolve which chat model provider the test run should use.

    Precedence: an explicit, valid KHOJ_TEST_CHAT_PROVIDER env var wins;
    otherwise the first provider whose API key is set in the environment
    (OpenAI, then Gemini, then Anthropic); otherwise the given default.
    """
    requested = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
    if requested and requested in ChatModel.ModelType:
        return ChatModel.ModelType(requested)

    # Fall back to the first provider with a configured API key.
    key_to_provider = [
        ("OPENAI_API_KEY", ChatModel.ModelType.OPENAI),
        ("GEMINI_API_KEY", ChatModel.ModelType.GOOGLE),
        ("ANTHROPIC_API_KEY", ChatModel.ModelType.ANTHROPIC),
    ]
    for env_var, provider in key_to_provider:
        if os.getenv(env_var):
            return provider
    return default
|
|
|
|
|
|
def get_chat_api_key(provider: ChatModel.ModelType | None = None):
    """Return the API key for the given chat provider.

    When no provider is passed (or it is falsy), auto-detect one via
    get_chat_provider(). For an unrecognized/offline provider, fall back
    to the first key set among OpenAI, Gemini and Anthropic.
    """
    # Fix: the default of None requires an optional annotation
    # (previously annotated as a bare ChatModel.ModelType).
    provider = provider or get_chat_provider()
    if provider == ChatModel.ModelType.OPENAI:
        return os.getenv("OPENAI_API_KEY")
    elif provider == ChatModel.ModelType.GOOGLE:
        return os.getenv("GEMINI_API_KEY")
    elif provider == ChatModel.ModelType.ANTHROPIC:
        return os.getenv("ANTHROPIC_API_KEY")
    else:
        # Offline/unknown provider: best-effort, first configured key.
        return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
|
|
|
|
|
|
def generate_chat_history(message_list):
    """Build a typed chat history from (user_message, response, context) triples."""
    chat_history: list[ChatMessageModel] = []
    for user_message, chat_response, context in message_list:
        # Mirror the metadata shape message_to_log expects for each turn.
        metadata = {
            "context": context,
            "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
        }
        message_to_log(user_message, chat_response, metadata, chat_history=chat_history)
    return chat_history
|
|
|
|
|
|
class UserFactory(factory.django.DjangoModelFactory):
    """Factory for KhojUser records with faked identity fields."""

    class Meta:
        model = KhojUser

    username = factory.Faker("name")
    email = factory.Faker("email")
    password = factory.Faker("password")
    uuid = factory.Faker("uuid4")
|
|
|
|
|
|
class ApiUserFactory(factory.django.DjangoModelFactory):
    """Factory for KhojApiUser API-token records."""

    class Meta:
        model = KhojApiUser

    # No sensible default owner; tests must supply the user explicitly.
    user = None
    name = factory.Faker("name")
    token = factory.Faker("password")
|
|
|
|
|
|
class AiModelApiFactory(factory.django.DjangoModelFactory):
    """Factory for AiModelApi rows holding a chat provider API key."""

    class Meta:
        model = AiModelApi

    # Resolve the key lazily per created instance rather than once at module
    # import time, so keys set/changed by test setup are picked up.
    api_key = factory.LazyFunction(get_chat_api_key)
|
|
|
|
|
|
class ChatModelFactory(factory.django.DjangoModelFactory):
    """Factory for ChatModel configured from the test environment."""

    class Meta:
        model = ChatModel

    max_prompt_size = 20000
    tokenizer = None
    name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
    # NOTE(review): evaluated once at module import time, not per instance —
    # confirm freezing the provider choice for the whole test run is intended.
    model_type = get_chat_provider()
    # Only attach an API config when a provider key is actually available.
    ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
|
|
|
|
|
|
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
    """Factory linking a user to a chat model setting (UserConversationConfig)."""

    class Meta:
        model = UserConversationConfig

    user = factory.SubFactory(UserFactory)
    setting = factory.SubFactory(ChatModelFactory)
|
|
|
|
|
|
class ConversationFactory(factory.django.DjangoModelFactory):
    """Factory for a Conversation owned by a freshly created user."""

    class Meta:
        model = Conversation

    user = factory.SubFactory(UserFactory)
|
|
|
|
|
|
class SearchModelFactory(factory.django.DjangoModelFactory):
    """Factory for the default text SearchModelConfig used in tests."""

    class Meta:
        model = SearchModelConfig

    name = "default"
    model_type = "text"
    # Encoder model identifiers used for embedding and reranking.
    bi_encoder = "thenlper/gte-small"
    cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
|
|
|
|
|
|
class SubscriptionFactory(factory.django.DjangoModelFactory):
    """Factory for a non-recurring standard Subscription."""

    class Meta:
        model = Subscription

    user = factory.SubFactory(UserFactory)
    type = Subscription.Type.STANDARD
    is_recurring = False
    # Far-future renewal so tests never observe an expired subscription.
    renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
|
|
|
|
|
class ProcessLockFactory(factory.django.DjangoModelFactory):
    """Factory for a ProcessLock with a fixed test lock name."""

    class Meta:
        model = ProcessLock

    name = "test_lock"
|