mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Add MATH500 dataset to eval
Evaluate simpler MATH500 responses with gemini 1.5 flash This improves both the speed and cost of running this eval
This commit is contained in:
5
.github/workflows/run_evals.yml
vendored
5
.github/workflows/run_evals.yml
vendored
@@ -26,6 +26,7 @@ on:
|
|||||||
- frames
|
- frames
|
||||||
- simpleqa
|
- simpleqa
|
||||||
- gpqa
|
- gpqa
|
||||||
|
- math500
|
||||||
sample_size:
|
sample_size:
|
||||||
description: 'Number of samples to evaluate'
|
description: 'Number of samples to evaluate'
|
||||||
required: false
|
required: false
|
||||||
@@ -96,8 +97,8 @@ jobs:
|
|||||||
KHOJ_URL: "http://localhost:42110"
|
KHOJ_URL: "http://localhost:42110"
|
||||||
KHOJ_LLM_SEED: "42"
|
KHOJ_LLM_SEED: "42"
|
||||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||||
SERPER_DEV_API_KEY: ${{ secrets.SERPER_DEV_API_KEY }}
|
SERPER_DEV_API_KEY: ${{ matrix.dataset != 'math500' && secrets.SERPER_DEV_API_KEY }}
|
||||||
OLOSTEP_API_KEY: ${{ secrets.OLOSTEP_API_KEY }}
|
OLOSTEP_API_KEY: ${{ matrix.dataset != 'math500' && secrets.OLOSTEP_API_KEY }}
|
||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
KHOJ_ADMIN_EMAIL: khoj
|
KHOJ_ADMIN_EMAIL: khoj
|
||||||
KHOJ_ADMIN_PASSWORD: khoj
|
KHOJ_ADMIN_PASSWORD: khoj
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from functools import partial
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
@@ -195,6 +196,30 @@ D) {choices[3]}
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_math500_dataset():
    """
    Load and format the MATH500 dataset to match the evaluation script's structure.

    Takes no arguments; instead it reads the module-level settings:
    SAMPLE_SIZE limits the number of samples when set (truthy), and
    RANDOMIZE shuffles the dataset before sampling so the subset is drawn
    uniformly rather than always taking the first rows.

    Returns:
        Dataset: Formatted HuggingFace Dataset, or None if loading fails.
    """
    try:
        # Load the MATH500 dataset from HuggingFace
        dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        # Align column names with the schema the evaluation script expects
        dataset = dataset.rename_columns({"problem": "Prompt", "answer": "Answer", "subject": "reasoning_types"})
        # Shuffle first so a SAMPLE_SIZE-limited subset is a random sample
        dataset = dataset.shuffle() if RANDOMIZE else dataset
        dataset = dataset.select(range(int(SAMPLE_SIZE))) if SAMPLE_SIZE else dataset

        return dataset
    except Exception as e:
        # Best-effort: report the failure and signal it via None so main() can bail out
        print(f"Error loading and formatting MATH500 dataset: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
def get_agent_response(prompt: str) -> Dict[str, Any]:
|
def get_agent_response(prompt: str) -> Dict[str, Any]:
|
||||||
"""Get response from the Khoj API"""
|
"""Get response from the Khoj API"""
|
||||||
# Set headers
|
# Set headers
|
||||||
@@ -385,7 +410,7 @@ def parse_args():
|
|||||||
"--dataset",
|
"--dataset",
|
||||||
"-d",
|
"-d",
|
||||||
default="frames",
|
default="frames",
|
||||||
choices=["frames", "simpleqa", "gpqa"],
|
choices=["frames", "simpleqa", "gpqa", "math500"],
|
||||||
help="Dataset to use for evaluation (default: frames)",
|
help="Dataset to use for evaluation (default: frames)",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@@ -404,6 +429,8 @@ def main():
|
|||||||
dataset = load_simpleqa_dataset()
|
dataset = load_simpleqa_dataset()
|
||||||
elif args.dataset == "gpqa":
|
elif args.dataset == "gpqa":
|
||||||
dataset = load_gpqa_dataset()
|
dataset = load_gpqa_dataset()
|
||||||
|
elif args.dataset == "math500":
|
||||||
|
dataset = load_math500_dataset()
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -412,6 +439,10 @@ def main():
|
|||||||
dataset_length = len(dataset["Prompt"])
|
dataset_length = len(dataset["Prompt"])
|
||||||
if args.dataset == "gpqa":
|
if args.dataset == "gpqa":
|
||||||
response_evaluator = evaluate_response_with_mcq_match
|
response_evaluator = evaluate_response_with_mcq_match
|
||||||
|
elif args.dataset == "math500":
|
||||||
|
response_evaluator = partial(
|
||||||
|
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response_evaluator = evaluate_response_with_gemini
|
response_evaluator = evaluate_response_with_gemini
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user