mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Return accuracy as decision to generalize across IR & standard scorers
This commit is contained in:
@@ -58,7 +58,7 @@ class Counter:
|
|||||||
# Track running metrics while evaluating
|
# Track running metrics while evaluating
|
||||||
running_cost = Counter()
|
running_cost = Counter()
|
||||||
running_true_count = Counter(0)
|
running_true_count = Counter(0)
|
||||||
running_false_count = Counter(0)
|
running_total_count = Counter(0)
|
||||||
|
|
||||||
|
|
||||||
def load_frames_dataset():
|
def load_frames_dataset():
|
||||||
@@ -259,7 +259,7 @@ def evaluate_response_with_mcq_match(
|
|||||||
explanation = f"Agent response {'matches' if decision else 'does not match'} ground truth {ground_truth}"
|
explanation = f"Agent response {'matches' if decision else 'does not match'} ground truth {ground_truth}"
|
||||||
|
|
||||||
# Return decision, explanation and cost in structured form
|
# Return decision, explanation and cost in structured form
|
||||||
return decision, explanation, 0.0
|
return float(decision), explanation, 0.0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in evaluation: {e}")
|
logger.error(f"Error in evaluation: {e}")
|
||||||
return None, f"Evaluation failed: {str(e)}", 0.0
|
return None, f"Evaluation failed: {str(e)}", 0.0
|
||||||
@@ -306,7 +306,7 @@ def evaluate_response_with_gemini(
|
|||||||
eval_response: dict[str, str] = json.loads(
|
eval_response: dict[str, str] = json.loads(
|
||||||
clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
|
clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
|
||||||
)
|
)
|
||||||
decision = str(eval_response.get("decision", "")).upper() == "TRUE"
|
decision = float(str(eval_response.get("decision", "")).upper() == "TRUE")
|
||||||
explanation = eval_response.get("explanation", "")
|
explanation = eval_response.get("explanation", "")
|
||||||
# Handle evaluation service errors
|
# Handle evaluation service errors
|
||||||
if "503 Service Error" in explanation:
|
if "503 Service Error" in explanation:
|
||||||
@@ -360,11 +360,12 @@ def process_batch(batch, batch_start, results, dataset_length, response_evaluato
|
|||||||
# Update running accuracy
|
# Update running accuracy
|
||||||
running_accuracy = 0.0
|
running_accuracy = 0.0
|
||||||
if decision is not None:
|
if decision is not None:
|
||||||
running_true_count.add(1) if decision == True else running_false_count.add(1)
|
running_true_count.add(decision)
|
||||||
running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
|
running_total_count.add(1)
|
||||||
|
running_accuracy = running_true_count.get() / running_total_count.get()
|
||||||
|
|
||||||
## Log results
|
## Log results
|
||||||
decision_color = {True: "green", None: "blue", False: "red"}[decision]
|
decision_color = {True: "green", None: "blue", False: "red"}[decision > 0.5]
|
||||||
colored_decision = color_text(str(decision), decision_color)
|
colored_decision = color_text(str(decision), decision_color)
|
||||||
result_to_print = f"""
|
result_to_print = f"""
|
||||||
---------
|
---------
|
||||||
@@ -466,12 +467,10 @@ def main():
|
|||||||
# Calculate metrics
|
# Calculate metrics
|
||||||
df = pd.DataFrame(results)
|
df = pd.DataFrame(results)
|
||||||
eval_df = df.dropna(subset=["evaluation_decision"]) # Exclude rows with missing evaluation decision
|
eval_df = df.dropna(subset=["evaluation_decision"]) # Exclude rows with missing evaluation decision
|
||||||
accuracy = (eval_df["evaluation_decision"] == True).mean()
|
accuracy = (eval_df["evaluation_decision"]).mean()
|
||||||
|
|
||||||
# Calculate accuracy by reasoning type
|
# Calculate accuracy by reasoning type
|
||||||
reasoning_type_accuracy = eval_df.groupby("reasoning_type")["evaluation_decision"].apply(
|
reasoning_type_accuracy = (eval_df.groupby("reasoning_type")["evaluation_decision"]).apply(lambda x: x.mean())
|
||||||
lambda x: (x == True).mean()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect summary
|
# Collect summary
|
||||||
colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
|
colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
|
||||||
|
|||||||
Reference in New Issue
Block a user