Add GPQA (diamond) dataset to eval

This commit is contained in:
Debanjum
2024-11-27 16:30:20 -08:00
parent f1190ccf32
commit 22aef9bf53
2 changed files with 114 additions and 11 deletions

View File

@@ -3,6 +3,7 @@ import concurrent.futures
import json
import logging
import os
import re
import time
from datetime import datetime
from io import StringIO
@@ -24,13 +25,10 @@ logger = logging.getLogger(__name__)
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
KHOJ_MODE = os.getenv("KHOJ_MODE", "default") # E.g research, general, notes etc.
KHOJ_MODE = os.getenv("KHOJ_MODE", "default").lower() # E.g research, general, notes etc.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-pro-002")
GEMINI_API_URL = (
f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_EVAL_MODEL}:generateContent?key={GEMINI_API_KEY}"
)
SAMPLE_SIZE = os.getenv("SAMPLE_SIZE") # Number of examples to evaluate
RANDOMIZE = os.getenv("RANDOMIZE", "false").lower() == "true" # Randomize examples
@@ -128,6 +126,75 @@ def load_simpleqa_dataset():
return None
def load_gpqa_dataset():
"""
Load the Google GPQA benchmark dataset from HuggingFace
GPQA is a benchmark dataset to evaluate retrieval and answering capabilities of agents.
It contains ~800 requiring multi-hop retrieval and reasoning across various topics.
### Data Fields
- Prompt: The question to be answered
- Answer: The ground truth answer
- reasoning_types: The type of reasoning required to answer the question
"""
import random
def format_multiple_choice_question(row: Dict) -> tuple[str, str]:
"""
Create GPQA multi-choice prompt from shuffled answer choices and question.
Refer: https://github.com/openai/simple-evals/blob/a8e85cc8a5dea497d915f870895250e07f9cc737/common.py#L12
Returns formatted prompt and correct answer letter.
"""
# Gather choices
choices = [
row["Incorrect Answer 1"],
row["Incorrect Answer 2"],
row["Incorrect Answer 3"],
row["Correct Answer"],
]
# Shuffle choices
random.shuffle(choices)
# Get correct answer letter
correct_index = choices.index(row["Correct Answer"])
correct_letter = "ABCD"[correct_index]
prompt = f"""
Answer the following multiple choice question. Answer should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{row["Question"]}
A) {choices[0]}
B) {choices[1]}
C) {choices[2]}
D) {choices[3]}
""".strip()
return prompt, correct_letter
try:
dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
# Create multi-choice q&a prompt from choices and correct answer
prompts_and_answers = [format_multiple_choice_question(row) for row in dataset]
# Normalize dataset to FRAMES format
dataset = dataset.rename_columns({"Subdomain": "reasoning_types"})
dataset = dataset.add_column("Prompt", [p[0] for p in prompts_and_answers])
dataset = dataset.add_column("Answer", [p[1] for p in prompts_and_answers])
# Sample and shuffle dataset if configured
dataset = dataset.shuffle() if RANDOMIZE else dataset
dataset = dataset[: int(SAMPLE_SIZE)] if SAMPLE_SIZE else dataset
return dataset
except Exception as e:
logger.error(f"Error loading dataset: {e}")
return None
def get_agent_response(prompt: str) -> Dict[str, Any]:
"""Get response from the Khoj API"""
# Set headers
@@ -152,7 +219,30 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
return {"response": "", "usage": {}}
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
def evaluate_response_with_mcq_match(
query: str, agent_response: str, ground_truth: str
) -> tuple[bool | None, str, float]:
"""Evaluate Khoj response against benchmark ground truth using string matching"""
try:
# Extract answer from agent response
answer_pattern_multichoice = r"(?i)Answer\s*:\s*([A-D])"
match = re.search(answer_pattern_multichoice, agent_response)
extracted_answer = match.group(1) if match else None
# Check if extracted answer matches ground truth
decision = extracted_answer == ground_truth
explanation = f"Agent response {'matches' if decision else 'does not match'} ground truth {ground_truth}"
# Return decision, explanation and cost in structured form
return decision, explanation, 0.0
except Exception as e:
logger.error(f"Error in evaluation: {e}")
return None, f"Evaluation failed: {str(e)}", 0.0
def evaluate_response_with_gemini(
query: str, agent_response: str, ground_truth: str, eval_model=GEMINI_EVAL_MODEL
) -> tuple[bool | None, str, float]:
"""Evaluate Khoj response against benchmark ground truth using Gemini"""
evaluation_prompt = f"""
Compare the following agent response with the ground truth answer.
@@ -166,10 +256,13 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
Provide your evaluation in the following json format:
{"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"}
"""
gemini_api_url = (
f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}"
)
try:
response = requests.post(
GEMINI_API_URL,
gemini_api_url,
headers={"Content-Type": "application/json"},
json={
"contents": [{"parts": [{"text": evaluation_prompt}]}],
@@ -182,7 +275,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
# Update cost of evaluation
input_tokens = response_json["usageMetadata"]["promptTokenCount"]
ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, ouput_tokens)
cost = get_cost_of_chat_message(eval_model, input_tokens, ouput_tokens)
# Parse evaluation response
eval_response: dict[str, str] = json.loads(
@@ -200,7 +293,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
return None, f"Evaluation failed: {str(e)}", 0.0
def process_batch(batch, batch_start, results, dataset_length):
def process_batch(batch, batch_start, results, dataset_length, response_evaluator):
global running_cost
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
current_index = batch_start + idx
@@ -219,7 +312,7 @@ def process_batch(batch, batch_start, results, dataset_length):
decision = None
explanation = "Agent response is empty. This maybe due to a service error."
else:
decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)
decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer)
# Store results
results.append(
@@ -292,7 +385,7 @@ def parse_args():
"--dataset",
"-d",
default="frames",
choices=["frames", "simpleqa"],
choices=["frames", "simpleqa", "gpqa"],
help="Dataset to use for evaluation (default: frames)",
)
return parser.parse_args()
@@ -309,12 +402,18 @@ def main():
dataset = load_frames_dataset()
elif args.dataset == "simpleqa":
dataset = load_simpleqa_dataset()
elif args.dataset == "gpqa":
dataset = load_gpqa_dataset()
if dataset is None:
return
# Initialize variables
results = []
dataset_length = len(dataset["Prompt"])
if args.dataset == "gpqa":
response_evaluator = evaluate_response_with_mcq_match
else:
response_evaluator = evaluate_response_with_gemini
# Process examples in batches
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -326,7 +425,9 @@ def main():
dataset["Answer"][i : i + BATCH_SIZE],
dataset["reasoning_types"][i : i + BATCH_SIZE],
)
futures.append(executor.submit(process_batch, batch, batch_start, results, dataset_length))
futures.append(
executor.submit(process_batch, batch, batch_start, results, dataset_length, response_evaluator)
)
# Wait for all futures to complete
concurrent.futures.wait(futures)