This is an automated email from the ASF dual-hosted git repository.
rzo1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 3cf42d4a OPENNLP-1823: Harden SvmDoccatModel.deserialize() with
ObjectInputFilter and resource limits (#1029)
3cf42d4a is described below
commit 3cf42d4a0145cefd9dd06138873a91312052c767
Author: Richard Zowalla <[email protected]>
AuthorDate: Sat May 2 16:58:43 2026 +0200
OPENNLP-1823: Harden SvmDoccatModel.deserialize() with ObjectInputFilter
and resource limits (#1029)
* OPENNLP-1823: Harden SvmDoccatModel.deserialize() with ObjectInputFilter
and resource limits
Apply a JEP 290 ObjectInputFilter to SvmDoccatModel.deserialize() that
allow-lists only the classes reachable from a legitimate model graph and
bounds graph depth, references, and array length. Foreign payloads are
now rejected with InvalidClassException before readObject() returns,
rather than after the cast.
Add a public DeserializationLimits record and a
deserialize(InputStream, DeserializationLimits) overload so callers with
unusually large models can raise the resource limits without touching
the class allow-list. The original deserialize(InputStream) signature is
preserved and now delegates to DeserializationLimits.DEFAULT.
* OPENNLP-1823: Add null-stream tests for SvmDoccatModel.deserialize()
overloads
* OPENNLP-1823: Reject null streams with IllegalArgumentException and
clarify allow-list comment
Validate stream and limits arguments to SvmDoccatModel.serialize() and
SvmDoccatModel.deserialize() up front and throw IllegalArgumentException
for null inputs instead of letting a NullPointerException surface from
inside the JDK stream constructors. Document the contract on each
method's Javadoc.
Update the comment on the JDK section of the deserialization filter
allow-list to explain why the abstract supertypes java.lang.Number and
java.lang.Enum must remain on the list — ObjectInputStream invokes the
filter for every class descriptor in the inheritance chain, not only
for the runtime class.
---
.../tools/ml/libsvm/doccat/SvmDoccatModel.java | 195 ++++++++++++++++++++-
.../tools/ml/libsvm/doccat/SvmDoccatModelTest.java | 144 +++++++++++++++
2 files changed, 336 insertions(+), 3 deletions(-)
diff --git a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
index b1ed01d0..5702a8c0 100644
--- a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
+++ b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
@@ -19,6 +19,7 @@ package opennlp.tools.ml.libsvm.doccat;
import java.io.IOException;
import java.io.InputStream;
+import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
@@ -28,6 +29,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.Set;
import de.hhn.mi.domain.SvmModel;
@@ -35,6 +37,17 @@ import de.hhn.mi.domain.SvmModel;
* A model for SVM-based document categorization. This model wraps a zlibsvm
* {@link SvmModel} together with the feature vocabulary, category label
* mappings, corpus statistics, and configuration required for classification.
+ * <p>
+ * Persistence uses Java object serialization via {@link #serialize(OutputStream)}
+ * and {@link #deserialize(InputStream)}. Reads are guarded by an
+ * {@link java.io.ObjectInputFilter ObjectInputFilter} that allow-lists only the
+ * classes reachable from a legitimate {@code SvmDoccatModel} graph and bounds
+ * graph depth, references, and array length. Foreign payloads are rejected
+ * with {@link java.io.InvalidClassException} before being materialised.
+ * <p>
+ * Treat the filter as defense-in-depth, not as a license to deserialize from
+ * untrusted sources: callers should still ensure the input stream originates
+ * from a location they trust.
*
* @see DocumentCategorizerSVM
*/
@@ -72,6 +85,8 @@ public class SvmDoccatModel implements Serializable {
* @param configuration The {@link SvmDoccatConfiguration} used for training.
* Must not be {@code null}.
* @param languageCode An ISO conform language code.
+ * @throws NullPointerException if any argument other than {@code languageCode}
+ * is {@code null}.
*/
SvmDoccatModel(SvmModel svmModel,
Map<String, Integer> featureVocabulary,
@@ -172,28 +187,202 @@ public class SvmDoccatModel implements Serializable {
}
/**
- * Serializes this model to the given {@link OutputStream}.
+ * Serializes this model to the given {@link OutputStream} using Java object
+ * serialization. The resulting stream can be read back with
+ * {@link #deserialize(InputStream)}.
*
* @param out The {@link OutputStream} to write to. Must not be {@code null}.
* @throws IOException Thrown if IO errors occurred during serialization.
+ * @throws IllegalArgumentException if {@code out} is {@code null}.
*/
public void serialize(OutputStream out) throws IOException {
+ if (out == null) {
+ throw new IllegalArgumentException("out must not be null");
+ }
try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
oos.writeObject(this);
}
}
/**
- * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}.
+ * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}
+ * using {@link DeserializationLimits#DEFAULT default} resource limits.
+ * <p>
+ * The stream is filtered via an {@link ObjectInputFilter} that allow-lists
+ * only the classes required to reconstruct an {@link SvmDoccatModel}, plus
+ * resource limits on graph depth, references, and array length. Foreign
+ * payloads are rejected with {@link java.io.InvalidClassException} before
+ * {@link ObjectInputStream#readObject()} returns.
+ * <p>
+ * Callers should still treat this method as defense-in-depth: only invoke
+ * it on streams from trusted sources.
+ * <p>
+ * If the default limits are too tight for an unusually large model — for
+ * example, a model with more than {@value #MAX_ARRAY_DEFAULT} entries in a
+ * single map — use {@link #deserialize(InputStream, DeserializationLimits)}
+ * to supply higher limits. The class allow-list is intentionally not
+ * configurable; loosening it would defeat the purpose of the filter.
*
* @param in The {@link InputStream} to read from. Must not be {@code null}.
* @return A valid {@link SvmDoccatModel} instance.
- * @throws IOException Thrown if IO errors occurred during deserialization.
+ * @throws IOException Thrown if IO errors occurred during deserialization,
+ * including {@link java.io.InvalidClassException} when
+ * the stream contains a class outside the allow-list or
+ * exceeds a resource limit.
* @throws ClassNotFoundException Thrown if required classes are not found.
+ * @throws IllegalArgumentException if {@code in} is {@code null}.
*/
public static SvmDoccatModel deserialize(InputStream in) throws IOException,
ClassNotFoundException {
+ return deserialize(in, DeserializationLimits.DEFAULT);
+ }
+
+ /**
+ * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}
+ * using the supplied {@link DeserializationLimits resource limits}.
+ * <p>
+ * Use this overload when the {@link DeserializationLimits#DEFAULT default
+ * limits} reject a legitimate model — for example, models trained with very
+ * large feature vocabularies. The class allow-list applied to the stream is
+ * the same as for {@link #deserialize(InputStream)}; only the numeric
+ * limits change.
+ *
+ * @param in The {@link InputStream} to read from. Must not be {@code null}.
+ * @param limits The {@link DeserializationLimits} to apply. Must not be
+ * {@code null}.
+ * @return A valid {@link SvmDoccatModel} instance.
+ * @throws IOException Thrown if IO errors occurred during deserialization,
+ * including {@link java.io.InvalidClassException} when
+ * the stream contains a class outside the allow-list or
+ * exceeds one of the supplied limits.
+ * @throws ClassNotFoundException Thrown if required classes are not found.
+ * @throws IllegalArgumentException if {@code in} or {@code limits} is
+ * {@code null}.
+ */
+ public static SvmDoccatModel deserialize(InputStream in, DeserializationLimits limits)
+ throws IOException, ClassNotFoundException {
+ if (in == null) {
+ throw new IllegalArgumentException("in must not be null");
+ }
+ if (limits == null) {
+ throw new IllegalArgumentException("limits must not be null");
+ }
try (ObjectInputStream ois = new ObjectInputStream(in)) {
+ ois.setObjectInputFilter(buildFilter(limits));
return (SvmDoccatModel) ois.readObject();
}
}
+
+ /**
+ * Resource limits applied to the {@link ObjectInputFilter} used by
+ * {@link SvmDoccatModel#deserialize(InputStream, DeserializationLimits)}.
+ * <p>
+ * The limits bound graph traversal regardless of the class allow-list and
+ * provide defense-in-depth against pathological streams. The
+ * {@linkplain #DEFAULT default values} are generous enough for typical
+ * production models; raise them only if a legitimate model is rejected.
+ *
+ * @param maxDepth Maximum object-graph nesting depth. Must be {@code > 0}.
+ * @param maxRefs Maximum number of internal references the stream may
+ * create. Must be {@code > 0}.
+ * @param maxArrayLength Maximum length of any single array allocation
+ * requested by the stream. Must be {@code > 0}.
+ * @throws IllegalArgumentException if any of {@code maxDepth},
+ * {@code maxArrayLength}
+ * is {@code <= 0}.
+ */
+ public record DeserializationLimits(long maxDepth, long maxRefs, long maxArrayLength) {
+
+ /**
+ * Default limits. Sized to allow typical production-scale models to
+ * round-trip while still bounding pathological streams.
+ */
+ public static final DeserializationLimits DEFAULT =
+ new DeserializationLimits(MAX_DEPTH_DEFAULT, MAX_REFS_DEFAULT, MAX_ARRAY_DEFAULT);
+
+ public DeserializationLimits {
+ if (maxDepth <= 0) {
+ throw new IllegalArgumentException("maxDepth must be > 0");
+ }
+ if (maxRefs <= 0) {
+ throw new IllegalArgumentException("maxRefs must be > 0");
+ }
+ if (maxArrayLength <= 0) {
+ throw new IllegalArgumentException("maxArrayLength must be > 0");
+ }
+ }
+ }
+
+ // Default resource limits applied during deserialization. These are
+ // intentionally generous so that legitimate models — which may contain
+ // millions of support vectors — round-trip without issue, while still
+ // bounding pathological inputs.
+ private static final long MAX_DEPTH_DEFAULT = 64;
+ private static final long MAX_REFS_DEFAULT = 5_000_000;
+ private static final long MAX_ARRAY_DEFAULT = 10_000_000;
+
+ // Allow-list of fully qualified class names that may appear in the
+ // serialized graph of a SvmDoccatModel. Anything else is rejected.
+ private static final Set<String> ALLOWED_CLASSES = Set.of(
+ // OpenNLP types persisted by this model
+ "opennlp.tools.ml.libsvm.doccat.SvmDoccatModel",
+ "opennlp.tools.ml.libsvm.doccat.SvmDoccatConfiguration",
+ "opennlp.tools.ml.libsvm.doccat.TermWeightingStrategy",
+ "opennlp.tools.ml.libsvm.doccat.FeatureSelectionStrategy",
+ // zlibsvm domain + configuration types reachable from SvmModel
+ "de.hhn.mi.configuration.SvmConfigurationImpl",
+ "de.hhn.mi.configuration.SvmType",
+ "de.hhn.mi.configuration.KernelType",
+ "de.hhn.mi.domain.SvmModelImpl",
+ "de.hhn.mi.domain.SvmMetaInformationImpl",
+ "de.hhn.mi.domain.SvmFeatureImpl",
+ "de.hhn.mi.domain.SvmClassLabelImpl",
+ // Native libsvm structures embedded in SvmModelImpl
+ "libsvm.svm_model",
+ "libsvm.svm_node",
+ "libsvm.svm_parameter",
+ // JDK types used in field declarations. Note: ObjectInputStream
+ // invokes the filter for every class descriptor in the inheritance
+ // chain, not only for the runtime class — so the abstract superclasses
+ // java.lang.Number (super of Integer/Double) and java.lang.Enum (super
+ // of every concrete enum value in the graph) must be allow-listed even
+ // though no instance of either appears in the stream.
+ "java.lang.String",
+ "java.lang.Number",
+ "java.lang.Integer",
+ "java.lang.Double",
+ "java.lang.Boolean",
+ "java.lang.Enum",
+ "java.util.HashMap",
+ // HashMap.readObject() requests permission to allocate a Map.Entry[]
+ // before reading entries; the array type itself never appears as a
+ // value in the stream.
+ "java.util.Map$Entry"
+ );
+
+ private static ObjectInputFilter buildFilter(DeserializationLimits limits) {
+ return info -> {
+ if (info.depth() > limits.maxDepth()
+ || info.references() > limits.maxRefs()
+ || info.arrayLength() > limits.maxArrayLength()) {
+ return ObjectInputFilter.Status.REJECTED;
+ }
+
+ Class<?> serialClass = info.serialClass();
+ if (serialClass == null) {
+ return ObjectInputFilter.Status.UNDECIDED;
+ }
+
+ Class<?> componentType = serialClass;
+ while (componentType.isArray()) {
+ componentType = componentType.getComponentType();
+ }
+ if (componentType.isPrimitive()) {
+ return ObjectInputFilter.Status.ALLOWED;
+ }
+
+ return ALLOWED_CLASSES.contains(componentType.getName())
+ ? ObjectInputFilter.Status.ALLOWED
+ : ObjectInputFilter.Status.REJECTED;
+ };
+ }
}
diff --git a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
index b7108a8b..b1978063 100644
--- a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
@@ -19,8 +19,12 @@ package opennlp.tools.ml.libsvm.doccat;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
+import java.io.File;
import java.io.IOException;
+import java.io.InvalidClassException;
+import java.io.ObjectOutputStream;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -132,6 +136,146 @@ class SvmDoccatModelTest {
SvmDoccatModel.deserialize(new ByteArrayInputStream(garbage)));
}
+ @Test
+ void testDeserializeRejectsForeignClass() throws IOException {
+ // A well-formed Java serialization stream containing a class that is NOT
+ // part of the SvmDoccatModel graph. Without ObjectInputFilter, readObject()
+ // would fully deserialize the foreign object before the cast threw
+ // ClassCastException — the gadget-chain attack window. With the filter,
+ // the read fails at the class check.
+ byte[] payload = serialize(new File("/tmp/poc"));
+ InvalidClassException ex = assertThrows(InvalidClassException.class, () ->
+ SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+ assertTrue(ex.getMessage().contains("filter"),
+ "expected filter rejection, got: " + ex.getMessage());
+ }
+
+ @Test
+ void testDeserializeRejectsForeignCollectionType() throws IOException {
+ // java.util.HashMap is on the allow-list, but java.util.Hashtable is not.
+ // Confirms the filter does not blanket-allow java.util collections.
+ byte[] payload = serialize(new java.util.Hashtable<>(Map.of("k", "v")));
+ assertThrows(InvalidClassException.class, () ->
+ SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+ }
+
+ @Test
+ void testDeserializeRejectsUnrelatedSerializable() throws IOException {
+ // A user-defined Serializable that is not in the allow-list must be
+ // rejected even though it is structurally valid.
+ byte[] payload = serialize(new ForeignPayload("yolo"));
+ assertThrows(InvalidClassException.class, () ->
+ SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+ }
+
+ @Test
+ void testDeserializeRejectsDeeplyNestedGraph() throws IOException {
+ // Build a chain of HashMaps deeper than the configured maxdepth (64).
+ // HashMap is on the allow-list, so this exercises the resource limits.
+ HashMap<String, Object> root = new HashMap<>();
+ HashMap<String, Object> cursor = root;
+ for (int i = 0; i < 200; i++) {
+ HashMap<String, Object> next = new HashMap<>();
+ cursor.put("n", next);
+ cursor = next;
+ }
+ byte[] payload = serialize(root);
+ assertThrows(InvalidClassException.class, () ->
+ SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+ }
+
+ @Test
+ void testDeserializeWithCustomLimitsRoundTrips() throws IOException, ClassNotFoundException {
+ // The configurable overload accepts a real model with custom limits and
+ // round-trips it identically to the default-limits path.
+ SvmDoccatModel model = trainSimpleModel();
+
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ model.serialize(baos);
+ byte[] bytes = baos.toByteArray();
+
+ SvmDoccatModel.DeserializationLimits raised =
+ new SvmDoccatModel.DeserializationLimits(128, 50_000_000L, 50_000_000L);
+ SvmDoccatModel restored = SvmDoccatModel.deserialize(
+ new ByteArrayInputStream(bytes), raised);
+
+ assertEquals(model.getLanguageCode(), restored.getLanguageCode());
+ assertEquals(model.getNumberOfCategories(), restored.getNumberOfCategories());
+ assertEquals(model.getFeatureVocabulary(), restored.getFeatureVocabulary());
+ }
+
+ @Test
+ void testDeserializeWithCustomLimitsAcceptsDeeperGraphs() throws IOException, ClassNotFoundException {
+ // A 100-deep HashMap chain exceeds the default maxDepth (64) but fits
+ // within a raised limit. Confirms the configuration knob actually moves.
+ HashMap<String, Object> root = new HashMap<>();
+ HashMap<String, Object> cursor = root;
+ for (int i = 0; i < 100; i++) {
+ HashMap<String, Object> next = new HashMap<>();
+ cursor.put("n", next);
+ cursor = next;
+ }
+ byte[] payload = serialize(root);
+
+ // Default limits reject this graph...
+ assertThrows(InvalidClassException.class, () ->
+ SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+
+ // ...but a wider depth limit lets it through far enough to fail at the
+ // final cast (HashMap is not an SvmDoccatModel) — proving the depth check
+ // is no longer the gate.
+ SvmDoccatModel.DeserializationLimits raised =
+ new SvmDoccatModel.DeserializationLimits(256, 5_000_000L, 10_000_000L);
+ assertThrows(ClassCastException.class, () ->
+ SvmDoccatModel.deserialize(new ByteArrayInputStream(payload), raised));
+ }
+
+ @Test
+ void testDeserializationLimitsValidatesArguments() {
+ assertThrows(IllegalArgumentException.class, () ->
+ new SvmDoccatModel.DeserializationLimits(0, 1, 1));
+ assertThrows(IllegalArgumentException.class, () ->
+ new SvmDoccatModel.DeserializationLimits(1, 0, 1));
+ assertThrows(IllegalArgumentException.class, () ->
+ new SvmDoccatModel.DeserializationLimits(1, 1, 0));
+ }
+
+ @Test
+ void testDeserializeRejectsNullLimits() {
+ byte[] empty = {};
+ assertThrows(IllegalArgumentException.class, () ->
+ SvmDoccatModel.deserialize(new ByteArrayInputStream(empty), null));
+ }
+
+ @Test
+ void testDeserializeRejectsNullStream() {
+ assertThrows(IllegalArgumentException.class, () ->
+ SvmDoccatModel.deserialize(null));
+ assertThrows(IllegalArgumentException.class, () ->
+ SvmDoccatModel.deserialize(null, SvmDoccatModel.DeserializationLimits.DEFAULT));
+ }
+
+ @Test
+ void testSerializeRejectsNullStream() throws IOException {
+ SvmDoccatModel model = trainSimpleModel();
+ assertThrows(IllegalArgumentException.class, () -> model.serialize(null));
+ }
+
+ private static byte[] serialize(Object obj) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+ oos.writeObject(obj);
+ }
+ return baos.toByteArray();
+ }
+
+ private static class ForeignPayload implements java.io.Serializable {
+ private static final long serialVersionUID = 1L;
+ @SuppressWarnings("unused")
+ private final String name;
+ ForeignPayload(String name) { this.name = name; }
+ }
+
private SvmDoccatModel trainSimpleModel() throws IOException {
List<DocumentSample> samples = new ArrayList<>();
samples.add(new DocumentSample("a", new String[]{"x", "y"}));