This is an automated email from the ASF dual-hosted git repository.
mawiesne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 9ebcaaa23 OPENNLP-1838: Adopt BertTokenizer in opennlp-dl components (#1075)
9ebcaaa23 is described below
commit 9ebcaaa2367ebf0725f4c489c877fdd9da4d1ed7
Author: Kristian Rickert <[email protected]>
AuthorDate: Sat Jun 13 15:16:30 2026 -0400
OPENNLP-1838: Adopt BertTokenizer in opennlp-dl components (#1075)
---
.../src/main/java/opennlp/dl/AbstractDL.java | 73 +++++++++++-
.../src/main/java/opennlp/dl/InferenceOptions.java | 25 ++++
.../opennlp/dl/doccat/DocumentCategorizerDL.java | 13 ++-
.../java/opennlp/dl/namefinder/NameFinderDL.java | 12 +-
.../java/opennlp/dl/vectors/SentenceVectorsDL.java | 34 +++++-
.../test/java/opennlp/dl/CreateTokenizerTest.java | 130 +++++++++++++++++++++
.../dl/doccat/DocumentCategorizerDLEval.java | 52 ++++-----
.../opennlp/dl/vectors/SentenceVectorsDLEval.java | 7 ++
8 files changed, 307 insertions(+), 39 deletions(-)
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
index d46d68a6f..38137a057 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
@@ -33,6 +33,7 @@ import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import opennlp.tools.tokenize.BertTokenizer;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
@@ -109,19 +110,81 @@ public abstract class AbstractDL implements AutoCloseable
{
WordpieceTokenizer.ROBERTA_CLS_TOKEN)
&& vocab.containsKey(
WordpieceTokenizer.ROBERTA_SEP_TOKEN)) {
- final String unk = vocab.containsKey(
- WordpieceTokenizer.ROBERTA_UNK_TOKEN)
- ? WordpieceTokenizer.ROBERTA_UNK_TOKEN
- : WordpieceTokenizer.BERT_UNK_TOKEN;
return new WordpieceTokenizer(
vocab.keySet(),
WordpieceTokenizer.ROBERTA_CLS_TOKEN,
WordpieceTokenizer.ROBERTA_SEP_TOKEN,
- unk);
+ resolveUnknownToken(vocab));
}
return new WordpieceTokenizer(vocab.keySet());
}
+ /**
+ * Creates a {@link BertTokenizer} that performs the full BERT tokenization
+ * pipeline: basic tokenization (text normalization) followed by wordpiece.
+ * The special tokens are selected based on the vocabulary: if it contains
+ * RoBERTa-style tokens, those are used, otherwise the BERT defaults.
+ *
+ * @param vocab The vocabulary map.
+ * @param lowerCase {@code true} for uncased models (lower casing and accent
+ * stripping), {@code false} for cased models.
+ * @return A configured {@link BertTokenizer}.
+ * @throws IllegalArgumentException Thrown if a RoBERTa-style vocabulary
+ * contains no supported unknown token.
+ */
+ protected BertTokenizer createTokenizer(
+ final Map<String, Integer> vocab, final boolean lowerCase) {
+ if (vocab.containsKey(
+ WordpieceTokenizer.ROBERTA_CLS_TOKEN)
+ && vocab.containsKey(
+ WordpieceTokenizer.ROBERTA_SEP_TOKEN)) {
+ return new BertTokenizer(
+ vocab.keySet(),
+ lowerCase,
+ WordpieceTokenizer.ROBERTA_CLS_TOKEN,
+ WordpieceTokenizer.ROBERTA_SEP_TOKEN,
+ resolveUnknownToken(vocab));
+ }
+ return new BertTokenizer(vocab.keySet(), lowerCase);
+ }
+
+ /**
+ * Resolves the unknown token of a RoBERTa-style vocabulary. The RoBERTa
+ * token {@link WordpieceTokenizer#ROBERTA_UNK_TOKEN} is preferred;
vocabularies
+ * mixing conventions may instead contain {@link
WordpieceTokenizer#BERT_UNK_TOKEN}.
+ * An unknown token that is absent from the vocabulary must never be
selected, as
+ * the tokenizer would emit tokens that later fail the token-to-id mapping.
+ *
+ * @param vocab The vocabulary map.
+ * @return The unknown token present in the vocabulary.
+ * @throws IllegalArgumentException Thrown if the vocabulary contains neither
+ * supported unknown token.
+ */
+ private static String resolveUnknownToken(final Map<String, Integer> vocab) {
+ if (vocab.containsKey(WordpieceTokenizer.ROBERTA_UNK_TOKEN)) {
+ return WordpieceTokenizer.ROBERTA_UNK_TOKEN;
+ }
+ if (vocab.containsKey(WordpieceTokenizer.BERT_UNK_TOKEN)) {
+ return WordpieceTokenizer.BERT_UNK_TOKEN;
+ }
+ throw new IllegalArgumentException(
+ "The vocabulary contains neither '" +
WordpieceTokenizer.ROBERTA_UNK_TOKEN
+ + "' nor '" + WordpieceTokenizer.BERT_UNK_TOKEN + "' as an unknown
token.");
+ }
+
+ /**
+ * Resolves the effective lower casing behavior from the
+ * given {@link InferenceOptions}.
+ *
+ * @param options The {@link InferenceOptions} to consult.
+ * @param componentDefault The default to apply if the option is not set.
+ * @return The effective lower casing behavior.
+ */
+ protected static boolean resolveLowerCase(
+ final InferenceOptions options, final boolean componentDefault) {
+ return options.getLowerCase() != null ? options.getLowerCase() :
componentDefault;
+ }
+
private Map<String, Integer> loadJsonVocab(final String json) {
final Map<String, Integer> vocab = new HashMap<>();
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
index 606c8bc02..344c5846d 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
@@ -25,6 +25,7 @@ public class InferenceOptions {
private int gpuDeviceId = 0;
private int documentSplitSize = 250;
private int splitOverlapSize = 50;
+ private Boolean lowerCase;
public boolean isIncludeAttentionMask() {
return includeAttentionMask;
@@ -74,4 +75,28 @@ public class InferenceOptions {
this.splitOverlapSize = splitOverlapSize;
}
+ /**
+ * Returns whether tokenization should lower case the input text and strip
+ * accents, as required by uncased models.
+ *
+ * @return {@code Boolean.TRUE} for uncased models, {@code Boolean.FALSE} for
+ * cased models, or {@code null} if not set, in which case each component
+ * applies the default that matches its commonly used models.
+ */
+ public Boolean getLowerCase() {
+ return lowerCase;
+ }
+
+ /**
+ * Sets whether tokenization should lower case the input text and strip
+ * accents. Set {@code true} for uncased models and {@code false} for cased
+ * models. If not set, each component applies the default that matches its
+ * commonly used models.
+ *
+ * @param lowerCase Whether to lower case the input text during tokenization.
+ */
+ public void setLowerCase(boolean lowerCase) {
+ this.lowerCase = lowerCase;
+ }
+
}
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index f02dd875b..cf01631bf 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -51,6 +51,12 @@ import opennlp.tools.doccat.DocumentCategorizer;
* An implementation of {@link DocumentCategorizer} that performs document
classification
* using ONNX models.
*
+ * <p>Tokenization performs BERT basic tokenization (text normalization)
+ * before wordpiece, see {@link opennlp.tools.tokenize.BertTokenizer}. Input
+ * text is lower cased and accent stripped by default, matching the uncased
+ * models commonly used for classification. For cased models, set
+ * {@link InferenceOptions#setLowerCase(boolean)} to {@code false}.</p>
+ *
* @see DocumentCategorizer
* @see InferenceOptions
* @see ClassificationScoringStrategy
@@ -59,6 +65,9 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
private static final Logger logger =
LoggerFactory.getLogger(DocumentCategorizerDL.class);
+ /** Classification models are commonly uncased, so lower casing is the
default. */
+ private static final boolean LOWER_CASE_DEFAULT = true;
+
private final Map<Integer, String> categories;
private final ClassificationScoringStrategy classificationScoringStrategy;
private final InferenceOptions inferenceOptions;
@@ -90,7 +99,7 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
this.session = env.createSession(model.getPath(), sessionOptions);
this.vocab = loadVocab(vocabulary);
- this.tokenizer = createTokenizer(vocab);
+ this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
this.categories = categories;
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
@@ -125,7 +134,7 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
this.session = env.createSession(model.getPath(), sessionOptions);
this.vocab = loadVocab(vocabulary);
- this.tokenizer = createTokenizer(vocab);
+ this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
this.categories = readCategoriesFromFile(config);
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index f7373700e..d2adee0b3 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -44,6 +44,13 @@ import opennlp.tools.util.Span;
/**
* An implementation of {@link TokenNameFinder} that uses ONNX models.
*
+ * <p>Tokenization performs BERT basic tokenization (text normalization)
+ * before wordpiece, see {@link opennlp.tools.tokenize.BertTokenizer}. Input
+ * text is <b>not</b> lower cased by default, because named entity recognition
+ * models are commonly cased: capitalization is a strong signal for entity
+ * boundaries. For uncased models, set
+ * {@link InferenceOptions#setLowerCase(boolean)} to {@code true}.</p>
+ *
* @see TokenNameFinder
* @see InferenceOptions
*/
@@ -53,6 +60,9 @@ public class NameFinderDL extends AbstractDL implements
TokenNameFinder {
public static final String B_PER = "B-PER";
public static final String SEPARATOR = "[SEP]";
+ /** NER models are commonly cased, so lower casing is off by default. */
+ private static final boolean LOWER_CASE_DEFAULT = false;
+
private static final String CHARS_TO_REPLACE = "##";
private final SentenceDetector sentenceDetector;
@@ -103,7 +113,7 @@ public class NameFinderDL extends AbstractDL implements
TokenNameFinder {
this.session = env.createSession(model.getPath(), sessionOptions);
this.ids2Labels = ids2Labels;
this.vocab = loadVocab(vocabulary);
- this.tokenizer = createTokenizer(vocab);
+ this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
this.inferenceOptions = inferenceOptions;
this.sentenceDetector = sentenceDetector;
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
index 06ba5cbd6..c7b7fda86 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
@@ -45,14 +45,20 @@ import opennlp.tools.tokenize.Tokenizer;
* <p><b>Release note (OpenNLP 3.0.0):</b> prior releases sent an
* all-zero {@code attention_mask} and all-one {@code token_type_ids},
* so the encoder attended to nothing and the output vectors were
- * incorrect. Output vectors change with the corrected encoding; any
- * embeddings persisted from the previous behavior are not comparable
- * with the corrected output and must be re-embedded.</p>
+ * incorrect. Additionally, tokenization now performs BERT basic
+ * tokenization (lower casing and accent stripping by default, see
+ * {@link opennlp.tools.tokenize.BertTokenizer}) before wordpiece.
+ * Output vectors change with the corrected encoding and tokenization;
+ * any embeddings persisted from the previous behavior are not
+ * comparable with the corrected output and must be re-embedded.</p>
*/
public class SentenceVectorsDL extends AbstractDL {
/**
- * Instantiates a {@link SentenceVectorsDL sentence vector generator} using
ONNX models.
+ * Instantiates a {@link SentenceVectorsDL sentence vector generator} for an
+ * uncased model. Input text is lower cased and accent stripped during
+ * tokenization, as required by uncased models such as the
+ * sentence-transformers MiniLM family.
*
* @param model The file name of a sentence vectors ONNX model.
* @param vocabulary The file name of the vocabulary file for the model.
@@ -63,10 +69,28 @@ public class SentenceVectorsDL extends AbstractDL {
public SentenceVectorsDL(final File model, final File vocabulary)
throws OrtException, IOException {
+ this(model, vocabulary, true);
+
+ }
+
+ /**
+ * Instantiates a {@link SentenceVectorsDL sentence vector generator} using
ONNX models.
+ *
+ * @param model The file name of a sentence vectors ONNX model.
+ * @param vocabulary The file name of the vocabulary file for the model.
+ * @param lowerCase {@code true} for uncased models (lower casing and accent
+ * stripping during tokenization), {@code false} for cased models.
+ *
+ * @throws OrtException Thrown if the {@code model} cannot be loaded.
+ * @throws IOException Thrown if errors occurred loading the {@code model}
or {@code vocabulary}.
+ */
+ public SentenceVectorsDL(final File model, final File vocabulary, final
boolean lowerCase)
+ throws OrtException, IOException {
+
env = OrtEnvironment.getEnvironment();
session = env.createSession(model.getPath(), new
OrtSession.SessionOptions());
vocab = loadVocab(vocabulary);
- tokenizer = createTokenizer(vocab);
+ tokenizer = createTokenizer(vocab, lowerCase);
}
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/CreateTokenizerTest.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/CreateTokenizerTest.java
new file mode 100644
index 000000000..a373cb159
--- /dev/null
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/CreateTokenizerTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.tokenize.BertTokenizer;
+import opennlp.tools.tokenize.WordpieceTokenizer;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class CreateTokenizerTest {
+
+ /** A concrete subclass to access the protected factory methods. */
+ private static class TestDL extends AbstractDL {
+ }
+
+ private static Map<String, Integer> bertVocab() {
+ final Map<String, Integer> vocab = new HashMap<>();
+ vocab.put(WordpieceTokenizer.BERT_CLS_TOKEN, 0);
+ vocab.put(WordpieceTokenizer.BERT_SEP_TOKEN, 1);
+ vocab.put(WordpieceTokenizer.BERT_UNK_TOKEN, 2);
+ vocab.put("hello", 3);
+ vocab.put("world", 4);
+ return vocab;
+ }
+
+ private static Map<String, Integer> robertaVocab() {
+ final Map<String, Integer> vocab = new HashMap<>();
+ vocab.put(WordpieceTokenizer.ROBERTA_CLS_TOKEN, 0);
+ vocab.put(WordpieceTokenizer.ROBERTA_SEP_TOKEN, 1);
+ vocab.put(WordpieceTokenizer.ROBERTA_UNK_TOKEN, 2);
+ vocab.put("hello", 3);
+ return vocab;
+ }
+
+ @Test
+ void testCreatesLowerCasingBertTokenizer() {
+ final BertTokenizer tokenizer = new TestDL().createTokenizer(bertVocab(),
true);
+
+ // Capitalized input must be lower cased before the wordpiece lookup.
+ assertArrayEquals(new String[] {
+ WordpieceTokenizer.BERT_CLS_TOKEN, "hello", "world",
WordpieceTokenizer.BERT_SEP_TOKEN},
+ tokenizer.tokenize("Hello World"));
+ }
+
+ @Test
+ void testCreatesCasePreservingBertTokenizer() {
+ final BertTokenizer tokenizer = new TestDL().createTokenizer(bertVocab(),
false);
+
+ // Without lower casing, capitalized words miss the lowercase-only
vocabulary.
+ assertArrayEquals(new String[] {
+ WordpieceTokenizer.BERT_CLS_TOKEN, WordpieceTokenizer.BERT_UNK_TOKEN,
"world",
+ WordpieceTokenizer.BERT_SEP_TOKEN},
+ tokenizer.tokenize("Hello world"));
+ }
+
+ @Test
+ void testSelectsRobertaSpecialTokens() {
+ final BertTokenizer tokenizer = new
TestDL().createTokenizer(robertaVocab(), false);
+
+ assertArrayEquals(new String[] {
+ WordpieceTokenizer.ROBERTA_CLS_TOKEN, "hello",
WordpieceTokenizer.ROBERTA_UNK_TOKEN,
+ WordpieceTokenizer.ROBERTA_SEP_TOKEN},
+ tokenizer.tokenize("hello missing"));
+ }
+
+ @Test
+ void testFallsBackToBertUnknownToken() {
+ final Map<String, Integer> vocab = robertaVocab();
+ vocab.remove(WordpieceTokenizer.ROBERTA_UNK_TOKEN);
+ vocab.put(WordpieceTokenizer.BERT_UNK_TOKEN, 2);
+
+ final BertTokenizer tokenizer = new TestDL().createTokenizer(vocab, false);
+
+ assertArrayEquals(new String[] {
+ WordpieceTokenizer.ROBERTA_CLS_TOKEN, "hello",
WordpieceTokenizer.BERT_UNK_TOKEN,
+ WordpieceTokenizer.ROBERTA_SEP_TOKEN},
+ tokenizer.tokenize("hello missing"));
+ }
+
+ @Test
+ void testRejectsRobertaVocabularyWithoutUnknownToken() {
+ final Map<String, Integer> vocab = robertaVocab();
+ vocab.remove(WordpieceTokenizer.ROBERTA_UNK_TOKEN);
+
+ final TestDL dl = new TestDL();
+ assertThrows(IllegalArgumentException.class, () ->
dl.createTokenizer(vocab, false));
+ assertThrows(IllegalArgumentException.class, () ->
dl.createTokenizer(vocab));
+ }
+
+ @Test
+ void testResolveLowerCaseUsesComponentDefaultWhenUnset() {
+ final InferenceOptions options = new InferenceOptions();
+
+ assertTrue(AbstractDL.resolveLowerCase(options, true));
+ assertFalse(AbstractDL.resolveLowerCase(options, false));
+ }
+
+ @Test
+ void testResolveLowerCaseOverridesComponentDefault() {
+ final InferenceOptions options = new InferenceOptions();
+ options.setLowerCase(false);
+ assertFalse(AbstractDL.resolveLowerCase(options, true));
+
+ options.setLowerCase(true);
+ assertTrue(AbstractDL.resolveLowerCase(options, false));
+ }
+}
diff --git
a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index f1d2b84d8..e045ec90e 100644
---
a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++
b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -73,11 +73,11 @@ public class DocumentCategorizerDLEval extends
AbstractEvalTest {
.sorted(Collections.reverseOrder()).mapToDouble(Double::doubleValue).toArray();
final double[] expected = new double[]
- {0.3391093313694,
- 0.2611352801322937,
- 0.24420668184757233,
- 0.11939861625432968,
- 0.03615010157227516};
+ {0.407059907913208,
+ 0.3602477014064789,
+ 0.14488528668880463,
+ 0.07669895142316818,
+ 0.011108151637017727};
logger.debug("Actual: {}", Arrays.toString(sortedResult));
logger.debug("Expected: {}", Arrays.toString(expected));
@@ -114,11 +114,11 @@ public class DocumentCategorizerDLEval extends
AbstractEvalTest {
.sorted(Collections.reverseOrder()).mapToDouble(Double::doubleValue).toArray();
final double[] expected = new double[]
- {0.3391093313694,
- 0.2611352801322937,
- 0.24420668184757233,
- 0.11939861625432968,
- 0.03615010157227516};
+ {0.407059907913208,
+ 0.3602477014064789,
+ 0.14488528668880463,
+ 0.07669895142316818,
+ 0.011108151637017727};
logger.debug("Actual: {}", Arrays.toString(sortedResult));
logger.debug("Expected: {}", Arrays.toString(expected));
@@ -154,11 +154,11 @@ public class DocumentCategorizerDLEval extends
AbstractEvalTest {
logger.debug(Arrays.toString(result));
final double[] expected = new double[]
- {0.007819971069693565,
- 0.006593209225684404,
- 0.04995147883892059,
- 0.3003573715686798,
- 0.6352779865264893};
+ {0.00752239441499114,
+ 0.0074586994014680386,
+ 0.05470007658004761,
+ 0.3344593346118927,
+ 0.5958595275878906};
Assertions.assertArrayEquals(expected, result, 0.000001);
Assertions.assertEquals(5, result.length);
@@ -191,7 +191,7 @@ public class DocumentCategorizerDLEval extends
AbstractEvalTest {
final double[] result = documentCategorizerDL.categorize(new String[]
{"I am angry"});
- final double[] expected = new double[] {0.8851314783096313,
0.11486853659152985};
+ final double[] expected = new double[] {0.9072678089141846,
0.09273219853639603};
Assertions.assertArrayEquals(expected, result, 0.000001);
Assertions.assertEquals(2, result.length);
@@ -216,11 +216,11 @@ public class DocumentCategorizerDLEval extends
AbstractEvalTest {
final Map<String, Double> result = documentCategorizerDL.scoreMap(new
String[] {"I am happy"});
- Assertions.assertEquals(0.6352779865264893, result.get("very good"),
0.000001);
- Assertions.assertEquals(0.3003573715686798, result.get("good"),
0.000001);
- Assertions.assertEquals(0.04995147883892059, result.get("neutral"),
0.000001);
- Assertions.assertEquals(0.006593209225684404, result.get("bad"),
0.000001);
- Assertions.assertEquals(0.007819971069693565, result.get("very bad"),
0.000001);
+ Assertions.assertEquals(0.5958595275878906, result.get("very good"),
0.000001);
+ Assertions.assertEquals(0.3344593346118927, result.get("good"),
0.000001);
+ Assertions.assertEquals(0.05470007658004761, result.get("neutral"),
0.000001);
+ Assertions.assertEquals(0.0074586994014680386, result.get("bad"),
0.000001);
+ Assertions.assertEquals(0.00752239441499114, result.get("very bad"),
0.000001);
}
}
@@ -248,23 +248,23 @@ public class DocumentCategorizerDLEval extends
AbstractEvalTest {
// we assume a sorted map here, so lets check in sorted order (lower
values first).
Map.Entry<Double, Set<String>> e = it.next();
- Assertions.assertEquals(0.006593209225684404, e.getKey(), 0.000001);
+ Assertions.assertEquals(0.0074586994014680386, e.getKey(), 0.000001);
Assertions.assertEquals(e.getValue().size(), 1);
e = it.next();
- Assertions.assertEquals(0.007819971069693565, e.getKey(), 0.000001);
+ Assertions.assertEquals(0.00752239441499114, e.getKey(), 0.000001);
Assertions.assertEquals(e.getValue().size(), 1);
e = it.next();
- Assertions.assertEquals(0.04995147883892059, e.getKey(), 0.000001);
+ Assertions.assertEquals(0.05470007658004761, e.getKey(), 0.000001);
Assertions.assertEquals(e.getValue().size(), 1);
e = it.next();
- Assertions.assertEquals(0.3003573715686798, e.getKey(), 0.000001);
+ Assertions.assertEquals(0.3344593346118927, e.getKey(), 0.000001);
Assertions.assertEquals(e.getValue().size(), 1);
e = it.next();
- Assertions.assertEquals(0.6352779865264893, e.getKey(), 0.000001);
+ Assertions.assertEquals(0.5958595275878906, e.getKey(), 0.000001);
Assertions.assertEquals(e.getValue().size(), 1);
}
diff --git
a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
index 47976648e..ed4c20d06 100644
---
a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
+++
b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
@@ -61,6 +61,13 @@ public class SentenceVectorsDLEval extends AbstractEvalTest {
Assertions.assertEquals(0.20219636, vectors[1], 0.00001);
Assertions.assertEquals(0.41306049, vectors[2], 0.00001);
Assertions.assertEquals(384, vectors.length);
+
+ // The uncased model lower cases during tokenization, so a capitalized
+ // variant must produce the same vectors. Prior to BERT basic
+ // tokenization, every capitalized word was mapped to [UNK].
+ final float[] capitalized = sv.getVectors("George Washington was
President");
+
+ Assertions.assertArrayEquals(vectors, capitalized, 0.00001f);
}
}