This is an automated email from the ASF dual-hosted git repository.

mawiesne pushed a commit to branch opennlp-2.x
in repository https://gitbox.apache.org/repos/asf/opennlp.git

commit 717fb2982df64ad3656c3753bdb4b8a5f2c75a13
Author: Richard Zowalla <[email protected]>
AuthorDate: Tue Mar 31 08:57:38 2026 +0200

    OPENNLP-1518: Roberta-based Models - Add support for utilization via Onnx (#998)
    
    (cherry picked from commit bd73ea1394dcc5a46cbd71f9b4034bde0e2fe89d)
---
 .../src/main/java/opennlp/dl/AbstractDL.java       |  89 ++++++++++++++++--
 .../opennlp/dl/doccat/DocumentCategorizerDL.java   |  23 +++--
 .../java/opennlp/dl/namefinder/NameFinderDL.java   |   3 +-
 .../java/opennlp/dl/vectors/SentenceVectorsDL.java |   4 +-
 .../src/test/java/opennlp/dl/LoadVocabTest.java    | 102 +++++++++++++++++++++
 .../src/test/resources/opennlp/dl/vocab-plain.txt  |   6 ++
 .../src/test/resources/opennlp/dl/vocab.json       |   8 ++
 .../opennlp/tools/tokenize/WordpieceTokenizer.java |  58 ++++++++++--
 8 files changed, 266 insertions(+), 27 deletions(-)

diff --git a/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java b/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
index eb8c41cef..d46d68a6f 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
@@ -19,11 +19,14 @@ package opennlp.dl;
 
 import java.io.File;
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import java.util.stream.Stream;
 
 import ai.onnxruntime.OrtEnvironment;
@@ -31,6 +34,7 @@ import ai.onnxruntime.OrtException;
 import ai.onnxruntime.OrtSession;
 
 import opennlp.tools.tokenize.Tokenizer;
+import opennlp.tools.tokenize.WordpieceTokenizer;
 
 /**
  * Base class for OpenNLP deep-learning classes using ONNX Runtime.
@@ -46,21 +50,173 @@ public abstract class AbstractDL implements AutoCloseable {
   protected Tokenizer tokenizer;
   protected Map<String, Integer> vocab;
 
+  private static final Pattern JSON_ENTRY_PATTERN =
+      Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)");
+
   /**
    * Loads a vocabulary {@link File} from disk.
+   * Supports both plain text files (one token per line,
+   * IDs are the 0-based line index) and JSON files
+   * mapping tokens to integer IDs. Content starting with
+   * an opening brace (after trimming) is read as JSON.
    *
    * @param vocabFile The vocabulary file.
-   * @return A map of vocabulary words to integer IDs.
-   * @throws IOException Thrown if the vocabulary file cannot be opened or read.
+   * @return A map of vocabulary words to IDs.
+   * @throws IOException Thrown if the vocabulary
+   *     file cannot be opened or read.
    */
-  public Map<String, Integer> loadVocab(final File vocabFile) throws IOException {
+  public Map<String, Integer> loadVocab(
+      final File vocabFile) throws IOException {
 
-    final Map<String, Integer> vocab = new HashMap<>();
-    final AtomicInteger counter = new AtomicInteger(0);
+    final Path vocabPath = vocabFile.toPath();
+    final String content = Files.readString(
+        vocabPath, StandardCharsets.UTF_8);
+    final String trimmed = content.trim();
+
+    // Detect JSON format by leading brace
+    if (trimmed.startsWith("{")) {
+      return loadJsonVocab(trimmed);
+    }
+
+    final Map<String, Integer> vocab =
+        new HashMap<>();
+    final AtomicInteger counter =
+        new AtomicInteger(0);
+
+    // Reuse the content read above instead of opening
+    // and decoding the file a second time.
+    try (Stream<String> lines = content.lines()) {
+      lines.forEach(line ->
+          vocab.put(line, counter.getAndIncrement())
+      );
+    }
+
+    return vocab;
+  }
 
-    try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) {
+  /**
+   * Creates a {@link WordpieceTokenizer} that uses the
+   * appropriate special tokens based on the vocabulary.
+   * If the vocabulary contains RoBERTa-style tokens,
+   * those are used. Otherwise, the BERT defaults are
+   * used.
+   *
+   * @param vocab The vocabulary map.
+   * @return A configured {@link WordpieceTokenizer}.
+   */
+  protected WordpieceTokenizer createTokenizer(
+      final Map<String, Integer> vocab) {
+    if (vocab.containsKey(
+            WordpieceTokenizer.ROBERTA_CLS_TOKEN)
+        && vocab.containsKey(
+            WordpieceTokenizer.ROBERTA_SEP_TOKEN)) {
+      final String unk = vocab.containsKey(
+          WordpieceTokenizer.ROBERTA_UNK_TOKEN)
+          ? WordpieceTokenizer.ROBERTA_UNK_TOKEN
+          : WordpieceTokenizer.BERT_UNK_TOKEN;
+      return new WordpieceTokenizer(
+          vocab.keySet(),
+          WordpieceTokenizer.ROBERTA_CLS_TOKEN,
+          WordpieceTokenizer.ROBERTA_SEP_TOKEN,
+          unk);
+    }
+    return new WordpieceTokenizer(vocab.keySet());
+  }
+
+  /**
+   * Decodes the escape sequences allowed in a JSON string
+   * (RFC 8259, section 7) in a single left-to-right pass,
+   * so an escaped backslash is never re-read as the start
+   * of another escape: the JSON text {@code "a\\n"} decodes
+   * to {@code a}, a backslash and {@code n}, not a newline.
+   * Four-digit hex escapes become the corresponding UTF-16
+   * code unit, so surrogate pairs combine naturally.
+   * Unknown or malformed escapes are kept verbatim.
+   *
+   * @param raw The still-escaped text between the quotes.
+   * @return The decoded token.
+   */
+  private static String unescapeJson(final String raw) {
+    if (raw.indexOf('\\') < 0) {
+      return raw;
+    }
+    final StringBuilder decoded = new StringBuilder(raw.length());
+    int i = 0;
+    while (i < raw.length()) {
+      final char c = raw.charAt(i);
+      if (c != '\\' || i + 1 >= raw.length()) {
+        decoded.append(c);
+        i++;
+        continue;
+      }
+      final char esc = raw.charAt(i + 1);
+      switch (esc) {
+        case '"' -> decoded.append('"');
+        case '\\' -> decoded.append('\\');
+        case '/' -> decoded.append('/');
+        case 'b' -> decoded.append('\b');
+        case 'f' -> decoded.append('\f');
+        case 'n' -> decoded.append('\n');
+        case 'r' -> decoded.append('\r');
+        case 't' -> decoded.append('\t');
+        case 'u' -> {
+          final int unit = parseHex4(raw, i + 2);
+          if (unit < 0) {
+            decoded.append(c).append(esc);
+          } else {
+            decoded.append((char) unit);
+            i += 4;
+          }
+        }
+        default -> decoded.append(c).append(esc);
+      }
+      i += 2;
+    }
+    return decoded.toString();
+  }
+
+  /**
+   * Parses exactly four ASCII hex digits starting at {@code start}.
+   *
+   * @param s     The text to read from.
+   * @param start The index of the first digit.
+   * @return The parsed value, or {@code -1} if fewer than four
+   *     characters remain or one of them is not an ASCII hex digit.
+   */
+  private static int parseHex4(final String s, final int start) {
+    if (start + 4 > s.length()) {
+      return -1;
+    }
+    int value = 0;
+    for (int k = start; k < start + 4; k++) {
+      final char ch = s.charAt(k);
+      final int digit = ch < 128 ? Character.digit(ch, 16) : -1;
+      if (digit < 0) {
+        return -1;
+      }
+      value = (value << 4) | digit;
+    }
+    return value;
+  }
+
+  /**
+   * Parses a flat JSON object of {@code "token": id} entries.
+   * This is a lightweight scan rather than a full JSON parser:
+   * every quoted string followed by a colon and a non-negative
+   * integer is taken as an entry; other content is ignored.
+   *
+   * @param json The trimmed JSON text.
+   * @return A map of decoded tokens to their IDs.
+   */
+  private Map<String, Integer> loadJsonVocab(final String json) {
+
+    final Map<String, Integer> vocab = new HashMap<>();
+    final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json);
 
-      lines.forEach(line -> vocab.put(line, counter.getAndIncrement()));
+    while (matcher.find()) {
+      final String token = unescapeJson(matcher.group(1));
+      final int id = Integer.parseInt(matcher.group(2));
+      vocab.put(token, id);
     }
 
     return vocab;
diff --git a/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java b/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index 9173f30e4..a0c9ede77 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -45,7 +45,6 @@ import opennlp.dl.InferenceOptions;
 import opennlp.dl.Tokens;
 import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
 import opennlp.tools.doccat.DocumentCategorizer;
-import opennlp.tools.tokenize.WordpieceTokenizer;
 
 /**
  * An implementation of {@link DocumentCategorizer} that performs document classification
@@ -90,7 +89,7 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
 
     this.session = env.createSession(model.getPath(), sessionOptions);
     this.vocab = loadVocab(vocabulary);
-    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+    this.tokenizer = createTokenizer(vocab);
     this.categories = categories;
     this.classificationScoringStrategy = classificationScoringStrategy;
     this.inferenceOptions = inferenceOptions;
@@ -125,7 +124,7 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
 
     this.session = env.createSession(model.getPath(), sessionOptions);
     this.vocab = loadVocab(vocabulary);
-    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+    this.tokenizer = createTokenizer(vocab);
     this.categories = readCategoriesFromFile(config);
     this.classificationScoringStrategy = classificationScoringStrategy;
     this.inferenceOptions = inferenceOptions;
@@ -158,11 +157,26 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor
               LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
         }
 
-        // The outputs from the model.
-        final float[][] v = (float[][]) session.run(inputs).get(0).getValue();
+        // The outputs from the model. Some models return a 2D array (e.g. BERT),
+        // while others return a 1D array (e.g. RoBERTa).
+        // Closing the Result releases the native output tensors; getValue() has copied the data.
+        final Object output;
+        try (var result = session.run(inputs)) {
+          output = result.get(0).getValue();
+        }
+
+        final float[] rawScores;
+        if (output instanceof float[][] v) {
+          rawScores = v[0];
+        } else if (output instanceof float[] v) {
+          rawScores = v;
+        } else {
+          throw new IllegalStateException(
+              "Unexpected model output type: " + output.getClass().getName());
+        }
 
         // Keep track of all scores.
-        final double[] categoryScoresForTokens = softmax(v[0]);
+        final double[] categoryScoresForTokens = softmax(rawScores);
         scores.add(categoryScoresForTokens);
 
       }
diff --git a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index 3cbf0e2a0..74e5a1aac 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -39,7 +39,6 @@ import opennlp.dl.SpanEnd;
 import opennlp.dl.Tokens;
 import opennlp.tools.namefind.TokenNameFinder;
 import opennlp.tools.sentdetect.SentenceDetector;
-import opennlp.tools.tokenize.WordpieceTokenizer;
 import opennlp.tools.util.Span;
 
 /**
@@ -104,7 +103,7 @@ public class NameFinderDL extends AbstractDL implements TokenNameFinder {
     this.session = env.createSession(model.getPath(), sessionOptions);
     this.ids2Labels = ids2Labels;
     this.vocab = loadVocab(vocabulary);
-    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
+    this.tokenizer = createTokenizer(vocab);
     this.inferenceOptions = inferenceOptions;
     this.sentenceDetector = sentenceDetector;
 
diff --git a/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java b/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
index 805b41188..85abfbe4f 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
@@ -32,7 +32,6 @@ import ai.onnxruntime.OrtSession;
 import opennlp.dl.AbstractDL;
 import opennlp.dl.Tokens;
 import opennlp.tools.tokenize.Tokenizer;
-import opennlp.tools.tokenize.WordpieceTokenizer;
 
 /**
  * Facilitates the generation of sentence vectors using
@@ -55,7 +54,7 @@ public class SentenceVectorsDL extends AbstractDL {
     env = OrtEnvironment.getEnvironment();
     session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
     vocab = loadVocab(new File(vocabulary.getPath()));
-    tokenizer = new WordpieceTokenizer(vocab.keySet());
+    tokenizer = createTokenizer(vocab);
 
   }
 
diff --git a/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java b/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
new file mode 100644
index 000000000..d8554c3fb
--- /dev/null
+++ b/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
+import java.util.Map;
+import java.util.Objects;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+public class LoadVocabTest {
+
+  private final AbstractDL dl = new AbstractDL() {
+    @Override
+    public void close() {
+    }
+  };
+
+  private File getResource(String name) throws IOException {
+    try (InputStream is = Objects.requireNonNull(
+        getClass().getResourceAsStream("/opennlp/dl/" + name))) {
+      final File tempFile = File.createTempFile("vocab-test-", "-" + name);
+      tempFile.deleteOnExit();
+      Files.copy(is, tempFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
+      return tempFile;
+    }
+  }
+
+  @Test
+  void testLoadPlainTextVocab() throws IOException {
+    final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab-plain.txt"));
+
+    assertNotNull(vocab);
+    assertEquals(6, vocab.size());
+    assertEquals(0, vocab.get("[CLS]"));
+    assertEquals(1, vocab.get("[SEP]"));
+    assertEquals(2, vocab.get("[UNK]"));
+    assertEquals(3, vocab.get("hello"));
+    assertEquals(4, vocab.get("world"));
+    assertEquals(5, vocab.get("##ing"));
+  }
+
+  @Test
+  void testLoadJsonVocab() throws IOException {
+    final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab.json"));
+
+    assertNotNull(vocab);
+    assertEquals(6, vocab.size());
+    assertEquals(0, vocab.get("[CLS]"));
+    assertEquals(1, vocab.get("[SEP]"));
+    assertEquals(2, vocab.get("[UNK]"));
+    assertEquals(3, vocab.get("hello"));
+    assertEquals(4, vocab.get("world"));
+    assertEquals(5, vocab.get("##ing"));
+  }
+
+  @Test
+  void testJsonVocabWithEscapedCharacters() throws IOException {
+    final File tempFile = File.createTempFile("vocab-escaped", ".json");
+    tempFile.deleteOnExit();
+
+    Files.writeString(tempFile.toPath(),
+        "{\"hello\\\"world\": 0, \"back\\\\slash\": 1}");
+
+    final Map<String, Integer> vocab = dl.loadVocab(tempFile);
+
+    assertNotNull(vocab);
+    assertEquals(2, vocab.size());
+    assertEquals(0, vocab.get("hello\"world"));
+    assertEquals(1, vocab.get("back\\slash"));
+  }
+
+  @Test
+  void testJsonAndPlainTextVocabProduceSameResult() throws IOException {
+    final Map<String, Integer> plainVocab = dl.loadVocab(getResource("vocab-plain.txt"));
+    final Map<String, Integer> jsonVocab = dl.loadVocab(getResource("vocab.json"));
+
+    assertEquals(plainVocab, jsonVocab);
+  }
+}
diff --git a/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt b/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt
new file mode 100644
index 000000000..2f458c144
--- /dev/null
+++ b/opennlp-dl/src/test/resources/opennlp/dl/vocab-plain.txt
@@ -0,0 +1,6 @@
+[CLS]
+[SEP]
+[UNK]
+hello
+world
+##ing
\ No newline at end of file
diff --git a/opennlp-dl/src/test/resources/opennlp/dl/vocab.json b/opennlp-dl/src/test/resources/opennlp/dl/vocab.json
new file mode 100644
index 000000000..2e62cee38
--- /dev/null
+++ b/opennlp-dl/src/test/resources/opennlp/dl/vocab.json
@@ -0,0 +1,8 @@
+{
+  "[CLS]": 0,
+  "[SEP]": 1,
+  "[UNK]": 2,
+  "hello": 3,
+  "world": 4,
+  "##ing": 5
+}
\ No newline at end of file
diff --git a/opennlp-tools/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java b/opennlp-tools/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
index 1cf9aa0c5..4d92c86bd 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/tokenize/WordpieceTokenizer.java
@@ -45,12 +45,27 @@ import opennlp.tools.util.Span;
  */
 public class WordpieceTokenizer implements Tokenizer {
 
-  private static final Pattern PUNCTUATION_PATTERN = Pattern.compile("\\p{Punct}+");
-  private static final String CLASSIFICATION_TOKEN = "[CLS]";
-  private static final String SEPARATOR_TOKEN = "[SEP]";
-  private static final String UNKNOWN_TOKEN = "[UNK]";
+  /** BERT classification token: {@code [CLS]}. */
+  public static final String BERT_CLS_TOKEN = "[CLS]";
+  /** BERT separator token: {@code [SEP]}. */
+  public static final String BERT_SEP_TOKEN = "[SEP]";
+  /** BERT unknown token: {@code [UNK]}. */
+  public static final String BERT_UNK_TOKEN = "[UNK]";
+
+  /** RoBERTa classification token: {@code <s>}. */
+  public static final String ROBERTA_CLS_TOKEN = "<s>";
+  /** RoBERTa separator token: {@code </s>}. */
+  public static final String ROBERTA_SEP_TOKEN = "</s>";
+  /** RoBERTa unknown token: {@code <unk>}. */
+  public static final String ROBERTA_UNK_TOKEN = "<unk>";
+
+  private static final Pattern PUNCTUATION_PATTERN =
+      Pattern.compile("\\p{Punct}+");
 
   private final Set<String> vocabulary;
+  private final String classificationToken;
+  private final String separatorToken;
+  private final String unknownToken;
   private int maxTokenLength = 50;
 
   /**
@@ -60,7 +75,7 @@ public class WordpieceTokenizer implements Tokenizer {
    * @param vocabulary  A set of tokens considered the vocabulary.
    */
   public WordpieceTokenizer(Set<String> vocabulary) {
-    this.vocabulary = vocabulary;
+    this(vocabulary, BERT_CLS_TOKEN, BERT_SEP_TOKEN, BERT_UNK_TOKEN);
   }
 
   /**
@@ -75,6 +90,38 @@ public class WordpieceTokenizer implements Tokenizer {
     this.maxTokenLength = maxTokenLength;
   }
 
+  /**
+   * Initializes a {@link WordpieceTokenizer} with a
+   * {@code vocabulary} and custom special tokens.
+   * This allows support for models like RoBERTa that
+   * use different special tokens instead of the BERT
+   * defaults. Only the special tokens are configurable;
+   * the sub-word splitting done by {@link #tokenize(String)}
+   * is the same regardless.
+   *
+   * @param vocabulary          The vocabulary.
+   * @param classificationToken The CLS token.
+   * @param separatorToken      The SEP token.
+   * @param unknownToken        The UNK token.
+   * @throws IllegalArgumentException Thrown if any of the
+   *     special tokens is {@code null}.
+   */
+  public WordpieceTokenizer(
+      final Set<String> vocabulary,
+      final String classificationToken,
+      final String separatorToken,
+      final String unknownToken) {
+    if (classificationToken == null || separatorToken == null
+        || unknownToken == null) {
+      throw new IllegalArgumentException(
+          "Special tokens must not be null.");
+    }
+    this.vocabulary = vocabulary;
+    this.classificationToken = classificationToken;
+    this.separatorToken = separatorToken;
+    this.unknownToken = unknownToken;
+  }
+
   @Override
   public Span[] tokenizePos(final String text) {
     // TODO: Implement this.
@@ -85,7 +132,7 @@ public class WordpieceTokenizer implements Tokenizer {
   public String[] tokenize(final String text) {
 
     final List<String> tokens = new LinkedList<>();
-    tokens.add(CLASSIFICATION_TOKEN);
+    tokens.add(classificationToken);
 
     // Put spaces around punctuation.
     final String spacedPunctuation = PUNCTUATION_PATTERN.matcher(text).replaceAll(" $0 ");
@@ -146,7 +193,7 @@ public class WordpieceTokenizer implements Tokenizer {
           // If the word can't be represented by vocabulary pieces replace
           // it with a specified "unknown" token.
           if (!found) {
-            tokens.add(UNKNOWN_TOKEN);
+            tokens.add(unknownToken);
             break;
           }
 
@@ -157,14 +204,14 @@ public class WordpieceTokenizer implements Tokenizer {
 
       } else {
 
-        // If the token's length is greater than the max length just add [UNK] instead.
-        tokens.add(UNKNOWN_TOKEN);
+        // If the token's length is greater than the max length just add the unknown token instead.
+        tokens.add(unknownToken);
 
       }
 
     }
 
-    tokens.add(SEPARATOR_TOKEN);
+    tokens.add(separatorToken);
 
     return tokens.toArray(new String[0]);
 

Reply via email to