weiqingy commented on code in PR #870:
URL: https://github.com/apache/flink-agents/pull/870#discussion_r3524288183


##########
integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java:
##########
@@ -133,6 +133,11 @@ public float[] embed(String text, Map<String, Object> parameters) {
 
         try {
             JsonNode result = MAPPER.readTree(response.body().asUtf8String());
+            JsonNode inputTokenCount = result.get("inputTextTokenCount");
+            if (inputTokenCount != null && inputTokenCount.isNumber()) {
+                long tokens = inputTokenCount.asLong();
+                recordTokenUsage(tokens, tokens);
+            }

Review Comment:
   This records into the current thread's `ThreadLocal`, which is right for the 
single-text path. But this same connection's `embed(List)` fans a batch of 
size > 1 out to `embedPool` workers (`CompletableFuture.supplyAsync(() -> 
embed(text, parameters), embedPool)`), so each per-text `recordTokenUsage` 
lands on a *worker* thread while `BaseEmbeddingModelSetup.embed(List)` consumes 
on the calling (mailbox) thread. For every Bedrock batch of size > 1, the 
consume therefore returns `null` and the tokens are silently dropped (exactly 
the RAG `add(List<Document>)` path #858 targets). Meanwhile the workers' 
`ThreadLocal`s are never consumed and, since `recordTokenUsage` accumulates 
onto any existing value, they grow across batches. The `size <= 1` path stays 
on the calling thread, so it's fine. This is the batch symptom of the 
side-channel root cause described in the top-level comment; would carrying 
usage back with the result close it here?
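
A minimal sketch of the "carry usage back with the result" idea, to make the thread mismatch concrete. Every name below (`UsageWithResultSketch`, `EmbeddingResult`, `SIDE_CHANNEL`, `embedOne`, `embedBatch`) is hypothetical and not taken from the PR, and the text length is only a stand-in for `inputTextTokenCount`. It fans a batch out to a pool the way the comment describes, then reads usage both from the returned results and from a `ThreadLocal` on the calling thread.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/** Illustration only: every name here is hypothetical, none is from the PR. */
final class UsageWithResultSketch {

    /** Hypothetical holder pairing one embedding with the tokens it cost. */
    static final class EmbeddingResult {
        final float[] embedding;
        final long tokens;

        EmbeddingResult(float[] embedding, long tokens) {
            this.embedding = embedding;
            this.tokens = tokens;
        }
    }

    /** Stand-in for the side channel: a slot visible only to the writing thread. */
    static final ThreadLocal<Long> SIDE_CHANNEL = new ThreadLocal<>();

    /** Fake provider call; the text length stands in for inputTextTokenCount. */
    static EmbeddingResult embedOne(String text) {
        long tokens = text.length();
        SIDE_CHANNEL.set(tokens); // lands on whichever thread runs this call
        return new EmbeddingResult(new float[] {0.1f, 0.2f}, tokens);
    }

    /**
     * Fans the batch out to workers, then reads usage both ways on the calling thread.
     * Returns {tokens summed from results, side-channel value or -1 if the slot is empty}.
     */
    static long[] embedBatch(List<String> texts) {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            List<CompletableFuture<EmbeddingResult>> futures = new ArrayList<>();
            for (String text : texts) {
                futures.add(CompletableFuture.supplyAsync(() -> embedOne(text), pool));
            }
            long fromResults = 0L;
            for (CompletableFuture<EmbeddingResult> future : futures) {
                fromResults += future.join().tokens; // usage travels with the value
            }
            Long fromSideChannel = SIDE_CHANNEL.get(); // the caller's own slot
            return new long[] {fromResults, fromSideChannel == null ? -1L : fromSideChannel};
        } finally {
            pool.shutdown();
        }
    }
}
```

When the calling thread has not recorded anything itself, the second element comes back as -1 (an empty slot) while the first holds the full batch total, which is the asymmetry described above.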



##########
api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java:
##########
@@ -106,7 +107,10 @@ public float[] embed(String text) {
     public float[] embed(String text, Map<String, Object> parameters) {
         Map<String, Object> params = this.getParameters();
         params.putAll(parameters);
-        return getConnection().embed(text, params);
+        BaseEmbeddingModelConnection currentConnection = getConnection();
+        float[] embedding = currentConnection.embed(text, params);
+        recordTokenMetrics(currentConnection.consumeTokenUsage());

Review Comment:
   There's no `try/finally` around `embed()` and the consume, so if `embed()` 
throws after the connection has already recorded usage, `consumeTokenUsage()` 
is skipped and the value persists on the thread. Bedrock hits this concretely: 
`BedrockEmbeddingModelConnection` records at line 139 *before* 
`result.get("embedding")` can return `null` and trigger an NPE at line 142, so 
a failed request leaves `{tokens, tokens}` in the `ThreadLocal`, and the next 
successful `embed()` on that thread folds those tokens in via the accumulate 
branch, inflating the metric with a plausible-looking number. The `List` 
overload at line 132 has the same shape, and `embedding_model.py` mirrors it. 
Would wrapping the consume in `finally` on both overloads (both languages) 
close the gap? The result-plus-usage change in the top-level comment would 
remove the need entirely, so it's worth deciding that direction first.
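
A rough sketch of the `finally` shape, using a stand-in class rather than the PR's `BaseEmbeddingModelSetup` and `BaseEmbeddingModelConnection`. The class name, `connectionEmbed`, the `fail` flag, and `metricValue()` are all hypothetical. The slot is always cleared, and the metric is only updated on success; whether tokens from a failed request should be counted is a separate decision that this sketch simply assumes away.

```java
/** Illustration only: a stand-in for the setup/connection pair, not the PR's classes. */
final class ConsumeInFinallySketch {
    private final ThreadLocal<Long> usageSlot = new ThreadLocal<>();
    private long promptTokensMetric = 0L;

    /** Mirrors the accumulate-onto-existing behaviour described in the review. */
    private void recordTokenUsage(long tokens) {
        Long current = usageSlot.get();
        usageSlot.set(current == null ? tokens : current + tokens);
    }

    /** Reads and clears the calling thread's slot. */
    private Long consumeTokenUsage() {
        Long usage = usageSlot.get();
        usageSlot.remove();
        return usage;
    }

    /** Fake connection call: records usage first, then optionally fails. */
    private float[] connectionEmbed(long tokens, boolean fail) {
        recordTokenUsage(tokens);
        if (fail) {
            throw new IllegalStateException("response had no embedding");
        }
        return new float[] {0.1f, 0.2f};
    }

    /** Setup-side call: the consume runs in finally so a throw cannot leak usage. */
    float[] embed(long tokens, boolean fail) {
        boolean succeeded = false;
        try {
            float[] embedding = connectionEmbed(tokens, fail);
            succeeded = true;
            return embedding;
        } finally {
            Long usage = consumeTokenUsage(); // always clears the slot
            if (succeeded && usage != null) {
                promptTokensMetric += usage; // assumption: count successes only
            }
        }
    }

    long metricValue() {
        return promptTokensMetric;
    }
}
```

With this shape, a failed `embed(5L, true)` followed by a successful `embed(7L, false)` on the same thread leaves the metric at 7 rather than 12.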



##########
api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java:
##########
@@ -80,4 +81,31 @@ public ResourceType getResourceType() {
      *     embeddings. The length of each array is determined by the model itself.
      */
     public abstract List<float[]> embed(List<String> texts, Map<String, Object> parameters);
+
+    /**
+     * Record token usage observed during the current embedding request.
+     *
+     * <p>Connections should call this only when the provider response exposes usage. Providers that
+     * do not expose token usage should leave it unset.
+     */
+    protected void recordTokenUsage(Long promptTokens, Long totalTokens) {
+        if (promptTokens == null && totalTokens == null) {
+            return;
+        }
+
+        long prompt = promptTokens == null ? 0L : promptTokens;
+        long total = totalTokens == null ? prompt : totalTokens;
+        EmbeddingTokenUsage current = lastTokenUsage.get();
+        if (current != null) {

Review Comment:
   Because `consumeTokenUsage()` clears the slot after every `embed()`, in the 
happy path `current` is always `null` and this branch never runs. It only 
fires when a prior consume was skipped (the leak conditions above), and when it 
does it *adds to* the stale value rather than resetting, turning a dropped 
metric into a wrong one that's harder to spot. If the side channel stays, is 
overwrite (`set`, not `+=`) a safer default here than accumulate? I looked 
for a real multiple-record-per-consume case that would motivate the `+=` and 
couldn't find one (Bedrock's `size <= 1` loop runs at most once), but if there 
is one I've missed, it'd be worth a comment here since the branch reads as 
intentional.
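
To make the difference between the two policies concrete, here is a small sketch with a hypothetical `RecordPolicySketch` class (not the PR's code) that records into a `ThreadLocal` either by accumulating or by overwriting. A first record that is never consumed plays the role of the stale value left behind by a skipped consume.

```java
/** Illustration only: contrasts accumulate and overwrite on a stale slot. */
final class RecordPolicySketch {
    private final ThreadLocal<Long> slot = new ThreadLocal<>();

    /** Accumulate: adds to whatever is already in the slot, stale or not. */
    void recordAccumulating(long tokens) {
        Long current = slot.get();
        slot.set(current == null ? tokens : current + tokens);
    }

    /** Overwrite: the latest record replaces any earlier value. */
    void recordOverwriting(long tokens) {
        slot.set(tokens);
    }

    /** Reads and clears the calling thread's slot. */
    Long consume() {
        Long usage = slot.get();
        slot.remove();
        return usage;
    }
}
```

Recording 5 (never consumed) and then 7 yields 12 under the accumulating policy and 7 under the overwriting one. Overwrite still hides the fact that a consume was skipped, so it limits the damage rather than fixing the leak.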



##########
api/src/test/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetupTokenMetricsTest.java:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.embedding.model;
+
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+import org.apache.flink.agents.api.metrics.UpdatableGauge;
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.metrics.Counter;
+import org.apache.flink.metrics.Histogram;
+import org.apache.flink.metrics.Meter;
+import org.apache.flink.metrics.SimpleCounter;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+/** Test cases for embedding model token metrics. */
+class BaseEmbeddingModelSetupTokenMetricsTest {
+
+    private static class TestEmbeddingModelSetup extends BaseEmbeddingModelSetup {
+
+        TestEmbeddingModelSetup(BaseEmbeddingModelConnection connection) {
+            super(
+                    new ResourceDescriptor(
+                            TestEmbeddingModelSetup.class.getName(),
+                            Map.of("connection", "mock-connection", "model", "mock-model")),
+                    mock(ResourceContext.class));
+            this.connection = connection;
+        }
+
+        @Override
+        public Map<String, Object> getParameters() {
+            return new HashMap<>();
+        }
+    }
+
+    private static class TestEmbeddingModelConnection extends BaseEmbeddingModelConnection {
+
+        TestEmbeddingModelConnection() {
+            super(
+                    new ResourceDescriptor(
+                            TestEmbeddingModelConnection.class.getName(), Collections.emptyMap()),
+                    mock(ResourceContext.class));
+        }
+
+        @Override
+        public float[] embed(String text, Map<String, Object> parameters) {
+            recordTokenUsage(7L, 9L);
+            return new float[] {0.1f, 0.2f};
+        }
+
+        @Override
+        public List<float[]> embed(List<String> texts, Map<String, Object> parameters) {
+            recordTokenUsage(11L, 13L);
+            List<float[]> embeddings = new ArrayList<>();
+            for (String ignored : texts) {
+                embeddings.add(new float[] {0.1f, 0.2f});
+            }
+            return embeddings;
+        }
+    }
+
+    private static class TestEmbeddingModelConnectionWithoutUsage
+            extends BaseEmbeddingModelConnection {
+
+        TestEmbeddingModelConnectionWithoutUsage() {
+            super(
+                    new ResourceDescriptor(
+                            TestEmbeddingModelConnectionWithoutUsage.class.getName(),
+                            Collections.emptyMap()),
+                    mock(ResourceContext.class));
+        }
+
+        @Override
+        public float[] embed(String text, Map<String, Object> parameters) {
+            return new float[] {0.1f, 0.2f};
+        }
+
+        @Override
+        public List<float[]> embed(List<String> texts, Map<String, Object> parameters) {
+            List<float[]> embeddings = new ArrayList<>();
+            for (String ignored : texts) {
+                embeddings.add(new float[] {0.1f, 0.2f});
+            }
+            return embeddings;
+        }
+    }
+
+    @Test
+    void testEmbeddingTokenMetricsAreRecordedWhenUsageIsReported() {
+        TestEmbeddingModelSetup setup =
+                new TestEmbeddingModelSetup(new TestEmbeddingModelConnection());
+        TestMetricGroup metricGroup = new TestMetricGroup();
+        setup.setMetricGroup(metricGroup);
+
+        assertArrayEquals(new float[] {0.1f, 0.2f}, setup.embed("hello"));
+
+        TestMetricGroup modelGroup =
+                (TestMetricGroup) metricGroup.getSubGroup("model", "mock-model");
+        assertEquals(7L, modelGroup.counters.get("promptTokens").getCount());
+        assertEquals(9L, modelGroup.counters.get("totalTokens").getCount());
+    }
+
+    @Test
+    void testEmbeddingTokenMetricsAreNoopWhenUsageIsAbsent() {
+        TestEmbeddingModelSetup setup =
+                new TestEmbeddingModelSetup(new TestEmbeddingModelConnectionWithoutUsage());
+        FlinkAgentsMetricGroup metricGroup = mock(FlinkAgentsMetricGroup.class);
+        setup.setMetricGroup(metricGroup);
+
+        setup.embed("hello");
+
+        verifyNoInteractions(metricGroup);
+    }
+
+    @Test
+    void testEmbeddingTokenMetricsAccumulateAcrossRequests() {

Review Comment:
   The mocks here are honest, so this isn't about the existing assertions; 
it's about the failure modes that aren't reached. Despite its name, 
`testEmbeddingTokenMetricsAccumulateAcrossRequests` makes two separate 
`embed()` calls with a consume clearing the slot between them, so it proves 
the Flink `Counter` is monotonic but never enters the `if (current != null)` 
branch in `BaseEmbeddingModelConnection`. The paths most likely to regress are 
untested: a record-then-throw followed by a successful `embed()` (asserting 
the second call's metric is *not* inflated), and the Bedrock batch path. Are 
those two cases worth adding? They're the ones that would catch the 
correctness issues above if the side channel stays.



