diff --git a/tests/evals/eval.py b/tests/evals/eval.py
index ff5d9cf0..20a6051e 100644
--- a/tests/evals/eval.py
+++ b/tests/evals/eval.py
@@ -1,10 +1,13 @@
 import argparse
+import base64
 import concurrent.futures
+import hashlib
 import json
 import logging
 import os
 import re
 import time
+import uuid
 from datetime import datetime
 from functools import partial
 from io import StringIO
@@ -14,9 +17,16 @@ from typing import Any, Dict
 
 import pandas as pd
 import requests
+import yaml
 from datasets import Dataset, load_dataset
+from tqdm import tqdm
 
-from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer
+from khoj.utils.helpers import (
+    batcher,
+    get_cost_of_chat_message,
+    is_none_or_empty,
+    timer,
+)
 
 # Configure root logger
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -61,6 +71,116 @@ running_true_count = Counter(0)
 running_total_count = Counter(0)
 
 
+def get_article_filename(article: dict[str, str]) -> str:
+    """Create a unique filename for a Wikipedia article"""
+    # Construct filename from frames prompt ids associated with each article and url
+    encoded_url = base64.urlsafe_b64encode(article["link"].encode()).decode()
+    return "-".join(map(str, article["frames_prompt_id"])) + f"_{encoded_url}.txt"
+
+
+def extract_prompt_ids_from_filename(filename: str) -> set[int]:
+    """Extract FRAMES prompt ids from an indexed file name"""
+    return set(map(int, filename.split("_", 1)[0].split("-")))
+
+
+def extract_article_url_from_filename(filename: str) -> str:
+    """Decode the article URL from an indexed file name"""
+    encoded_url = filename.split("_", 1)[1].rsplit(".", 1)[0]
+    return base64.urlsafe_b64decode(encoded_url).decode()
+
+
+def get_articles_by_prompt_id(prompt_id: int):
+    """Get all Wikipedia articles relevant to a specific FRAMES prompt ID"""
+    try:
+        # Load dataset
+        dataset = load_dataset("parasail-ai/frames-benchmark-wikipedia")
+
+        # Filter function to check if prompt_id exists in sequence
+        def has_prompt_id(example):
+            return prompt_id in example["frames_prompt_id"]
+
+        # Filter dataset and return matching rows
+        filtered_dataset = dataset["train"].filter(has_prompt_id)
+        return filtered_dataset
+
+    except Exception as e:
+        logger.error(f"Error filtering dataset for prompt {prompt_id}: {e}")
+        return None
+
+
+def load_frames_kb():
+    """
+    Load Wikipedia articles used as the knowledge base by the FRAMES benchmark dataset from HuggingFace
+
+    FRAMES is a benchmark dataset to evaluate the retrieval and answering capabilities of agents.
+    It contains ~800 prompts requiring multi-hop retrieval and reasoning across various topics from Wikipedia.
+
+    ### Data Fields
+    - link: The link to the Wikipedia article
+    - text: The text content of the Wikipedia article
+    - frames_prompt_id: The list of FRAMES prompt ids for which this article is relevant
+    """
+    try:
+        dataset_name = "parasail-ai/frames-benchmark-wikipedia"
+        dataset = load_dataset(dataset_name)
+        return dataset["train"]
+
+    except Exception as e:
+        logger.error(f"Error loading {dataset_name} dataset: {e}")
+        return None
+
+
+def index_frames_kb():
+    """Index Wikipedia articles from FRAMES dataset into Khoj"""
+    try:
+        # Load dataset
+        dataset = load_frames_kb()
+        dataset_files = set(map(get_article_filename, dataset))
+
+        # Get indexed files from Khoj API
+        headers = {"Authorization": f"Bearer {KHOJ_API_KEY}"} if KHOJ_API_KEY else {}
+        try:
+            response = requests.get(f"{KHOJ_URL}/api/content/computer", headers=headers)
+            response.raise_for_status()
+            indexed_files = set(response.json())
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Failed to get indexed files: {e}")
+            return False
+
+        # Find missing files to index
+        missing_files = dataset_files - indexed_files
+        filtered_dataset = [
+            article
+            for article in dataset
+            if get_article_filename(article) in missing_files and not is_none_or_empty(article["text"])
+        ]
+        if not filtered_dataset:
+            return True
+        logger.info(f"Found {len(filtered_dataset)} files to index")
+
+        # Process Wikipedia articles from FRAMES knowledge base in batches
+        batch_size = 300
+        total_batches = -(len(filtered_dataset) // -batch_size)  # ceiling division
+        for batch in tqdm(batcher(filtered_dataset, batch_size), total=total_batches, desc="Indexing FRAMES KB"):
+            # Create files batch to index
+            files = []
+            for article in batch:
+                filename = get_article_filename(article)
+                files.append(("files", (filename, article["text"], "text/plaintext")))
+            # Send files batch to index
+            try:
+                response = requests.patch(f"{KHOJ_URL}/api/content?client=eval", headers=headers, files=files)
+                response.raise_for_status()
+                time.sleep(SLEEP_SECONDS)  # Rate limiting
+            except Exception as e:
+                logger.error(f"Failed to index batch: {e}")
+                return False
+        return True
+    except Exception as e:
+        logger.error(f"Failed to index KB: {e}")
+        return False
+
+
 def load_frames_dataset():
     """
     Load the Google FRAMES benchmark dataset from HuggingFace
@@ -248,6 +368,62 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
         return {"response": "", "usage": {}, "references": {}}
 
 
+def calculate_precision_recall(numerator: int, denominator: int) -> float:
+    """Safely compute a precision or recall ratio, treating 0/0 as a perfect score"""
+    if numerator == 0 and denominator == 0:
+        return 1.0
+    elif numerator > 0 and denominator == 0:
+        return 0.0
+    else:
+        return numerator / denominator
+
+
+def calculate_f1(precision: float, recall: float) -> float:
+    """Calculate F1 score from precision and recall"""
+    return 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
+
+
+def evaluate_response_for_ir(
+    query: str, agent_response: str, ground_truth: int, agent_references: dict = {}
+) -> tuple[float | None, str, float]:
+    """Evaluate the articles retrieved by Khoj against the articles expected for the benchmark prompt"""
+    try:
+        # Extract files referenced by the agent from its response
+        referenced_files: list[dict[str, str]] = agent_references.get("context", [])
+        count_of_correct_articles_used_by_agent: int = 0
+        # Count how many of the expected articles the agent actually retrieved from the KB
+        unique_file_refs = {file["file"] for file in referenced_files}
+        referenced_articles = list(map(extract_article_url_from_filename, unique_file_refs))
+        for file in unique_file_refs:
+            frames_ids_for_articles_used_by_agent = extract_prompt_ids_from_filename(file)
+            count_of_correct_articles_used_by_agent += int(ground_truth in frames_ids_for_articles_used_by_agent)
+
+        articles = get_articles_by_prompt_id(ground_truth)
+        precision = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(unique_file_refs))
+        recall = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(articles))
+        f1 = calculate_f1(precision, recall)
+
+        explanation = (
+            f"Information Retrieval F1 Score: {f1:.2%}, Recall: {recall:.2%}, Precision: {precision:.2%}.\n"
+            f"{count_of_correct_articles_used_by_agent} of {len(articles)} correct from {len(unique_file_refs)} total retrievals for {ground_truth}.\n"
+            f"Queries:\n{yaml.dump(sorted([r['query'] for r in referenced_files]))}\n"
+            f"Expected Articles for {ground_truth}:\n{yaml.dump(sorted([a['link'] for a in articles]))}\n"
+            f"Retrieved Articles for {ground_truth}:\n{yaml.dump(referenced_articles)}\n"
+        )
+
+        # Truncate referenced files for logging
+        truncated_refs = [
+            {k: v[:200] + "..." if len(v) > 200 else v for k, v in ref.items()} for ref in referenced_files
+        ]
+        logger.info(f"Retrieved Article Details:\n{yaml.dump(truncated_refs, sort_keys=False)}\n")
+
+        # Return recall score, explanation and cost in structured form
+        return recall, explanation, 0.0
+    except Exception as e:
+        logger.error(f"Error in IR evaluation: {e}")
+        return None, f"Evaluation failed: {str(e)}", 0.0
+
+
 def evaluate_response_with_mcq_match(
     query: str, agent_response: str, ground_truth: str, agent_references: dict = {}
 ) -> tuple[bool | None, str, float]:
@@ -417,7 +593,7 @@ def parse_args():
         "--dataset",
         "-d",
         default="frames",
-        choices=["frames", "simpleqa", "gpqa", "math500"],
+        choices=["frames", "frames_ir", "simpleqa", "gpqa", "math500"],
         help="Dataset to use for evaluation (default: frames)",
     )
     return parser.parse_args()
@@ -438,6 +614,12 @@ def main():
         dataset = load_gpqa_dataset()
     elif args.dataset == "math500":
         dataset = load_math500_dataset()
+    elif args.dataset == "frames_ir":
+        indexed = index_frames_kb()
+        if indexed:
+            dataset = load_frames_dataset()
+            # Copy the dataset index field, 'Unnamed: 0', into 'Answer' for IR evaluation
+            dataset["Answer"] = dataset["Unnamed: 0"]
 
     if dataset is None:
         return
@@ -450,6 +632,8 @@ def main():
         response_evaluator = partial(
             evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
         )
+    elif args.dataset == "frames_ir":
+        response_evaluator = evaluate_response_for_ir
     else:
         response_evaluator = evaluate_response_with_gemini
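
Reviewer note: a quick, self-contained round-trip check of the filename encoding scheme introduced above. The three helpers are copied from this diff; the article record is a made-up example, not taken from the dataset:

import base64

# Helpers copied from the diff so the check runs standalone
def get_article_filename(article: dict) -> str:
    encoded_url = base64.urlsafe_b64encode(article["link"].encode()).decode()
    return "-".join(map(str, article["frames_prompt_id"])) + f"_{encoded_url}.txt"

def extract_prompt_ids_from_filename(filename: str) -> set[int]:
    return set(map(int, filename.split("_", 1)[0].split("-")))

def extract_article_url_from_filename(filename: str) -> str:
    encoded_url = filename.split("_", 1)[1].rsplit(".", 1)[0]
    return base64.urlsafe_b64decode(encoded_url).decode()

# Hypothetical article record with the dataset's documented fields
article = {"link": "https://en.wikipedia.org/wiki/Alan_Turing", "frames_prompt_id": [3, 41]}
filename = get_article_filename(article)

# Prompt ids and URL survive the encode/decode round trip: the prompt-id
# prefix contains no underscores, so splitting on the first "_" is safe even
# though urlsafe base64 may itself contain "-" and "_"
assert extract_prompt_ids_from_filename(filename) == {3, 41}
assert extract_article_url_from_filename(filename) == article["link"]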
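
And a small worked example of the scoring math in evaluate_response_for_ir, using made-up retrieval counts (4 unique files retrieved, 2 of them among the 5 articles expected for the prompt):

# Toy numbers, not from a real eval run
correct, retrieved, expected = 2, 4, 5

# calculate_precision_recall guards the 0/0 case as a perfect 1.0;
# with nonzero denominators it reduces to plain division
precision = correct / retrieved  # 0.50
recall = correct / expected      # 0.40

# Harmonic mean of precision and recall, as in calculate_f1
f1 = 2 * (precision * recall) / (precision + recall)

print(f"F1: {f1:.2%}, Recall: {recall:.2%}, Precision: {precision:.2%}")
# -> F1: 44.44%, Recall: 40.00%, Precision: 50.00%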