This is an automated email from the ASF dual-hosted git repository. rzo1 pushed a commit to branch OPENNLP-1518 in repository https://gitbox.apache.org/repos/asf/opennlp.git
commit 2d95220ed2f6a482f170234a49b6f28e2c86bd0a Author: Richard Zowalla <[email protected]> AuthorDate: Thu Mar 26 20:49:49 2026 +0100 OPENNLP-1518 - Roberta-based Models - Add support for utilization via Onnx --- .../opennlp/tools/tokenize/WordpieceTokenizer.java | 46 +++++++++-- .../src/main/java/opennlp/dl/AbstractDL.java | 60 +++++++++++++- .../opennlp/dl/doccat/DocumentCategorizerDL.java | 23 ++++-- .../java/opennlp/dl/namefinder/NameFinderDL.java | 4 +- .../java/opennlp/dl/vectors/SentenceVectorsDL.java | 4 +- .../src/test/java/opennlp/dl/LoadVocabTest.java | 95 ++++++++++++++++++++++ .../src/test/resources/opennlp/dl/vocab-plain.txt | 6 ++ .../src/test/resources/opennlp/dl/vocab.json | 8 ++ 8 files changed, 225 insertions(+), 21 deletions(-) diff --git a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java index 1cf9aa0c..3e59f830 100644 --- a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java +++ b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java @@ -46,11 +46,21 @@ import opennlp.tools.util.Span; public class WordpieceTokenizer implements Tokenizer { private static final Pattern PUNCTUATION_PATTERN = Pattern.compile("\\p{Punct}+"); - private static final String CLASSIFICATION_TOKEN = "[CLS]"; - private static final String SEPARATOR_TOKEN = "[SEP]"; - private static final String UNKNOWN_TOKEN = "[UNK]"; + + // BERT special tokens + public static final String BERT_CLS_TOKEN = "[CLS]"; + public static final String BERT_SEP_TOKEN = "[SEP]"; + public static final String BERT_UNK_TOKEN = "[UNK]"; + + // RoBERTa special tokens + public static final String ROBERTA_CLS_TOKEN = "<s>"; + public static final String ROBERTA_SEP_TOKEN = "</s>"; + public static final String ROBERTA_UNK_TOKEN = "<unk>"; private final Set<String> vocabulary; + private final String classificationToken; + private final String separatorToken; + private final String unknownToken; private int maxTokenLength = 50; /** @@ -60,7 +70,7 @@ public class WordpieceTokenizer implements Tokenizer { * @param vocabulary A set of tokens considered the vocabulary. */ public WordpieceTokenizer(Set<String> vocabulary) { - this.vocabulary = vocabulary; + this(vocabulary, BERT_CLS_TOKEN, BERT_SEP_TOKEN, BERT_UNK_TOKEN); } /** @@ -75,6 +85,24 @@ public class WordpieceTokenizer implements Tokenizer { this.maxTokenLength = maxTokenLength; } + /** + * Initializes a {@link WordpieceTokenizer} with a {@code vocabulary} and custom special tokens. + * This allows support for models like RoBERTa that use different special tokens + * (e.g. {@code <s>}, {@code </s>}, {@code <unk>}) instead of the BERT defaults. + * + * @param vocabulary A set of tokens considered the vocabulary. + * @param classificationToken The token to prepend (e.g. {@code [CLS]} or {@code <s>}). + * @param separatorToken The token to append (e.g. {@code [SEP]} or {@code </s>}). + * @param unknownToken The token for unknown words (e.g. {@code [UNK]} or {@code <unk>}). + */ + public WordpieceTokenizer(Set<String> vocabulary, String classificationToken, + String separatorToken, String unknownToken) { + this.vocabulary = vocabulary; + this.classificationToken = classificationToken; + this.separatorToken = separatorToken; + this.unknownToken = unknownToken; + } + @Override public Span[] tokenizePos(final String text) { // TODO: Implement this. @@ -85,7 +113,7 @@ public class WordpieceTokenizer implements Tokenizer { public String[] tokenize(final String text) { final List<String> tokens = new LinkedList<>(); - tokens.add(CLASSIFICATION_TOKEN); + tokens.add(classificationToken); // Put spaces around punctuation. final String spacedPunctuation = PUNCTUATION_PATTERN.matcher(text).replaceAll(" $0 "); @@ -146,7 +174,7 @@ public class WordpieceTokenizer implements Tokenizer { // If the word can't be represented by vocabulary pieces replace // it with a specified "unknown" token. if (!found) { - tokens.add(UNKNOWN_TOKEN); + tokens.add(unknownToken); break; } @@ -157,14 +185,14 @@ public class WordpieceTokenizer implements Tokenizer { } else { - // If the token's length is greater than the max length just add [UNK] instead. - tokens.add(UNKNOWN_TOKEN); + // If the token's length is greater than the max length just add unknown token instead. + tokens.add(unknownToken); } } - tokens.add(SEPARATOR_TOKEN); + tokens.add(separatorToken); return tokens.toArray(new String[0]); diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java index eb8c41ce..b95c034d 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java @@ -19,11 +19,14 @@ package opennlp.dl; import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Stream; import ai.onnxruntime.OrtEnvironment; @@ -31,6 +34,7 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import opennlp.tools.tokenize.Tokenizer; +import opennlp.tools.tokenize.WordpieceTokenizer; /** * Base class for OpenNLP deep-learning classes using ONNX Runtime. @@ -46,8 +50,13 @@ public abstract class AbstractDL implements AutoCloseable { protected Tokenizer tokenizer; protected Map<String, Integer> vocab; + private static final Pattern JSON_ENTRY_PATTERN = + Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)"); + /** * Loads a vocabulary {@link File} from disk. + * Supports both plain text files (one token per line) and JSON files + * mapping tokens to integer IDs (e.g. {@code {"token": 0, ...}}). * * @param vocabFile The vocabulary file. * @return A map of vocabulary words to integer IDs. @@ -55,17 +64,64 @@ public abstract class AbstractDL implements AutoCloseable { */ public Map<String, Integer> loadVocab(final File vocabFile) throws IOException { + final Path vocabPath = Path.of(vocabFile.getPath()); + final String content = Files.readString(vocabPath, StandardCharsets.UTF_8); + final String trimmed = content.trim(); + + if (trimmed.startsWith("{")) { //we avoid using a JSON parser lib, so we do a quick'n-dirty check for JSON format and try to parse it with a regex if it looks like JSON + return loadJsonVocab(trimmed); + } + final Map<String, Integer> vocab = new HashMap<>(); final AtomicInteger counter = new AtomicInteger(0); - try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) { - + try (Stream<String> lines = Files.lines(vocabPath, StandardCharsets.UTF_8)) { lines.forEach(line -> vocab.put(line, counter.getAndIncrement())); } return vocab; } + /** + * Creates a {@link WordpieceTokenizer} that uses the appropriate special tokens + * based on the vocabulary. If the vocabulary contains RoBERTa-style tokens + * ({@code <s>}, {@code </s>}, {@code <unk>}), those are used. Otherwise, + * the BERT defaults ({@code [CLS]}, {@code [SEP]}, {@code [UNK]}) are used. + * + * @param vocab The vocabulary map. + * @return A configured {@link WordpieceTokenizer}. + */ + protected WordpieceTokenizer createTokenizer(final Map<String, Integer> vocab) { + if (vocab.containsKey(WordpieceTokenizer.ROBERTA_CLS_TOKEN) + && vocab.containsKey(WordpieceTokenizer.ROBERTA_SEP_TOKEN)) { + return new WordpieceTokenizer(vocab.keySet(), + WordpieceTokenizer.ROBERTA_CLS_TOKEN, + WordpieceTokenizer.ROBERTA_SEP_TOKEN, + vocab.containsKey(WordpieceTokenizer.ROBERTA_UNK_TOKEN) + ? WordpieceTokenizer.ROBERTA_UNK_TOKEN : WordpieceTokenizer.BERT_UNK_TOKEN); + } + return new WordpieceTokenizer(vocab.keySet()); + } + + private Map<String, Integer> loadJsonVocab(final String json) { + + final Map<String, Integer> vocab = new HashMap<>(); + final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json); + + while (matcher.find()) { + final String token = matcher.group(1) + .replace("\\\"", "\"") + .replace("\\\\", "\\") + .replace("\\/", "/") + .replace("\\n", "\n") + .replace("\\t", "\t"); + final int id = Integer.parseInt(matcher.group(2)); + vocab.put(token, id); + } + + return vocab; + } + /** * Closes this resource, relinquishing any underlying resources. * diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java index 9173f30e..a0c9ede7 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java @@ -45,7 +45,7 @@ import opennlp.dl.InferenceOptions; import opennlp.dl.Tokens; import opennlp.dl.doccat.scoring.ClassificationScoringStrategy; import opennlp.tools.doccat.DocumentCategorizer; -import opennlp.tools.tokenize.WordpieceTokenizer; + /** * An implementation of {@link DocumentCategorizer} that performs document classification @@ -90,7 +90,7 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor this.session = env.createSession(model.getPath(), sessionOptions); this.vocab = loadVocab(vocabulary); - this.tokenizer = new WordpieceTokenizer(vocab.keySet()); + this.tokenizer = createTokenizer(vocab); this.categories = categories; this.classificationScoringStrategy = classificationScoringStrategy; this.inferenceOptions = inferenceOptions; @@ -125,7 +125,7 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor this.session = env.createSession(model.getPath(), sessionOptions); this.vocab = loadVocab(vocabulary); - this.tokenizer = new WordpieceTokenizer(vocab.keySet()); + this.tokenizer = createTokenizer(vocab); this.categories = readCategoriesFromFile(config); this.classificationScoringStrategy = classificationScoringStrategy; this.inferenceOptions = inferenceOptions; @@ -158,11 +158,22 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor LongBuffer.wrap(t.types()), new long[] {1, t.types().length})); } - // The outputs from the model. - final float[][] v = (float[][]) session.run(inputs).get(0).getValue(); + // The outputs from the model. Some models return a 2D array (e.g. BERT), + // while others return a 1D array (e.g. RoBERTa). + final Object output = session.run(inputs).get(0).getValue(); + + final float[] rawScores; + if (output instanceof float[][] v) { + rawScores = v[0]; + } else if (output instanceof float[] v) { + rawScores = v; + } else { + throw new IllegalStateException( + "Unexpected model output type: " + output.getClass().getName()); + } // Keep track of all scores. - final double[] categoryScoresForTokens = softmax(v[0]); + final double[] categoryScoresForTokens = softmax(rawScores); scores.add(categoryScoresForTokens); } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java index 3cbf0e2a..9f4e1030 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java @@ -39,7 +39,7 @@ import opennlp.dl.SpanEnd; import opennlp.dl.Tokens; import opennlp.tools.namefind.TokenNameFinder; import opennlp.tools.sentdetect.SentenceDetector; -import opennlp.tools.tokenize.WordpieceTokenizer; + import opennlp.tools.util.Span; /** @@ -104,7 +104,7 @@ public class NameFinderDL extends AbstractDL implements TokenNameFinder { this.session = env.createSession(model.getPath(), sessionOptions); this.ids2Labels = ids2Labels; this.vocab = loadVocab(vocabulary); - this.tokenizer = new WordpieceTokenizer(vocab.keySet()); + this.tokenizer = createTokenizer(vocab); this.inferenceOptions = inferenceOptions; this.sentenceDetector = sentenceDetector; diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java index 805b4118..85abfbe4 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java @@ -32,7 +32,7 @@ import ai.onnxruntime.OrtSession; import opennlp.dl.AbstractDL; import opennlp.dl.Tokens; import opennlp.tools.tokenize.Tokenizer; -import opennlp.tools.tokenize.WordpieceTokenizer; + /** * Facilitates the generation of sentence vectors using @@ -55,7 +55,7 @@ public class SentenceVectorsDL extends AbstractDL { env = OrtEnvironment.getEnvironment(); session = env.createSession(model.getPath(), new OrtSession.SessionOptions()); vocab = loadVocab(new File(vocabulary.getPath())); - tokenizer = new WordpieceTokenizer(vocab.keySet()); + tokenizer = createTokenizer(vocab); } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java new file mode 100644 index 00000000..04e736d4 --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.dl; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class LoadVocabTest { + + private final AbstractDL dl = new AbstractDL() { + @Override + public void close() { + } + }; + + private File getResource(String name) { + return new File(Objects.requireNonNull( + getClass().getResource("/opennlp/dl/" + name)).getFile()); + } + + @Test + void testLoadPlainTextVocab() throws IOException { + final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab-plain.txt")); + + assertNotNull(vocab); + assertEquals(6, vocab.size()); + assertEquals(0, vocab.get("[CLS]")); + assertEquals(1, vocab.get("[SEP]")); + assertEquals(2, vocab.get("[UNK]")); + assertEquals(3, vocab.get("hello")); + assertEquals(4, vocab.get("world")); + assertEquals(5, vocab.get("##ing")); + } + + @Test + void testLoadJsonVocab() throws IOException { + final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab.json")); + + assertNotNull(vocab); + assertEquals(6, vocab.size()); + assertEquals(0, vocab.get("[CLS]")); + assertEquals(1, vocab.get("[SEP]")); + assertEquals(2, vocab.get("[UNK]")); + assertEquals(3, vocab.get("hello")); + assertEquals(4, vocab.get("world")); + assertEquals(5, vocab.get("##ing")); + } + + @Test + void testJsonVocabWithEscapedCharacters() throws IOException { + // Write a temp file with escaped characters + final File tempFile = File.createTempFile("vocab-escaped", ".json"); + tempFile.deleteOnExit(); + + java.nio.file.Files.writeString(tempFile.toPath(), + "{\"hello\\\"world\": 0, \"back\\\\slash\": 1}"); + + final Map<String, Integer> vocab = dl.loadVocab(tempFile); + + assertNotNull(vocab); + assertEquals(2, vocab.size()); + assertEquals(0, vocab.get("hello\"world")); + assertEquals(1, vocab.get("back\\slash")); + } + + @Test + void testJsonAndPlainTextVocabProduceSameResult() throws IOException { + final Map<String, Integer> plainVocab = dl.loadVocab(getResource("vocab-plain.txt")); + final Map<String, Integer> jsonVocab = dl.loadVocab(getResource("vocab.json")); + + assertEquals(plainVocab, jsonVocab); + } +} diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt new file mode 100644 index 00000000..2f458c14 --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt @@ -0,0 +1,6 @@ +[CLS] +[SEP] +[UNK] +hello +world +##ing \ No newline at end of file diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json new file mode 100644 index 00000000..2e62cee3 --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json @@ -0,0 +1,8 @@ +{ + "[CLS]": 0, + "[SEP]": 1, + "[UNK]": 2, + "hello": 3, + "world": 4, + "##ing": 5 +} \ No newline at end of file
