Merge branch 'master' of github.com:khoj-ai/khoj into features/allow-multi-outputs-in-chat

This commit is contained in:
sabaimran
2024-11-29 14:12:03 -08:00
38 changed files with 437 additions and 180 deletions

View File

@@ -3,8 +3,10 @@ import concurrent.futures
import json
import logging
import os
import re
import time
from datetime import datetime
from functools import partial
from io import StringIO
from textwrap import dedent
from threading import Lock
@@ -24,13 +26,10 @@ logger = logging.getLogger(__name__)
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
KHOJ_MODE = os.getenv("KHOJ_MODE", "default") # E.g research, general, notes etc.
KHOJ_MODE = os.getenv("KHOJ_MODE", "default").lower() # E.g research, general, notes etc.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-pro-002")
GEMINI_API_URL = (
f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_EVAL_MODEL}:generateContent?key={GEMINI_API_KEY}"
)
SAMPLE_SIZE = os.getenv("SAMPLE_SIZE") # Number of examples to evaluate
RANDOMIZE = os.getenv("RANDOMIZE", "false").lower() == "true" # Randomize examples
@@ -128,6 +127,99 @@ def load_simpleqa_dataset():
return None
def load_gpqa_dataset():
"""
Load the Google GPQA benchmark dataset from HuggingFace
GPQA is a benchmark dataset to evaluate retrieval and answering capabilities of agents.
It contains ~800 requiring multi-hop retrieval and reasoning across various topics.
### Data Fields
- Prompt: The question to be answered
- Answer: The ground truth answer
- reasoning_types: The type of reasoning required to answer the question
"""
import random
def format_multiple_choice_question(row: Dict) -> tuple[str, str]:
"""
Create GPQA multi-choice prompt from shuffled answer choices and question.
Refer: https://github.com/openai/simple-evals/blob/a8e85cc8a5dea497d915f870895250e07f9cc737/common.py#L12
Returns formatted prompt and correct answer letter.
"""
# Gather choices
choices = [
row["Incorrect Answer 1"],
row["Incorrect Answer 2"],
row["Incorrect Answer 3"],
row["Correct Answer"],
]
# Shuffle choices
random.shuffle(choices)
# Get correct answer letter
correct_index = choices.index(row["Correct Answer"])
correct_letter = "ABCD"[correct_index]
prompt = f"""
Answer the following multiple choice question. Answer should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{row["Question"]}
A) {choices[0]}
B) {choices[1]}
C) {choices[2]}
D) {choices[3]}
""".strip()
return prompt, correct_letter
try:
dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
# Create multi-choice q&a prompt from choices and correct answer
prompts_and_answers = [format_multiple_choice_question(row) for row in dataset]
# Normalize dataset to FRAMES format
dataset = dataset.rename_columns({"Subdomain": "reasoning_types"})
dataset = dataset.add_column("Prompt", [p[0] for p in prompts_and_answers])
dataset = dataset.add_column("Answer", [p[1] for p in prompts_and_answers])
# Sample and shuffle dataset if configured
dataset = dataset.shuffle() if RANDOMIZE else dataset
dataset = dataset[: int(SAMPLE_SIZE)] if SAMPLE_SIZE else dataset
return dataset
except Exception as e:
logger.error(f"Error loading dataset: {e}")
return None
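
For intuition, here is a minimal, self-contained sketch of the same shuffle-and-letter logic on a toy row. The helper name and row values are invented for illustration; the real helper is nested inside load_gpqa_dataset and also emits the multiple choice instructions:

import random

def format_mcq(row):
    # Gather the three incorrect and one correct choices, then shuffle their order
    choices = [row["Incorrect Answer 1"], row["Incorrect Answer 2"], row["Incorrect Answer 3"], row["Correct Answer"]]
    random.shuffle(choices)
    # The ground-truth label is the letter at the correct answer's shuffled position
    letter = "ABCD"[choices.index(row["Correct Answer"])]
    options = "\n".join(f"{l}) {c}" for l, c in zip("ABCD", choices))
    return f"{row['Question']}\n{options}", letter

toy_row = {
    "Question": "What is 2 + 2?",
    "Correct Answer": "4",
    "Incorrect Answer 1": "3",
    "Incorrect Answer 2": "5",
    "Incorrect Answer 3": "22",
}
prompt, answer = format_mcq(toy_row)
print(prompt)  # question followed by A) .. D) options in shuffled order
print(answer)  # e.g. "B", whichever letter "4" landed on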
def load_math500_dataset():
"""
Load and format the MATH500 dataset to match the evaluation script's structure.
Args:
sample_size (int, optional): Number of samples to include. Defaults to None (use full dataset).
randomize (bool, optional): Whether to randomize the dataset. Defaults to False.
Returns:
Dataset: Formatted HuggingFace Dataset.
"""
try:
# Load the MATH500 dataset from HuggingFace
dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
dataset = dataset.rename_columns({"problem": "Prompt", "answer": "Answer", "subject": "reasoning_types"})
dataset = dataset.shuffle() if RANDOMIZE else dataset
dataset = dataset.select(range(int(SAMPLE_SIZE))) if SAMPLE_SIZE else dataset
return dataset
except Exception as e:
print(f"Error loading and formatting MATH500 dataset: {e}")
return None
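
Note that the two loaders sample differently: slicing a HuggingFace Dataset (dataset[: n]) returns a plain dict of column lists, while .select() returns another Dataset object. Both expose dataset["Prompt"] and friends, which is all main() relies on, so either form works downstream. A minimal sketch of the difference, assuming the datasets library is installed:

from datasets import Dataset

ds = Dataset.from_dict({"Prompt": ["q1", "q2", "q3"], "Answer": ["a1", "a2", "a3"]})

sliced = ds[:2]                 # plain dict of columns: {"Prompt": ["q1", "q2"], "Answer": ["a1", "a2"]}
selected = ds.select(range(2))  # still a Dataset, now with 2 rows

# Both support the column access pattern used by the batching loop in main()
assert sliced["Prompt"] == selected["Prompt"] == ["q1", "q2"]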
def get_agent_response(prompt: str) -> Dict[str, Any]:
"""Get response from the Khoj API"""
# Set headers
@@ -152,7 +244,30 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
return {"response": "", "usage": {}}
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
def evaluate_response_with_mcq_match(
query: str, agent_response: str, ground_truth: str
) -> tuple[bool | None, str, float]:
"""Evaluate Khoj response against benchmark ground truth using string matching"""
try:
# Extract answer from agent response
answer_pattern_multichoice = r"(?i)Answer\s*:\s*([A-D])"
match = re.search(answer_pattern_multichoice, agent_response)
extracted_answer = match.group(1) if match else None
# Check if extracted answer matches ground truth
decision = extracted_answer == ground_truth
explanation = f"Agent response {'matches' if decision else 'does not match'} ground truth {ground_truth}"
# Return decision, explanation and cost in structured form
return decision, explanation, 0.0
except Exception as e:
logger.error(f"Error in evaluation: {e}")
return None, f"Evaluation failed: {str(e)}", 0.0
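
The regex above does the heavy lifting: it matches case-insensitively, tolerates whitespace around the colon, and captures only the choice letter. A small standalone sketch of its behavior, with example responses invented for illustration:

import re

ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-D])"

responses = [
    "Let me think step by step... Answer: C",  # canonical form
    "answer :  b",                             # different case and spacing still match
    "The correct option is (C).",              # no "Answer:" prefix, so no match
]
for response in responses:
    match = re.search(ANSWER_PATTERN, response)
    # Prints "C", then "b", then None. Note a lowercase capture like "b"
    # would fail the strict == comparison against an uppercase ground truth.
    print(match.group(1) if match else None)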
def evaluate_response_with_gemini(
query: str, agent_response: str, ground_truth: str, eval_model=GEMINI_EVAL_MODEL
) -> tuple[bool | None, str, float]:
"""Evaluate Khoj response against benchmark ground truth using Gemini"""
evaluation_prompt = f"""
Compare the following agent response with the ground truth answer.
@@ -166,10 +281,13 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
Provide your evaluation in the following json format:
{"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"}
"""
gemini_api_url = (
f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}"
)
try:
response = requests.post(
GEMINI_API_URL,
gemini_api_url,
headers={"Content-Type": "application/json"},
json={
"contents": [{"parts": [{"text": evaluation_prompt}]}],
@@ -182,7 +300,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
# Update cost of evaluation
input_tokens = response_json["usageMetadata"]["promptTokenCount"]
output_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, output_tokens)
cost = get_cost_of_chat_message(eval_model, input_tokens, output_tokens)
# Parse evaluation response
eval_response: dict[str, str] = json.loads(
@@ -200,7 +318,7 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tup
return None, f"Evaluation failed: {str(e)}", 0.0
def process_batch(batch, batch_start, results, dataset_length):
def process_batch(batch, batch_start, results, dataset_length, response_evaluator):
global running_cost
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
current_index = batch_start + idx
@@ -219,7 +337,7 @@ def process_batch(batch, batch_start, results, dataset_length):
decision = None
explanation = "Agent response is empty. This maybe due to a service error."
else:
decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)
decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer)
# Store results
results.append(
@@ -292,7 +410,7 @@ def parse_args():
"--dataset",
"-d",
default="frames",
choices=["frames", "simpleqa"],
choices=["frames", "simpleqa", "gpqa", "math500"],
help="Dataset to use for evaluation (default: frames)",
)
return parser.parse_args()
@@ -309,12 +427,24 @@ def main():
dataset = load_frames_dataset()
elif args.dataset == "simpleqa":
dataset = load_simpleqa_dataset()
elif args.dataset == "gpqa":
dataset = load_gpqa_dataset()
elif args.dataset == "math500":
dataset = load_math500_dataset()
if dataset is None:
return
# Initialize variables
results = []
dataset_length = len(dataset["Prompt"])
if args.dataset == "gpqa":
response_evaluator = evaluate_response_with_mcq_match
elif args.dataset == "math500":
response_evaluator = partial(
evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
)
else:
response_evaluator = evaluate_response_with_gemini
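
functools.partial is used here so process_batch can call every evaluator with the same three positional arguments, while the MATH500 path quietly pins a cheaper Gemini model via a keyword argument. A minimal sketch of the pattern, using stand-in functions and model names:

from functools import partial

def evaluate(query, response, truth, eval_model="gemini-1.5-pro-002"):
    return f"evaluating with {eval_model}"

# Bind the keyword argument up front; the call site stays identical
cheap_evaluator = partial(evaluate, eval_model="gemini-1.5-flash-002")

print(evaluate("q", "r", "t"))         # evaluating with gemini-1.5-pro-002
print(cheap_evaluator("q", "r", "t"))  # evaluating with gemini-1.5-flash-002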
# Process examples in batches
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -326,7 +456,9 @@ def main():
dataset["Answer"][i : i + BATCH_SIZE],
dataset["reasoning_types"][i : i + BATCH_SIZE],
)
futures.append(executor.submit(process_batch, batch, batch_start, results, dataset_length))
futures.append(
executor.submit(process_batch, batch, batch_start, results, dataset_length, response_evaluator)
)
# Wait for all futures to complete
concurrent.futures.wait(futures)
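
For reference, the batching above zips the three column lists into per-example (prompt, answer, reasoning_type) tuples before handing each slice to the thread pool. A small sketch of that slicing-and-zipping, with toy columns and a made-up BATCH_SIZE:

import concurrent.futures

BATCH_SIZE = 2
prompts = ["p1", "p2", "p3", "p4", "p5"]
answers = ["a1", "a2", "a3", "a4", "a5"]
types = ["t1", "t2", "t3", "t4", "t5"]

def process_batch(batch, batch_start):
    # Each item is a (prompt, answer, reasoning_type) tuple
    return [(batch_start + i, p) for i, (p, a, t) in enumerate(batch)]

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = []
    for start in range(0, len(prompts), BATCH_SIZE):
        batch = zip(prompts[start:start + BATCH_SIZE], answers[start:start + BATCH_SIZE], types[start:start + BATCH_SIZE])
        futures.append(executor.submit(process_batch, batch, start))
    concurrent.futures.wait(futures)
    print([f.result() for f in futures])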

View File

@@ -104,6 +104,18 @@ class TestTruncateMessage:
assert truncated_chat_history[0] != copy_big_chat_message
def test_load_complex_raw_json_string():
# Arrange
raw_json = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
expeced_json = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}
# Act
parsed_json = utils.load_complex_json(raw_json)
# Assert
assert parsed_json == expected_json
def generate_content(count):
return " ".join([f"{index}" for index, _ in enumerate(range(count))])