Files
khoj/tests/helpers.py
Debanjum b1f2737c9a Drop native offline chat support with llama-cpp-python
It is recommended to chat with open-source models by running an
open-source server like Ollama, Llama.cpp on your GPU powered machine
or use a commercial provider of open-source models like DeepInfra or
OpenRouter.

These chat model serving options provide a mature OpenAI-compatible
API that already works with Khoj.

Directly using offline chat models only worked reasonably with pip
install on a machine with GPU. Docker setup of khoj had trouble with
accessing GPU. And without GPU access offline chat is too slow.

Deprecating support for an offline chat provider directly from within
Khoj will reduce code complexity and increase development velocity.
Offline models are subsumed under the existing OpenAI AI model provider.
2025-07-31 18:25:32 -07:00

141 lines
3.8 KiB
Python

import os
from datetime import datetime
import factory
from django.utils.timezone import make_aware
from khoj.database.models import (
AiModelApi,
ChatMessageModel,
ChatModel,
Conversation,
KhojApiUser,
KhojUser,
ProcessLock,
SearchModelConfig,
Subscription,
UserConversationConfig,
)
from khoj.processor.conversation.utils import message_to_log
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE):
    """Resolve which chat model provider the test suite should use.

    Priority order:
    1. The KHOJ_TEST_CHAT_PROVIDER env var, when it names a valid provider.
    2. Whichever provider's API key is set in the environment
       (OpenAI, then Gemini, then Anthropic).
    3. The supplied *default* (Google unless overridden).
    """
    requested = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
    if requested and requested in ChatModel.ModelType:
        return ChatModel.ModelType(requested)

    # Map each provider API-key env var to its provider, in priority order.
    key_provider_pairs = (
        ("OPENAI_API_KEY", ChatModel.ModelType.OPENAI),
        ("GEMINI_API_KEY", ChatModel.ModelType.GOOGLE),
        ("ANTHROPIC_API_KEY", ChatModel.ModelType.ANTHROPIC),
    )
    for env_var, provider in key_provider_pairs:
        if os.getenv(env_var):
            return provider
    return default
def get_chat_api_key(provider: ChatModel.ModelType | None = None):
    """Return the API key for *provider* from the environment.

    When *provider* is None, it is resolved via get_chat_provider().
    For an unrecognized provider, fall back to the first configured key
    among OpenAI, Gemini, and Anthropic. Returns None when no matching
    env var is set.
    """
    provider = provider or get_chat_provider()
    if provider == ChatModel.ModelType.OPENAI:
        return os.getenv("OPENAI_API_KEY")
    elif provider == ChatModel.ModelType.GOOGLE:
        return os.getenv("GEMINI_API_KEY")
    elif provider == ChatModel.ModelType.ANTHROPIC:
        return os.getenv("ANTHROPIC_API_KEY")
    else:
        # Best-effort fallback: any configured provider key, in priority order.
        return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
def generate_chat_history(message_list):
    """Build chat history logs from (user_message, chat_response, context) triples."""
    chat_history: list[ChatMessageModel] = []
    for user_message, chat_response, context in message_list:
        # Tag every exchange with a synthetic "memory" intent for test fixtures.
        metadata = {
            "context": context,
            "intent": {"type": "memory", "query": user_message, "inferred-queries": [user_message]},
        }
        message_to_log(user_message, chat_response, metadata, chat_history=chat_history)
    return chat_history
class UserFactory(factory.django.DjangoModelFactory):
    """Factory producing KhojUser rows with faked identity fields."""

    class Meta:
        model = KhojUser

    username = factory.Faker("name")
    email = factory.Faker("email")
    password = factory.Faker("password")
    uuid = factory.Faker("uuid4")
class ApiUserFactory(factory.django.DjangoModelFactory):
    """Factory producing KhojApiUser rows; caller must supply the linked user."""

    class Meta:
        model = KhojApiUser

    # No default user — tests pass one in explicitly (e.g. from UserFactory).
    user = None
    name = factory.Faker("name")
    token = factory.Faker("password")
class AiModelApiFactory(factory.django.DjangoModelFactory):
    """Factory producing AiModelApi rows keyed off the test environment's API key."""

    class Meta:
        model = AiModelApi

    # NOTE(review): evaluated once at class-definition (import) time, not per
    # instance — env var changes made after import won't be picked up; confirm
    # this is intentional (factory.LazyFunction would defer evaluation).
    api_key = get_chat_api_key()
class ChatModelFactory(factory.django.DjangoModelFactory):
    """Factory producing ChatModel rows configured for the detected test provider."""

    class Meta:
        model = ChatModel

    max_prompt_size = 20000
    tokenizer = None
    # NOTE(review): model name is hard-coded to a Gemini model even when
    # model_type resolves to OpenAI/Anthropic — confirm downstream tests
    # tolerate the mismatch.
    name = "gemini-2.0-flash"
    # Evaluated at import time; see note on AiModelApiFactory.api_key.
    model_type = get_chat_provider()
    # Only attach an API config when a key is actually available.
    ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
    """Factory linking a user to a chat model setting (UserConversationConfig)."""

    class Meta:
        model = UserConversationConfig

    user = factory.SubFactory(UserFactory)
    setting = factory.SubFactory(ChatModelFactory)
class ConversationFactory(factory.django.DjangoModelFactory):
    """Factory producing Conversation rows owned by a fresh user."""

    class Meta:
        model = Conversation

    user = factory.SubFactory(UserFactory)
class SearchModelFactory(factory.django.DjangoModelFactory):
    """Factory producing the default text SearchModelConfig used in tests."""

    class Meta:
        model = SearchModelConfig

    name = "default"
    model_type = "text"
    bi_encoder = "thenlper/gte-small"
    cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
class SubscriptionFactory(factory.django.DjangoModelFactory):
    """Factory producing a non-recurring standard Subscription that never expires in practice."""

    class Meta:
        model = Subscription

    user = factory.SubFactory(UserFactory)
    type = Subscription.Type.STANDARD
    is_recurring = False
    # Far-future renewal date so subscription checks always pass during tests.
    renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
class ProcessLockFactory(factory.django.DjangoModelFactory):
    """Factory producing a named ProcessLock row for lock-contention tests."""

    class Meta:
        model = ProcessLock

    name = "test_lock"