zxs1633079383 commented on code in PR #870:
URL: https://github.com/apache/flink-agents/pull/870#discussion_r3524307244


##########
integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java:
##########
@@ -133,6 +133,11 @@ public float[] embed(String text, Map<String, Object> 
parameters) {
 
         try {
             JsonNode result = MAPPER.readTree(response.body().asUtf8String());
+            JsonNode inputTokenCount = result.get("inputTextTokenCount");
+            if (inputTokenCount != null && inputTokenCount.isNumber()) {
+                long tokens = inputTokenCount.asLong();
+                recordTokenUsage(tokens, tokens);
+            }

Review Comment:
   Resolved in e023346 by removing the ThreadLocal side channel.
`BedrockEmbeddingModelConnection.embedWithUsage(List)` now collects an
`EmbeddingResult<float[]>` from each worker future and merges the token usage
on the caller's code path.
`BedrockEmbeddingModelTest.testBatchEmbeddingAggregatesTokenUsage` covers the
batch size > 1 path.



##########
api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java:
##########
@@ -106,7 +107,10 @@ public float[] embed(String text) {
     public float[] embed(String text, Map<String, Object> parameters) {
         Map<String, Object> params = this.getParameters();
         params.putAll(parameters);
-        return getConnection().embed(text, params);
+        BaseEmbeddingModelConnection currentConnection = getConnection();
+        float[] embedding = currentConnection.embed(text, params);
+        recordTokenMetrics(currentConnection.consumeTokenUsage());

Review Comment:
   Resolved via the result-plus-usage path rather than `try/finally`.
`BaseEmbeddingModelSetup` records token metrics only after
`embedWithUsage(...)` returns a successful result, so there is no longer a
side-channel slot through which a provider exception could leak usage into the
next call. Added Java and Python regression tests verifying that no usage leaks.



##########
api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java:
##########
@@ -80,4 +81,31 @@ public ResourceType getResourceType() {
      *     embeddings. The length of each array is determined by the model 
itself.
      */
     public abstract List<float[]> embed(List<String> texts, Map<String, 
Object> parameters);
+
+    /**
+     * Record token usage observed during the current embedding request.
+     *
+     * <p>Connections should call this only when the provider response exposes 
usage. Providers that
+     * do not expose token usage should leave it unset.
+     */
+    protected void recordTokenUsage(Long promptTokens, Long totalTokens) {
+        if (promptTokens == null && totalTokens == null) {
+            return;
+        }
+
+        long prompt = promptTokens == null ? 0L : promptTokens;
+        long total = totalTokens == null ? prompt : totalTokens;
+        EmbeddingTokenUsage current = lastTokenUsage.get();
+        if (current != null) {

Review Comment:
   Removed along with the side channel. There is no per-thread accumulation
branch anymore; provider usage is returned together with the embedding result,
and metrics are accumulated only in the Flink counters after the setup receives
that result.



##########
api/src/test/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetupTokenMetricsTest.java:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.embedding.model;
+
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+import org.apache.flink.agents.api.metrics.UpdatableGauge;
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.metrics.Counter;
+import org.apache.flink.metrics.Histogram;
+import org.apache.flink.metrics.Meter;
+import org.apache.flink.metrics.SimpleCounter;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+/** Test cases for embedding model token metrics. */
+class BaseEmbeddingModelSetupTokenMetricsTest {
+
+    private static class TestEmbeddingModelSetup extends 
BaseEmbeddingModelSetup {
+
+        TestEmbeddingModelSetup(BaseEmbeddingModelConnection connection) {
+            super(
+                    new ResourceDescriptor(
+                            TestEmbeddingModelSetup.class.getName(),
+                            Map.of("connection", "mock-connection", "model", 
"mock-model")),
+                    mock(ResourceContext.class));
+            this.connection = connection;
+        }
+
+        @Override
+        public Map<String, Object> getParameters() {
+            return new HashMap<>();
+        }
+    }
+
+    private static class TestEmbeddingModelConnection extends 
BaseEmbeddingModelConnection {
+
+        TestEmbeddingModelConnection() {
+            super(
+                    new ResourceDescriptor(
+                            TestEmbeddingModelConnection.class.getName(), 
Collections.emptyMap()),
+                    mock(ResourceContext.class));
+        }
+
+        @Override
+        public float[] embed(String text, Map<String, Object> parameters) {
+            recordTokenUsage(7L, 9L);
+            return new float[] {0.1f, 0.2f};
+        }
+
+        @Override
+        public List<float[]> embed(List<String> texts, Map<String, Object> 
parameters) {
+            recordTokenUsage(11L, 13L);
+            List<float[]> embeddings = new ArrayList<>();
+            for (String ignored : texts) {
+                embeddings.add(new float[] {0.1f, 0.2f});
+            }
+            return embeddings;
+        }
+    }
+
+    private static class TestEmbeddingModelConnectionWithoutUsage
+            extends BaseEmbeddingModelConnection {
+
+        TestEmbeddingModelConnectionWithoutUsage() {
+            super(
+                    new ResourceDescriptor(
+                            
TestEmbeddingModelConnectionWithoutUsage.class.getName(),
+                            Collections.emptyMap()),
+                    mock(ResourceContext.class));
+        }
+
+        @Override
+        public float[] embed(String text, Map<String, Object> parameters) {
+            return new float[] {0.1f, 0.2f};
+        }
+
+        @Override
+        public List<float[]> embed(List<String> texts, Map<String, Object> 
parameters) {
+            List<float[]> embeddings = new ArrayList<>();
+            for (String ignored : texts) {
+                embeddings.add(new float[] {0.1f, 0.2f});
+            }
+            return embeddings;
+        }
+    }
+
+    @Test
+    void testEmbeddingTokenMetricsAreRecordedWhenUsageIsReported() {
+        TestEmbeddingModelSetup setup =
+                new TestEmbeddingModelSetup(new 
TestEmbeddingModelConnection());
+        TestMetricGroup metricGroup = new TestMetricGroup();
+        setup.setMetricGroup(metricGroup);
+
+        assertArrayEquals(new float[] {0.1f, 0.2f}, setup.embed("hello"));
+
+        TestMetricGroup modelGroup =
+                (TestMetricGroup) metricGroup.getSubGroup("model", 
"mock-model");
+        assertEquals(7L, modelGroup.counters.get("promptTokens").getCount());
+        assertEquals(9L, modelGroup.counters.get("totalTokens").getCount());
+    }
+
+    @Test
+    void testEmbeddingTokenMetricsAreNoopWhenUsageIsAbsent() {
+        TestEmbeddingModelSetup setup =
+                new TestEmbeddingModelSetup(new 
TestEmbeddingModelConnectionWithoutUsage());
+        FlinkAgentsMetricGroup metricGroup = 
mock(FlinkAgentsMetricGroup.class);
+        setup.setMetricGroup(metricGroup);
+
+        setup.embed("hello");
+
+        verifyNoInteractions(metricGroup);
+    }
+
+    @Test
+    void testEmbeddingTokenMetricsAccumulateAcrossRequests() {

Review Comment:
   Added the two regression cases called out here: a provider failure followed
by a successful call does not inflate metrics (Java and Python setup tests), and
Bedrock batching with batch size > 1 aggregates usage from the worker results.
The existing cross-request test now only verifies Flink counter monotonicity.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to