Default to gemini 2.5 model series on init and for eval

This commit is contained in:
Debanjum
2025-08-22 18:44:26 -07:00
parent c53a70c997
commit 2823c84bb4
9 changed files with 17 additions and 15 deletions

View File

@@ -246,7 +246,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
if chat_provider == ChatModel.ModelType.OPENAI:
online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
elif chat_provider == ChatModel.ModelType.GOOGLE:
online_chat_model = ChatModelFactory(name="gemini-2.0-flash", model_type="google")
online_chat_model = ChatModelFactory(name="gemini-2.5-flash", model_type="google")
elif chat_provider == ChatModel.ModelType.ANTHROPIC:
online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")
if online_chat_model:
@@ -355,7 +355,7 @@ End of file {i}.
if chat_provider == ChatModel.ModelType.OPENAI:
online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
elif chat_provider == ChatModel.ModelType.GOOGLE:
online_chat_model = ChatModelFactory(name="gemini-2.0-flash", model_type="google")
online_chat_model = ChatModelFactory(name="gemini-2.5-flash", model_type="google")
elif chat_provider == ChatModel.ModelType.ANTHROPIC:
online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")

View File

@@ -34,10 +34,10 @@ logger = logging.getLogger(__name__)
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
KHOJ_MODE = os.getenv("KHOJ_MODE", "default").lower() # E.g research, general, notes etc.
KHOJ_MODE = os.getenv("KHOJ_MODE", "default").lower()  # E.g. research, general, default, etc.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-2.0-flash-001")
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-2.5-flash")
LLM_SEED = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
SAMPLE_SIZE = os.getenv("SAMPLE_SIZE") # Number of examples to evaluate
@@ -636,7 +636,7 @@ def main():
response_evaluator = evaluate_response_with_mcq_match
elif args.dataset == "math500":
response_evaluator = partial(
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-2.0-flash-001")
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-2.5-flash")
)
elif args.dataset == "frames_ir":
response_evaluator = evaluate_response_for_ir
@@ -696,7 +696,7 @@ def main():
if __name__ == "__main__":
"""
Evaluate Khoj on supported benchmarks.
Response are evaluated by GEMINI_EVAL_MODEL (default: gemini-pro-1.5-002).
Response are evaluated by GEMINI_EVAL_MODEL (default: gemini-2.5-flash).
Khoj should be running at KHOJ_URL (default: http://localhost:42110).
The Gemini judge model is accessed via the Gemini API with your GEMINI_API_KEY.

View File

@@ -232,7 +232,7 @@ class ChatModelFactory(factory.django.DjangoModelFactory):
max_prompt_size = 20000
tokenizer = None
name = "gemini-2.0-flash"
name = "gemini-2.5-flash"
model_type = get_chat_provider()
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)

View File

@@ -20,7 +20,7 @@ def create_test_automation(client: TestClient) -> str:
"""Helper function to create a test automation and return its ID."""
state.anonymous_mode = True
ChatModelFactory(
name="gemini-2.0-flash", model_type="google", ai_model_api=AiModelApiFactory(api_key=get_chat_api_key("google"))
name="gemini-2.5-flash", model_type="google", ai_model_api=AiModelApiFactory(api_key=get_chat_api_key("google"))
)
params = {
"q": "test automation",
@@ -37,7 +37,7 @@ def test_create_automation(client: TestClient):
# Arrange
state.anonymous_mode = True
ChatModelFactory(
name="gemini-2.0-flash", model_type="google", ai_model_api=AiModelApiFactory(api_key=get_chat_api_key("google"))
name="gemini-2.5-flash", model_type="google", ai_model_api=AiModelApiFactory(api_key=get_chat_api_key("google"))
)
params = {
"q": "test automation",