This is an automated email from the ASF dual-hosted git repository.
mawiesne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new bd73ea13 OPENNLP-1518: Roberta-based Models - Add support for
utilization via Onnx (#998)
bd73ea13 is described below
commit bd73ea1394dcc5a46cbd71f9b4034bde0e2fe89d
Author: Richard Zowalla <[email protected]>
AuthorDate: Tue Mar 31 08:57:38 2026 +0200
OPENNLP-1518: Roberta-based Models - Add support for utilization via Onnx
(#998)
---
.../opennlp/tools/tokenize/WordpieceTokenizer.java | 58 ++++++++++--
.../src/main/java/opennlp/dl/AbstractDL.java | 89 ++++++++++++++++--
.../opennlp/dl/doccat/DocumentCategorizerDL.java | 23 +++--
.../java/opennlp/dl/namefinder/NameFinderDL.java | 3 +-
.../java/opennlp/dl/vectors/SentenceVectorsDL.java | 4 +-
.../src/test/java/opennlp/dl/LoadVocabTest.java | 102 +++++++++++++++++++++
.../src/test/resources/opennlp/dl/vocab-plain.txt | 6 ++
.../src/test/resources/opennlp/dl/vocab.json | 8 ++
8 files changed, 266 insertions(+), 27 deletions(-)
diff --git a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
index 1cf9aa0c..4d92c86b 100644
--- a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
+++ b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
@@ -45,12 +45,27 @@ import opennlp.tools.util.Span;
*/
public class WordpieceTokenizer implements Tokenizer {
- private static final Pattern PUNCTUATION_PATTERN =
Pattern.compile("\\p{Punct}+");
- private static final String CLASSIFICATION_TOKEN = "[CLS]";
- private static final String SEPARATOR_TOKEN = "[SEP]";
- private static final String UNKNOWN_TOKEN = "[UNK]";
+ /** BERT classification token: {@code [CLS]}. */
+ public static final String BERT_CLS_TOKEN = "[CLS]";
+ /** BERT separator token: {@code [SEP]}. */
+ public static final String BERT_SEP_TOKEN = "[SEP]";
+ /** BERT unknown token: {@code [UNK]}. */
+ public static final String BERT_UNK_TOKEN = "[UNK]";
+
+ /** RoBERTa classification token: {@code <s>}. */
+ public static final String ROBERTA_CLS_TOKEN = "<s>";
+ /** RoBERTa separator token: {@code </s>}. */
+ public static final String ROBERTA_SEP_TOKEN = "</s>";
+ /** RoBERTa unknown token: {@code <unk>}. */
+ public static final String ROBERTA_UNK_TOKEN = "<unk>";
+
+ private static final Pattern PUNCTUATION_PATTERN =
+ Pattern.compile("\\p{Punct}+");
private final Set<String> vocabulary;
+ private final String classificationToken;
+ private final String separatorToken;
+ private final String unknownToken;
private int maxTokenLength = 50;
/**
@@ -60,7 +75,7 @@ public class WordpieceTokenizer implements Tokenizer {
* @param vocabulary A set of tokens considered the vocabulary.
*/
public WordpieceTokenizer(Set<String> vocabulary) {
- this.vocabulary = vocabulary;
+ this(vocabulary, BERT_CLS_TOKEN, BERT_SEP_TOKEN, BERT_UNK_TOKEN);
}
/**
@@ -75,6 +90,29 @@ public class WordpieceTokenizer implements Tokenizer {
this.maxTokenLength = maxTokenLength;
}
+  /**
+   * Initializes a {@link WordpieceTokenizer} with a
+   * {@code vocabulary} and custom special tokens.
+   * This allows support for models like RoBERTa that
+   * use different special tokens instead of the BERT
+   * defaults.
+   *
+   * @param vocabulary The vocabulary.
+   * @param classificationToken The CLS token, emitted as the first token.
+   * @param separatorToken The SEP token, emitted as the last token.
+   * @param unknownToken The UNK token, emitted in place of a word that
+   *     cannot be represented by vocabulary pieces or that exceeds the
+   *     maximum token length.
+   * @throws IllegalArgumentException Thrown if any of the special
+   *     tokens is {@code null}.
+   */
+  public WordpieceTokenizer(
+      final Set<String> vocabulary,
+      final String classificationToken,
+      final String separatorToken,
+      final String unknownToken) {
+    // Reject null up front: tokenize() would otherwise silently place a
+    // null element into the token array it returns.
+    if (classificationToken == null || separatorToken == null
+        || unknownToken == null) {
+      throw new IllegalArgumentException(
+          "Special tokens must not be null");
+    }
+    this.vocabulary = vocabulary;
+    this.classificationToken = classificationToken;
+    this.separatorToken = separatorToken;
+    this.unknownToken = unknownToken;
+  }
+
@Override
public Span[] tokenizePos(final String text) {
// TODO: Implement this.
@@ -85,7 +123,7 @@ public class WordpieceTokenizer implements Tokenizer {
public String[] tokenize(final String text) {
final List<String> tokens = new LinkedList<>();
- tokens.add(CLASSIFICATION_TOKEN);
+ tokens.add(classificationToken);
// Put spaces around punctuation.
final String spacedPunctuation =
PUNCTUATION_PATTERN.matcher(text).replaceAll(" $0 ");
@@ -146,7 +184,7 @@ public class WordpieceTokenizer implements Tokenizer {
// If the word can't be represented by vocabulary pieces replace
// it with a specified "unknown" token.
if (!found) {
- tokens.add(UNKNOWN_TOKEN);
+ tokens.add(unknownToken);
break;
}
@@ -157,14 +195,14 @@ public class WordpieceTokenizer implements Tokenizer {
} else {
- // If the token's length is greater than the max length just add [UNK]
instead.
- tokens.add(UNKNOWN_TOKEN);
+ // If the token's length is greater than the max length just add
unknown token instead.
+ tokens.add(unknownToken);
}
}
- tokens.add(SEPARATOR_TOKEN);
+ tokens.add(separatorToken);
return tokens.toArray(new String[0]);
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
index eb8c41ce..d46d68a6 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
@@ -19,11 +19,14 @@ package opennlp.dl;
import java.io.File;
import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import java.util.stream.Stream;
import ai.onnxruntime.OrtEnvironment;
@@ -31,6 +34,7 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WordpieceTokenizer;
/**
* Base class for OpenNLP deep-learning classes using ONNX Runtime.
@@ -46,21 +50,92 @@ public abstract class AbstractDL implements AutoCloseable {
protected Tokenizer tokenizer;
protected Map<String, Integer> vocab;
+ private static final Pattern JSON_ENTRY_PATTERN =
+ Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)");
+
/**
* Loads a vocabulary {@link File} from disk.
+   * Supports both plain text files (one token per
+   * line, IDs assigned in line order starting at 0)
+   * and JSON files mapping tokens to integer IDs. A
+   * file whose content, ignoring a leading UTF-8 byte
+   * order mark and surrounding whitespace, starts with
+   * an opening brace is parsed as JSON.
*
* @param vocabFile The vocabulary file.
- * @return A map of vocabulary words to integer IDs.
- * @throws IOException Thrown if the vocabulary file cannot be opened or read.
+   * @return A map of vocabulary words to IDs.
+   * @throws IOException Thrown if the vocabulary
+   *     file cannot be opened or read.
*/
- public Map<String, Integer> loadVocab(final File vocabFile) throws IOException {
+  public Map<String, Integer> loadVocab(
+      final File vocabFile) throws IOException {
- final Map<String, Integer> vocab = new HashMap<>();
- final AtomicInteger counter = new AtomicInteger(0);
+    // Read the file exactly once and parse from memory: a second read
+    // doubles the I/O and could observe a different version of the file.
+    final String content = Files.readString(
+        Path.of(vocabFile.getPath()), StandardCharsets.UTF_8);
+
+    // A UTF-8 byte order mark decodes to U+FEFF, which trim() keeps; drop
+    // it so it neither defeats JSON detection nor sticks to the first token.
+    final String text = !content.isEmpty() && content.charAt(0) == 0xFEFF
+        ? content.substring(1) : content;
+
+    final String trimmed = text.trim();
+    if (trimmed.startsWith("{")) {
+      return loadJsonVocab(trimmed);
+    }
+
+    // Plain text: one token per line, ID = zero-based line index.
+    final Map<String, Integer> vocab = new HashMap<>();
+    final AtomicInteger counter = new AtomicInteger(0);
+    try (Stream<String> lines = text.lines()) {
+      lines.forEach(line -> vocab.put(line, counter.getAndIncrement()));
+    }
+    return vocab;
+  }
- try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) {
+  /**
+   * Builds a {@link WordpieceTokenizer} whose special tokens match the
+   * given vocabulary. When the vocabulary contains both {@code <s>} and
+   * {@code </s>}, the RoBERTa tokens are used, with {@code <unk>} as the
+   * unknown token if present and {@code [UNK]} otherwise. In every other
+   * case the BERT defaults ({@code [CLS]}, {@code [SEP]}, {@code [UNK]})
+   * apply.
+   *
+   * @param vocab The vocabulary map; only its keys are consulted.
+   * @return A configured {@link WordpieceTokenizer}.
+   */
+  protected WordpieceTokenizer createTokenizer(
+      final Map<String, Integer> vocab) {
+    final boolean robertaStyle =
+        vocab.containsKey(WordpieceTokenizer.ROBERTA_CLS_TOKEN)
+            && vocab.containsKey(WordpieceTokenizer.ROBERTA_SEP_TOKEN);
+    if (!robertaStyle) {
+      return new WordpieceTokenizer(vocab.keySet());
+    }
+
+    String unknown = WordpieceTokenizer.BERT_UNK_TOKEN;
+    if (vocab.containsKey(WordpieceTokenizer.ROBERTA_UNK_TOKEN)) {
+      unknown = WordpieceTokenizer.ROBERTA_UNK_TOKEN;
+    }
+    return new WordpieceTokenizer(vocab.keySet(),
+        WordpieceTokenizer.ROBERTA_CLS_TOKEN,
+        WordpieceTokenizer.ROBERTA_SEP_TOKEN,
+        unknown);
+  }
+
+  /**
+   * Decodes the content of a JSON string literal (the text between the
+   * quotes) by resolving its backslash escapes in a single left-to-right
+   * pass, so an escaped backslash can never combine with the character
+   * after it. Handles every RFC 8259 escape: {@code \"}, {@code \\},
+   * {@code \/}, {@code \b}, {@code \f}, {@code \n}, {@code \r},
+   * {@code \t}, and a backslash followed by {@code u} and four
+   * hexadecimal digits. Malformed escapes are kept verbatim.
+   *
+   * @param escaped The raw string content as it appears in the JSON text.
+   * @return The decoded string.
+   */
+  private static String unescapeJson(final String escaped) {
+    if (escaped.indexOf('\\') < 0) {
+      return escaped; // fast path: nothing to decode
+    }
+    final StringBuilder out = new StringBuilder(escaped.length());
+    int i = 0;
+    while (i < escaped.length()) {
+      final char c = escaped.charAt(i);
+      if (c != '\\' || i + 1 == escaped.length()) {
+        out.append(c);
+        i++;
+        continue;
+      }
+      final char e = escaped.charAt(i + 1);
+      if (e == 'u') {
+        final int code = parseHex4(escaped, i + 2);
+        if (code >= 0) {
+          // Characters outside the BMP arrive as two consecutive escapes
+          // (a surrogate pair) and reassemble naturally in UTF-16.
+          out.append((char) code);
+          i += 6;
+        } else {
+          out.append(c).append(e);
+          i += 2;
+        }
+        continue;
+      }
+      switch (e) {
+        case '"', '\\', '/' -> out.append(e);
+        case 'b' -> out.append('\b');
+        case 'f' -> out.append('\f');
+        case 'n' -> out.append('\n');
+        case 'r' -> out.append('\r');
+        case 't' -> out.append('\t');
+        default -> out.append(c).append(e);
+      }
+      i += 2;
+    }
+    return out.toString();
+  }
+
+  /**
+   * Parses exactly four ASCII hexadecimal digits starting at {@code from}.
+   *
+   * @param s The string to read from.
+   * @param from The index of the first digit.
+   * @return The parsed value, or {@code -1} if fewer than four characters
+   *     remain or any of them is not an ASCII hex digit.
+   */
+  private static int parseHex4(final String s, final int from) {
+    if (from + 4 > s.length()) {
+      return -1;
+    }
+    int value = 0;
+    for (int k = from; k < from + 4; k++) {
+      final char ch = s.charAt(k);
+      // Restrict to ASCII: Character.digit also accepts non-ASCII digits.
+      final int digit = ch < 128 ? Character.digit(ch, 16) : -1;
+      if (digit < 0) {
+        return -1;
+      }
+      value = (value << 4) | digit;
+    }
+    return value;
+  }
+
+  /**
+   * Parses a flat JSON object of {@code "token": id} entries, such as a
+   * {@code vocab.json} file, into a vocabulary map. Token strings are
+   * JSON-unescaped before they are stored.
+   *
+   * @param json The trimmed JSON document.
+   * @return A map of vocabulary tokens to IDs.
+   */
+  private Map<String, Integer> loadJsonVocab(final String json) {
+    final Map<String, Integer> vocab = new HashMap<>();
+    final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json);
- lines.forEach(line -> vocab.put(line, counter.getAndIncrement()));
+    while (matcher.find()) {
+      final String token = unescapeJson(matcher.group(1));
+      final int id = Integer.parseInt(matcher.group(2));
+      vocab.put(token, id);
}
return vocab;
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index 9173f30e..a0c9ede7 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -45,7 +45,7 @@ import opennlp.dl.InferenceOptions;
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
import opennlp.tools.doccat.DocumentCategorizer;
-import opennlp.tools.tokenize.WordpieceTokenizer;
+
/**
* An implementation of {@link DocumentCategorizer} that performs document
classification
@@ -90,7 +90,7 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
this.session = env.createSession(model.getPath(), sessionOptions);
this.vocab = loadVocab(vocabulary);
- this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+ this.tokenizer = createTokenizer(vocab);
this.categories = categories;
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
@@ -125,7 +125,7 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
this.session = env.createSession(model.getPath(), sessionOptions);
this.vocab = loadVocab(vocabulary);
- this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+ this.tokenizer = createTokenizer(vocab);
this.categories = readCategoriesFromFile(config);
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
@@ -158,11 +158,22 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
}
- // The outputs from the model.
- final float[][] v = (float[][]) session.run(inputs).get(0).getValue();
+ // The outputs from the model. Some models return a 2D array (e.g.
BERT),
+ // while others return a 1D array (e.g. RoBERTa).
+ final Object output = session.run(inputs).get(0).getValue();
+
+ final float[] rawScores;
+ if (output instanceof float[][] v) {
+ rawScores = v[0];
+ } else if (output instanceof float[] v) {
+ rawScores = v;
+ } else {
+ throw new IllegalStateException(
+ "Unexpected model output type: " + output.getClass().getName());
+ }
// Keep track of all scores.
- final double[] categoryScoresForTokens = softmax(v[0]);
+ final double[] categoryScoresForTokens = softmax(rawScores);
scores.add(categoryScoresForTokens);
}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index 3cbf0e2a..74e5a1aa 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -39,7 +39,6 @@ import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
-import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;
/**
@@ -104,7 +103,7 @@ public class NameFinderDL extends AbstractDL implements
TokenNameFinder {
this.session = env.createSession(model.getPath(), sessionOptions);
this.ids2Labels = ids2Labels;
this.vocab = loadVocab(vocabulary);
- this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+ this.tokenizer = createTokenizer(vocab);
this.inferenceOptions = inferenceOptions;
this.sentenceDetector = sentenceDetector;
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
index 805b4118..85abfbe4 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
@@ -32,7 +32,7 @@ import ai.onnxruntime.OrtSession;
import opennlp.dl.AbstractDL;
import opennlp.dl.Tokens;
import opennlp.tools.tokenize.Tokenizer;
-import opennlp.tools.tokenize.WordpieceTokenizer;
+
/**
* Facilitates the generation of sentence vectors using
@@ -55,7 +55,7 @@ public class SentenceVectorsDL extends AbstractDL {
env = OrtEnvironment.getEnvironment();
session = env.createSession(model.getPath(), new
OrtSession.SessionOptions());
vocab = loadVocab(new File(vocabulary.getPath()));
- tokenizer = new WordpieceTokenizer(vocab.keySet());
+ tokenizer = createTokenizer(vocab);
}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
new file mode 100644
index 00000000..d8554c3f
--- /dev/null
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
+import java.util.Map;
+import java.util.Objects;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+/**
+ * Tests vocabulary loading (plain text and JSON) and the selection of
+ * special tokens by {@link AbstractDL#createTokenizer(Map)}.
+ */
+public class LoadVocabTest {
+
+  // Minimal concrete AbstractDL so the protected/public helpers can be called.
+  private final AbstractDL dl = new AbstractDL() {
+    @Override
+    public void close() {
+    }
+  };
+
+  /**
+   * Copies a classpath resource from {@code /opennlp/dl/} to a temporary
+   * file, because {@code loadVocab} takes a {@link File}.
+   */
+  private File getResource(String name) throws IOException {
+    try (InputStream is = Objects.requireNonNull(
+        getClass().getResourceAsStream("/opennlp/dl/" + name),
+        "Missing test resource: " + name)) {
+      final File tempFile = File.createTempFile("vocab-test-", "-" + name);
+      tempFile.deleteOnExit();
+      Files.copy(is, tempFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
+      return tempFile;
+    }
+  }
+
+  @Test
+  void testLoadPlainTextVocab() throws IOException {
+    final Map<String, Integer> vocab =
+        dl.loadVocab(getResource("vocab-plain.txt"));
+
+    assertNotNull(vocab);
+    assertEquals(6, vocab.size());
+    assertEquals(0, vocab.get("[CLS]"));
+    assertEquals(1, vocab.get("[SEP]"));
+    assertEquals(2, vocab.get("[UNK]"));
+    assertEquals(3, vocab.get("hello"));
+    assertEquals(4, vocab.get("world"));
+    assertEquals(5, vocab.get("##ing"));
+  }
+
+  @Test
+  void testLoadJsonVocab() throws IOException {
+    final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab.json"));
+
+    assertNotNull(vocab);
+    assertEquals(6, vocab.size());
+    assertEquals(0, vocab.get("[CLS]"));
+    assertEquals(1, vocab.get("[SEP]"));
+    assertEquals(2, vocab.get("[UNK]"));
+    assertEquals(3, vocab.get("hello"));
+    assertEquals(4, vocab.get("world"));
+    assertEquals(5, vocab.get("##ing"));
+  }
+
+  @Test
+  void testJsonVocabWithEscapedCharacters() throws IOException {
+    final File tempFile = File.createTempFile("vocab-escaped", ".json");
+    tempFile.deleteOnExit();
+
+    Files.writeString(tempFile.toPath(),
+        "{\"hello\\\"world\": 0, \"back\\\\slash\": 1}");
+
+    final Map<String, Integer> vocab = dl.loadVocab(tempFile);
+
+    assertNotNull(vocab);
+    assertEquals(2, vocab.size());
+    assertEquals(0, vocab.get("hello\"world"));
+    assertEquals(1, vocab.get("back\\slash"));
+  }
+
+  @Test
+  void testJsonAndPlainTextVocabProduceSameResult() throws IOException {
+    final Map<String, Integer> plainVocab =
+        dl.loadVocab(getResource("vocab-plain.txt"));
+    final Map<String, Integer> jsonVocab =
+        dl.loadVocab(getResource("vocab.json"));
+
+    assertEquals(plainVocab, jsonVocab);
+  }
+
+  @Test
+  void testCreateTokenizerSelectsRobertaSpecialTokens() {
+    final Map<String, Integer> vocab =
+        Map.of("<s>", 0, "</s>", 1, "<unk>", 2, "hello", 3);
+
+    final String[] tokens = dl.createTokenizer(vocab).tokenize("hello xyz");
+
+    assertEquals("<s> hello <unk> </s>", String.join(" ", tokens));
+  }
+
+  @Test
+  void testCreateTokenizerFallsBackToBertUnknownToken() {
+    // RoBERTa CLS/SEP present but no <unk>: the documented fallback is [UNK].
+    final Map<String, Integer> vocab = Map.of("<s>", 0, "</s>", 1, "hello", 2);
+
+    final String[] tokens = dl.createTokenizer(vocab).tokenize("hello xyz");
+
+    assertEquals("<s> hello [UNK] </s>", String.join(" ", tokens));
+  }
+
+  @Test
+  void testCreateTokenizerDefaultsToBertSpecialTokens() throws IOException {
+    final Map<String, Integer> vocab =
+        dl.loadVocab(getResource("vocab-plain.txt"));
+
+    final String[] tokens =
+        dl.createTokenizer(vocab).tokenize("hello worlding xyz");
+
+    assertEquals("[CLS] hello world ##ing [UNK] [SEP]",
+        String.join(" ", tokens));
+  }
+}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt
new file mode 100644
index 00000000..2f458c14
--- /dev/null
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt
@@ -0,0 +1,6 @@
+[CLS]
+[SEP]
+[UNK]
+hello
+world
+##ing
\ No newline at end of file
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json
new file mode 100644
index 00000000..2e62cee3
--- /dev/null
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json
@@ -0,0 +1,8 @@
+{
+ "[CLS]": 0,
+ "[SEP]": 1,
+ "[UNK]": 2,
+ "hello": 3,
+ "world": 4,
+ "##ing": 5
+}
\ No newline at end of file