Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/google/adk/evaluation/eval_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ class MatchType(Enum):
),
)

ignore_args: bool = Field(
default=False,
description=(
"When True, tool call matching checks only the tool name and "
"ignores argument values. Useful for non-deterministic "
"arguments such as timestamps, generated queries, or "
"other dynamic content."
),
)

@field_validator("match_type", mode="before")
@classmethod
def _coerce_match_type(cls, value: object) -> object:
Expand Down
37 changes: 28 additions & 9 deletions src/google/adk/evaluation/trajectory_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
)
self._threshold = criterion.threshold
self._match_type = criterion.match_type
self._ignore_args = criterion.ignore_args
except ValidationError as e:
expected_criterion_type_error = ValueError(
f"`{eval_metric.metric_name}` metric expects a criterion of type"
Expand All @@ -91,9 +92,11 @@ def __init__(
elif eval_metric:
self._threshold = eval_metric.threshold
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
self._ignore_args = False
else:
self._threshold = threshold
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
self._ignore_args = False

@override
def evaluate_invocations(
Expand Down Expand Up @@ -148,25 +151,42 @@ def _calculate_tool_use_accuracy(
tool_use_match_status = False
if self._match_type == ToolTrajectoryCriterion.MatchType.EXACT:
tool_use_match_status = self._are_tool_calls_exact_match(
actual_tool_uses, expected_tool_uses
actual_tool_uses, expected_tool_uses, self._ignore_args
)
elif self._match_type == ToolTrajectoryCriterion.MatchType.IN_ORDER:
tool_use_match_status = self._are_tool_calls_in_order_match(
actual_tool_uses, expected_tool_uses
actual_tool_uses, expected_tool_uses, self._ignore_args
)
elif self._match_type == ToolTrajectoryCriterion.MatchType.ANY_ORDER:
tool_use_match_status = self._are_tool_calls_any_order_match(
actual_tool_uses, expected_tool_uses
actual_tool_uses, expected_tool_uses, self._ignore_args
)
else:
raise ValueError(f"Unsupported match type {self._match_type}")

return 1.0 if tool_use_match_status else 0.0

def _calls_match(
self,
expected: genai_types.FunctionCall,
actual: genai_types.FunctionCall,
ignore_args: bool,
) -> bool:
"""Returns True if two tool calls are considered a match."""

if expected.name != actual.name:
return False

if ignore_args:
return True

return bool(expected.args == actual.args)

def _are_tool_calls_in_order_match(
self,
actual_tool_calls: list[genai_types.FunctionCall],
expected_tool_calls: list[genai_types.FunctionCall],
ignore_args: bool = False,
) -> bool:
"""Checks if expected tool calls appear in actual tool calls in order.

Expand All @@ -191,10 +211,7 @@ def _are_tool_calls_in_order_match(
try:
current_expected = next(expected_it)
for actual in actual_tool_calls:
if (
actual.name == current_expected.name
and actual.args == current_expected.args
):
if self._calls_match(current_expected, actual, ignore_args):
current_expected = next(expected_it)
except StopIteration:
return True
Expand All @@ -205,6 +222,7 @@ def _are_tool_calls_any_order_match(
self,
actual_tool_calls: list[genai_types.FunctionCall],
expected_tool_calls: list[genai_types.FunctionCall],
ignore_args: bool = False,
) -> bool:
"""Checks if expected tool calls appear in actual tool calls in any order.

Expand All @@ -229,7 +247,7 @@ def _are_tool_calls_any_order_match(
for expected in expected_tool_calls:
found = False
for i, actual in enumerate(actual_tool_calls_copy):
if actual.name == expected.name and actual.args == expected.args:
if self._calls_match(expected, actual, ignore_args):
actual_tool_calls_copy.pop(i)
found = True
break
Expand All @@ -241,6 +259,7 @@ def _are_tool_calls_exact_match(
self,
actual_tool_calls: list[genai_types.FunctionCall],
expected_tool_calls: list[genai_types.FunctionCall],
ignore_args: bool = False,
) -> bool:
"""Checks if actual tool calls exactly match expected tool calls.

Expand All @@ -260,7 +279,7 @@ def _are_tool_calls_exact_match(
return False

for actual, expected in zip(actual_tool_calls, expected_tool_calls):
if actual.name != expected.name or actual.args != expected.args:
if not self._calls_match(expected, actual, ignore_args):
return False

return True
Expand Down
69 changes: 69 additions & 0 deletions tests/unittests/evaluation/test_trajectory_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,72 @@ def test_evaluate_invocations_no_invocations(evaluator: TrajectoryEvaluator):
assert result.overall_score is None
assert result.overall_eval_status == EvalStatus.NOT_EVALUATED
assert not result.per_invocation_results


def test_exact_match_ignore_args():
"""Tests EXACT match when args differ but names are the same."""
evaluator = TrajectoryEvaluator()

actual = [genai_types.FunctionCall(name="search", args={"query": "cats"})]
expected = [genai_types.FunctionCall(name="search", args={"query": "dogs"})]

assert evaluator._are_tool_calls_exact_match(
actual, expected, ignore_args=True
)


def test_in_order_match_ignore_args():
"""Tests IN_ORDER match when args differ."""
evaluator = TrajectoryEvaluator()

actual = [
genai_types.FunctionCall(name="search", args={"query": "cats"}),
genai_types.FunctionCall(name="lookup", args={"id": "123"}),
]

expected = [
genai_types.FunctionCall(name="search", args={"query": "dogs"}),
genai_types.FunctionCall(name="lookup", args={"id": "999"}),
]

assert evaluator._are_tool_calls_in_order_match(
actual, expected, ignore_args=True
)


def test_any_order_match_ignore_args():
"""Tests ANY_ORDER match when args differ and order differs."""
evaluator = TrajectoryEvaluator()

actual = [
genai_types.FunctionCall(name="lookup", args={"id": "123"}),
genai_types.FunctionCall(name="search", args={"query": "cats"}),
]

expected = [
genai_types.FunctionCall(name="search", args={"query": "dogs"}),
genai_types.FunctionCall(name="lookup", args={"id": "999"}),
]

assert evaluator._are_tool_calls_any_order_match(
actual, expected, ignore_args=True
)


def test_trajectory_evaluator_loads_ignore_args_from_criterion():
"""Tests ignore_args configuration is propagated from criterion."""
criterion = ToolTrajectoryCriterion(
threshold=0.5,
match_type=ToolTrajectoryCriterion.MatchType.EXACT,
ignore_args=True,
)

metric = EvalMetric(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
threshold=0.5,
criterion=criterion,
)

evaluator = TrajectoryEvaluator(eval_metric=metric)

assert evaluator._ignore_args is True