diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 5b154862e..95ef0e358 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -24,8 +24,7 @@ import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; import com.google.adk.plugins.Plugin; -import com.google.adk.telemetry.Instrumentation; -import com.google.adk.telemetry.Instrumentation.AgentInvocation; +import com.google.adk.telemetry.Instrumentation.AgentInvocationTransformer; import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -324,39 +323,30 @@ private Flowable run( InvocationContext parentContext, Function> runImplementation) { Context otelContext = Context.current(); - return Flowable.using( - () -> - Instrumentation.recordAgentInvocation( - createInvocationContext(parentContext), this, otelContext), - agentInvocation -> { - InvocationContext invocationContext = agentInvocation.getCtx(); - Flowable mainAndAfterEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)) - .concatWith( - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .toFlowable())); - - return callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEvent -> { - if (invocationContext.endInvocation()) { - return Flowable.just(beforeEvent); - } - return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); - }) - .switchIfEmpty(mainAndAfterEvents) - .doOnNext(agentInvocation::addEvent) - .doOnError(agentInvocation::setError); - }, - AgentInvocation::close); + InvocationContext invocationContext = createInvocationContext(parentContext); + Flowable mainAndAfterEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)) + .concatWith( + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .toFlowable())); + + return callCallback( + beforeCallbacksToFunctions(invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEvent -> { + if (invocationContext.endInvocation()) { + return Flowable.just(beforeEvent); + } + return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); + }) + .switchIfEmpty(mainAndAfterEvents) + .compose(new AgentInvocationTransformer(invocationContext, this, otelContext)); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 8c60ebf76..a68160f4f 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -29,8 +29,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.events.ToolConfirmation; -import com.google.adk.telemetry.Instrumentation; -import com.google.adk.telemetry.Instrumentation.ToolExecution; +import com.google.adk.telemetry.Instrumentation.ToolExecutionTransformer; import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; @@ -290,43 +289,42 @@ private static Function> getFunctionCallMapper( Context parentContext) { return functionCall -> Maybe.defer( - () -> { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation( - functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = - functionCall.args().map(HashMap::new).orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> - isLive - ? processFunctionLive( - invocationContext, - tool, - toolContext, - functionCall, - functionArgs) - : callTool(tool, functionArgs, toolContext)) - .compose(Tracing.withContext(parentContext))); - - return postProcessFunctionResult( + () -> { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation(functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = + functionCall.args().map(HashMap::new).orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> + isLive + ? processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs) + : callTool(tool, functionArgs, toolContext))); + + return processFunctionResult( maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, - isLive, - parentContext); - }) - .compose(Tracing.withContext(parentContext)); + isLive) + .compose( + new ToolExecutionTransformer( + tool, invocationContext.agent(), functionArgs, parentContext)); + }); } /** @@ -424,26 +422,6 @@ public static Set getLongRunningFunctionCalls( return longRunningFunctionCalls; } - private static Maybe postProcessFunctionResult( - Maybe> maybeFunctionResult, - InvocationContext invocationContext, - BaseTool tool, - Map functionArgs, - ToolContext toolContext, - boolean isLive, - Context parentContext) { - return Maybe.using( - () -> - Instrumentation.recordToolExecution( - tool, invocationContext.agent(), functionArgs, parentContext), - toolExecution -> - processFunctionResult( - maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive) - .doOnSuccess(event -> toolExecution.context().setFunctionResponseEvent(event)) - .doOnError(toolExecution::setError), - ToolExecution::close); - } - private static Maybe processFunctionResult( Maybe> maybeFunctionResult, InvocationContext invocationContext, diff --git a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java index fd27878c9..03c26dbeb 100644 --- a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java +++ b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java @@ -24,6 +24,11 @@ import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.FlowableTransformer; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.MaybeSource; +import io.reactivex.rxjava3.core.MaybeTransformer; import java.time.Duration; import java.util.ArrayList; import java.util.Collections; @@ -31,6 +36,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import org.jspecify.annotations.Nullable; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -67,7 +73,7 @@ public void setFunctionResponseEvent(@Nullable Event functionResponseEvent) { public abstract static class ClosableTelemetryScope implements AutoCloseable { protected final long startTimeNanos; protected final Span span; - protected final Scope scope; + protected Scope scope; protected final TelemetryContext telemetryContext; protected @Nullable Throwable caughtError; protected final AtomicBoolean closed = new AtomicBoolean(false); @@ -80,6 +86,11 @@ public abstract static class ClosableTelemetryScope implements AutoCloseable { this.telemetryContext = new TelemetryContext(Context.current()); } + @SuppressWarnings("MustBeClosedChecker") + public void makeCurrent() { + this.scope = span.makeCurrent(); + } + public TelemetryContext context() { return telemetryContext; } @@ -136,10 +147,6 @@ public AgentInvocation(InvocationContext ctx, BaseAgent agent, Context parentCon Tracing.traceAgentInvocation(span, agent.name(), agent.description(), ctx); } - public InvocationContext getCtx() { - return ctx; - } - public void addEvent(Event event) { events.add(event); } @@ -203,24 +210,53 @@ protected void handleMetricsError(RuntimeException e) { } } - /** Creates an AgentInvocation context to record agent invocation telemetry. */ - public static AgentInvocation recordAgentInvocation(InvocationContext ctx, BaseAgent agent) { - return recordAgentInvocation(ctx, agent, Context.current()); - } + /** A transformer that manages an AgentInvocation telemetry scope for RxJava streams. */ + public static final class AgentInvocationTransformer + implements FlowableTransformer { + private final AgentInvocation agentInvocation; - public static AgentInvocation recordAgentInvocation( - InvocationContext ctx, BaseAgent agent, Context parentContext) { - return new AgentInvocation(ctx, agent, parentContext); - } + public AgentInvocationTransformer( + InvocationContext ctx, BaseAgent agent, Context parentContext) { + this.agentInvocation = new AgentInvocation(ctx, agent, parentContext); + } - /** Creates a ToolExecution context to record tool execution telemetry. */ - public static ToolExecution recordToolExecution( - BaseTool tool, BaseAgent agent, Map functionArgs) { - return recordToolExecution(tool, agent, functionArgs, Context.current()); + @Override + public Publisher apply(Flowable upstream) { + return Flowable.using( + () -> { + agentInvocation.makeCurrent(); + return agentInvocation; + }, + agentInvocation -> + upstream.doOnNext(agentInvocation::addEvent).doOnError(agentInvocation::setError), + AgentInvocation::close); + } } - public static ToolExecution recordToolExecution( - BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { - return new ToolExecution(tool, agent, functionArgs, parentContext); + /** A transformer that manages a ToolExecution telemetry scope for RxJava Maybe streams. */ + public static final class ToolExecutionTransformer implements MaybeTransformer { + private final BaseTool tool; + private final BaseAgent agent; + private final Map functionArgs; + private final Context parentContext; + + public ToolExecutionTransformer( + BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { + this.tool = tool; + this.agent = agent; + this.functionArgs = functionArgs; + this.parentContext = parentContext; + } + + @Override + public MaybeSource apply(Maybe upstream) { + return Maybe.using( + () -> new ToolExecution(tool, agent, functionArgs, parentContext), + toolExecution -> + upstream + .doOnSuccess(event -> toolExecution.context().setFunctionResponseEvent(event)) + .doOnError(toolExecution::setError), + ToolExecution::close); + } } } diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index f901484ea..c77dcac3f 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -467,26 +467,6 @@ public static TracerProvider trace(String spanName) { return new TracerProvider<>(spanName); } - /** - * Returns a transformer that traces an agent invocation. - * - * @param spanName The name of the span to create. - * @param agentName The name of the agent. - * @param agentDescription The description of the agent. - * @param invocationContext The invocation context. - * @param The type of the stream. - * @return A TracerProvider configured for agent invocation. - */ - public static TracerProvider traceAgent( - String spanName, - String agentName, - String agentDescription, - InvocationContext invocationContext) { - return new TracerProvider(spanName) - .configure( - span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); - } - /** * A transformer that manages an OpenTelemetry span and scope for RxJava streams. * @@ -535,11 +515,11 @@ private Context getParentContext() { } private final class TracingLifecycle { - private Span span; - private Scope scope; + final Span span; + final Scope scope; @SuppressWarnings("MustBeClosedChecker") - void start() { + TracingLifecycle() { span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); spanConfigurers.forEach(c -> c.accept(span)); scope = span.makeCurrent(); @@ -557,50 +537,40 @@ void end() { @Override public Publisher apply(Flowable upstream) { - return Flowable.defer( - () -> { - TracingLifecycle lifecycle = new TracingLifecycle(); - Flowable pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); - if (onSuccessConsumer != null) { - pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - return pipeline.doFinally(lifecycle::end); - }); + return Flowable.using( + TracingLifecycle::new, + lifecycle -> + onSuccessConsumer != null + ? upstream.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)) + : upstream, + TracingLifecycle::end); } @Override public SingleSource apply(Single upstream) { - return Single.defer( - () -> { - TracingLifecycle lifecycle = new TracingLifecycle(); - Single pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); - if (onSuccessConsumer != null) { - pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - return pipeline.doFinally(lifecycle::end); - }); + return Single.using( + TracingLifecycle::new, + lifecycle -> + onSuccessConsumer != null + ? upstream.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)) + : upstream, + TracingLifecycle::end); } @Override public MaybeSource apply(Maybe upstream) { - return Maybe.defer( - () -> { - TracingLifecycle lifecycle = new TracingLifecycle(); - Maybe pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); - if (onSuccessConsumer != null) { - pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - return pipeline.doFinally(lifecycle::end); - }); + return Maybe.using( + TracingLifecycle::new, + lifecycle -> + onSuccessConsumer != null + ? upstream.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)) + : upstream, + TracingLifecycle::end); } @Override public CompletableSource apply(Completable upstream) { - return Completable.defer( - () -> { - TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); - }); + return Completable.using(TracingLifecycle::new, lifecycle -> upstream, TracingLifecycle::end); } } diff --git a/core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java b/core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java index d99a6878e..1b5ddbe8f 100644 --- a/core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java @@ -17,6 +17,7 @@ package com.google.adk.telemetry; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; @@ -24,6 +25,8 @@ import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; +import com.google.adk.telemetry.Instrumentation.AgentInvocationTransformer; +import com.google.adk.telemetry.Instrumentation.ToolExecutionTransformer; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; @@ -33,11 +36,13 @@ import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.metrics.Meter; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; import io.opentelemetry.sdk.metrics.data.HistogramPointData; import io.opentelemetry.sdk.metrics.data.MetricData; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Map; @@ -116,11 +121,15 @@ public void tearDown() { @Test public void recordAgentInvocation_success() { - try (Instrumentation.AgentInvocation invocation = - Instrumentation.recordAgentInvocation(invocationContext, testAgent)) { - assertThat(invocation.context()).isNotNull(); - assertThat(invocation.context().otelContext()).isNotNull(); - } + Event testEvent = Event.builder().id("test-event").build(); + + List result = + Flowable.just(testEvent) + .compose( + new AgentInvocationTransformer(invocationContext, testAgent, Context.current())) + .toList() + .blockingGet(); + assertThat(result).containsExactly(testEvent); // Verify trace span List spans = openTelemetryRule.getSpans(); @@ -143,10 +152,14 @@ public void recordAgentInvocation_success() { @Test public void recordAgentInvocation_withError() { RuntimeException testException = new RuntimeException("test error"); - try (Instrumentation.AgentInvocation invocation = - Instrumentation.recordAgentInvocation(invocationContext, testAgent)) { - invocation.setError(testException); - } + assertThrows( + RuntimeException.class, + () -> + Flowable.error(testException) + .compose( + new AgentInvocationTransformer(invocationContext, testAgent, Context.current())) + .toList() + .blockingGet()); List spans = openTelemetryRule.getSpans(); assertThat(spans).hasSize(1); @@ -162,12 +175,15 @@ public void recordAgentInvocation_withError() { @Test public void recordToolExecution_success() { TestTool testTool = new TestTool(); - - try (Instrumentation.ToolExecution execution = - Instrumentation.recordToolExecution( - testTool, testAgent, ImmutableMap.of("arg1", "value1"))) { - assertThat(execution.context()).isNotNull(); - } + Event testEvent = Event.builder().id("test-event").build(); + + Event result = + Maybe.just(testEvent) + .compose( + new ToolExecutionTransformer( + testTool, testAgent, ImmutableMap.of("arg1", "value1"), Context.current())) + .blockingGet(); + assertThat(result).isNotNull(); List spans = openTelemetryRule.getSpans(); assertThat(spans).hasSize(1);