mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Add GPQA (diamond) dataset to eval
This commit is contained in:
@@ -3,6 +3,7 @@ import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
@@ -24,13 +25,10 @@ logger = logging.getLogger(__name__)
|
||||
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
|
||||
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
|
||||
KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
|
||||
KHOJ_MODE = os.getenv("KHOJ_MODE", "default") # E.g research, general, notes etc.
|
||||
KHOJ_MODE = os.getenv("KHOJ_MODE", "default").lower() # E.g research, general, notes etc.
|
||||
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-pro-002")
|
||||
GEMINI_API_URL = (
|
||||
f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_EVAL_MODEL}:generateContent?key={GEMINI_API_KEY}"
|
||||
)
|
||||
|
||||
SAMPLE_SIZE = os.getenv("SAMPLE_SIZE") # Number of examples to evaluate
|
||||
RANDOMIZE = os.getenv("RANDOMIZE", "false").lower() == "true" # Randomize examples
|
||||
@@ -128,6 +126,75 @@ def load_simpleqa_dataset():
|
||||
return None
|
||||
|
||||
|
||||
def load_gpqa_dataset():
|
||||
"""
|
||||
Load the Google GPQA benchmark dataset from HuggingFace
|
||||
|
||||
GPQA is a benchmark dataset to evaluate retrieval and answering capabilities of agents.
|
||||
It contains ~800 requiring multi-hop retrieval and reasoning across various topics.
|
||||
|
||||
### Data Fields
|
||||
- Prompt: The question to be answered
|
||||
- Answer: The ground truth answer
|
||||
- reasoning_types: The type of reasoning required to answer the question
|
||||
"""
|
||||
import random
|
||||
|
||||
def format_multiple_choice_question(row: Dict) -> tuple[str, str]:
|
||||
"""
|
||||
Create GPQA multi-choice prompt from shuffled answer choices and question.
|
||||
Refer: https://github.com/openai/simple-evals/blob/a8e85cc8a5dea497d915f870895250e07f9cc737/common.py#L12
|
||||
|
||||
Returns formatted prompt and correct answer letter.
|
||||
"""
|
||||
# Gather choices
|
||||
choices = [
|
||||
row["Incorrect Answer 1"],
|
||||
row["Incorrect Answer 2"],
|
||||
row["Incorrect Answer 3"],
|
||||
row["Correct Answer"],
|
||||
]
|
||||
# Shuffle choices
|
||||
random.shuffle(choices)
|
||||
|
||||
# Get correct answer letter
|
||||
correct_index = choices.index(row["Correct Answer"])
|
||||
correct_letter = "ABCD"[correct_index]
|
||||
|
||||
prompt = f"""
|
||||
Answer the following multiple choice question. Answer should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
|
||||
|
||||
{row["Question"]}
|
||||
|
||||
A) {choices[0]}
|
||||
B) {choices[1]}
|
||||
C) {choices[2]}
|
||||
D) {choices[3]}
|
||||
""".strip()
|
||||
|
||||
return prompt, correct_letter
|
||||
|
||||
try:
|
||||
dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
|
||||
|
||||
# Create multi-choice q&a prompt from choices and correct answer
|
||||
prompts_and_answers = [format_multiple_choice_question(row) for row in dataset]
|
||||
|
||||
# Normalize dataset to FRAMES format
|
||||
dataset = dataset.rename_columns({"Subdomain": "reasoning_types"})
|
||||
dataset = dataset.add_column("Prompt", [p[0] for p in prompts_and_answers])
|
||||
dataset = dataset.add_column("Answer", [p[1] for p in prompts_and_answers])
|
||||
|
||||
# Sample and shuffle dataset if configured
|
||||
dataset = dataset.shuffle() if RANDOMIZE else dataset
|
||||
dataset = dataset[: int(SAMPLE_SIZE)] if SAMPLE_SIZE else dataset
|
||||
|
||||
return dataset
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_agent_response(prompt: str) -> Dict[str, Any]:
|
||||
"""Get response from the Khoj API"""
|
||||
# Set headers
|
||||
@@ -152,7 +219,30 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
|
||||
return {"response": "", "usage": {}}
|
||||
|
||||
|
||||
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
|
||||
def evaluate_response_with_mcq_match(
|
||||
query: str, agent_response: str, ground_truth: str
|
||||
) -> tuple[bool | None, str, float]:
|
||||
"""Evaluate Khoj response against benchmark ground truth using string matching"""
|
||||
try:
|
||||
# Extract answer from agent response
|
||||
answer_pattern_multichoice = r"(?i)Answer\s*:\s*([A-D])"
|
||||
match = re.search(answer_pattern_multichoice, agent_response)
|
||||
extracted_answer = match.group(1) if match else None
|
||||
|
||||
# Check if extracted answer matches ground truth
|
||||
decision = extracted_answer == ground_truth
|
||||
explanation = f"Agent response {'matches' if decision else 'does not match'} ground truth {ground_truth}"
|
||||
|
||||
# Return decision, explanation and cost in structured form
|
||||
return decision, explanation, 0.0
|
||||
except Exception as e:
|
||||
logger.error(f"Error in evaluation: {e}")
|
||||
return None, f"Evaluation failed: {str(e)}", 0.0
|
||||
|
||||
|
||||
def evaluate_response_with_gemini(
|
||||
query: str, agent_response: str, ground_truth: str, eval_model=GEMINI_EVAL_MODEL
|
||||
) -> tuple[bool | None, str, float]:
|
||||
"""Evaluate Khoj response against benchmark ground truth using Gemini"""
|
||||
evaluation_prompt = f"""
|
||||
Compare the following agent response with the ground truth answer.
|
||||
@@ -166,10 +256,13 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
|
||||
Provide your evaluation in the following json format:
|
||||
{"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"}
|
||||
"""
|
||||
gemini_api_url = (
|
||||
f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
GEMINI_API_URL,
|
||||
gemini_api_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={
|
||||
"contents": [{"parts": [{"text": evaluation_prompt}]}],
|
||||
@@ -182,7 +275,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
|
||||
# Update cost of evaluation
|
||||
input_tokens = response_json["usageMetadata"]["promptTokenCount"]
|
||||
ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
|
||||
cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, ouput_tokens)
|
||||
cost = get_cost_of_chat_message(eval_model, input_tokens, ouput_tokens)
|
||||
|
||||
# Parse evaluation response
|
||||
eval_response: dict[str, str] = json.loads(
|
||||
@@ -200,7 +293,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
|
||||
return None, f"Evaluation failed: {str(e)}", 0.0
|
||||
|
||||
|
||||
def process_batch(batch, batch_start, results, dataset_length):
|
||||
def process_batch(batch, batch_start, results, dataset_length, response_evaluator):
|
||||
global running_cost
|
||||
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
|
||||
current_index = batch_start + idx
|
||||
@@ -219,7 +312,7 @@ def process_batch(batch, batch_start, results, dataset_length):
|
||||
decision = None
|
||||
explanation = "Agent response is empty. This maybe due to a service error."
|
||||
else:
|
||||
decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)
|
||||
decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer)
|
||||
|
||||
# Store results
|
||||
results.append(
|
||||
@@ -292,7 +385,7 @@ def parse_args():
|
||||
"--dataset",
|
||||
"-d",
|
||||
default="frames",
|
||||
choices=["frames", "simpleqa"],
|
||||
choices=["frames", "simpleqa", "gpqa"],
|
||||
help="Dataset to use for evaluation (default: frames)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
@@ -309,12 +402,18 @@ def main():
|
||||
dataset = load_frames_dataset()
|
||||
elif args.dataset == "simpleqa":
|
||||
dataset = load_simpleqa_dataset()
|
||||
elif args.dataset == "gpqa":
|
||||
dataset = load_gpqa_dataset()
|
||||
if dataset is None:
|
||||
return
|
||||
|
||||
# Initialize variables
|
||||
results = []
|
||||
dataset_length = len(dataset["Prompt"])
|
||||
if args.dataset == "gpqa":
|
||||
response_evaluator = evaluate_response_with_mcq_match
|
||||
else:
|
||||
response_evaluator = evaluate_response_with_gemini
|
||||
|
||||
# Process examples in batches
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
@@ -326,7 +425,9 @@ def main():
|
||||
dataset["Answer"][i : i + BATCH_SIZE],
|
||||
dataset["reasoning_types"][i : i + BATCH_SIZE],
|
||||
)
|
||||
futures.append(executor.submit(process_batch, batch, batch_start, results, dataset_length))
|
||||
futures.append(
|
||||
executor.submit(process_batch, batch, batch_start, results, dataset_length, response_evaluator)
|
||||
)
|
||||
|
||||
# Wait for all futures to complete
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
Reference in New Issue
Block a user