diff --git a/tests/evals/eval.py b/tests/evals/eval.py
index 629df91e..ff5d9cf0 100644
--- a/tests/evals/eval.py
+++ b/tests/evals/eval.py
@@ -238,14 +238,18 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
         )
         response.raise_for_status()
         response_json = response.json()
-        return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
+        return {
+            "response": response_json.get("response", ""),
+            "usage": response_json.get("usage", {}),
+            "references": response_json.get("references", {}),
+        }
     except Exception as e:
         logger.error(f"Error getting agent response: {e}")
-        return {"response": "", "usage": {}}
+        return {"response": "", "usage": {}, "references": {}}
 
 
 def evaluate_response_with_mcq_match(
-    query: str, agent_response: str, ground_truth: str
+    query: str, agent_response: str, ground_truth: str, agent_references: dict = {}
 ) -> tuple[bool | None, str, float]:
     """Evaluate Khoj response against benchmark ground truth using string matching"""
     try:
@@ -266,7 +270,7 @@
 
 
 def evaluate_response_with_gemini(
-    query: str, agent_response: str, ground_truth: str, eval_model=GEMINI_EVAL_MODEL
+    query: str, agent_response: str, ground_truth: str, agent_references: dict = {}, eval_model=GEMINI_EVAL_MODEL
 ) -> tuple[bool | None, str, float]:
     """Evaluate Khoj response against benchmark ground truth using Gemini"""
     evaluation_prompt = f"""
@@ -331,13 +335,14 @@ def process_batch(batch, batch_start, results, dataset_length, response_evaluato
         response = get_agent_response(prompt)
         agent_response = response["response"]
         agent_usage = response["usage"]
+        agent_references = response["references"]
 
         # Evaluate response
         if is_none_or_empty(agent_response):
             decision = None
             explanation = "Agent response is empty. This maybe due to a service error."
         else:
-            decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer)
+            decision, explanation, eval_cost = response_evaluator(prompt, agent_response, answer, agent_references)
 
         # Store results
         results.append(
@@ -350,6 +355,7 @@
                 "evaluation_explanation": explanation,
                 "reasoning_type": reasoning_type,
                 "usage": agent_usage,
+                "references": agent_references,
             }
         )