Add MATH500 dataset to eval

Evaluate simpler MATH500 responses with gemini 1.5 flash

This improves the speed and reduces the cost of running this eval
This commit is contained in:
Debanjum
2024-11-27 16:29:15 -08:00
parent 22aef9bf53
commit 29e801c381
2 changed files with 35 additions and 3 deletions

View File

@@ -26,6 +26,7 @@ on:
- frames - frames
- simpleqa - simpleqa
- gpqa - gpqa
- math500
sample_size: sample_size:
description: 'Number of samples to evaluate' description: 'Number of samples to evaluate'
required: false required: false
@@ -96,8 +97,8 @@ jobs:
KHOJ_URL: "http://localhost:42110" KHOJ_URL: "http://localhost:42110"
KHOJ_LLM_SEED: "42" KHOJ_LLM_SEED: "42"
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
SERPER_DEV_API_KEY: ${{ secrets.SERPER_DEV_API_KEY }} SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY }}
OLOSTEP_API_KEY: ${{ secrets.OLOSTEP_API_KEY }} OLOSTEP_API_KEY: ${{ matrix.dataset != 'math500' && secrets.OLOSTEP_API_KEY }}
HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_TOKEN: ${{ secrets.HF_TOKEN }}
KHOJ_ADMIN_EMAIL: khoj KHOJ_ADMIN_EMAIL: khoj
KHOJ_ADMIN_PASSWORD: khoj KHOJ_ADMIN_PASSWORD: khoj

View File

@@ -6,6 +6,7 @@ import os
import re import re
import time import time
from datetime import datetime from datetime import datetime
from functools import partial
from io import StringIO from io import StringIO
from textwrap import dedent from textwrap import dedent
from threading import Lock from threading import Lock
@@ -195,6 +196,30 @@ D) {choices[3]}
return None return None
def load_math500_dataset(sample_size=None, randomize=None):
    """
    Load and format the MATH500 dataset to match the evaluation script's structure.

    Args:
        sample_size (int, optional): Number of samples to include. Falls back to
            the module-level SAMPLE_SIZE when None (use full dataset if that is
            also falsy).
        randomize (bool, optional): Whether to shuffle the dataset. Falls back to
            the module-level RANDOMIZE when None.

    Returns:
        Dataset: Formatted HuggingFace Dataset, or None if loading/formatting fails.
    """
    try:
        # Resolve overrides inside the try so a missing module-level config
        # surfaces as a logged failure rather than an unhandled NameError.
        sample_size = SAMPLE_SIZE if sample_size is None else sample_size
        randomize = RANDOMIZE if randomize is None else randomize

        # Load the MATH500 test split from HuggingFace and normalize column
        # names to the schema shared by the other eval datasets
        # (Prompt / Answer / reasoning_types).
        dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        dataset = dataset.rename_columns({"problem": "Prompt", "answer": "Answer", "subject": "reasoning_types"})
        if randomize:
            dataset = dataset.shuffle()
        if sample_size:
            dataset = dataset.select(range(int(sample_size)))
        return dataset
    except Exception as e:
        # Best-effort loader: report and return None so main() can bail cleanly.
        print(f"Error loading and formatting MATH500 dataset: {e}")
        return None
def get_agent_response(prompt: str) -> Dict[str, Any]: def get_agent_response(prompt: str) -> Dict[str, Any]:
"""Get response from the Khoj API""" """Get response from the Khoj API"""
# Set headers # Set headers
@@ -385,7 +410,7 @@ def parse_args():
"--dataset", "--dataset",
"-d", "-d",
default="frames", default="frames",
choices=["frames", "simpleqa", "gpqa"], choices=["frames", "simpleqa", "gpqa", "math500"],
help="Dataset to use for evaluation (default: frames)", help="Dataset to use for evaluation (default: frames)",
) )
return parser.parse_args() return parser.parse_args()
@@ -404,6 +429,8 @@ def main():
dataset = load_simpleqa_dataset() dataset = load_simpleqa_dataset()
elif args.dataset == "gpqa": elif args.dataset == "gpqa":
dataset = load_gpqa_dataset() dataset = load_gpqa_dataset()
elif args.dataset == "math500":
dataset = load_math500_dataset()
if dataset is None: if dataset is None:
return return
@@ -412,6 +439,10 @@ def main():
dataset_length = len(dataset["Prompt"]) dataset_length = len(dataset["Prompt"])
if args.dataset == "gpqa": if args.dataset == "gpqa":
response_evaluator = evaluate_response_with_mcq_match response_evaluator = evaluate_response_with_mcq_match
elif args.dataset == "math500":
response_evaluator = partial(
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
)
else: else:
response_evaluator = evaluate_response_with_gemini response_evaluator = evaluate_response_with_gemini