Add MATH500 dataset to eval

Evaluate simpler MATH500 responses with gemini 1.5 flash

This improves both the speed and cost of running this eval
This commit is contained in:
Debanjum
2024-11-27 16:29:15 -08:00
parent 22aef9bf53
commit 29e801c381
2 changed files with 35 additions and 3 deletions

View File

@@ -26,6 +26,7 @@ on:
- frames
- simpleqa
- gpqa
- math500
sample_size:
description: 'Number of samples to evaluate'
required: false
@@ -96,8 +97,8 @@ jobs:
KHOJ_URL: "http://localhost:42110"
KHOJ_LLM_SEED: "42"
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
SERPER_DEV_API_KEY: ${{ secrets.SERPER_DEV_API_KEY }}
OLOSTEP_API_KEY: ${{ secrets.OLOSTEP_API_KEY }}
# For math500 the search/scrape keys are withheld. `cond && secret` alone evaluates to boolean false
# (exported as the string "false") when cond fails, so fall back to '' to leave the variable empty.
SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY || '' }}
OLOSTEP_API_KEY: ${{ matrix.dataset != 'math500' && secrets.OLOSTEP_API_KEY || '' }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
KHOJ_ADMIN_EMAIL: khoj
KHOJ_ADMIN_PASSWORD: khoj

View File

@@ -6,6 +6,7 @@ import os
import re
import time
from datetime import datetime
from functools import partial
from io import StringIO
from textwrap import dedent
from threading import Lock
@@ -195,6 +196,30 @@ D) {choices[3]}
return None
def load_math500_dataset():
    """
    Load and format the MATH500 dataset to match the evaluation script's structure.

    Takes no arguments. Sampling is driven by the module-level settings:
    RANDOMIZE (shuffle the dataset when truthy) and SAMPLE_SIZE (keep only the
    first int(SAMPLE_SIZE) rows when truthy; otherwise use the full dataset).

    Returns:
        Dataset: HuggingFace Dataset with the "problem", "answer" and "subject"
        columns renamed to "Prompt", "Answer" and "reasoning_types", or None if
        loading or formatting fails (the error is printed).
    """
    try:
        # Load the MATH500 dataset from HuggingFace
        dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        dataset = dataset.rename_columns({"problem": "Prompt", "answer": "Answer", "subject": "reasoning_types"})
        dataset = dataset.shuffle() if RANDOMIZE else dataset
        if SAMPLE_SIZE:
            # Clamp to the dataset length so an oversized sample request evaluates
            # the full dataset instead of failing on out-of-range indices.
            sample_count = min(int(SAMPLE_SIZE), len(dataset))
            dataset = dataset.select(range(sample_count))
        return dataset
    except Exception as e:
        print(f"Error loading and formatting MATH500 dataset: {e}")
        return None
def get_agent_response(prompt: str) -> Dict[str, Any]:
"""Get response from the Khoj API"""
# Set headers
@@ -385,7 +410,7 @@ def parse_args():
"--dataset",
"-d",
default="frames",
choices=["frames", "simpleqa", "gpqa"],
choices=["frames", "simpleqa", "gpqa", "math500"],
help="Dataset to use for evaluation (default: frames)",
)
return parser.parse_args()
@@ -404,6 +429,8 @@ def main():
dataset = load_simpleqa_dataset()
elif args.dataset == "gpqa":
dataset = load_gpqa_dataset()
elif args.dataset == "math500":
dataset = load_math500_dataset()
if dataset is None:
return
@@ -412,6 +439,10 @@ def main():
dataset_length = len(dataset["Prompt"])
if args.dataset == "gpqa":
response_evaluator = evaluate_response_with_mcq_match
elif args.dataset == "math500":
response_evaluator = partial(
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
)
else:
response_evaluator = evaluate_response_with_gemini