From dc0bc5bcca59fd9dea672d54598f14907e05721a Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 3 Jan 2025 13:50:49 +0700 Subject: [PATCH] Evaluate information retrieval quality using eval script - Encode article urls in filename indexed in Khoj KB Makes it easier for humans to compare, trace retrieval performance by looking at logs than using content hash (which was previously explored) --- tests/evals/eval.py | 188 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 186 insertions(+), 2 deletions(-) diff --git a/tests/evals/eval.py b/tests/evals/eval.py index ff5d9cf0..20a6051e 100644 --- a/tests/evals/eval.py +++ b/tests/evals/eval.py @@ -1,10 +1,13 @@ import argparse +import base64 import concurrent.futures +import hashlib import json import logging import os import re import time +import uuid from datetime import datetime from functools import partial from io import StringIO @@ -14,9 +17,16 @@ from typing import Any, Dict import pandas as pd import requests +import yaml from datasets import Dataset, load_dataset +from tqdm import tqdm -from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer +from khoj.utils.helpers import ( + batcher, + get_cost_of_chat_message, + is_none_or_empty, + timer, +) # Configure root logger logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -61,6 +71,116 @@ running_true_count = Counter(0) running_total_count = Counter(0) +def get_article_filename(article: dict[str, str]) -> str: + """Create a unique filename for a Wikipedia article""" + # Construct filename from frames prompt ids associated with each article and url + encoded_url = base64.urlsafe_b64encode(article["link"].encode()).decode() + return "-".join(map(str, article["frames_prompt_id"])) + f"_{encoded_url}.txt" + + +def extract_prompt_ids_from_filename(filename: str) -> set[int]: + """Extract frames prompt id from a indexed file name""" + return set(map(int, filename.split("_", 1)[0].split("-"))) + + +def extract_article_url_from_filename(filename: str) -> set[int]: + """Decode URL from filename""" + encoded_url = filename.split("_", 1)[1].rsplit(".", 1)[0] + return base64.urlsafe_b64decode(encoded_url).decode() + + +def get_articles_by_prompt_id(prompt_id: int): + """Get all Wikipedia articles relevant to a specific FRAMES prompt ID""" + try: + # Load dataset + dataset = load_dataset("parasail-ai/frames-benchmark-wikipedia") + + # Filter function to check if prompt_id exists in sequence + def has_prompt_id(example): + return prompt_id in example["frames_prompt_id"] + + # Filter dataset and return matching rows + filtered_dataset = dataset["train"].filter(has_prompt_id) + return filtered_dataset + + except Exception as e: + logger.error(f"Error filtering dataset for prompt {prompt_id}: {e}") + return None + + +def load_frames_kb(): + """ + Load Wikipedia articles used as Knowledge Base by the FRAMES benchmark dataset from HuggingFace + + FRAMES is a benchmark dataset to evaluate retrieval and answering capabilities of agents. + It contains ~800 requiring multi-hop retrieval and reasoning across various topics from Wikipedia. + + ### Data Fields + - link: The link to the Wikipedia article + - text: The text content of the Wikipedia article + - frames_prompt_id: The list of FRAMES prompt ids for which this article is relevant + """ + try: + dataset_name = "parasail-ai/frames-benchmark-wikipedia" + dataset = load_dataset(dataset_name) + return dataset["train"] + + except Exception as e: + logger.error(f"Error loading {dataset_name} dataset: {e}") + return None + + +def index_frames_kb(): + """Index Wikipedia articles from FRAMES dataset into Khoj""" + try: + # Load dataset + dataset = load_frames_kb() + dataset_files = set(map(get_article_filename, dataset)) + + # Get indexed files from Khoj API + headers = {"Authorization": f"Bearer {KHOJ_API_KEY}"} if KHOJ_API_KEY else {} + try: + response = requests.get(f"{KHOJ_URL}/api/content/computer", headers=headers) + response.raise_for_status() + indexed_files = set(response.json()) + except requests.exceptions.RequestException as e: + logger.error(f"Failed to get indexed files: {e}") + return False + + # Find missing files to index + missing_files = dataset_files - indexed_files + filtered_dataset = [ + article + for article in dataset + if get_article_filename(article) in missing_files and not is_none_or_empty(article["text"]) + ] + if not filtered_dataset: + return True + logger.info(f"Found {len(filtered_dataset)} files to index") + + # Process Wikipedia articles from FRAMES knowledge base in batches + batch_size = 300 + total_batches = len(filtered_dataset) // batch_size + 1 + for batch in tqdm(batcher(filtered_dataset, batch_size), total=total_batches, desc="Indexing FRAMES KB"): + # Create files batch to index + files = [] + for article in batch: + filename = get_article_filename(article) + files.append(("files", (filename, article["text"], "text/plaintext"))) + # Send files batch to index + try: + response = requests.patch(f"{KHOJ_URL}/api/content?client=eval", headers=headers, files=files) + response.raise_for_status() + time.sleep(SLEEP_SECONDS) # Rate limiting + except Exception as e: + logger.error(f"Failed to index batch: {e}") + return False + return True + except Exception as e: + logger.error(f"Failed to index KB: {e}") + return False + + def load_frames_dataset(): """ Load the Google FRAMES benchmark dataset from HuggingFace @@ -248,6 +368,62 @@ def get_agent_response(prompt: str) -> Dict[str, Any]: return {"response": "", "usage": {}, "references": {}} +def calculate_precision_recall(numerator: int, denominator: int) -> float: + """Calculate precision and recall from numerator and denominator""" + if numerator == 0 and denominator == 0: + return 1.0 + elif numerator > 0 and denominator == 0: + return 0.0 + else: + return numerator / denominator + + +def calculate_fi(precision: float, recall: float) -> float: + """Calculate F1 score from precision and recall""" + return 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0 + + +def evaluate_response_for_ir( + query: str, agent_response: str, ground_truth: int, agent_references: dict = {} +) -> tuple[bool | None, str, float]: + """Evaluate Khoj response against benchmark ground truth using string matching""" + try: + # Extract answer from agent response + referenced_files: list[dict[str, str]] = agent_references.get("context", []) + count_of_correct_articles_used_by_agent: int = 0 + # Count how many of the expected articles the agent actually retrieved from the KB + unique_file_refs = {file["file"] for file in referenced_files} + referenced_articles = list(map(extract_article_url_from_filename, unique_file_refs)) + for file in unique_file_refs: + frames_ids_for_articles_used_by_agent = extract_prompt_ids_from_filename(file) + count_of_correct_articles_used_by_agent += int(ground_truth in frames_ids_for_articles_used_by_agent) + + articles = get_articles_by_prompt_id(ground_truth) + precision = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(unique_file_refs)) + recall = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(articles)) + f1 = calculate_fi(precision, recall) + + explanation = ( + f"Information Retrieval F1 Score: {f1:.2%} Recall: {recall:.2%}, Precision: {precision:.2%}.\n" + f"{count_of_correct_articles_used_by_agent} of {len(articles)} correct from {len(unique_file_refs)} total retrievals for {ground_truth}.\n" + f"Queries:\n{yaml.dump(sorted([r['query'] for r in referenced_files]))}\n" + f"Expected Articles for {ground_truth}:\n{yaml.dump(sorted([a['link'] for a in articles]))}\n" + f"Retrieved Articles for {ground_truth}:\n{yaml.dump(referenced_articles)}\n" + ) + + # Truncate referenced files for logging + truncated_refs = [ + {k: v[:200] + "..." if len(v) > 200 else v for k, v in ref.items()} for ref in referenced_files + ] + logger.info(f"Retrieved Article Details:\n{yaml.dump(truncated_refs, sort_keys=False)}\n") + + # Return decision, explanation and cost in structured form + return recall, explanation, 0.0 + except Exception as e: + logger.error(f"Error in IR evaluation: {e}") + return None, f"Evaluation failed: {str(e)}", 0.0 + + def evaluate_response_with_mcq_match( query: str, agent_response: str, ground_truth: str, agent_references: dict = {} ) -> tuple[bool | None, str, float]: @@ -417,7 +593,7 @@ def parse_args(): "--dataset", "-d", default="frames", - choices=["frames", "simpleqa", "gpqa", "math500"], + choices=["frames", "frames_ir", "simpleqa", "gpqa", "math500"], help="Dataset to use for evaluation (default: frames)", ) return parser.parse_args() @@ -438,6 +614,12 @@ def main(): dataset = load_gpqa_dataset() elif args.dataset == "math500": dataset = load_math500_dataset() + elif args.dataset == "frames_ir": + indexed = index_frames_kb() + if indexed: + dataset = load_frames_dataset() + # Rename the index field, 'Unnamed: 0' to 'Answer' for IR evaluation + dataset["Answer"] = dataset["Unnamed: 0"] if dataset is None: return @@ -450,6 +632,8 @@ def main(): response_evaluator = partial( evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002") ) + elif args.dataset == "frames_ir": + response_evaluator = evaluate_response_for_ir else: response_evaluator = evaluate_response_with_gemini