This is an automated email from the ASF dual-hosted git repository.

rzo1 pushed a commit to branch OPENNLP-1518
in repository https://gitbox.apache.org/repos/asf/opennlp.git

commit 2d95220ed2f6a482f170234a49b6f28e2c86bd0a
Author: Richard Zowalla <[email protected]>
AuthorDate: Thu Mar 26 20:49:49 2026 +0100

    OPENNLP-1518 - Roberta-based Models - Add support for utilization via Onnx
---
 .../opennlp/tools/tokenize/WordpieceTokenizer.java | 46 +++++++++--
 .../src/main/java/opennlp/dl/AbstractDL.java       | 60 +++++++++++++-
 .../opennlp/dl/doccat/DocumentCategorizerDL.java   | 23 ++++--
 .../java/opennlp/dl/namefinder/NameFinderDL.java   |  4 +-
 .../java/opennlp/dl/vectors/SentenceVectorsDL.java |  4 +-
 .../src/test/java/opennlp/dl/LoadVocabTest.java    | 95 ++++++++++++++++++++++
 .../src/test/resources/opennlp/dl/vocab-plain.txt  |  6 ++
 .../src/test/resources/opennlp/dl/vocab.json       |  8 ++
 8 files changed, 225 insertions(+), 21 deletions(-)

