mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
It is recommended to chat with open-source models by running an open-source server like Ollama or Llama.cpp on your GPU-powered machine, or to use a commercial provider of open-source models like DeepInfra or OpenRouter. These chat model serving options provide a mature OpenAI-compatible API that already works with Khoj. Directly using offline chat models only worked reasonably with a pip install on a machine with a GPU. The Docker setup of Khoj had trouble accessing the GPU, and without GPU access offline chat is too slow. Deprecating support for an offline chat provider directly within Khoj will reduce code complexity and increase development velocity. Offline models are subsumed under the existing OpenAI AI model provider.
141 lines
3.8 KiB
Python
141 lines
3.8 KiB
Python
import os
|
|
from datetime import datetime
|
|
|
|
import factory
|
|
from django.utils.timezone import make_aware
|
|
|
|
from khoj.database.models import (
|
|
AiModelApi,
|
|
ChatMessageModel,
|
|
ChatModel,
|
|
Conversation,
|
|
KhojApiUser,
|
|
KhojUser,
|
|
ProcessLock,
|
|
SearchModelConfig,
|
|
Subscription,
|
|
UserConversationConfig,
|
|
)
|
|
from khoj.processor.conversation.utils import message_to_log
|
|
|
|
|
|
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE):
    """Pick the chat model provider to use for tests.

    Precedence: explicit KHOJ_TEST_CHAT_PROVIDER env var (if it names a valid
    provider), then whichever provider API key is present in the environment
    (OpenAI, then Gemini, then Anthropic), falling back to `default`.
    """
    requested = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
    if requested and requested in ChatModel.ModelType:
        return ChatModel.ModelType(requested)

    # Detect provider from whichever API key is configured, in priority order.
    key_provider_pairs = (
        ("OPENAI_API_KEY", ChatModel.ModelType.OPENAI),
        ("GEMINI_API_KEY", ChatModel.ModelType.GOOGLE),
        ("ANTHROPIC_API_KEY", ChatModel.ModelType.ANTHROPIC),
    )
    for env_var, provider in key_provider_pairs:
        if os.getenv(env_var):
            return provider

    return default
|
|
|
|
|
|
def get_chat_api_key(provider: ChatModel.ModelType | None = None):
    """Return the API key for the given chat provider from the environment.

    Fix: the annotation previously read `ChatModel.ModelType = None`, claiming a
    non-optional parameter while defaulting to None; it is now `... | None`.

    When `provider` is None, auto-detect it via get_chat_provider(). If the
    detected provider has no matching key, fall back to the first configured
    key in OpenAI -> Gemini -> Anthropic order.
    """
    provider = provider or get_chat_provider()
    if provider == ChatModel.ModelType.OPENAI:
        return os.getenv("OPENAI_API_KEY")
    elif provider == ChatModel.ModelType.GOOGLE:
        return os.getenv("GEMINI_API_KEY")
    elif provider == ChatModel.ModelType.ANTHROPIC:
        return os.getenv("ANTHROPIC_API_KEY")
    else:
        # Unknown/default provider: return any configured key, in priority order.
        return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
|
|
|
|
|
|
def generate_chat_history(message_list):
    """Replay (user_message, response, context) triples into a chat history.

    Each triple is appended to the history via message_to_log, tagging the
    exchange with a "memory" intent whose query and inferred queries are the
    user's message itself.
    """
    history: list[ChatMessageModel] = []
    for query, response, references in message_list:
        metadata = {
            "context": references,
            "intent": {"type": "memory", "query": query, "inferred-queries": [query]},
        }
        message_to_log(query, response, metadata, chat_history=history)
    return history
|
|
|
|
|
|
class UserFactory(factory.django.DjangoModelFactory):
    """Create KhojUser rows with randomized credentials for tests."""

    class Meta:
        model = KhojUser

    # Faker providers generate fresh values per instance.
    username = factory.Faker("name")
    email = factory.Faker("email")
    password = factory.Faker("password")
    uuid = factory.Faker("uuid4")
|
|
|
|
|
|
class ApiUserFactory(factory.django.DjangoModelFactory):
    """Create KhojApiUser API-token rows for tests."""

    class Meta:
        model = KhojApiUser

    # No default owner: tests must pass user=... explicitly.
    user = None
    name = factory.Faker("name")
    token = factory.Faker("password")
|
|
|
|
|
|
class AiModelApiFactory(factory.django.DjangoModelFactory):
    """Create AiModelApi rows holding the chat provider's API key."""

    class Meta:
        model = AiModelApi

    # Fix: the key was previously resolved once at module import time
    # (api_key = get_chat_api_key()), so env vars set later by test fixtures
    # were ignored. LazyFunction defers the lookup to instance-creation time.
    api_key = factory.LazyFunction(get_chat_api_key)
|
|
|
|
|
|
class ChatModelFactory(factory.django.DjangoModelFactory):
    """Create ChatModel rows targeting the provider detected from the environment."""

    class Meta:
        model = ChatModel

    max_prompt_size = 20000
    tokenizer = None
    name = "gemini-2.0-flash"
    # Fix: provider detection was previously run once at module import time
    # (model_type = get_chat_provider()), ignoring env vars set later by test
    # fixtures. LazyFunction defers detection to instance-creation time.
    model_type = factory.LazyFunction(get_chat_provider)
    # Only attach an AiModelApi record when an API key is actually configured.
    ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
|
|
|
|
|
|
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
    """Link a user to a chat model setting (UserConversationConfig)."""

    class Meta:
        model = UserConversationConfig

    user = factory.SubFactory(UserFactory)
    setting = factory.SubFactory(ChatModelFactory)
|
|
|
|
|
|
class ConversationFactory(factory.django.DjangoModelFactory):
    """Create an empty Conversation owned by a fresh user."""

    class Meta:
        model = Conversation

    user = factory.SubFactory(UserFactory)
|
|
|
|
|
|
class SearchModelFactory(factory.django.DjangoModelFactory):
    """Create the default text search model configuration for tests."""

    class Meta:
        model = SearchModelConfig

    name = "default"
    model_type = "text"
    # Small embedding/reranking models to keep test runs fast.
    bi_encoder = "thenlper/gte-small"
    cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
|
|
|
|
|
|
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
|
class Meta:
|
|
model = Subscription
|
|
|
|
user = factory.SubFactory(UserFactory)
|
|
type = Subscription.Type.STANDARD
|
|
is_recurring = False
|
|
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
|
|
|
|
|
class ProcessLockFactory(factory.django.DjangoModelFactory):
    """Create a named ProcessLock row for concurrency-related tests."""

    class Meta:
        model = ProcessLock

    name = "test_lock"
|