This is an automated email from the ASF dual-hosted git repository.
krickert pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 136633b4b OPENNLP-1844 - Make opennlp-dl components thread-safe (#1084)
136633b4b is described below
commit 136633b4b04a8557918443451d20f6ccd2450b6d
Author: Kristian Rickert <[email protected]>
AuthorDate: Tue Jun 16 08:07:14 2026 -0400
OPENNLP-1844 - Make opennlp-dl components thread-safe (#1084)
Merging, then prepping for 1086.
---
.../src/main/java/opennlp/dl/AbstractDL.java | 249 +++++++++++++++++++--
.../opennlp/dl/doccat/DocumentCategorizerDL.java | 135 ++++++-----
.../java/opennlp/dl/namefinder/NameFinderDL.java | 76 ++++---
.../java/opennlp/dl/vectors/SentenceVectorsDL.java | 13 +-
.../src/test/java/opennlp/dl/ChunkRangesTest.java | 85 +++++++
.../test/java/opennlp/dl/CreateTokenizerTest.java | 17 +-
.../opennlp/dl/InferenceOptionsValidationTest.java | 77 +++++++
.../src/test/java/opennlp/dl/LoadVocabTest.java | 44 +++-
.../dl/doccat/DocumentCategorizerDLTest.java | 50 +----
.../dl/doccat/DocumentCategorizerDLEval.java | 63 ++++++
.../opennlp/dl/namefinder/NameFinderDLEval.java | 161 +++++++++++++
.../opennlp/dl/vectors/SentenceVectorsDLEval.java | 61 +++++
12 files changed, 840 insertions(+), 191 deletions(-)
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
index 38137a057..6e6e54767 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
@@ -22,8 +22,12 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -46,19 +50,107 @@ public abstract class AbstractDL implements AutoCloseable {
public static final String ATTENTION_MASK = "attention_mask";
public static final String TOKEN_TYPE_IDS = "token_type_ids";
- protected OrtEnvironment env;
- protected OrtSession session;
- protected Tokenizer tokenizer;
- protected Map<String, Integer> vocab;
+ protected final OrtEnvironment env;
+ protected final OrtSession session;
+ protected final Tokenizer tokenizer;
+ protected final Map<String, Integer> vocab;
+
+ private final AtomicBoolean closed = new AtomicBoolean();
+
+ protected record ChunkRange(int start, int end) {
+ }
private static final Pattern JSON_ENTRY_PATTERN =
Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)");
+ /**
+ * Initializes the shared, immutable inference state: the ONNX environment
and session,
+ * the loaded vocabulary and the configured tokenizer. These fields are
{@code final}
+ * and assigned exactly once here, so a fully constructed instance is safely
published
+ * and can be shared across threads.
+ *
+ * @param model The ONNX model file.
+ * @param vocabulary The vocabulary file matching the model.
+ * @param sessionOptions The session options (e.g. CUDA execution provider);
build with
+ * {@link #sessionOptions(InferenceOptions)} when honoring {@link
InferenceOptions}.
+ * @param lowerCase {@code true} for uncased models (lower casing and accent
stripping
+ * during tokenization), {@code false} for cased models.
+ *
+ * @throws OrtException Thrown if the {@code model} cannot be loaded.
+ * @throws IOException Thrown if the {@code model} or {@code vocabulary}
cannot be read.
+ */
+ protected AbstractDL(final File model, final File vocabulary,
+ final OrtSession.SessionOptions sessionOptions, final boolean lowerCase)
+ throws IOException, OrtException {
+ Objects.requireNonNull(sessionOptions, "sessionOptions");
+ // This constructor takes ownership of the session options. Subclasses build them inline as a
+ // super() argument, so try-with-resources must cover every exit path -- including the
+ // argument checks below -- or a failed check would leak the native options handle.
+ try (sessionOptions) {
+ Objects.requireNonNull(model, "model");
+ Objects.requireNonNull(vocabulary, "vocabulary");
+ this.env = OrtEnvironment.getEnvironment();
+ final OrtSession createdSession = env.createSession(model.getPath(), sessionOptions);
+ try {
+ this.vocab = Map.copyOf(loadVocabFile(vocabulary));
+ this.tokenizer = createBertTokenizer(vocab, lowerCase);
+ } catch (IOException | RuntimeException e) {
+ // Vocabulary/tokenizer init failed after the native session was created; close it
+ // so a partially constructed instance never leaks the ONNX session.
+ try {
+ createdSession.close();
+ } catch (OrtException suppressed) {
+ e.addSuppressed(suppressed);
+ }
+ throw e;
+ }
+ this.session = createdSession;
+ }
+ }
+
+ /**
+ * Directly assigns the shared inference state. This seam exists for unit
tests that need to
+ * construct a component without loading an ONNX model (e.g. passing a
{@code null}
+ * {@link OrtSession} to exercise inference-failure handling). The fields
remain {@code final}
+ * and are assigned exactly once, so safe publication is preserved.
+ *
+ * @param env The ONNX environment, or {@code null} in tests.
+ * @param session The ONNX session, or {@code null} in tests that do not run
inference.
+ * @param vocab The vocabulary used by the tokenizer.
+ * @param lowerCase {@code true} for uncased models, {@code false} for cased
models.
+ */
+ protected AbstractDL(final OrtEnvironment env, final OrtSession session,
+ final Map<String, Integer> vocab, final boolean
lowerCase) {
+ this.env = env;
+ this.session = session;
+ this.vocab = vocab;
+ this.tokenizer = createBertTokenizer(vocab, lowerCase);
+ }
+
+ /**
+ * Builds ONNX session options from the given {@link InferenceOptions}, enabling the CUDA
+ * execution provider on the configured device when GPU inference is requested. The document
+ * split settings are validated first, so a misconfiguration fails before any native resource
+ * is allocated.
+ *
+ * <p>Ownership of the returned options passes to the caller. If the CUDA execution provider
+ * cannot be added, the partially configured options are closed before the exception
+ * propagates, so no native handle is leaked on failure.</p>
+ *
+ * @param inferenceOptions The inference options to read the GPU configuration from.
+ * @return The configured session options.
+ *
+ * @throws OrtException Thrown if the CUDA execution provider cannot be added.
+ * @throws IllegalArgumentException Thrown if the document split settings cannot make progress.
+ * @throws NullPointerException Thrown if {@code inferenceOptions} is {@code null}.
+ */
+ protected static OrtSession.SessionOptions sessionOptions(final InferenceOptions inferenceOptions)
+ throws OrtException {
+ Objects.requireNonNull(inferenceOptions, "inferenceOptions");
+ validateSplitOptions(inferenceOptions);
+ final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
+ if (inferenceOptions.isGpu()) {
+ try {
+ sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
+ } catch (OrtException | RuntimeException e) {
+ // The caller never receives these options on failure, so release the native handle here.
+ sessionOptions.close();
+ throw e;
+ }
+ }
+ return sessionOptions;
+ }
+
/**
* Loads a vocabulary {@link File} from disk.
* Supports both plain text files (one token per
- * line) and JSON files mapping tokens to integer
- * IDs.
+ * line) and simple JSON vocabulary files mapping tokens to integer
+ * IDs. JSON support is intentionally limited to the HuggingFace vocabulary
+ * shape; it is not a general-purpose JSON parser.
*
* @param vocabFile The vocabulary file.
* @return A map of vocabulary words to IDs.
@@ -68,6 +160,12 @@ public abstract class AbstractDL implements AutoCloseable {
public Map<String, Integer> loadVocab(
final File vocabFile) throws IOException {
+ return loadVocabFile(vocabFile);
+ }
+
+ static Map<String, Integer> loadVocabFile(
+ final File vocabFile) throws IOException {
+
final Path vocabPath =
Path.of(vocabFile.getPath());
final String content = Files.readString(
@@ -106,6 +204,12 @@ public abstract class AbstractDL implements AutoCloseable {
*/
protected WordpieceTokenizer createTokenizer(
final Map<String, Integer> vocab) {
+
+ return createWordpieceTokenizer(vocab);
+ }
+
+ static WordpieceTokenizer createWordpieceTokenizer(
+ final Map<String, Integer> vocab) {
if (vocab.containsKey(
WordpieceTokenizer.ROBERTA_CLS_TOKEN)
&& vocab.containsKey(
@@ -134,6 +238,12 @@ public abstract class AbstractDL implements AutoCloseable {
*/
protected BertTokenizer createTokenizer(
final Map<String, Integer> vocab, final boolean lowerCase) {
+
+ return createBertTokenizer(vocab, lowerCase);
+ }
+
+ static BertTokenizer createBertTokenizer(
+ final Map<String, Integer> vocab, final boolean lowerCase) {
if (vocab.containsKey(
WordpieceTokenizer.ROBERTA_CLS_TOKEN)
&& vocab.containsKey(
@@ -182,21 +292,76 @@ public abstract class AbstractDL implements AutoCloseable
{
*/
protected static boolean resolveLowerCase(
final InferenceOptions options, final boolean componentDefault) {
+ Objects.requireNonNull(options, "options");
return options.getLowerCase() != null ? options.getLowerCase() :
componentDefault;
}
- private Map<String, Integer> loadJsonVocab(final String json) {
+ /**
+ * Validates the document splitting options used by tokenizers that split
long inputs.
+ *
+ * @param options The inference options to validate.
+ * @throws IllegalArgumentException Thrown if the split settings cannot make
progress.
+ */
+ protected static void validateSplitOptions(final InferenceOptions options) {
+ Objects.requireNonNull(options, "options");
+ validateSplitOptions(options.getDocumentSplitSize(),
options.getSplitOverlapSize());
+ }
+
+ /**
+ * Validates the document splitting values used by tokenizers that split
long inputs.
+ *
+ * @param documentSplitSize The number of tokens per split.
+ * @param splitOverlapSize The number of tokens to overlap between adjacent
splits.
+ * @throws IllegalArgumentException Thrown if the split settings cannot make
progress.
+ */
+ protected static void validateSplitOptions(final int documentSplitSize,
final int splitOverlapSize) {
+ if (documentSplitSize <= 0) {
+ throw new IllegalArgumentException("documentSplitSize must be greater
than zero.");
+ }
+ if (splitOverlapSize < 0) {
+ throw new IllegalArgumentException("splitOverlapSize must not be
negative.");
+ }
+ if (splitOverlapSize >= documentSplitSize) {
+ throw new IllegalArgumentException(
+ "splitOverlapSize must be smaller than documentSplitSize.");
+ }
+ }
+
+ /**
+ * Splits a token sequence into overlapping chunk ranges.
+ *
+ * @param tokenCount The number of tokens to split.
+ * @param documentSplitSize The number of tokens per split.
+ * @param splitOverlapSize The number of tokens to overlap between adjacent
splits.
+ * @return The chunk ranges to process.
+ * @throws IllegalArgumentException Thrown if the token count is negative or
the split settings
+ * cannot make progress.
+ */
+ protected static List<ChunkRange> chunkRanges(final int tokenCount, final
int documentSplitSize,
+ final int splitOverlapSize) {
+ if (tokenCount < 0) {
+ throw new IllegalArgumentException("tokenCount must not be negative.");
+ }
+ validateSplitOptions(documentSplitSize, splitOverlapSize);
+
+ final List<ChunkRange> ranges = new ArrayList<>();
+ int start = 0;
+ while (start < tokenCount) {
+ final int end = Math.min(start + documentSplitSize, tokenCount);
+ ranges.add(new ChunkRange(start, end));
+ start = end == tokenCount ? end : end - splitOverlapSize;
+ }
+ return List.copyOf(ranges);
+ }
+
+ private static Map<String, Integer> loadJsonVocab(final String json) {
final Map<String, Integer> vocab = new HashMap<>();
final Matcher matcher = JSON_ENTRY_PATTERN.matcher(json);
while (matcher.find()) {
final String token = matcher.group(1)
- .replace("\\\"", "\"")
- .replace("\\\\", "\\")
- .replace("\\/", "/")
- .replace("\\n", "\n")
- .replace("\\t", "\t");
+ .transform(AbstractDL::unescapeJsonString);
final int id = Integer.parseInt(matcher.group(2));
vocab.put(token, id);
}
@@ -204,20 +369,64 @@ public abstract class AbstractDL implements AutoCloseable
{
return vocab;
}
+ /**
+ * Decodes the JSON escape sequences in the raw content of a JSON string literal.
+ *
+ * @param value The still-escaped characters captured between the enclosing quotes.
+ * @return The unescaped string.
+ * @throws IllegalArgumentException Thrown if {@code value} ends in a dangling backslash, uses an
+ * escape character JSON does not define, or has a backslash-u escape that is not followed by
+ * exactly four ASCII hexadecimal digits.
+ */
+ private static String unescapeJsonString(final String value) {
+ final StringBuilder result = new StringBuilder(value.length());
+ for (int i = 0; i < value.length(); i++) {
+ final char ch = value.charAt(i);
+ if (ch != '\\') {
+ result.append(ch);
+ continue;
+ }
+ if (++i == value.length()) {
+ throw new IllegalArgumentException("Invalid JSON string escape.");
+ }
+ final char escaped = value.charAt(i);
+ switch (escaped) {
+ case '"' -> result.append('"');
+ case '\\' -> result.append('\\');
+ case '/' -> result.append('/');
+ case 'b' -> result.append('\b');
+ case 'f' -> result.append('\f');
+ case 'n' -> result.append('\n');
+ case 'r' -> result.append('\r');
+ case 't' -> result.append('\t');
+ case 'u' -> {
+ if (i + 4 >= value.length()) {
+ throw new IllegalArgumentException("Invalid JSON unicode escape.");
+ }
+ // JSON requires exactly four ASCII hex digits. Integer.parseInt(hex, 16) would also
+ // accept a leading '+' or '-' and non-ASCII digits, so each digit is decoded strictly.
+ int codeUnit = 0;
+ for (int k = i + 1; k <= i + 4; k++) {
+ final int digit = hexDigit(value.charAt(k));
+ if (digit < 0) {
+ throw new IllegalArgumentException("Invalid JSON unicode escape.");
+ }
+ codeUnit = (codeUnit << 4) | digit;
+ }
+ result.append((char) codeUnit);
+ i += 4;
+ }
+ default -> throw new IllegalArgumentException("Invalid JSON string escape.");
+ }
+ }
+ return result.toString();
+ }
+
+ /** Returns the value of an ASCII hexadecimal digit, or {@code -1} if {@code ch} is not one. */
+ private static int hexDigit(final char ch) {
+ if (ch >= '0' && ch <= '9') {
+ return ch - '0';
+ }
+ if (ch >= 'a' && ch <= 'f') {
+ return ch - 'a' + 10;
+ }
+ if (ch >= 'A' && ch <= 'F') {
+ return ch - 'A' + 10;
+ }
+ return -1;
+ }
+
/**
- * Closes this resource, relinquishing any underlying resources.
+ * Closes the ONNX {@link OrtSession} owned by this instance.
*
- * @throws OrtException Thrown if it failed to close Ort resources.
- * @throws IllegalStateException Thrown if the underlying resources were already closed.
+ * <p>The {@link OrtEnvironment} is deliberately <b>not</b> closed:
+ * {@link OrtEnvironment#getEnvironment()} returns a process-wide singleton shared by
+ * every deep-learning component, so closing it here would tear down the environment
+ * other live components still depend on.</p>
+ *
+ * <p>This method is idempotent: only the first call attempts to close the session, and
+ * every later call is a no-op even if that first attempt failed, so the underlying
+ * {@link OrtSession#close()} is invoked at most once. Instances created through the
+ * test seam with a {@code null} session close without error.</p>
+ *
+ * @throws OrtException Thrown if the close attempt fails in the native layer.
*/
@Override
- public void close() throws OrtException, IllegalStateException {
- if (session != null) {
+ public void close() throws OrtException {
+ // Claim the close first so repeated or concurrent calls never close the session twice;
+ // the null check keeps instances built via the test seam (null session) closeable.
+ if (closed.compareAndSet(false, true) && session != null) {
session.close();
}
- if (env != null) {
- env.close();
- }
}
}
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
index e357c48f2..7aa36e494 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -22,13 +22,13 @@ import java.io.IOException;
import java.nio.LongBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
@@ -43,6 +43,7 @@ import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
+import opennlp.tools.commons.ThreadSafe;
import opennlp.tools.doccat.DocumentCategorizer;
@@ -56,10 +57,21 @@ import opennlp.tools.doccat.DocumentCategorizer;
* models commonly used for classification. For cased models, set
* {@link InferenceOptions#setLowerCase(boolean)} to {@code false}.</p>
*
+ * <p>This class is thread-safe and may be shared across threads, provided the
supplied
+ * {@link ClassificationScoringStrategy} is thread-safe (the built-in
+ * {@link opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy} is
stateless).
+ * Inference holds no per-call instance state, the relevant {@link
InferenceOptions} values
+ * are snapshotted into final fields at construction (so mutating the passed
options
+ * afterwards does not affect a shared instance), and the underlying {@link
OrtSession}
+ * supports concurrent execution. This thread-safety guarantee applies until
+ * {@link #close()} is called; callers must not race {@code close()} with
inference
+ * methods.</p>
+ *
* @see DocumentCategorizer
* @see InferenceOptions
* @see ClassificationScoringStrategy
*/
+@ThreadSafe
public class DocumentCategorizerDL extends AbstractDL implements
DocumentCategorizer {
/** Classification models are commonly uncased, so lower casing is the
default. */
@@ -67,19 +79,28 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
private final Map<Integer, String> categories;
private final ClassificationScoringStrategy classificationScoringStrategy;
- private final InferenceOptions inferenceOptions;
+ // Inference options are snapshotted into final fields at construction so a
shared
+ // instance never reads the caller's mutable InferenceOptions during
inference.
+ private final boolean includeAttentionMask;
+ private final boolean includeTokenTypeIds;
+ private final int documentSplitSize;
+ private final int splitOverlapSize;
+ /**
+ * Test-only constructor that injects an already-built {@link OrtSession}
(or {@code null}),
+ * bypassing model loading so inference paths can be exercised in unit tests.
+ */
DocumentCategorizerDL(OrtEnvironment env, OrtSession session, Map<String,
Integer> vocab,
Map<Integer, String> categories,
ClassificationScoringStrategy
classificationScoringStrategy,
InferenceOptions inferenceOptions) {
- this.env = env;
- this.session = session;
- this.vocab = vocab;
- this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
- this.categories = categories;
+ super(env, session, vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
+ this.categories = Map.copyOf(categories);
this.classificationScoringStrategy = classificationScoringStrategy;
- this.inferenceOptions = inferenceOptions;
+ this.includeAttentionMask = inferenceOptions.isIncludeAttentionMask();
+ this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds();
+ this.documentSplitSize = inferenceOptions.getDocumentSplitSize();
+ this.splitOverlapSize = inferenceOptions.getSplitOverlapSize();
}
/**
@@ -100,19 +121,17 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
InferenceOptions inferenceOptions)
throws IOException, OrtException {
- this.env = OrtEnvironment.getEnvironment();
-
- final OrtSession.SessionOptions sessionOptions = new
OrtSession.SessionOptions();
- if (inferenceOptions.isGpu()) {
- sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
- }
+ super(model, vocabulary,
+ sessionOptions(validateConstructorArguments(
+ inferenceOptions, categories, classificationScoringStrategy)),
+ resolveLowerCase(inferenceOptions, LOWER_CASE_DEFAULT));
- this.session = env.createSession(model.getPath(), sessionOptions);
- this.vocab = loadVocab(vocabulary);
- this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
- this.categories = categories;
+ this.categories = Map.copyOf(categories);
this.classificationScoringStrategy = classificationScoringStrategy;
- this.inferenceOptions = inferenceOptions;
+ this.includeAttentionMask = inferenceOptions.isIncludeAttentionMask();
+ this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds();
+ this.documentSplitSize = inferenceOptions.getDocumentSplitSize();
+ this.splitOverlapSize = inferenceOptions.getSplitOverlapSize();
}
@@ -135,22 +154,28 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
InferenceOptions inferenceOptions)
throws IOException, OrtException {
- this.env = OrtEnvironment.getEnvironment();
+ super(model, vocabulary,
+ sessionOptions(validateConstructorArguments(
+ inferenceOptions, config, classificationScoringStrategy)),
+ resolveLowerCase(inferenceOptions, LOWER_CASE_DEFAULT));
- final OrtSession.SessionOptions sessionOptions = new
OrtSession.SessionOptions();
- if (inferenceOptions.isGpu()) {
- sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
- }
-
- this.session = env.createSession(model.getPath(), sessionOptions);
- this.vocab = loadVocab(vocabulary);
- this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
- this.categories = readCategoriesFromFile(config);
+ this.categories = Map.copyOf(readCategoriesFromFile(config));
this.classificationScoringStrategy = classificationScoringStrategy;
- this.inferenceOptions = inferenceOptions;
+ this.includeAttentionMask = inferenceOptions.isIncludeAttentionMask();
+ this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds();
+ this.documentSplitSize = inferenceOptions.getDocumentSplitSize();
+ this.splitOverlapSize = inferenceOptions.getSplitOverlapSize();
}
+ private static InferenceOptions validateConstructorArguments(
+ final InferenceOptions inferenceOptions, final Object categoriesOrConfig,
+ final ClassificationScoringStrategy classificationScoringStrategy) {
+ Objects.requireNonNull(categoriesOrConfig, "categoriesOrConfig");
+ Objects.requireNonNull(classificationScoringStrategy,
"classificationScoringStrategy");
+ return inferenceOptions;
+ }
+
/**
* Categorizes the document, failing loudly rather than returning an invalid
distribution:
* malformed input is rejected with {@link IllegalArgumentException}, and
any failure executing
@@ -191,12 +216,12 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(t.ids()), new long[] {1, t.ids().length}));
- if (inferenceOptions.isIncludeAttentionMask()) {
+ if (includeAttentionMask) {
inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
LongBuffer.wrap(t.mask()), new long[] {1, t.mask().length}));
}
- if (inferenceOptions.isIncludeTokenTypeIds()) {
+ if (includeTokenTypeIds) {
inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(t.types()), new long[] {1, t.types().length}));
}
@@ -306,21 +331,19 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
final List<Tokens> t = new LinkedList<>();
- // In this article as the paper suggests, we are going to segment the
input into smaller text and feed
- // each of them into BERT, it means for each row, we will split the text
in order to have some
- // smaller text (200 words long each)
+ // Segment long input text into overlapping chunks configured by
InferenceOptions before
+ // feeding each chunk into BERT.
//
https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd
-
- // Split the input text into 200 word chunks with 50 overlapping between
chunks.
final String[] whitespaceTokenized = text.split("\\s+");
- for (final int[] range : chunkRanges(whitespaceTokenized.length,
- inferenceOptions.getDocumentSplitSize(),
inferenceOptions.getSplitOverlapSize())) {
+ for (ChunkRange chunkRange : chunkRanges(
+ whitespaceTokenized.length, documentSplitSize, splitOverlapSize)) {
- // The group is that subsection of the input.
- final String group =
- String.join(" ", Arrays.copyOfRange(whitespaceTokenized, range[0],
range[1]));
+ // The group is that subsection of string.
+ final String group = String.join(" ",
+ Arrays.copyOfRange(whitespaceTokenized, chunkRange.start(),
chunkRange.end()));
+ // Now we can tokenize the group and continue.
final String[] tokens = tokenizer.tokenize(group);
final long[] ids = tokenIds(tokens, vocab);
@@ -339,32 +362,6 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
}
- /**
- * Computes the {@code [start, end)} word-index ranges the input is split
into: chunks of
- * {@code splitSize} words overlapping by {@code overlapSize}. The loop
always advances by
- * at least one word, so a misconfigured {@code overlapSize >= splitSize}
can neither stall
- * the loop nor produce negative indices.
- *
- * @param length The number of whitespace-separated words.
- * @param splitSize The chunk size in words.
- * @param overlapSize The overlap between consecutive chunks in words.
- * @return The ordered list of {@code [start, end)} ranges; empty when
{@code length == 0}.
- */
- static List<int[]> chunkRanges(final int length, final int splitSize, final
int overlapSize) {
- final List<int[]> ranges = new ArrayList<>();
- int start = 0;
- while (start < length) {
- final int end = Math.min(start + splitSize, length);
- ranges.add(new int[] {start, end});
- if (end == length) {
- break;
- }
- // Overlap by overlapSize words, but always move forward by at least one.
- start = Math.max(end - overlapSize, start + 1);
- }
- return ranges;
- }
-
/**
* Maps tokens to their vocabulary ids.
*
@@ -410,6 +407,7 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
final double[] t = new double[input.length];
double sum = 0.0;
+
for (int x = 0; x < input.length; x++) {
final double val = Math.exp(input[x] - max);
sum += val;
@@ -417,6 +415,7 @@ public class DocumentCategorizerDL extends AbstractDL
implements DocumentCategor
}
final double[] output = new double[input.length];
+
for (int x = 0; x < output.length; x++) {
output[x] = t[x] / sum;
}
diff --git
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
index d2adee0b3..3445969e8 100644
---
a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
+++
b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -25,11 +25,11 @@ import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import ai.onnxruntime.OnnxTensor;
-import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
@@ -37,6 +37,7 @@ import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
+import opennlp.tools.commons.ThreadSafe;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
import opennlp.tools.util.Span;
@@ -51,9 +52,19 @@ import opennlp.tools.util.Span;
* boundaries. For uncased models, set
* {@link InferenceOptions#setLowerCase(boolean)} to {@code true}.</p>
*
+ * <p>This class is thread-safe and may be shared across threads, provided the
supplied
+ * {@link SentenceDetector} is itself thread-safe (e.g. {@link
opennlp.tools.sentdetect.SentenceDetectorME},
+ * which is {@code @ThreadSafe}). Inference holds no per-call instance state,
the relevant
+ * {@link InferenceOptions} values are snapshotted into final fields at
construction (so
+ * mutating the passed options afterwards does not affect a shared instance),
and the
+ * underlying {@link OrtSession} supports concurrent execution. This
thread-safety
+ * guarantee applies until {@link #close()} is called; callers must not race
+ * {@code close()} with inference methods.</p>
+ *
* @see TokenNameFinder
* @see InferenceOptions
*/
+@ThreadSafe
public class NameFinderDL extends AbstractDL implements TokenNameFinder {
public static final String I_PER = "I-PER";
@@ -67,7 +78,12 @@ public class NameFinderDL extends AbstractDL implements
TokenNameFinder {
private final SentenceDetector sentenceDetector;
private final Map<Integer, String> ids2Labels;
- private final InferenceOptions inferenceOptions;
+ // Inference options are snapshotted into final fields at construction so a
shared
+ // instance never reads the caller's mutable InferenceOptions during
inference.
+ private final boolean includeAttentionMask;
+ private final boolean includeTokenTypeIds;
+ private final int documentSplitSize;
+ private final int splitOverlapSize;
/**
* Instantiates a {@link TokenNameFinder name finder} using ONNX models.
@@ -103,22 +119,28 @@ public class NameFinderDL extends AbstractDL implements
TokenNameFinder {
InferenceOptions inferenceOptions,
SentenceDetector sentenceDetector) throws IOException,
OrtException {
- this.env = OrtEnvironment.getEnvironment();
-
- final OrtSession.SessionOptions sessionOptions = new
OrtSession.SessionOptions();
- if (inferenceOptions.isGpu()) {
- sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
- }
+ super(model, vocabulary,
+ sessionOptions(validateConstructorArguments(
+ inferenceOptions, ids2Labels, sentenceDetector)),
+ resolveLowerCase(inferenceOptions, LOWER_CASE_DEFAULT));
- this.session = env.createSession(model.getPath(), sessionOptions);
- this.ids2Labels = ids2Labels;
- this.vocab = loadVocab(vocabulary);
- this.tokenizer = createTokenizer(vocab, resolveLowerCase(inferenceOptions,
LOWER_CASE_DEFAULT));
- this.inferenceOptions = inferenceOptions;
+ this.ids2Labels = Map.copyOf(ids2Labels);
+ this.includeAttentionMask = inferenceOptions.isIncludeAttentionMask();
+ this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds();
+ this.documentSplitSize = inferenceOptions.getDocumentSplitSize();
+ this.splitOverlapSize = inferenceOptions.getSplitOverlapSize();
this.sentenceDetector = sentenceDetector;
}
+ private static InferenceOptions validateConstructorArguments(
+ final InferenceOptions inferenceOptions, final Map<Integer, String>
ids2Labels,
+ final SentenceDetector sentenceDetector) {
+ Objects.requireNonNull(ids2Labels, "ids2Labels");
+ Objects.requireNonNull(sentenceDetector, "sentenceDetector");
+ return inferenceOptions;
+ }
+
@Override
public Span[] find(String[] input) {
@@ -146,12 +168,12 @@ public class NameFinderDL extends AbstractDL implements
TokenNameFinder {
inputs.put(INPUT_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.ids()),
new long[] {1, tokens.ids().length}));
- if (inferenceOptions.isIncludeAttentionMask()) {
+ if (includeAttentionMask) {
inputs.put(ATTENTION_MASK, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.mask()), new long[] {1,
tokens.mask().length}));
}
- if (inferenceOptions.isIncludeTokenTypeIds()) {
+ if (includeTokenTypeIds) {
inputs.put(TOKEN_TYPE_IDS, OnnxTensor.createTensor(env,
LongBuffer.wrap(tokens.types()), new long[] {1,
tokens.types().length}));
}
@@ -368,29 +390,17 @@ public class NameFinderDL extends AbstractDL implements
TokenNameFinder {
final List<Tokens> t = new LinkedList<>();
- // In this article as the paper suggests, we are going to segment the
input into smaller text and feed
- // each of them into BERT, it means for each row, we will split the text
in order to have some
- // smaller text (200 words long each)
+ // Segment long input text into overlapping chunks configured by
InferenceOptions before
+ // feeding each chunk into BERT.
//
https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd
-
- // Split the input text into 200 word chunks with 50 overlapping between
chunks.
final String[] whitespaceTokenized = text.split("\\s+");
- for (int start = 0; start < whitespaceTokenized.length;
- start = start + inferenceOptions.getDocumentSplitSize()) {
-
- // 200 word length chunk
- // Check the end do don't go past and get a
StringIndexOutOfBoundsException
- int end = start + inferenceOptions.getDocumentSplitSize();
- if (end > whitespaceTokenized.length) {
- end = whitespaceTokenized.length;
- }
+ for (ChunkRange chunkRange : chunkRanges(
+ whitespaceTokenized.length, documentSplitSize, splitOverlapSize)) {
// The group is that subsection of string.
- final String group = String.join(" ",
Arrays.copyOfRange(whitespaceTokenized, start, end));
-
- // We want to overlap each chunk by 50 words so scoot back 50 words for
the next iteration.
- start = start - inferenceOptions.getSplitOverlapSize();
+ final String group = String.join(" ",
+ Arrays.copyOfRange(whitespaceTokenized, chunkRange.start(),
chunkRange.end()));
// Now we can tokenize the group and continue.
final String[] tokens = tokenizer.tokenize(group);
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
index c7b7fda86..6bc76ce18 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/vectors/SentenceVectorsDL.java
@@ -25,12 +25,12 @@ import java.util.HashMap;
import java.util.Map;
import ai.onnxruntime.OnnxTensor;
-import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import opennlp.dl.AbstractDL;
import opennlp.dl.Tokens;
+import opennlp.tools.commons.ThreadSafe;
import opennlp.tools.tokenize.Tokenizer;
@@ -51,7 +51,13 @@ import opennlp.tools.tokenize.Tokenizer;
* Output vectors change with the corrected encoding and tokenization;
* any embeddings persisted from the previous behavior are not
* comparable with the corrected output and must be re-embedded.</p>
+ *
+ * <p>This class is thread-safe and may be shared across threads: {@link #getVectors(String)}
+ * holds no per-call instance state and the underlying {@link OrtSession} supports
+ * concurrent execution. This thread-safety guarantee applies until {@link #close()}
+ * is called; callers must not race {@code close()} with inference methods.</p>
*/
+@ThreadSafe
public class SentenceVectorsDL extends AbstractDL {
/**
@@ -87,10 +93,7 @@ public class SentenceVectorsDL extends AbstractDL {
public SentenceVectorsDL(final File model, final File vocabulary, final boolean lowerCase)
throws OrtException, IOException {
- env = OrtEnvironment.getEnvironment();
- session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
- vocab = loadVocab(vocabulary);
- tokenizer = createTokenizer(vocab, lowerCase);
+ super(model, vocabulary, new OrtSession.SessionOptions(), lowerCase);
}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/ChunkRangesTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/ChunkRangesTest.java
new file mode 100644
index 000000000..e9dc1cc7f
--- /dev/null
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/ChunkRangesTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ChunkRangesTest {
+
+ @Test
+ void testCreatesOverlappingChunkRanges() {
+ final List<AbstractDL.ChunkRange> ranges = AbstractDL.chunkRanges(10, 4, 1);
+
+ assertEquals(List.of(
+ new AbstractDL.ChunkRange(0, 4),
+ new AbstractDL.ChunkRange(3, 7),
+ new AbstractDL.ChunkRange(6, 10)), ranges);
+ }
+
+ @Test
+ void testCreatesTrailingPartialChunkRange() {
+ final List<AbstractDL.ChunkRange> ranges = AbstractDL.chunkRanges(9, 4, 1);
+
+ assertEquals(List.of(
+ new AbstractDL.ChunkRange(0, 4),
+ new AbstractDL.ChunkRange(3, 7),
+ new AbstractDL.ChunkRange(6, 9)), ranges);
+ }
+
+ @Test
+ void testDoesNotCreateRedundantTrailingChunkRange() {
+ final List<AbstractDL.ChunkRange> ranges = AbstractDL.chunkRanges(10, 8, 4);
+
+ assertEquals(List.of(
+ new AbstractDL.ChunkRange(0, 8),
+ new AbstractDL.ChunkRange(4, 10)), ranges);
+ }
+
+ @Test
+ void testCreatesAdjacentChunkRangesWithoutOverlap() {
+ // splitOverlapSize == 0: chunks partition the tokens end-to-end with no shared tokens.
+ final List<AbstractDL.ChunkRange> ranges = AbstractDL.chunkRanges(10, 5, 0);
+
+ assertEquals(List.of(
+ new AbstractDL.ChunkRange(0, 5),
+ new AbstractDL.ChunkRange(5, 10)), ranges);
+ }
+
+ @Test
+ void testCreatesSingleChunkWhenInputFitsInSplit() {
+ // Fewer tokens than the split size, and exactly the split size, both yield one chunk.
+ assertEquals(List.of(new AbstractDL.ChunkRange(0, 4)), AbstractDL.chunkRanges(4, 8, 2));
+ assertEquals(List.of(new AbstractDL.ChunkRange(0, 8)), AbstractDL.chunkRanges(8, 8, 2));
+ }
+
+ @Test
+ void testCreatesNoRangesForEmptyInput() {
+ assertTrue(AbstractDL.chunkRanges(0, 4, 1).isEmpty());
+ }
+
+ @Test
+ void testRejectsNegativeTokenCount() {
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.chunkRanges(-1, 4, 1));
+ }
+}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/CreateTokenizerTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/CreateTokenizerTest.java
index a373cb159..54c4600a8 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/CreateTokenizerTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/CreateTokenizerTest.java
@@ -32,10 +32,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class CreateTokenizerTest {
- /** A concrete subclass to access the protected factory methods. */
- private static class TestDL extends AbstractDL {
- }
-
private static Map<String, Integer> bertVocab() {
final Map<String, Integer> vocab = new HashMap<>();
vocab.put(WordpieceTokenizer.BERT_CLS_TOKEN, 0);
@@ -57,7 +53,7 @@ public class CreateTokenizerTest {
@Test
void testCreatesLowerCasingBertTokenizer() {
- final BertTokenizer tokenizer = new TestDL().createTokenizer(bertVocab(), true);
+ final BertTokenizer tokenizer = AbstractDL.createBertTokenizer(bertVocab(), true);
// Capitalized input must be lower cased before the wordpiece lookup.
assertArrayEquals(new String[] {
@@ -67,7 +63,7 @@ public class CreateTokenizerTest {
@Test
void testCreatesCasePreservingBertTokenizer() {
- final BertTokenizer tokenizer = new TestDL().createTokenizer(bertVocab(), false);
+ final BertTokenizer tokenizer = AbstractDL.createBertTokenizer(bertVocab(), false);
// Without lower casing, capitalized words miss the lowercase-only vocabulary.
assertArrayEquals(new String[] {
@@ -78,7 +74,7 @@ public class CreateTokenizerTest {
@Test
void testSelectsRobertaSpecialTokens() {
- final BertTokenizer tokenizer = new TestDL().createTokenizer(robertaVocab(), false);
+ final BertTokenizer tokenizer = AbstractDL.createBertTokenizer(robertaVocab(), false);
assertArrayEquals(new String[] {
WordpieceTokenizer.ROBERTA_CLS_TOKEN, "hello",
WordpieceTokenizer.ROBERTA_UNK_TOKEN,
@@ -92,7 +88,7 @@ public class CreateTokenizerTest {
vocab.remove(WordpieceTokenizer.ROBERTA_UNK_TOKEN);
vocab.put(WordpieceTokenizer.BERT_UNK_TOKEN, 2);
- final BertTokenizer tokenizer = new TestDL().createTokenizer(vocab, false);
+ final BertTokenizer tokenizer = AbstractDL.createBertTokenizer(vocab, false);
assertArrayEquals(new String[] {
WordpieceTokenizer.ROBERTA_CLS_TOKEN, "hello",
WordpieceTokenizer.BERT_UNK_TOKEN,
@@ -105,9 +101,8 @@ public class CreateTokenizerTest {
final Map<String, Integer> vocab = robertaVocab();
vocab.remove(WordpieceTokenizer.ROBERTA_UNK_TOKEN);
- final TestDL dl = new TestDL();
- assertThrows(IllegalArgumentException.class, () -> dl.createTokenizer(vocab, false));
- assertThrows(IllegalArgumentException.class, () -> dl.createTokenizer(vocab));
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.createBertTokenizer(vocab, false));
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.createWordpieceTokenizer(vocab));
}
@Test
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/InferenceOptionsValidationTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/InferenceOptionsValidationTest.java
new file mode 100644
index 000000000..d54597793
--- /dev/null
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/InferenceOptionsValidationTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.dl;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class InferenceOptionsValidationTest {
+
+ @Test
+ void testValidSplitOptions() {
+ final InferenceOptions options = new InferenceOptions();
+ options.setDocumentSplitSize(2);
+ options.setSplitOverlapSize(1);
+
+ assertDoesNotThrow(() -> AbstractDL.validateSplitOptions(options));
+ }
+
+ @Test
+ void testRejectsZeroDocumentSplitSize() {
+ final InferenceOptions options = new InferenceOptions();
+ options.setDocumentSplitSize(0);
+
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.validateSplitOptions(options));
+ }
+
+ @Test
+ void testRejectsNegativeDocumentSplitSize() {
+ final InferenceOptions options = new InferenceOptions();
+ options.setDocumentSplitSize(-1);
+
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.validateSplitOptions(options));
+ }
+
+ @Test
+ void testRejectsNegativeSplitOverlapSize() {
+ final InferenceOptions options = new InferenceOptions();
+ options.setSplitOverlapSize(-1);
+
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.validateSplitOptions(options));
+ }
+
+ @Test
+ void testRejectsSplitOverlapSizeEqualToDocumentSplitSize() {
+ final InferenceOptions options = new InferenceOptions();
+ options.setDocumentSplitSize(2);
+ options.setSplitOverlapSize(2);
+
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.validateSplitOptions(options));
+ }
+
+ @Test
+ void testRejectsSplitOverlapSizeGreaterThanDocumentSplitSize() {
+ final InferenceOptions options = new InferenceOptions();
+ options.setDocumentSplitSize(2);
+ options.setSplitOverlapSize(3);
+
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.validateSplitOptions(options));
+ }
+}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
index d8554c3fb..8b3961e78 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/LoadVocabTest.java
@@ -29,15 +29,10 @@ import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
public class LoadVocabTest {
- private final AbstractDL dl = new AbstractDL() {
- @Override
- public void close() {
- }
- };
-
private File getResource(String name) throws IOException {
try (InputStream is = Objects.requireNonNull(
getClass().getResourceAsStream("/opennlp/dl/" + name))) {
@@ -50,7 +45,7 @@ public class LoadVocabTest {
@Test
void testLoadPlainTextVocab() throws IOException {
- final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab-plain.txt"));
+ final Map<String, Integer> vocab = AbstractDL.loadVocabFile(getResource("vocab-plain.txt"));
assertNotNull(vocab);
assertEquals(6, vocab.size());
@@ -64,7 +59,7 @@ public class LoadVocabTest {
@Test
void testLoadJsonVocab() throws IOException {
- final Map<String, Integer> vocab = dl.loadVocab(getResource("vocab.json"));
+ final Map<String, Integer> vocab = AbstractDL.loadVocabFile(getResource("vocab.json"));
assertNotNull(vocab);
assertEquals(6, vocab.size());
@@ -84,7 +79,7 @@ public class LoadVocabTest {
Files.writeString(tempFile.toPath(),
"{\"hello\\\"world\": 0, \"back\\\\slash\": 1}");
- final Map<String, Integer> vocab = dl.loadVocab(tempFile);
+ final Map<String, Integer> vocab = AbstractDL.loadVocabFile(tempFile);
assertNotNull(vocab);
assertEquals(2, vocab.size());
@@ -92,10 +87,37 @@ public class LoadVocabTest {
assertEquals(1, vocab.get("back\\slash"));
}
+ @Test
+ void testJsonVocabWithUnicodeEscapedCharacters() throws IOException {
+ final File tempFile = File.createTempFile("vocab-unicode", ".json");
+ tempFile.deleteOnExit();
+
+ Files.writeString(tempFile.toPath(),
+ "{\"\\u0120token\": 0, \"line\\rbreak\": 1, \"form\\ffeed\": 2}");
+
+ final Map<String, Integer> vocab = AbstractDL.loadVocabFile(tempFile);
+
+ assertNotNull(vocab);
+ assertEquals(3, vocab.size());
+ assertEquals(0, vocab.get("Ġtoken"));
+ assertEquals(1, vocab.get("line\rbreak"));
+ assertEquals(2, vocab.get("form\ffeed"));
+ }
+
+ @Test
+ void testJsonVocabRejectsInvalidEscapedCharacters() throws IOException {
+ final File tempFile = File.createTempFile("vocab-invalid-escape", ".json");
+ tempFile.deleteOnExit();
+
+ Files.writeString(tempFile.toPath(), "{\"bad\\xescape\": 0}");
+
+ assertThrows(IllegalArgumentException.class, () -> AbstractDL.loadVocabFile(tempFile));
+ }
+
@Test
void testJsonAndPlainTextVocabProduceSameResult() throws IOException {
- final Map<String, Integer> plainVocab = dl.loadVocab(getResource("vocab-plain.txt"));
- final Map<String, Integer> jsonVocab = dl.loadVocab(getResource("vocab.json"));
+ final Map<String, Integer> plainVocab = AbstractDL.loadVocabFile(getResource("vocab-plain.txt"));
+ final Map<String, Integer> jsonVocab = AbstractDL.loadVocabFile(getResource("vocab.json"));
assertEquals(plainVocab, jsonVocab);
}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
index e962e4bee..80087d97c 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
@@ -18,7 +18,6 @@
package opennlp.dl.doccat;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
@@ -131,6 +130,13 @@ public class DocumentCategorizerDLTest {
assertEquals(1.0, sum, 1e-12);
}
+ @Test
+ void testSoftmaxIsNumericallyStable() {
+ final double[] scores = DocumentCategorizerDL.softmax(new float[] {1000.0f, 1001.0f});
+
+ assertArrayEquals(new double[] {0.2689414213699951, 0.7310585786300049}, scores, 1e-15);
+ }
+
@Test
void testSoftmaxMatchesReferenceDistribution() {
// Reference (numpy): softmax([1,2,3]) = [0.09003057, 0.24472847, 0.66524096].
@@ -140,46 +146,4 @@ public class DocumentCategorizerDLTest {
assertEquals(0.24472847, out[1], 1e-6);
assertEquals(0.66524096, out[2], 1e-6);
}
-
- @Test
- void testChunkRangesSplitsWithOverlap() {
- // 210 words, 200-word chunks overlapping by 50 -> [0,200), [150,210).
- final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(210, 200, 50);
-
- assertEquals(2, ranges.size());
- assertArrayEquals(new int[] {0, 200}, ranges.get(0));
- assertArrayEquals(new int[] {150, 210}, ranges.get(1));
- }
-
- @Test
- void testChunkRangesSingleChunkWhenShorterThanSplit() {
- final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(30, 200, 50);
-
- assertEquals(1, ranges.size());
- assertArrayEquals(new int[] {0, 30}, ranges.get(0));
- }
-
- @Test
- void testChunkRangesEmptyForZeroLength() {
- assertTrue(DocumentCategorizerDL.chunkRanges(0, 200, 50).isEmpty());
- }
-
- @Test
- void testChunkRangesAlwaysProgressesForInvalidOverlap() {
- // overlap == split would stall forever, and overlap > split would make the start index
- // negative, without the forward-progress guard.
- for (final int[] cfg : new int[][] {{10, 5, 5}, {8, 3, 10}, {7, 4, 100}}) {
- final int length = cfg[0];
- final List<int[]> ranges = DocumentCategorizerDL.chunkRanges(length, cfg[1], cfg[2]);
-
- int previousStart = -1;
- for (final int[] range : ranges) {
- assertTrue(range[0] >= 0, "start must never be negative: " + range[0]);
- assertTrue(range[1] >= range[0], "end must be >= start");
- assertTrue(range[0] > previousStart, "each chunk must advance the start index");
- previousStart = range[0];
- }
- assertEquals(length, ranges.get(ranges.size() - 1)[1], "last chunk must reach the end");
- }
- }
}
diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index d19443b5b..120e331c4 100644
--- a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++ b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -18,12 +18,18 @@
package opennlp.dl.doccat;
import java.io.File;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
+import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
@@ -92,6 +98,63 @@ public class DocumentCategorizerDLEval extends AbstractEvalTest {
}
+ /**
+ * Verifies that a single {@link DocumentCategorizerDL} instance is safe to share across
+ * threads: concurrent {@link DocumentCategorizerDL#categorize(String[])} calls on one
+ * instance must all return the same scores as the single-threaded baseline.
+ */
+ @Test
+ public void categorizeConcurrentTest() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(),
+ "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx");
+ final File vocab = new File(getOpennlpDataDir(),
+ "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab");
+
+ final int threads = 8;
+ final int iterationsPerThread = 10;
+
+ try (final DocumentCategorizerDL documentCategorizerDL =
+ new DocumentCategorizerDL(model, vocab, getCategories(),
+ new AverageClassificationScoringStrategy(), new InferenceOptions())) {
+
+ final double[] baseline = documentCategorizerDL.categorize(new String[] {text});
+
+ final ExecutorService executor = Executors.newFixedThreadPool(threads);
+ try {
+ final CountDownLatch startGate = new CountDownLatch(1);
+ final List<Future<Boolean>> futures = new ArrayList<>();
+
+ for (int t = 0; t < threads; t++) {
+ futures.add(executor.submit(() -> {
+ startGate.await();
+ for (int i = 0; i < iterationsPerThread; i++) {
+ final double[] result = documentCategorizerDL.categorize(new String[] {text});
+ if (result.length != baseline.length) {
+ return false;
+ }
+ for (int c = 0; c < baseline.length; c++) {
+ if (Math.abs(result[c] - baseline[c]) > 0.000001) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }));
+ }
+
+ startGate.countDown();
+ for (Future<Boolean> future : futures) {
+ Assertions.assertTrue(future.get(),
+ "a concurrent categorize() returned scores inconsistent with the single-threaded case");
+ }
+ } finally {
+ executor.shutdownNow();
+ }
+ }
+
+ }
+
@Test
public void categorizeFailsLoudlyOnFailure() throws Exception {
diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
index 79b1bcf7c..553c31590 100644
--- a/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
+++ b/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
@@ -19,8 +19,14 @@ package opennlp.dl.namefinder;
import java.io.File;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import ai.onnxruntime.OrtException;
import org.junit.jupiter.api.Assertions;
@@ -28,6 +34,7 @@ import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import opennlp.dl.InferenceOptions;
import opennlp.tools.eval.AbstractEvalTest;
import opennlp.tools.sentdetect.SentenceDetector;
import opennlp.tools.sentdetect.SentenceDetectorME;
@@ -72,6 +79,160 @@ public class NameFinderDLEval extends AbstractEvalTest {
}
+ /**
+ * Verifies that a single {@link NameFinderDL} instance is safe to share across threads:
+ * many threads call {@link NameFinderDL#find(String[])} concurrently on one instance and
+ * every call must return the same correct result as the single-threaded case. A data
+ * race on the shared inference state would surface here as a wrong span, an exception or
+ * a non-deterministic result.
+ */
+ @Test
+ public void tokenNameFinderConcurrentTest() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt");
+
+ final String[] tokens = new String[]
+ {"George", "Washington", "was", "president", "of", "the", "United", "States", "."};
+ final String text = String.join(" ", tokens);
+
+ final int threads = 8;
+ final int iterationsPerThread = 25;
+
+ try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(),
+ sentenceDetector)) {
+
+ final ExecutorService executor = Executors.newFixedThreadPool(threads);
+ try {
+ final CountDownLatch startGate = new CountDownLatch(1);
+ final List<Future<Boolean>> futures = new ArrayList<>();
+
+ for (int t = 0; t < threads; t++) {
+ futures.add(executor.submit(() -> {
+ // Release all threads at once to maximize contention on the shared instance.
+ startGate.await();
+ for (int i = 0; i < iterationsPerThread; i++) {
+ final Span[] spans = nameFinderDL.find(tokens);
+ if (spans.length != 1
+ || spans[0].getStart() != 0
+ || spans[0].getEnd() != 17
+ || !"George Washington".equals(spans[0].getCoveredText(text))) {
+ return false;
+ }
+ }
+ return true;
+ }));
+ }
+
+ startGate.countDown();
+ for (Future<Boolean> future : futures) {
+ Assertions.assertTrue(future.get(),
+ "a concurrent find() returned a result inconsistent with the single-threaded case");
+ }
+ } finally {
+ // Shut down on every path so a failed assertion can never leave the pool running.
+ executor.shutdownNow();
+ }
+ }
+
+ }
+
+ /**
+ * Concurrent test that explicitly pairs {@link NameFinderDL} with {@link SentenceDetectorME}
+ * to validate the documented {@code @ThreadSafe} precondition: {@code NameFinderDL} may be
+ * shared across threads only when the injected {@link SentenceDetector} is itself thread-safe.
+ * {@code SentenceDetectorME} is annotated {@code @ThreadSafe}, satisfying the contract.
+ */
+ @Test
+ public void nameFinderDlConcurrentWithSentenceDetectorMe() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt");
+
+ final String[] tokens = new String[]
+ {"George", "Washington", "was", "president", "of", "the", "United", "States", "."};
+
+ // Explicitly construct the detector inside the test to make the precondition visible.
+ final SentenceDetectorME detector = new SentenceDetectorME("en");
+
+ final int threads = 8;
+ final int iterationsPerThread = 25;
+
+ try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(),
+ detector)) {
+
+ final ExecutorService executor = Executors.newFixedThreadPool(threads);
+ try {
+ final CountDownLatch startGate = new CountDownLatch(1);
+ final List<Future<Boolean>> futures = new ArrayList<>();
+
+ for (int t = 0; t < threads; t++) {
+ futures.add(executor.submit(() -> {
+ startGate.await();
+ for (int i = 0; i < iterationsPerThread; i++) {
+ final Span[] spans = nameFinderDL.find(tokens);
+ if (spans.length != 1
+ || spans[0].getStart() != 0
+ || spans[0].getEnd() != 17) {
+ return false;
+ }
+ }
+ return true;
+ }));
+ }
+
+ startGate.countDown();
+ for (Future<Boolean> future : futures) {
+ Assertions.assertTrue(future.get(),
+ "concurrent find() with SentenceDetectorME returned inconsistent results");
+ }
+ } finally {
+ executor.shutdownNow();
+ }
+ }
+ }
+
+ /**
+ * Verifies that {@link InferenceOptions} are snapshotted at construction: mutating the
+ * options object after the {@link NameFinderDL} is built must not change its inference,
+ * which is what makes a shared instance safe against callers that hold the same options.
+ */
+ @Test
+ public void tokenNameFinderSnapshotsInferenceOptionsTest() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt");
+
+ final String[] tokens = new String[]
+ {"George", "Washington", "was", "president", "of", "the", "United", "States", "."};
+ final String text = String.join(" ", tokens);
+
+ final InferenceOptions options = new InferenceOptions();
+
+ try (final NameFinderDL nameFinderDL = new NameFinderDL(model, vocab, getIds2Labels(),
+ options, sentenceDetector)) {
+
+ final Span[] baseline = nameFinderDL.find(tokens);
+ Assertions.assertEquals(1, baseline.length);
+ Assertions.assertEquals("George Washington", baseline[0].getCoveredText(text));
+
+ // Mutate the options in ways that would change inference if they were read live:
+ // a split size of 1 would chunk the input one word at a time.
+ options.setIncludeAttentionMask(!options.isIncludeAttentionMask());
+ options.setIncludeTokenTypeIds(!options.isIncludeTokenTypeIds());
+ options.setDocumentSplitSize(1);
+ options.setSplitOverlapSize(0);
+
+ final Span[] afterMutation = nameFinderDL.find(tokens);
+ Assertions.assertEquals(1, afterMutation.length,
+ "mutating InferenceOptions after construction must not affect a built instance");
+ Assertions.assertEquals(0, afterMutation[0].getStart());
+ Assertions.assertEquals(17, afterMutation[0].getEnd());
+ Assertions.assertEquals("George Washington", afterMutation[0].getCoveredText(text));
+ }
+
+ }
+
@Test
public void tokenNameFinder2Test() throws Exception {
diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
index ed4c20d06..b3fb5e528 100644
--- a/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
+++ b/opennlp-eval-tests/src/test/java/opennlp/dl/vectors/SentenceVectorsDLEval.java
@@ -18,6 +18,12 @@
package opennlp.dl.vectors;
import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -72,4 +78,59 @@ public class SentenceVectorsDLEval extends AbstractEvalTest {
}
+ /**
+ * Verifies that a single {@link SentenceVectorsDL} instance is safe to share across
+ * threads: concurrent {@link SentenceVectorsDL#getVectors(String)} calls on one instance
+ * must all return the same vector as the single-threaded baseline.
+ */
+ @Test
+ public void generateVectorsConcurrentTest() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(), "onnx/sentence-transformers/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/sentence-transformers/vocab.txt");
+
+ final String sentence = "george washington was president";
+
+ final int threads = 8;
+ final int iterationsPerThread = 10;
+
+ try (final SentenceVectorsDL sv = new SentenceVectorsDL(model, vocab)) {
+
+ final float[] baseline = sv.getVectors(sentence);
+
+ final ExecutorService executor = Executors.newFixedThreadPool(threads);
+ try {
+ final CountDownLatch startGate = new CountDownLatch(1);
+ final List<Future<Boolean>> futures = new ArrayList<>();
+
+ for (int t = 0; t < threads; t++) {
+ futures.add(executor.submit(() -> {
+ startGate.await();
+ for (int i = 0; i < iterationsPerThread; i++) {
+ final float[] vectors = sv.getVectors(sentence);
+ if (vectors.length != baseline.length) {
+ return false;
+ }
+ for (int c = 0; c < baseline.length; c++) {
+ if (Math.abs(vectors[c] - baseline[c]) > 0.00001f) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }));
+ }
+
+ startGate.countDown();
+ for (Future<Boolean> future : futures) {
+ Assertions.assertTrue(future.get(),
+ "a concurrent getVectors() returned a vector inconsistent with the single-threaded case");
+ }
+ } finally {
+ executor.shutdownNow();
+ }
+ }
+
+ }
+
}