Evaluate information retrieval quality using eval script

- Encode article URLs in the filenames indexed into the Khoj KB
  This makes it easier for humans to compare and trace retrieval
  performance by looking at logs than the content hash approach
  explored previously
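  For illustration, an article relevant to FRAMES prompts 3 and 41
  gets indexed as `3-41_<urlsafe-base64-of-url>.txt`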
Debanjum
2025-01-03 13:50:49 +07:00
parent daeba66c0d
commit dc0bc5bcca

@@ -1,10 +1,13 @@
import argparse
import base64
import concurrent.futures
import hashlib
import json
import logging
import os
import re
import time
import uuid
from datetime import datetime
from functools import partial
from io import StringIO
@@ -14,9 +17,16 @@ from typing import Any, Dict
import pandas as pd
import requests
import yaml
from datasets import Dataset, load_dataset
from tqdm import tqdm
from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer
from khoj.utils.helpers import (
    batcher,
    get_cost_of_chat_message,
    is_none_or_empty,
    timer,
)
# Configure root logger
logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -61,6 +71,116 @@ running_true_count = Counter(0)
running_total_count = Counter(0)
def get_article_filename(article: dict[str, str]) -> str:
    """Create a unique filename for a Wikipedia article"""
    # Construct filename from the FRAMES prompt ids associated with the article and its URL
    encoded_url = base64.urlsafe_b64encode(article["link"].encode()).decode()
    return "-".join(map(str, article["frames_prompt_id"])) + f"_{encoded_url}.txt"


def extract_prompt_ids_from_filename(filename: str) -> set[int]:
    """Extract the FRAMES prompt ids from an indexed file name"""
    return set(map(int, filename.split("_", 1)[0].split("-")))


def extract_article_url_from_filename(filename: str) -> str:
    """Decode the article URL from an indexed file name"""
    encoded_url = filename.split("_", 1)[1].rsplit(".", 1)[0]
    return base64.urlsafe_b64decode(encoded_url).decode()
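
# A minimal round-trip sketch of the filename helpers above (illustrative values,
# not part of the eval run):
#
#   article = {"link": "https://en.wikipedia.org/wiki/Python", "frames_prompt_id": [3, 41]}
#   filename = get_article_filename(article)
#   # -> "3-41_aHR0cHM6Ly9lbi53aWtpcGVkaWEub3JnL3dpa2kvUHl0aG9u.txt"
#   extract_prompt_ids_from_filename(filename)   # -> {3, 41}
#   extract_article_url_from_filename(filename)  # -> "https://en.wikipedia.org/wiki/Python"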

def get_articles_by_prompt_id(prompt_id: int):
    """Get all Wikipedia articles relevant to a specific FRAMES prompt ID"""
    try:
        # Load dataset
        dataset = load_dataset("parasail-ai/frames-benchmark-wikipedia")

        # Filter function to check if prompt_id exists in sequence
        def has_prompt_id(example):
            return prompt_id in example["frames_prompt_id"]

        # Filter dataset and return matching rows
        filtered_dataset = dataset["train"].filter(has_prompt_id)
        return filtered_dataset
    except Exception as e:
        logger.error(f"Error filtering dataset for prompt {prompt_id}: {e}")
        return None

def load_frames_kb():
    """
    Load the Wikipedia articles used as the knowledge base by the FRAMES benchmark dataset from HuggingFace.

    FRAMES is a benchmark dataset to evaluate the retrieval and answering capabilities of agents.
    It contains ~800 questions requiring multi-hop retrieval and reasoning across various topics from Wikipedia.

    ### Data Fields
    - link: The link to the Wikipedia article
    - text: The text content of the Wikipedia article
    - frames_prompt_id: The list of FRAMES prompt ids for which this article is relevant
    """
    dataset_name = "parasail-ai/frames-benchmark-wikipedia"
    try:
        dataset = load_dataset(dataset_name)
        return dataset["train"]
    except Exception as e:
        logger.error(f"Error loading {dataset_name} dataset: {e}")
        return None

def index_frames_kb():
    """Index Wikipedia articles from the FRAMES dataset into Khoj"""
    try:
        # Load dataset
        dataset = load_frames_kb()
        if dataset is None:
            return False
        dataset_files = set(map(get_article_filename, dataset))

        # Get indexed files from Khoj API
        headers = {"Authorization": f"Bearer {KHOJ_API_KEY}"} if KHOJ_API_KEY else {}
        try:
            response = requests.get(f"{KHOJ_URL}/api/content/computer", headers=headers)
            response.raise_for_status()
            indexed_files = set(response.json())
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get indexed files: {e}")
            return False

        # Find missing, non-empty articles to index
        missing_files = dataset_files - indexed_files
        filtered_dataset = [
            article
            for article in dataset
            if get_article_filename(article) in missing_files and not is_none_or_empty(article["text"])
        ]
        if not filtered_dataset:
            return True
        logger.info(f"Found {len(filtered_dataset)} files to index")

        # Process Wikipedia articles from the FRAMES knowledge base in batches
        batch_size = 300
        total_batches = -(-len(filtered_dataset) // batch_size)  # ceiling division
        for batch in tqdm(batcher(filtered_dataset, batch_size), total=total_batches, desc="Indexing FRAMES KB"):
            # Create files batch to index
            files = []
            for article in batch:
                filename = get_article_filename(article)
                files.append(("files", (filename, article["text"], "text/plaintext")))

            # Send files batch to index
            try:
                response = requests.patch(f"{KHOJ_URL}/api/content?client=eval", headers=headers, files=files)
                response.raise_for_status()
                time.sleep(SLEEP_SECONDS)  # Rate limiting
            except Exception as e:
                logger.error(f"Failed to index batch: {e}")
                return False
        return True
    except Exception as e:
        logger.error(f"Failed to index KB: {e}")
        return False
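
# `batcher` is imported from khoj.utils.helpers. A rough equivalent, assuming it
# yields successive fixed-size chunks of an iterable (sketch, not the actual
# implementation):
#
#   from itertools import islice
#
#   def batcher(iterable, batch_size):
#       it = iter(iterable)
#       while chunk := list(islice(it, batch_size)):
#           yield chunk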

def load_frames_dataset():
    """
    Load the Google FRAMES benchmark dataset from HuggingFace
@@ -248,6 +368,62 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
return {"response": "", "usage": {}, "references": {}}

def calculate_precision_recall(numerator: int, denominator: int) -> float:
    """Safely compute a precision or recall style ratio, treating 0/0 as a perfect score"""
    if numerator == 0 and denominator == 0:
        return 1.0
    elif numerator > 0 and denominator == 0:
        return 0.0
    else:
        return numerator / denominator

def calculate_f1(precision: float, recall: float) -> float:
    """Calculate F1 score from precision and recall"""
    return 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
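
# Worked example: if the agent retrieves 4 unique files of which 2 are tagged
# with the ground-truth prompt id, and 5 articles are expected for that prompt:
#
#   precision = calculate_precision_recall(2, 4)  # 0.5
#   recall = calculate_precision_recall(2, 5)     # 0.4
#   f1 = calculate_f1(0.5, 0.4)                   # 2 * 0.2 / 0.9 ≈ 0.44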

def evaluate_response_for_ir(
    query: str, agent_response: str, ground_truth: int, agent_references: dict = {}
) -> tuple[float | None, str, float]:
    """Evaluate Khoj retrieval against the benchmark ground truth by matching indexed filenames"""
    try:
        # Count how many of the expected articles the agent actually retrieved from the KB
        referenced_files: list[dict[str, str]] = agent_references.get("context", [])
        count_of_correct_articles_used_by_agent: int = 0
        unique_file_refs = {file["file"] for file in referenced_files}
        referenced_articles = list(map(extract_article_url_from_filename, unique_file_refs))
        for file in unique_file_refs:
            frames_ids_for_articles_used_by_agent = extract_prompt_ids_from_filename(file)
            count_of_correct_articles_used_by_agent += int(ground_truth in frames_ids_for_articles_used_by_agent)

        # Score retrieval precision, recall and F1 against the articles expected for this prompt
        articles = get_articles_by_prompt_id(ground_truth)
        precision = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(unique_file_refs))
        recall = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(articles))
        f1 = calculate_f1(precision, recall)

        explanation = (
            f"Information Retrieval F1 Score: {f1:.2%}, Recall: {recall:.2%}, Precision: {precision:.2%}.\n"
            f"{count_of_correct_articles_used_by_agent} of {len(articles)} expected articles retrieved among {len(unique_file_refs)} total retrievals for prompt {ground_truth}.\n"
            f"Queries:\n{yaml.dump(sorted([r['query'] for r in referenced_files]))}\n"
            f"Expected Articles for {ground_truth}:\n{yaml.dump(sorted([a['link'] for a in articles]))}\n"
            f"Retrieved Articles for {ground_truth}:\n{yaml.dump(referenced_articles)}\n"
        )

        # Truncate referenced files for logging
        truncated_refs = [
            {k: v[:200] + "..." if len(v) > 200 else v for k, v in ref.items()} for ref in referenced_files
        ]
        logger.info(f"Retrieved Article Details:\n{yaml.dump(truncated_refs, sort_keys=False)}\n")

        # Return recall as the decision score, explanation and cost in structured form
        return recall, explanation, 0.0
    except Exception as e:
        logger.error(f"Error in IR evaluation: {e}")
        return None, f"Evaluation failed: {str(e)}", 0.0

def evaluate_response_with_mcq_match(
    query: str, agent_response: str, ground_truth: str, agent_references: dict = {}
) -> tuple[bool | None, str, float]:
@@ -417,7 +593,7 @@ def parse_args():
"--dataset",
"-d",
default="frames",
choices=["frames", "simpleqa", "gpqa", "math500"],
choices=["frames", "frames_ir", "simpleqa", "gpqa", "math500"],
help="Dataset to use for evaluation (default: frames)",
)
return parser.parse_args()
@@ -438,6 +614,12 @@ def main():
        dataset = load_gpqa_dataset()
    elif args.dataset == "math500":
        dataset = load_math500_dataset()
    elif args.dataset == "frames_ir":
        # Index the FRAMES KB into Khoj, then load the benchmark prompts
        indexed = index_frames_kb()
        dataset = load_frames_dataset() if indexed else None
        if dataset is not None:
            # Use the dataset's row index column, 'Unnamed: 0', as the ground-truth 'Answer' for IR evaluation
            dataset["Answer"] = dataset["Unnamed: 0"]
    if dataset is None:
        return
@@ -450,6 +632,8 @@ def main():
        response_evaluator = partial(
            evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
        )
    elif args.dataset == "frames_ir":
        response_evaluator = evaluate_response_for_ir
    else:
        response_evaluator = evaluate_response_with_gemini
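
# Invocation sketch for the new IR evaluation (the script filename and how
# KHOJ_URL / KHOJ_API_KEY get configured depend on this repo's setup):
#
#   python eval.py --dataset frames_ir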