mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Add GPQA (diamond) dataset to eval
This commit is contained in:
2
.github/workflows/run_evals.yml
vendored
2
.github/workflows/run_evals.yml
vendored
@@ -25,6 +25,7 @@ on:
|
|||||||
options:
|
options:
|
||||||
- frames
|
- frames
|
||||||
- simpleqa
|
- simpleqa
|
||||||
|
- gpqa
|
||||||
sample_size:
|
sample_size:
|
||||||
description: 'Number of samples to evaluate'
|
description: 'Number of samples to evaluate'
|
||||||
required: false
|
required: false
|
||||||
@@ -97,6 +98,7 @@ jobs:
|
|||||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||||
SERPER_DEV_API_KEY: ${{ secrets.SERPER_DEV_API_KEY }}
|
SERPER_DEV_API_KEY: ${{ secrets.SERPER_DEV_API_KEY }}
|
||||||
OLOSTEP_API_KEY: ${{ secrets.OLOSTEP_API_KEY }}
|
OLOSTEP_API_KEY: ${{ secrets.OLOSTEP_API_KEY }}
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
KHOJ_ADMIN_EMAIL: khoj
|
KHOJ_ADMIN_EMAIL: khoj
|
||||||
KHOJ_ADMIN_PASSWORD: khoj
|
KHOJ_ADMIN_PASSWORD: khoj
|
||||||
POSTGRES_HOST: localhost
|
POSTGRES_HOST: localhost
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import concurrent.futures
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
@@ -24,13 +25,10 @@ logger = logging.getLogger(__name__)
|
|||||||
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
|
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
|
||||||
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
|
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
|
||||||
KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
|
KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
|
||||||
KHOJ_MODE = os.getenv("KHOJ_MODE", "default") # E.g research, general, notes etc.
|
KHOJ_MODE = os.getenv("KHOJ_MODE", "default").lower() # E.g research, general, notes etc.
|
||||||
|
|
||||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||||
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-pro-002")
|
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-pro-002")
|
||||||
GEMINI_API_URL = (
|
|
||||||
f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_EVAL_MODEL}:generateContent?key={GEMINI_API_KEY}"
|
|
||||||
)
|
|
||||||
|
|
||||||
SAMPLE_SIZE = os.getenv("SAMPLE_SIZE") # Number of examples to evaluate
|
SAMPLE_SIZE = os.getenv("SAMPLE_SIZE") # Number of examples to evaluate
|
||||||
RANDOMIZE = os.getenv("RANDOMIZE", "false").lower() == "true" # Randomize examples
|
RANDOMIZE = os.getenv("RANDOMIZE", "false").lower() == "true" # Randomize examples
|
||||||
@@ -128,6 +126,75 @@ def load_simpleqa_dataset():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_gpqa_dataset():
    """
    Load the GPQA (diamond) benchmark from HuggingFace as multiple choice prompts.

    Every row of the `gpqa_diamond` configuration of `Idavidrein/gpqa` is turned into
    a four option (A-D) multiple choice question. The options are shuffled per row,
    so the letter of the correct answer differs between rows and between runs.

    The dataset is normalized to the column names used by the FRAMES dataset:

    ### Data Fields
    - Prompt: The multiple choice question, including answer format instructions
    - Answer: The letter (one of ABCD) of the correct choice in the Prompt
    - reasoning_types: The GPQA "Subdomain" column, renamed

    Returns the dataset, shuffled if RANDOMIZE is set and truncated to SAMPLE_SIZE
    examples if that is set. Returns None if loading or formatting fails.
    """
    import random

    def format_multiple_choice_question(row: Dict) -> tuple[str, str]:
        """
        Build the GPQA multiple choice prompt for a single dataset row.
        Refer: https://github.com/openai/simple-evals/blob/a8e85cc8a5dea497d915f870895250e07f9cc737/common.py#L12

        Returns the formatted prompt and the letter of the correct choice.
        """
        # Collect the three distractors followed by the correct answer
        options = [
            row["Incorrect Answer 1"],
            row["Incorrect Answer 2"],
            row["Incorrect Answer 3"],
            row["Correct Answer"],
        ]

        # Randomize option order so the correct answer is not always the last one
        random.shuffle(options)

        # Letter of the position the correct answer landed on
        correct_letter = "ABCD"[options.index(row["Correct Answer"])]

        prompt = f"""
Answer the following multiple choice question. Answer should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{row["Question"]}

A) {options[0]}
B) {options[1]}
C) {options[2]}
D) {options[3]}
""".strip()

        return prompt, correct_letter

    try:
        dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")

        # Build one (prompt, correct letter) pair per row
        formatted_rows = [format_multiple_choice_question(row) for row in dataset]
        prompts = [prompt for prompt, _ in formatted_rows]
        answers = [letter for _, letter in formatted_rows]

        # Normalize dataset to FRAMES format
        dataset = dataset.rename_columns({"Subdomain": "reasoning_types"})
        dataset = dataset.add_column("Prompt", prompts)
        dataset = dataset.add_column("Answer", answers)

        # Shuffle, then sample, the dataset if configured
        if RANDOMIZE:
            dataset = dataset.shuffle()
        if SAMPLE_SIZE:
            dataset = dataset[: int(SAMPLE_SIZE)]

        return dataset
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
def get_agent_response(prompt: str) -> Dict[str, Any]:
|
def get_agent_response(prompt: str) -> Dict[str, Any]:
|
||||||
"""Get response from the Khoj API"""
|
"""Get response from the Khoj API"""
|
||||||
# Set headers
|
# Set headers
|
||||||
@@ -152,7 +219,30 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
|
|||||||
return {"response": "", "usage": {}}
|
return {"response": "", "usage": {}}
|
||||||
|
|
||||||
|
|
||||||
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
|
def evaluate_response_with_mcq_match(
|
||||||
|
query: str, agent_response: str, ground_truth: str
|
||||||
|
) -> tuple[bool | None, str, float]:
|
||||||
|
"""Evaluate Khoj response against benchmark ground truth using string matching"""
|
||||||
|
try:
|
||||||
|
# Extract answer from agent response
|
||||||
|
answer_pattern_multichoice = r"(?i)Answer\s*:\s*([A-D])"
|
||||||
|
match = re.search(answer_pattern_multichoice, agent_response)
|
||||||
|
extracted_answer = match.group(1) if match else None
|
||||||
|
|
||||||
|
# Check if extracted answer matches ground truth
|
||||||
|
decision = extracted_answer == ground_truth
|
||||||
|
explanation = f"Agent response {'matches' if decision else 'does not match'} ground truth {ground_truth}"
|
||||||
|
|
||||||
|
# Return decision, explanation and cost in structured form
|
||||||
|
return decision, explanation, 0.0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in evaluation: {e}")
|
||||||
|
return None, f"Evaluation failed: {str(e)}", 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_response_with_gemini(
|
||||||
|
query: str, agent_response: str, ground_truth: str, eval_model=GEMINI_EVAL_MODEL
|
||||||
|
) -> tuple[bool | None, str, float]:
|
||||||
"""Evaluate Khoj response against benchmark ground truth using Gemini"""
|
"""Evaluate Khoj response against benchmark ground truth using Gemini"""
|
||||||
evaluation_prompt = f"""
|
evaluation_prompt = f"""
|
||||||
Compare the following agent response with the ground truth answer.
|
Compare the following agent response with the ground truth answer.
|
||||||
@@ -166,10 +256,13 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
|
|||||||
Provide your evaluation in the following json format:
|
Provide your evaluation in the following json format:
|
||||||
{"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"}
|
{"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"}
|
||||||
"""
|
"""
|
||||||
|
gemini_api_url = (
|
||||||
|
f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
GEMINI_API_URL,
|
gemini_api_url,
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
json={
|
json={
|
||||||
"contents": [{"parts": [{"text": evaluation_prompt}]}],
|
"contents": [{"parts": [{"text": evaluation_prompt}]}],
|
||||||
@@ -182,7 +275,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
|
|||||||
# Update cost of evaluation
|
# Update cost of evaluation
|
||||||
input_tokens = response_json["usageMetadata"]["promptTokenCount"]
|
input_tokens = response_json["usageMetadata"]["promptTokenCount"]
|
||||||
ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
|
ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
|
||||||
cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, ouput_tokens)
|
cost = get_cost_of_chat_message(eval_model, input_tokens, ouput_tokens)
|
||||||
|
|
||||||
# Parse evaluation response
|
# Parse evaluation response
|
||||||
eval_response: dict[str, str] = json.loads(
|
eval_response: dict[str, str] = json.loads(
|
||||||
@@ -200,7 +293,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
|
|||||||
return None, f"Evaluation failed: {str(e)}", 0.0
|
return None, f"Evaluation failed: {str(e)}", 0.0
|
||||||
|
|
||||||
|
|
||||||
def process_batch(batch, batch_start, results, dataset_length):
|
def process_batch(batch, batch_start, results, dataset_length, response_evaluator):
|
||||||
global running_cost
|
global running_cost
|
||||||
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
|
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
|
||||||
current_index = batch_start + idx
|
current_index = batch_start + idx
|
||||||
@@ -219,7 +312,7 @@ def process_batch(batch, batch_start, results, dataset_length):
|
|||||||
decision = None
|
decision = None
|
||||||
explanation = "Agent response is empty. This maybe due to a service error."
|
explanation = "Agent response is empty. This maybe due to a service error."
|
||||||
else:
|
else:
|
||||||
decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)
|
decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer)
|
||||||
|
|
||||||
# Store results
|
# Store results
|
||||||
results.append(
|
results.append(
|
||||||
@@ -292,7 +385,7 @@ def parse_args():
|
|||||||
"--dataset",
|
"--dataset",
|
||||||
"-d",
|
"-d",
|
||||||
default="frames",
|
default="frames",
|
||||||
choices=["frames", "simpleqa"],
|
choices=["frames", "simpleqa", "gpqa"],
|
||||||
help="Dataset to use for evaluation (default: frames)",
|
help="Dataset to use for evaluation (default: frames)",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@@ -309,12 +402,18 @@ def main():
|
|||||||
dataset = load_frames_dataset()
|
dataset = load_frames_dataset()
|
||||||
elif args.dataset == "simpleqa":
|
elif args.dataset == "simpleqa":
|
||||||
dataset = load_simpleqa_dataset()
|
dataset = load_simpleqa_dataset()
|
||||||
|
elif args.dataset == "gpqa":
|
||||||
|
dataset = load_gpqa_dataset()
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Initialize variables
|
# Initialize variables
|
||||||
results = []
|
results = []
|
||||||
dataset_length = len(dataset["Prompt"])
|
dataset_length = len(dataset["Prompt"])
|
||||||
|
if args.dataset == "gpqa":
|
||||||
|
response_evaluator = evaluate_response_with_mcq_match
|
||||||
|
else:
|
||||||
|
response_evaluator = evaluate_response_with_gemini
|
||||||
|
|
||||||
# Process examples in batches
|
# Process examples in batches
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
@@ -326,7 +425,9 @@ def main():
|
|||||||
dataset["Answer"][i : i + BATCH_SIZE],
|
dataset["Answer"][i : i + BATCH_SIZE],
|
||||||
dataset["reasoning_types"][i : i + BATCH_SIZE],
|
dataset["reasoning_types"][i : i + BATCH_SIZE],
|
||||||
)
|
)
|
||||||
futures.append(executor.submit(process_batch, batch, batch_start, results, dataset_length))
|
futures.append(
|
||||||
|
executor.submit(process_batch, batch, batch_start, results, dataset_length, response_evaluator)
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for all futures to complete
|
# Wait for all futures to complete
|
||||||
concurrent.futures.wait(futures)
|
concurrent.futures.wait(futures)
|
||||||
|
|||||||
Reference in New Issue
Block a user