weiqingy commented on code in PR #870:
URL: https://github.com/apache/flink-agents/pull/870#discussion_r3524288183
##########
integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java:
##########
@@ -133,6 +133,11 @@ public float[] embed(String text, Map<String, Object> parameters) {
try {
JsonNode result = MAPPER.readTree(response.body().asUtf8String());
+ JsonNode inputTokenCount = result.get("inputTextTokenCount");
+ if (inputTokenCount != null && inputTokenCount.isNumber()) {
+ long tokens = inputTokenCount.asLong();
+ recordTokenUsage(tokens, tokens);
+ }
Review Comment:
This records into the current thread's `ThreadLocal`, which is right for the
single-text path. But this same connection's `embed(List)` fans a batch of
size > 1 out to `embedPool` workers (`CompletableFuture.supplyAsync(() ->
embed(text, parameters), embedPool)`), so each per-text `recordTokenUsage`
lands on a *worker* thread while `BaseEmbeddingModelSetup.embed(List)` consumes
on the calling (mailbox) thread. For every Bedrock batch of size > 1 the
consume therefore returns `null` and the tokens are silently dropped (exactly
the RAG `add(List<Document>)` path #858 targets). Meanwhile the workers'
`ThreadLocal`s are never consumed and, since `recordTokenUsage` accumulates
onto any existing value, keep growing across batches. The `size <= 1` path
stays on the calling thread, so it's fine. This is the batch symptom of the
side-channel root cause described in the top-level comment; would carrying
usage back with the result close it here?
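
To make the thread mismatch concrete, here is a minimal, self-contained sketch. Every name in it (`BatchUsageSketch`, `EmbedResult`, `SIDE_CHANNEL`, `embedOne`) is hypothetical and stands in for the real classes; it only assumes that the fan-out uses `CompletableFuture.supplyAsync` on a dedicated pool, as quoted above. It shows that a value recorded in a worker's `ThreadLocal` is invisible to the caller, while usage returned alongside the result survives the hop.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/** Illustrative only: every name here is hypothetical, not flink-agents API. */
final class BatchUsageSketch {

    /** Hypothetical result type pairing an embedding with its reported tokens. */
    static final class EmbedResult {
        final float[] embedding;
        final long tokens;

        EmbedResult(float[] embedding, long tokens) {
            this.embedding = embedding;
            this.tokens = tokens;
        }
    }

    /** Side channel in the style of the PR: usage parked on the recording thread. */
    static final ThreadLocal<Long> SIDE_CHANNEL = new ThreadLocal<>();

    /** Fake provider call: reports 3 tokens per text through both channels. */
    static EmbedResult embedOne(String text) {
        SIDE_CHANNEL.set(3L); // lands on whichever thread runs this call
        return new EmbedResult(new float[] {0.1f, 0.2f}, 3L);
    }

    /** Fans out to the pool, then sums the usage that travels with each result. */
    static long embedBatch(List<String> texts, ExecutorService pool) {
        List<CompletableFuture<EmbedResult>> futures = new ArrayList<>();
        for (String text : texts) {
            futures.add(CompletableFuture.supplyAsync(() -> embedOne(text), pool));
        }
        long total = 0L;
        for (CompletableFuture<EmbedResult> future : futures) {
            total += future.join().tokens;
        }
        return total;
    }

    /** Returns {tokens carried with results, tokens visible to the caller or -1}. */
    static long[] runDemo() {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            SIDE_CHANNEL.remove();
            long carried = embedBatch(List.of("a", "b"), pool);
            Long seenByCaller = SIDE_CHANNEL.get(); // workers wrote to their own slots
            return new long[] {carried, seenByCaller == null ? -1L : seenByCaller};
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        long[] result = runDemo();
        System.out.println("carried=" + result[0] + ", seenByCaller=" + result[1]);
    }
}
```

In this sketch the caller sums 6 tokens from the returned results while its own `ThreadLocal` slot stays empty, which is the dropped-metric behavior described above in miniature.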
##########
api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java:
##########
@@ -106,7 +107,10 @@ public float[] embed(String text) {
public float[] embed(String text, Map<String, Object> parameters) {
Map<String, Object> params = this.getParameters();
params.putAll(parameters);
- return getConnection().embed(text, params);
+ BaseEmbeddingModelConnection currentConnection = getConnection();
+ float[] embedding = currentConnection.embed(text, params);
+ recordTokenMetrics(currentConnection.consumeTokenUsage());
Review Comment:
There's no `try/finally` around `embed()` and the consume, so if `embed()`
throws after the connection has already recorded usage, `consumeTokenUsage()`
is skipped and the value persists on the thread. Bedrock hits this concretely:
`BedrockEmbeddingModelConnection` records at line 139 *before*
`result.get("embedding")` can return `null`, causing an NPE at line 142, so a
failed request leaves `{tokens, tokens}` in the `ThreadLocal`. The next
successful `embed()` on that thread then folds those tokens in via the
accumulate branch, inflating the metric with a plausible-looking number. The
`List` overload at line 132 has the same shape, and `embedding_model.py`
mirrors it. Would wrapping the consume in `finally` on both overloads (in both
languages) close the gap? The result-plus-usage change in the top-level comment
would remove the need entirely, so it's worth deciding that direction first.
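
A minimal sketch of the `finally` shape, using hypothetical stand-ins (`ConsumeInFinallySketch`, `USAGE`, `meteredTokens`, `connectionEmbed`) rather than the real setup and connection classes. One assumption to flag: the sketch discards usage from a failed request instead of metering it; whether failed requests should count is a separate decision.

```java
/** Illustrative only: hypothetical names, not the flink-agents classes. */
final class ConsumeInFinallySketch {

    /** Stand-in for the connection's per-thread usage slot. */
    static final ThreadLocal<Long> USAGE = new ThreadLocal<>();

    /** Stand-in for the token counter the setup class would update. */
    static long meteredTokens = 0L;

    /** Accumulates onto any existing value, like the PR's recordTokenUsage. */
    static void record(long tokens) {
        Long current = USAGE.get();
        USAGE.set(current == null ? tokens : current + tokens);
    }

    /** Returns the recorded usage (or null) and clears the slot. */
    static Long consume() {
        Long usage = USAGE.get();
        USAGE.remove();
        return usage;
    }

    /** Fake connection call: records 5 tokens, then optionally fails while parsing. */
    static float[] connectionEmbed(String text, boolean failAfterRecord) {
        record(5L);
        if (failAfterRecord) {
            throw new IllegalStateException("response had no embedding field");
        }
        return new float[] {0.1f, 0.2f};
    }

    /** Setup-side wrapper: the slot is cleared on every exit path. */
    static float[] embed(String text, boolean failAfterRecord) {
        boolean succeeded = false;
        try {
            float[] embedding = connectionEmbed(text, failAfterRecord);
            succeeded = true;
            return embedding;
        } finally {
            Long usage = consume(); // runs on success and on failure
            // Assumption: usage from a failed request is discarded, not metered.
            if (succeeded && usage != null) {
                meteredTokens += usage;
            }
        }
    }

    public static void main(String[] args) {
        try {
            embed("first", true);
        } catch (IllegalStateException expected) {
            // The failed request's 5 tokens were cleared in the finally block.
        }
        embed("second", false);
        System.out.println("meteredTokens=" + meteredTokens); // 5, not 10
    }
}
```

Without the `finally`, the same sequence would leave 5 stale tokens on the thread and the second call would meter 10.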
##########
api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java:
##########
@@ -80,4 +81,31 @@ public ResourceType getResourceType() {
* embeddings. The length of each array is determined by the model itself.
*/
public abstract List<float[]> embed(List<String> texts, Map<String, Object> parameters);
+
+ /**
+ * Record token usage observed during the current embedding request.
+ *
+ * <p>Connections should call this only when the provider response exposes usage. Providers that
+ * do not expose token usage should leave it unset.
+ */
+ protected void recordTokenUsage(Long promptTokens, Long totalTokens) {
+ if (promptTokens == null && totalTokens == null) {
+ return;
+ }
+
+ long prompt = promptTokens == null ? 0L : promptTokens;
+ long total = totalTokens == null ? prompt : totalTokens;
+ EmbeddingTokenUsage current = lastTokenUsage.get();
+ if (current != null) {
Review Comment:
Because `consumeTokenUsage()` clears the slot after every `embed()`, `current`
is always `null` on the happy path and this branch never runs. It only fires
when a prior consume was skipped (the leak conditions above), and when it does
it *adds to* the stale value rather than resetting it, turning a dropped metric
into a wrong one that's harder to spot. If the side channel stays, is overwrite
(`set`, not `+=`) the safer default here than accumulate? I looked for a real
multiple-record-per-consume case that would motivate the `+=` and couldn't find
one (Bedrock's `size <= 1` loop runs at most once). If there is one I've
missed, it'd be worth a comment here, since the branch reads as intentional.
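
To compare the two policies side by side, here is a small sketch with hypothetical names (`RecordPolicySketch`, `SLOT`, `secondRequestUsage`). It replays the sequence described above: one request records and its consume is skipped, then a second request records and consumes normally.

```java
/** Illustrative only: hypothetical names, not the flink-agents classes. */
final class RecordPolicySketch {

    /** Stand-in for the connection's per-thread usage slot. */
    static final ThreadLocal<Long> SLOT = new ThreadLocal<>();

    /** Policy in the PR: add to whatever is already in the slot. */
    static void recordAccumulate(long tokens) {
        Long current = SLOT.get();
        SLOT.set(current == null ? tokens : current + tokens);
    }

    /** Alternative policy: the latest record replaces any stale value. */
    static void recordOverwrite(long tokens) {
        SLOT.set(tokens);
    }

    /** Returns the recorded usage (or null) and clears the slot. */
    static Long consume() {
        Long usage = SLOT.get();
        SLOT.remove();
        return usage;
    }

    static void record(long tokens, boolean overwrite) {
        if (overwrite) {
            recordOverwrite(tokens);
        } else {
            recordAccumulate(tokens);
        }
    }

    /** Usage attributed to request 2 after request 1's consume was skipped. */
    static long secondRequestUsage(boolean overwrite) {
        SLOT.remove();
        record(4L, overwrite); // request 1 records, but its consume never runs
        record(6L, overwrite); // request 2 records on the same thread
        return consume();      // request 2 consumes
    }

    public static void main(String[] args) {
        System.out.println("accumulate=" + secondRequestUsage(false)); // 10
        System.out.println("overwrite=" + secondRequestUsage(true));   // 6
    }
}
```

With overwrite, request 1's tokens are still lost, but request 2 reports its own 6 tokens rather than an inflated 10.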
##########
api/src/test/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetupTokenMetricsTest.java:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.embedding.model;
+
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+import org.apache.flink.agents.api.metrics.UpdatableGauge;
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.metrics.Counter;
+import org.apache.flink.metrics.Histogram;
+import org.apache.flink.metrics.Meter;
+import org.apache.flink.metrics.SimpleCounter;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+/** Test cases for embedding model token metrics. */
+class BaseEmbeddingModelSetupTokenMetricsTest {
+
+ private static class TestEmbeddingModelSetup extends BaseEmbeddingModelSetup {
+
+ TestEmbeddingModelSetup(BaseEmbeddingModelConnection connection) {
+ super(
+ new ResourceDescriptor(
+ TestEmbeddingModelSetup.class.getName(),
+ Map.of("connection", "mock-connection", "model", "mock-model")),
+ mock(ResourceContext.class));
+ this.connection = connection;
+ }
+
+ @Override
+ public Map<String, Object> getParameters() {
+ return new HashMap<>();
+ }
+ }
+
+ private static class TestEmbeddingModelConnection extends BaseEmbeddingModelConnection {
+
+ TestEmbeddingModelConnection() {
+ super(
+ new ResourceDescriptor(
+ TestEmbeddingModelConnection.class.getName(), Collections.emptyMap()),
+ mock(ResourceContext.class));
+ }
+
+ @Override
+ public float[] embed(String text, Map<String, Object> parameters) {
+ recordTokenUsage(7L, 9L);
+ return new float[] {0.1f, 0.2f};
+ }
+
+ @Override
+ public List<float[]> embed(List<String> texts, Map<String, Object> parameters) {
+ recordTokenUsage(11L, 13L);
+ List<float[]> embeddings = new ArrayList<>();
+ for (String ignored : texts) {
+ embeddings.add(new float[] {0.1f, 0.2f});
+ }
+ return embeddings;
+ }
+ }
+
+ private static class TestEmbeddingModelConnectionWithoutUsage
+ extends BaseEmbeddingModelConnection {
+
+ TestEmbeddingModelConnectionWithoutUsage() {
+ super(
+ new ResourceDescriptor(
+ TestEmbeddingModelConnectionWithoutUsage.class.getName(),
+ Collections.emptyMap()),
+ mock(ResourceContext.class));
+ }
+
+ @Override
+ public float[] embed(String text, Map<String, Object> parameters) {
+ return new float[] {0.1f, 0.2f};
+ }
+
+ @Override
+ public List<float[]> embed(List<String> texts, Map<String, Object> parameters) {
+ List<float[]> embeddings = new ArrayList<>();
+ for (String ignored : texts) {
+ embeddings.add(new float[] {0.1f, 0.2f});
+ }
+ return embeddings;
+ }
+ }
+
+ @Test
+ void testEmbeddingTokenMetricsAreRecordedWhenUsageIsReported() {
+ TestEmbeddingModelSetup setup =
+ new TestEmbeddingModelSetup(new TestEmbeddingModelConnection());
+ TestMetricGroup metricGroup = new TestMetricGroup();
+ setup.setMetricGroup(metricGroup);
+
+ assertArrayEquals(new float[] {0.1f, 0.2f}, setup.embed("hello"));
+
+ TestMetricGroup modelGroup =
+ (TestMetricGroup) metricGroup.getSubGroup("model", "mock-model");
+ assertEquals(7L, modelGroup.counters.get("promptTokens").getCount());
+ assertEquals(9L, modelGroup.counters.get("totalTokens").getCount());
+ }
+
+ @Test
+ void testEmbeddingTokenMetricsAreNoopWhenUsageIsAbsent() {
+ TestEmbeddingModelSetup setup =
+ new TestEmbeddingModelSetup(new TestEmbeddingModelConnectionWithoutUsage());
+ FlinkAgentsMetricGroup metricGroup = mock(FlinkAgentsMetricGroup.class);
+ setup.setMetricGroup(metricGroup);
+
+ setup.embed("hello");
+
+ verifyNoInteractions(metricGroup);
+ }
+
+ @Test
+ void testEmbeddingTokenMetricsAccumulateAcrossRequests() {
Review Comment:
The mocks here are honest, so this isn't about the existing assertions; it's
about the failure modes they never reach. Despite its name,
`testEmbeddingTokenMetricsAccumulateAcrossRequests` makes two separate
`embed()` calls with a consume clearing the slot between them, so it proves the
Flink `Counter` is monotonic but never enters the `if (current != null)` branch
in `BaseEmbeddingModelConnection`. The paths most likely to regress are
untested: a record-then-throw followed by a successful `embed()` (asserting the
second call's metric is *not* inflated), and the Bedrock batch path. Are those
two cases worth adding? They're the ones that would catch the correctness
issues above if the side channel stays.