mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Run online, offine chat actor, director tests for any supported provider
- Previously online chat actors, director tests only worked with openai.
This change allows running them for any supported onlnie provider
including Google, Anthropic and Openai.
- Enable online/offline chat actor, director in two ways:
1. Explicitly setting KHOJ_TEST_CHAT_PROVIDER environment variable to
google, anthropic, openai, offline
2. Implicitly by the first API key found from openai, google or anthropic.
- Default offline chat provider to use Llama 3.1 3B for faster, lower
compute test runs
This commit is contained in:
@@ -17,6 +17,32 @@ from khoj.database.models import (
|
||||
)
|
||||
|
||||
|
||||
def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE):
|
||||
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
|
||||
if provider and provider in ChatModelOptions.ModelType:
|
||||
return ChatModelOptions.ModelType(provider)
|
||||
elif os.getenv("OPENAI_API_KEY"):
|
||||
return ChatModelOptions.ModelType.OPENAI
|
||||
elif os.getenv("GEMINI_API_KEY"):
|
||||
return ChatModelOptions.ModelType.GOOGLE
|
||||
elif os.getenv("ANTHROPIC_API_KEY"):
|
||||
return ChatModelOptions.ModelType.ANTHROPIC
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def get_chat_api_key(provider: ChatModelOptions.ModelType = None):
|
||||
provider = provider or get_chat_provider()
|
||||
if provider == ChatModelOptions.ModelType.OPENAI:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
elif provider == ChatModelOptions.ModelType.GOOGLE:
|
||||
return os.getenv("GEMINI_API_KEY")
|
||||
elif provider == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
return os.getenv("ANTHROPIC_API_KEY")
|
||||
else:
|
||||
return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = KhojUser
|
||||
@@ -40,19 +66,19 @@ class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory
|
||||
class Meta:
|
||||
model = OpenAIProcessorConversationConfig
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
api_key = get_chat_api_key()
|
||||
|
||||
|
||||
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = ChatModelOptions
|
||||
|
||||
max_prompt_size = 3500
|
||||
max_prompt_size = 20000
|
||||
tokenizer = None
|
||||
chat_model = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
|
||||
model_type = "offline"
|
||||
chat_model = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
|
||||
model_type = get_chat_provider()
|
||||
openai_config = factory.LazyAttribute(
|
||||
lambda obj: OpenAIProcessorConversationConfigFactory() if os.getenv("OPENAI_API_KEY") else None
|
||||
lambda obj: OpenAIProcessorConversationConfigFactory() if get_chat_api_key() else None
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user