diff --git 
a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java 
b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
index 1cf9aa0c..3e59f830 100644
--- a/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
+++ b/opennlp-api/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
@@ -46,11 +46,21 @@ import opennlp.tools.util.Span;
 public class WordpieceTokenizer implements Tokenizer {
 
   private static final Pattern PUNCTUATION_PATTERN = 
Pattern.compile("\\p{Punct}+");
-  private static final String CLASSIFICATION_TOKEN = "[CLS]";
-  private static final String SEPARATOR_TOKEN = "[SEP]";
-  private static final String UNKNOWN_TOKEN = "[UNK]";
+
+  // BERT special tokens
+  public static final String BERT_CLS_TOKEN = "[CLS]";
+  public static final String BERT_SEP_TOKEN = "[SEP]";
+  public static final String BERT_UNK_TOKEN = "[UNK]";
+
+  // RoBERTa special tokens
+  public static final String ROBERTA_CLS_TOKEN = "<s>";
+  public static final String ROBERTA_SEP_TOKEN = "</s>";
+  public static final String ROBERTA_UNK_TOKEN = "<unk>";
 
   private final Set<String> vocabulary;
+  private final String classificationToken;
+  private final String separatorToken;
+  private final String unknownToken;
   private int maxTokenLength = 50;
 
   /**
@@ -60,7 +70,7 @@ public class WordpieceTokenizer implements Tokenizer {
    * @param vocabulary  A set of tokens considered the vocabulary.
    */
   public WordpieceTokenizer(Set<String> vocabulary) {
-    this.vocabulary = vocabulary;
+    this(vocabulary, BERT_CLS_TOKEN, BERT_SEP_TOKEN, BERT_UNK_TOKEN);
   }
 
   /**
@@ -75,6 +85,24 @@ public class WordpieceTokenizer implements Tokenizer {
     this.maxTokenLength = maxTokenLength;
   }
 
+  /**
+   * Initializes a {@link WordpieceTokenizer} with a {@code vocabulary} and 
custom special tokens.
+   * This allows support for models like RoBERTa that use different special 
tokens
+   * (e.g. {@code <s>}, {@code </s>}, {@code <unk>}) instead of the BERT 
defaults.
+   *
+   * @param vocabulary          A set of tokens considered the vocabulary.
+   * @param classificationToken The token to prepend (e.g. {@code [CLS]} or 
{@code <s>}).
+   * @param separatorToken      The token to append (e.g. {@code [SEP]} or 
{@code </s>}).
+   * @param unknownToken        The token for unknown words (e.g. {@code 
[UNK]} or {@code <unk>}).
+   */
+  public WordpieceTokenizer(Set<String> vocabulary, String classificationToken,
+                             String separatorToken, String unknownToken) {
+    this.vocabulary = vocabulary;
+    this.classificationToken = classificationToken;
+    this.separatorToken = separatorToken;
+    this.unknownToken = unknownToken;
+  }
+
   @Override
   public Span[] tokenizePos(final String text) {
     // TODO: Implement this.
@@ -85,7 +113,7 @@ public class WordpieceTokenizer implements Tokenizer {
   public String[] tokenize(final String text) {
 
     final List<String> tokens = new LinkedList<>();
-    tokens.add(CLASSIFICATION_TOKEN);
+    tokens.add(classificationToken);
 
     // Put spaces around punctuation.
     final String spacedPunctuation = 
PUNCTUATION_PATTERN.matcher(text).replaceAll(" $0 ");
@@ -146,7 +174,7 @@ public class WordpieceTokenizer implements Tokenizer {
           // If the word can't be represented by vocabulary pieces replace
           // it with a specified "unknown" token.
           if (!found) {
-            tokens.add(UNKNOWN_TOKEN);
+            tokens.add(unknownToken);
             break;
           }
 
@@ -157,14 +185,14 @@ public class WordpieceTokenizer implements Tokenizer {
 
       } else {
 
-        // If the token's length is greater than the max length just add [UNK] 
instead.
-        tokens.add(UNKNOWN_TOKEN);
+        // If the token's length is greater than the max length just add 
unknown token instead.
+        tokens.add(unknownToken);
 
       }
 
     }
 
-    tokens.add(SEPARATOR_TOKEN);
+    tokens.add(separatorToken);
 
     return tokens.toArray(new String[0]);
 
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
index eb8c41ce..b95c034d 100644
--- 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
@@ -19,11 +19,14 @@ package opennlp.dl;
 
 import java.io.File;
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import java.util.stream.Stream;
 
 import ai.onnxruntime.OrtEnvironment;
@@ -31,6 +34,7 @@ import ai.onnxruntime.OrtException;
 import ai.onnxruntime.OrtSession;
 
 import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WordpieceTokenizer;
 
 /**
  * Base class for OpenNLP deep-learning classes using ONNX Runtime.
@@ -46,8 +50,13 @@ public abstract class AbstractDL implements AutoCloseable {
   protected Tokenizer tokenizer;
   protected Map<String, Integer> vocab;
 
+  private static final Pattern JSON_ENTRY_PATTERN =
+      Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)");
+
   /**
    * Loads a vocabulary {@link File} from disk.
+   * Supports both plain text files (one token per line) and JSON files
+   * mapping tokens to integer IDs (e.g. {@code {"token": 0, ...}}).
    *
    * @param vocabFile The vocabulary file.
    * @return A map of vocabulary words to integer IDs.
@@ -55,17 +64,64 @@ public abstract class AbstractDL implements AutoCloseable {
    */
   public Map<String, Integer> loadVocab(final File vocabFile) throws 
IOException {
 
+    final Path vocabPath = Path.of(vocabFile.getPath());
+    final String content = Files.readString(vocabPath, StandardCharsets.UTF_8);
+    final String trimmed = content.trim();
+
+    if (trimmed.startsWith("{")) { // No JSON parser library is used: content starting with '{' is treated as JSON and parsed with a regex.
+      return loadJsonVocab(trimmed);
+    }
+
     final Map<String, Integer> vocab = new HashMap<>();
     final AtomicInteger counter = new AtomicInteger(0);
 
-    try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) {
-
+    try (Stream<String> lines = Files.lines(vocabPath, 
StandardCharsets.UTF_8)) {
       lines.forEach(line -> vocab.put(line, counter.getAndIncrement()));
     }
 
     return vocab;
   }
 
+  /**
+   * Creates a {@link WordpieceTokenizer} that uses the appropriate special 
tokens
+   * based on the vocabulary. If it contains both {@code <s>} and {@code </s>},
+   * those are used, together with {@code <unk>} if present (else {@code [UNK]}).
+   * Otherwise, the BERT defaults ({@code [CLS]}, {@code [SEP]}, {@code [UNK]}) are used.
+   *
+   * @param vocab The vocabulary map.
+   * @return A configured {@link WordpieceTokenizer}.
+   */
+  protected WordpieceTokenizer createTokenizer(final Map<String, Integer> 
vocab) {
+    if (vocab.containsKey(WordpieceTokenizer.ROBERTA_CLS_TOKEN)
+        && vocab.containsKey(WordpieceTokenizer.ROBERTA_SEP_TOKEN)) {
+      return new WordpieceTokenizer(vocab.keySet(),
+          WordpieceTokenizer.ROBERTA_CLS_TOKEN,
+          WordpieceTokenizer.ROBERTA_SEP_TOKEN,
+          vocab.containsKey(WordpieceTokenizer.ROBERTA_UNK_TOKEN)
+              ? WordpieceTokenizer.ROBERTA_UNK_TOKEN : 
WordpieceTokenizer.BERT_UNK_TOKEN);
+    }
+    return new WordpieceTokenizer(vocab.keySet());
+  }
+
+  private Map<String, Integer> loadJsonVocab(final String json) {
+
+    final Map<String, Integer> vocab = new HashMap<>();
+    final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json);
+
+    while (matcher.find()) {
+      final String token = matcher.group(1)
+          .replace("\\\"", "\"")
+          .replace("\\\\", "\\")
+          .replace("\\/", "/")
+          .replace("\\n", "\n")
+          .replace("\\t", "\t");
+      final int id = Integer.parseInt(matcher.group(2));
+      vocab.put(token, id);
+    }
+
+    return vocab;
+  }
+
   /**
    * Closes this resource, relinquishing any underlying resources.
    *
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index 9173f30e..a0c9ede7 100644
--- 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -45,7 +45,7 @@ import opennlp.dl.InferenceOptions;
 import opennlp.dl.Tokens;
 import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
 import opennlp.tools.doccat.DocumentCategorizer;
-import opennlp.tools.tokenize.WordpieceTokenizer;
+
 
 /**
  * An implementation of {@link DocumentCategorizer} that performs document 
classification
@@ -90,7 +90,7 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
 
     this.session = env.createSession(model.getPath(), sessionOptions);
     this.vocab = loadVocab(vocabulary);
-    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+    this.tokenizer = createTokenizer(vocab);
     this.categories = categories;
     this.classificationScoringStrategy = classificationScoringStrategy;
     this.inferenceOptions = inferenceOptions;
@@ -125,7 +125,7 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
 
     this.session = env.createSession(model.getPath(), sessionOptions);
     this.vocab = loadVocab(vocabulary);
-    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+    this.tokenizer = createTokenizer(vocab);
     this.categories = readCategoriesFromFile(config);
     this.classificationScoringStrategy = classificationScoringStrategy;
     this.inferenceOptions = inferenceOptions;
@@ -158,11 +158,22 @@ public class DocumentCategorizerDL extends AbstractDL 
implements DocumentCategor
               LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
         }
 
-        // The outputs from the model.
-        final float[][] v = (float[][]) session.run(inputs).get(0).getValue();
+        // The outputs from the model. Some models return a 2D array (e.g. 
BERT),
+        // while others return a 1D array (e.g. RoBERTa).
+        final Object output = session.run(inputs).get(0).getValue();
+
+        final float[] rawScores;
+        if (output instanceof float[][] v) {
+          rawScores = v[0];
+        } else if (output instanceof float[] v) {
+          rawScores = v;
+        } else {
+          throw new IllegalStateException(
+              "Unexpected model output type: " + output.getClass().getName());
+        }
 
         // Keep track of all scores.
-        final double[] categoryScoresForTokens = softmax(v[0]);
+        final double[] categoryScoresForTokens = softmax(rawScores);
         scores.add(categoryScoresForTokens);
 
       }
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index 3cbf0e2a..9f4e1030 100644
--- 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -39,7 +39,7 @@ import opennlp.dl.SpanEnd;
 import opennlp.dl.Tokens;
 import opennlp.tools.namefind.TokenNameFinder;
 import opennlp.tools.sentdetect.SentenceDetector;
-import opennlp.tools.tokenize.WordpieceTokenizer;
+
 import opennlp.tools.util.Span;
 
 /**
@@ -104,7 +104,7 @@ public class NameFinderDL extends AbstractDL implements 
TokenNameFinder {
     this.session = env.createSession(model.getPath(), sessionOptions);
     this.ids2Labels = ids2Labels;
     this.vocab = loadVocab(vocabulary);
-    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+    this.tokenizer = createTokenizer(vocab);
     this.inferenceOptions = inferenceOptions;
     this.sentenceDetector = sentenceDetector;
 
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
index 805b4118..85abfbe4 100644
--- 
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
@@ -32,7 +32,7 @@ import ai.onnxruntime.OrtSession;
 import opennlp.dl.AbstractDL;
 import opennlp.dl.Tokens;
 import opennlp.tools.tokenize.Tokenizer;
-import opennlp.tools.tokenize.WordpieceTokenizer;
+
 
 /**
  * Facilitates the generation of sentence vectors using
@@ -55,7 +55,7 @@ public class SentenceVectorsDL extends AbstractDL {
     env = OrtEnvironment.getEnvironment();
     session = env.createSession(model.getPath(), new 
OrtSession.SessionOptions());
     vocab = loadVocab(new File(vocabulary.getPath()));
-    tokenizer = new WordpieceTokenizer(vocab.keySet());
+    tokenizer = createTokenizer(vocab);
 
   }
 
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
new file mode 100644
index 00000000..04e736d4
--- /dev/null
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+public class LoadVocabTest {
+
+  private final AbstractDL dl = new AbstractDL() {
+    @Override
+    public void close() {
+    }
+  };
+
+  private File getResource(String name) {
+    return new File(Objects.requireNonNull(
+        getClass().getResource("/opennlp/dl/" + name)).getFile());
+  }
+
+  @Test
+  void testLoadPlainTextVocab() throws IOException {
+    final Map<String, Integer> vocab = 
dl.loadVocab(getResource("vocab-plain.txt"));
+
+    assertNotNull(vocab);
+    assertEquals(6, vocab.size());
+    assertEquals(0, vocab.get("[CLS]"));
+    assertEquals(1, vocab.get("[SEP]"));
+    assertEquals(2, vocab.get("[UNK]"));
+    assertEquals(3, vocab.get("hello"));
+    assertEquals(4, vocab.get("world"));
+    assertEquals(5, vocab.get("##ing"));
+  }
+
+  @Test
+  void testLoadJsonVocab() throws IOException {
+    final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab.json"));
+
+    assertNotNull(vocab);
+    assertEquals(6, vocab.size());
+    assertEquals(0, vocab.get("[CLS]"));
+    assertEquals(1, vocab.get("[SEP]"));
+    assertEquals(2, vocab.get("[UNK]"));
+    assertEquals(3, vocab.get("hello"));
+    assertEquals(4, vocab.get("world"));
+    assertEquals(5, vocab.get("##ing"));
+  }
+
+  @Test
+  void testJsonVocabWithEscapedCharacters() throws IOException {
+    // Write a temp file with escaped characters
+    final File tempFile = File.createTempFile("vocab-escaped", ".json");
+    tempFile.deleteOnExit();
+
+    java.nio.file.Files.writeString(tempFile.toPath(),
+        "{\"hello\\\"world\": 0, \"back\\\\slash\": 1}");
+
+    final Map<String, Integer> vocab = dl.loadVocab(tempFile);
+
+    assertNotNull(vocab);
+    assertEquals(2, vocab.size());
+    assertEquals(0, vocab.get("hello\"world"));
+    assertEquals(1, vocab.get("back\\slash"));
+  }
+
+  @Test
+  void testJsonAndPlainTextVocabProduceSameResult() throws IOException {
+    final Map<String, Integer> plainVocab = 
dl.loadVocab(getResource("vocab-plain.txt"));
+    final Map<String, Integer> jsonVocab = 
dl.loadVocab(getResource("vocab.json"));
+
+    assertEquals(plainVocab, jsonVocab);
+  }
+}
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt
 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt
new file mode 100644
index 00000000..2f458c14
--- /dev/null
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt
@@ -0,0 +1,6 @@
+[CLS]
+[SEP]
+[UNK]
+hello
+world
+##ing
\ No newline at end of file
diff --git 
a/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json
new file mode 100644
index 00000000..2e62cee3
--- /dev/null
+++ 
b/opennlp-core/opennlp-ml/opennlp-dl/src/test/resources/opennlp/dl/vocab.json
@@ -0,0 +1,8 @@
+{
+  "[CLS]": 0,
+  "[SEP]": 1,
+  "[UNK]": 2,
+  "hello": 3,
+  "world": 4,
+  "##ing": 5
+}
\ No newline at end of file

Reply via email to