This is an automated email from the ASF dual-hosted git repository.
mawiesne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 889ceab12 OPENNLP-1845: Fix numerically unstable softmax in DocumentCategorizerDL (#1085)
889ceab12 is described below
commit 889ceab12a932d26214cbec524d684589a370ca1
Author: Kristian Rickert <[email protected]>
AuthorDate: Tue Jun 16 05:34:31 2026 -0400
OPENNLP-1845: Fix numerically unstable softmax in DocumentCategorizerDL (#1085)
* Fix DocumentCategorizerDL softmax and error result handling
* Fail loudly instead of 0 result
---
.../opennlp/dl/doccat/DocumentCategorizerDL.java | 190 ++++++++++++---------
.../dl/doccat/DocumentCategorizerDLTest.java | 125 ++++++++++++++
.../dl/doccat/DocumentCategorizerDLEval.java | 37 ++++
3 files changed, 276 insertions(+), 76 deletions(-)
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index cf01631bf..e357c48f2 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -22,6 +22,7 @@ import java.io.IOException;
import java.nio.LongBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
@@ -37,8 +38,6 @@ import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
@@ -63,8 +62,6 @@ import opennlp.tools.doccat.DocumentCategorizer;
*/
public class DocumentCategorizerDL extends AbstractDL implements
DocumentCategorizer {
- private static final Logger logger =
LoggerFactory.getLogger(DocumentCategorizerDL.class);
-
/** Classification models are commonly uncased, so lower casing is the
default. */
private static final boolean LOWER_CASE_DEFAULT = true;
@@ -72,6 +69,19 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
private final ClassificationScoringStrategy classificationScoringStrategy;
private final InferenceOptions inferenceOptions;
+ DocumentCategorizerDL(OrtEnvironment env, OrtSession session, Map<String,
Integer> vocab,
+ Map<Integer, String> categories,
+ ClassificationScoringStrategy
classificationScoringStrategy,
+ InferenceOptions inferenceOptions) {
+ this.env = env;
+ this.session = session;
+ this.vocab = vocab;
+ this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
+ this.categories = categories;
+ this.classificationScoringStrategy = classificationScoringStrategy;
+ this.inferenceOptions = inferenceOptions;
+ }
+
/**
* Instantiates a {@link DocumentCategorizer document categorizer} using
ONNX models.
*
@@ -141,68 +151,74 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
}
+ /**
+ * Categorizes the document, failing loudly rather than returning an invalid
distribution:
+ * malformed input is rejected with {@link IllegalArgumentException}, and
any failure executing
+ * the model is surfaced as an {@link IllegalStateException} (cause
preserved).
+ *
+ * @param strings The document to categorize; {@code strings[0]} is
classified.
+ * @return The per-category probabilities.
+ * @throws IllegalArgumentException If {@code strings} is {@code null} or
empty.
+ * @throws IllegalStateException If inference fails or the model returns
an unexpected output.
+ */
@Override
public double[] categorize(String[] strings) {
- try {
+ if (strings == null || strings.length == 0) {
+ throw new IllegalArgumentException("strings must contain at least one
document to categorize");
+ }
+
+ final List<Tokens> tokens = tokenize(strings[0]);
+
+ final List<double[]> scores = new LinkedList<>();
+ for (final Tokens t : tokens) {
+ scores.add(softmax(infer(t)));
+ }
+
+ return classificationScoringStrategy.score(scores);
+ }
+
+ /**
+ * Runs the model on one token window and returns its raw per-category
logits. A failure executing
+ * the model (an {@link OrtException} or any runtime fault) is wrapped as an
+ * {@link IllegalStateException}; an unexpected output shape is its own loud
failure.
+ */
+ private float[] infer(final Tokens t) {
- final List<Tokens> tokens = tokenize(strings[0]);
-
- final List<double[]> scores = new LinkedList<>();
-
- for (final Tokens t : tokens) {
-
- final Map<String, OnnxTensor> inputs = new HashMap<>();
-
- final Object output;
- try {
- inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
- LongBuffer.wrap(t.ids()), new long[] {1, t.ids().length}));
-
- if (inferenceOptions.isIncludeAttentionMask()) {
- inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
- LongBuffer.wrap(t.mask()), new long[] {1, t.mask().length}));
- }
-
- if (inferenceOptions.isIncludeTokenTypeIds()) {
- inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
- LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
- }
-
- // The outputs from the model. Some models return a 2D array (e.g.
BERT),
- // while others return a 1D array (e.g. RoBERTa).
- try (OrtSession.Result result = session.run(inputs)) {
- // getValue() copies the tensor into Java arrays, so the result
can be closed safely.
- output = result.get(0).getValue();
- }
- } finally {
- inputs.values().forEach(OnnxTensor::close);
- }
-
- final float[] rawScores;
- if (output instanceof float[][] v) {
- rawScores = v[0];
- } else if (output instanceof float[] v) {
- rawScores = v;
- } else {
- throw new IllegalStateException(
- "Unexpected model output type: " + output.getClass().getName());
- }
-
- // Keep track of all scores.
- final double[] categoryScoresForTokens = softmax(rawScores);
- scores.add(categoryScoresForTokens);
+ final Map<String, OnnxTensor> inputs = new HashMap<>();
+ final Object output;
+ try {
+ inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(t.ids()), new long[] {1, t.ids().length}));
+ if (inferenceOptions.isIncludeAttentionMask()) {
+ inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(t.mask()), new long[] {1, t.mask().length}));
}
- return classificationScoringStrategy.score(scores);
+ if (inferenceOptions.isIncludeTokenTypeIds()) {
+ inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
+ LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
+ }
- } catch (Exception ex) {
- logger.error("Unload to perform document classification inference", ex);
+ // getValue() copies the tensor into Java arrays, so the result can be
closed safely.
+ try (OrtSession.Result result = session.run(inputs)) {
+ output = result.get(0).getValue();
+ }
+ } catch (OrtException | RuntimeException ex) {
+ throw new IllegalStateException("Unable to perform document
classification inference", ex);
+ } finally {
+ inputs.values().forEach(OnnxTensor::close);
}
- return new double[] {};
-
+ // Some models return a 2D array (e.g. BERT), others a 1D array (e.g.
RoBERTa). A different
+ // shape is a model-contract violation, surfaced on its own rather than as
"inference failed".
+ if (output instanceof float[][] v) {
+ return v[0];
+ } else if (output instanceof float[] v) {
+ return v;
+ }
+ throw new IllegalStateException("Unexpected model output type: " +
output.getClass().getName());
}
@Override
@@ -298,23 +314,13 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
// Split the input text into 200 word chunks with 50 overlapping between
chunks.
final String[] whitespaceTokenized = text.split("\\s+");
- for (int start = 0; start < whitespaceTokenized.length;
- start = start + inferenceOptions.getDocumentSplitSize()) {
-
- // 200 word length chunk
- // Check the end do don't go past and get a
StringIndexOutOfBoundsException
- int end = start + inferenceOptions.getDocumentSplitSize();
- if (end > whitespaceTokenized.length) {
- end = whitespaceTokenized.length;
- }
+ for (final int[] range : chunkRanges(whitespaceTokenized.length,
+ inferenceOptions.getDocumentSplitSize(),
inferenceOptions.getSplitOverlapSize())) {
- // The group is that subsection of string.
- final String group = String.join(" ",
Arrays.copyOfRange(whitespaceTokenized, start, end));
+ // The group is that subsection of the input.
+ final String group =
+ String.join(" ", Arrays.copyOfRange(whitespaceTokenized, range[0],
range[1]));
- // We want to overlap each chunk by 50 words so scoot back 50 words for
the next iteration.
- start = start - inferenceOptions.getSplitOverlapSize();
-
- // Now we can tokenize the group and continue.
final String[] tokens = tokenizer.tokenize(group);
final long[] ids = tokenIds(tokens, vocab);
@@ -333,6 +339,32 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
}
+ /**
+ * Computes the {@code [start, end)} word-index ranges the input is split
into: chunks of
+ * {@code splitSize} words overlapping by {@code overlapSize}. The loop
always advances by
+ * at least one word, so a misconfigured {@code overlapSize >= splitSize}
can neither stall
+ * the loop nor produce negative indices.
+ *
+ * @param length The number of whitespace-separated words.
+ * @param splitSize The chunk size in words.
+ * @param overlapSize The overlap between consecutive chunks in words.
+ * @return The ordered list of {@code [start, end)} ranges; empty when
{@code length == 0}.
+ */
+ static List<int[]> chunkRanges(final int length, final int splitSize, final
int overlapSize) {
+ final List<int[]> ranges = new ArrayList<>();
+ int start = 0;
+ while (start < length) {
+ final int end = Math.min(start + splitSize, length);
+ ranges.add(new int[] {start, end});
+ if (end == length) {
+ break;
+ }
+ // Overlap by overlapSize words, but always move forward by at least one.
+ start = Math.max(end - overlapSize, start + 1);
+ }
+ return ranges;
+ }
+
/**
* Maps tokens to their vocabulary ids.
*
@@ -366,21 +398,27 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
* @param input An array of values.
* @return The output array.
*/
- private double[] softmax(final float[] input) {
+ static double[] softmax(final float[] input) {
+
+ // Subtract the maximum before exponentiating (numerically stable
softmax): exp() of a
+ // large logit otherwise overflows to +Infinity, yielding NaN scores.
Mathematically
+ // identical to the naive form. Results are kept in double precision
throughout.
+ double max = Double.NEGATIVE_INFINITY;
+ for (final float value : input) {
+ max = Math.max(max, value);
+ }
final double[] t = new double[input.length];
double sum = 0.0;
-
for (int x = 0; x < input.length; x++) {
- double val = Math.exp(input[x]);
+ final double val = Math.exp(input[x] - max);
sum += val;
t[x] = val;
}
final double[] output = new double[input.length];
-
for (int x = 0; x < output.length; x++) {
- output[x] = (float) (t[x] / sum);
+ output[x] = t[x] / sum;
}
return output;
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
index a6bab39f6..e962e4bee 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
@@ -18,13 +18,18 @@
package opennlp.dl.doccat;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
+import opennlp.dl.InferenceOptions;
+import opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy;
import opennlp.tools.tokenize.WordpieceTokenizer;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -40,6 +45,47 @@ public class DocumentCategorizerDLTest {
return vocab;
}
+ private static Map<Integer, String> categories() {
+ final Map<Integer, String> categories = new HashMap<>();
+ categories.put(0, "negative");
+ categories.put(1, "positive");
+ return categories;
+ }
+
+ private static DocumentCategorizerDL categorizerWithoutSession() {
+ return new DocumentCategorizerDL(null, null, vocab(), categories(),
+ new AverageClassificationScoringStrategy(), new InferenceOptions());
+ }
+
+ @Test
+ void testCategorizeFailsLoudlyWhenInferenceFails() {
+ final IllegalStateException e = assertThrows(IllegalStateException.class,
() ->
+ categorizerWithoutSession().categorize(new String[] {"hello world"}));
+
+ assertTrue(e.getMessage().contains("document classification inference"));
+ assertTrue(e.getCause() instanceof RuntimeException);
+ }
+
+ @Test
+ void testScoreMapsFailLoudlyWhenInferenceFails() {
+ final DocumentCategorizerDL categorizer = categorizerWithoutSession();
+
+ assertThrows(IllegalStateException.class, () ->
+ categorizer.scoreMap(new String[] {"hello world"}));
+ assertThrows(IllegalStateException.class, () ->
+ categorizer.sortedScoreMap(new String[] {"hello world"}));
+ }
+
+ @Test
+ void testCategorizeRejectsMalformedInput() {
+ // A caller-side input bug is distinguished from an inference failure: it
is rejected up front
+ // with IllegalArgumentException, not wrapped as "document classification
inference" failure.
+ final DocumentCategorizerDL categorizer = categorizerWithoutSession();
+
+ assertThrows(IllegalArgumentException.class, () ->
categorizer.categorize(null));
+ assertThrows(IllegalArgumentException.class, () ->
categorizer.categorize(new String[0]));
+ }
+
@Test
void testTokenIdsMapsTokensToVocabularyIds() {
final long[] ids = DocumentCategorizerDL.tokenIds(
@@ -57,4 +103,83 @@ public class DocumentCategorizerDLTest {
assertTrue(e.getMessage().contains("missing"),
"the error message should name the missing token: " + e.getMessage());
}
+
+ @Test
+ void testSoftmaxIsUniformForEqualLogitsAndSumsToOne() {
+ final double[] out = DocumentCategorizerDL.softmax(new float[] {0f, 0f,
0f});
+
+ assertEquals(3, out.length);
+ for (final double p : out) {
+ assertEquals(1.0 / 3.0, p, 1e-12);
+ }
+ assertEquals(1.0, out[0] + out[1] + out[2], 1e-12);
+ }
+
+ @Test
+ void testSoftmaxIsNumericallyStableForLargeLogits() {
+ // The naive exp(logit) form overflows to +Infinity here and yields NaN;
subtracting
+ // the maximum keeps every value finite and the distribution uniform.
+ final double[] out = DocumentCategorizerDL.softmax(new float[] {1000f,
1000f, 1000f});
+
+ double sum = 0.0;
+ for (final double p : out) {
+ assertFalse(Double.isNaN(p) || Double.isInfinite(p),
+ "softmax must stay finite for large logits");
+ assertEquals(1.0 / 3.0, p, 1e-9);
+ sum += p;
+ }
+ assertEquals(1.0, sum, 1e-12);
+ }
+
+ @Test
+ void testSoftmaxMatchesReferenceDistribution() {
+ // Reference (numpy): softmax([1,2,3]) = [0.09003057, 0.24472847,
0.66524096].
+ final double[] out = DocumentCategorizerDL.softmax(new float[] {1f, 2f,
3f});
+
+ assertEquals(0.09003057, out[0], 1e-6);
+ assertEquals(0.24472847, out[1], 1e-6);
+ assertEquals(0.66524096, out[2], 1e-6);
+ }
+
+ @Test
+ void testChunkRangesSplitsWithOverlap() {
+ // 210 words, 200-word chunks overlapping by 50 -> [0,200), [150,210).
+ final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(210, 200, 50);
+
+ assertEquals(2, ranges.size());
+ assertArrayEquals(new int[] {0, 200}, ranges.get(0));
+ assertArrayEquals(new int[] {150, 210}, ranges.get(1));
+ }
+
+ @Test
+ void testChunkRangesSingleChunkWhenShorterThanSplit() {
+ final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(30, 200, 50);
+
+ assertEquals(1, ranges.size());
+ assertArrayEquals(new int[] {0, 30}, ranges.get(0));
+ }
+
+ @Test
+ void testChunkRangesEmptyForZeroLength() {
+ assertTrue(DocumentCategorizerDL.chunkRanges(0, 200, 50).isEmpty());
+ }
+
+ @Test
+ void testChunkRangesAlwaysProgressesForInvalidOverlap() {
+ // overlap == split would stall forever, and overlap > split would make
the start index
+ // negative, without the forward-progress guard.
+ for (final int[] cfg : new int[][] {{10, 5, 5}, {8, 3, 10}, {7, 4, 100}}) {
+ final int length = cfg[0];
+ final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(length,
cfg[1], cfg[2]);
+
+ int previousStart = -1;
+ for (final int[] range : ranges) {
+ assertTrue(range[0] >= 0, "start must never be negative: " + range[0]);
+ assertTrue(range[1] >= range[0], "end must be >= start");
+ assertTrue(range[0] > previousStart, "each chunk must advance the
start index");
+ previousStart = range[0];
+ }
+ assertEquals(length, ranges.get(ranges.size() - 1)[1], "last chunk must
reach the end");
+ }
+ }
}
diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index e045ec90e..d19443b5b 100644
--- a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++ b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -34,6 +34,7 @@ import org.slf4j.LoggerFactory;
import opennlp.dl.InferenceOptions;
import opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy;
import opennlp.tools.eval.AbstractEvalTest;
+import opennlp.tools.tokenize.WordpieceTokenizer;
public class DocumentCategorizerDLEval extends AbstractEvalTest {
@@ -91,6 +92,27 @@ public class DocumentCategorizerDLEval extends AbstractEvalTest {
}
+ @Test
+ public void categorizeFailsLoudlyOnFailure() throws Exception {
+
+ try (final DocumentCategorizerDL documentCategorizerDL =
+ categorizerWithoutSession()) {
+
+ // Empty input is rejected up front with IllegalArgumentException, so use a real document: it
+ // reaches inference, which fails without a session and must surface as IllegalStateException.
+ final IllegalStateException e = Assertions.assertThrows(IllegalStateException.class, () ->
+ documentCategorizerDL.categorize(new String[] {"hello world"}));
+ Assertions.assertTrue(e.getMessage().contains("document classification inference"));
+
+ // The dependent API must not mask that inference failure with all-zero scores.
+ Assertions.assertThrows(IllegalStateException.class, () ->
+ documentCategorizerDL.scoreMap(new String[] {"hello world"}));
+ Assertions.assertThrows(IllegalStateException.class, () ->
+ documentCategorizerDL.sortedScoreMap(new String[] {"hello world"}));
+ }
+
+ }
+
@Test
public void categorizeWithAutomaticLabels() throws Exception {
@@ -309,4 +331,19 @@ public class DocumentCategorizerDLEval extends
AbstractEvalTest {
}
+ private DocumentCategorizerDL categorizerWithoutSession() {
+ return new DocumentCategorizerDL(null, null, vocab(), getCategories(),
+ new AverageClassificationScoringStrategy(), new InferenceOptions());
+ }
+
+ private Map<String, Integer> vocab() {
+ final Map<String, Integer> vocab = new HashMap<>();
+ vocab.put(WordpieceTokenizer.BERT_CLS_TOKEN, 0);
+ vocab.put(WordpieceTokenizer.BERT_SEP_TOKEN, 1);
+ vocab.put(WordpieceTokenizer.BERT_UNK_TOKEN, 2);
+ vocab.put("hello", 3);
+ vocab.put("world", 4);
+ return vocab;
+ }
+
}