This is an automated email from the ASF dual-hosted git repository. rzo1 pushed a commit to branch opennlp-2.x in repository https://gitbox.apache.org/repos/asf/opennlp.git
commit 49ba4ef3a3c12541d9d7a910993de1dc540683f1 Author: subbudvk <[email protected]> AuthorDate: Mon Apr 27 18:03:07 2026 +0530 OPENNLP-1821: Prevent OutOfMemory Due To Huge Array Allocation (#1022) * Fix : Prevent OOM/DoS from Crafted Inputs * Customizable entry code in OpenNLP * Use Max_Entries Declared to prevent OOM * Use correct exception in fix for OOM (cherry picked from commit 96a073f693f3a0ded808a475f7d7773c072bb8a1) --- .../tools/ml/model/AbstractModelReader.java | 44 ++++++++ .../tools/ml/model/ModelParameterChunker.java | 1 + .../tools/ml/model/AbstractModelReaderOomTest.java | 119 +++++++++++++++++++++ 3 files changed, 164 insertions(+) diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModelReader.java b/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModelReader.java index 55164614..d8ae3cb6 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModelReader.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/model/AbstractModelReader.java @@ -29,6 +29,32 @@ import java.util.zip.GZIPInputStream; */ public abstract class AbstractModelReader { + /** + * System property for overriding the maximum number of entries (outcomes, predicates, + * outcome patterns, chunk counts) that may be read from a model file or training data. + * Set at JVM startup, e.g. {@code -DOPENNLP_MAX_ENTRIES=5000000}. + * Falls back to {@code 10_000_000} if absent or invalid. + */ + public static final String MAX_ENTRIES_PROPERTY = "OPENNLP_MAX_ENTRIES"; + + /** + * Upper bound on count fields read from a model file. + * Prevents OOM on crafted inputs with oversized array size declarations. + * Configurable via the {@link #MAX_ENTRIES_PROPERTY} system property. + */ + static final int MAX_ENTRIES = initMaxEntries(); + + private static int initMaxEntries() { + String prop = System.getProperty(MAX_ENTRIES_PROPERTY, "").trim(); + if (!prop.isEmpty()) { + try { + int val = Integer.parseInt(prop); + if (val > 0) return val; + } catch (NumberFormatException ignore) { } + } + return 10_000_000; + } + /** * The number of predicates contained in a model. */ @@ -128,9 +154,15 @@ public abstract class AbstractModelReader { /** * @return Reads and retrieves the {@code outcome labels} from the model. * @throws IOException Thrown if IO errors occurred. + * @throws IllegalArgumentException Thrown if the outcome count is negative or + * exceeds {@link #MAX_ENTRIES}. */ protected String[] getOutcomes() throws IOException { int numOutcomes = readInt(); + if (numOutcomes < 0 || numOutcomes > MAX_ENTRIES) { + throw new IllegalArgumentException( + "Outcome count " + numOutcomes + " exceeds safe limit of " + MAX_ENTRIES); + } String[] outcomeLabels = new String[numOutcomes]; for (int i = 0; i < numOutcomes; i++) outcomeLabels[i] = readUTF(); return outcomeLabels; @@ -139,9 +171,15 @@ public abstract class AbstractModelReader { /** * @return Reads and retrieves the {@code outcome patterns} from the model. * @throws IOException Thrown if IO errors occurred. + * @throws IllegalArgumentException Thrown if the outcome pattern count is negative or + * exceeds {@link #MAX_ENTRIES}. */ protected int[][] getOutcomePatterns() throws IOException { int numOCTypes = readInt(); + if (numOCTypes < 0 || numOCTypes > MAX_ENTRIES) { + throw new IllegalArgumentException( + "Outcome pattern count " + numOCTypes + " exceeds safe limit of " + MAX_ENTRIES); + } int[][] outcomePatterns = new int[numOCTypes][]; for (int i = 0; i < numOCTypes; i++) { StringTokenizer tok = new StringTokenizer(readUTF(), " "); @@ -157,9 +195,15 @@ public abstract class AbstractModelReader { /** * @return Reads and retrieves the {@code predicates} from the model. * @throws IOException Thrown if IO errors occurred. + * @throws IllegalArgumentException Thrown if the predicate count is negative or + * exceeds {@link #MAX_ENTRIES}. */ protected String[] getPredicates() throws IOException { NUM_PREDS = readInt(); + if (NUM_PREDS < 0 || NUM_PREDS > MAX_ENTRIES) { + throw new IllegalArgumentException( + "Predicate count " + NUM_PREDS + " exceeds safe limit of " + MAX_ENTRIES); + } String[] predLabels = new String[NUM_PREDS]; for (int i = 0; i < NUM_PREDS; i++) predLabels[i] = readUTF(); diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/model/ModelParameterChunker.java b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ModelParameterChunker.java index 98d74be1..eab2b853 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/ModelParameterChunker.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ModelParameterChunker.java @@ -29,6 +29,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; + /** * A helper class that handles Strings with more than 64k (65535 bytes) in length. * This is achieved via the signature {@link #SIGNATURE_CHUNKED_PARAMS} at the beginning of diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/model/AbstractModelReaderOomTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/model/AbstractModelReaderOomTest.java new file mode 100644 index 00000000..242555c7 --- /dev/null +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/model/AbstractModelReaderOomTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.ml.model; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Verifies that crafted model files with oversized count fields are rejected before array + * allocation occurs, preventing OOM DoS. See OPENNLP-1821. + */ +class AbstractModelReaderOomTest { + + /** + * Minimal concrete subclass that exposes the three protected methods under test. + */ + static class TestableReader extends AbstractModelReader { + TestableReader(DataReader dr) { super(dr); } + + @Override public void checkModelType() {} + @Override public AbstractModel constructModel() { return null; } + + String[] outcomes() throws IOException { return getOutcomes(); } + int[][] outcomePatterns() throws IOException { return getOutcomePatterns(); } + String[] predicates() throws IOException { return getPredicates(); } + } + + /** Reader whose stream starts with a single int (the count field). */ + private static TestableReader readerFor(int countValue) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + dos.writeInt(countValue); + dos.flush(); + DataInputStream dis = new DataInputStream(new ByteArrayInputStream(baos.toByteArray())); + return new TestableReader(new BinaryFileDataReader(dis)); + } + + @Test + void testGetOutcomes_RejectsMaxValue() throws IOException { + assertThrows(IllegalArgumentException.class, readerFor(Integer.MAX_VALUE)::outcomes); + } + + @Test + void testGetOutcomePatterns_RejectsMaxValue() throws IOException { + assertThrows(IllegalArgumentException.class, readerFor(Integer.MAX_VALUE)::outcomePatterns); + } + + @Test + void testGetPredicates_RejectsMaxValue() throws IOException { + assertThrows(IllegalArgumentException.class, readerFor(Integer.MAX_VALUE)::predicates); + } + + @Test + void testGetOutcomes_RejectsNegativeCount() throws IOException { + assertThrows(IllegalArgumentException.class, readerFor(-1)::outcomes); + } + + @Test + void testGetOutcomePatterns_RejectsNegativeCount() throws IOException { + assertThrows(IllegalArgumentException.class, readerFor(-1)::outcomePatterns); + } + + @Test + void testGetPredicates_RejectsNegativeCount() throws IOException { + assertThrows(IllegalArgumentException.class, readerFor(-1)::predicates); + } + + @Test + void testGetOutcomes_ValidCountReturnsLabels() throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + dos.writeInt(2); + dos.writeUTF("label-A"); + dos.writeUTF("label-B"); + dos.flush(); + + TestableReader reader = new TestableReader( + new BinaryFileDataReader(new DataInputStream(new ByteArrayInputStream(baos.toByteArray())))); + assertArrayEquals(new String[]{"label-A", "label-B"}, reader.outcomes()); + } + + @Test + void testGetPredicates_ValidCountReturnsLabels() throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + dos.writeInt(3); + dos.writeUTF("pred-X"); + dos.writeUTF("pred-Y"); + dos.writeUTF("pred-Z"); + dos.flush(); + + TestableReader reader = new TestableReader( + new BinaryFileDataReader(new DataInputStream(new ByteArrayInputStream(baos.toByteArray())))); + assertArrayEquals(new String[]{"pred-X", "pred-Y", "pred-Z"}, reader.predicates()); + } +}
