This is an automated email from the ASF dual-hosted git repository.

rzo1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git


The following commit(s) were added to refs/heads/main by this push:
     new 3cf42d4a OPENNLP-1823: Harden SvmDoccatModel.deserialize() with 
ObjectInputFilter and resource limits (#1029)
3cf42d4a is described below

commit 3cf42d4a0145cefd9dd06138873a91312052c767
Author: Richard Zowalla <[email protected]>
AuthorDate: Sat May 2 16:58:43 2026 +0200

    OPENNLP-1823: Harden SvmDoccatModel.deserialize() with ObjectInputFilter 
and resource limits (#1029)
    
    * OPENNLP-1823: Harden SvmDoccatModel.deserialize() with ObjectInputFilter 
and resource limits
    
    Apply a JEP 290 ObjectInputFilter to SvmDoccatModel.deserialize() that
    allow-lists only the classes reachable from a legitimate model graph and
    bounds graph depth, references, and array length. Foreign payloads are
    now rejected with InvalidClassException before readObject() returns,
    rather than after the cast.
    
    Add a public DeserializationLimits record and a
    deserialize(InputStream, DeserializationLimits) overload so callers with
    unusually large models can raise the resource limits without touching
    the class allow-list. The original deserialize(InputStream) signature is
    preserved and now delegates to DeserializationLimits.DEFAULT.
    
    * OPENNLP-1823: Add null-stream tests for SvmDoccatModel.deserialize() 
overloads
    
    * OPENNLP-1823: Reject null streams with IllegalArgumentException and 
clarify allow-list comment
    
    Validate stream and limits arguments to SvmDoccatModel.serialize() and
    SvmDoccatModel.deserialize() up front and throw IllegalArgumentException
    for null inputs instead of letting a NullPointerException surface from
    inside the JDK stream constructors. Document the contract on each
    method's Javadoc.
    
    Update the comment on the JDK section of the deserialization filter
    allow-list to explain why the abstract supertypes java.lang.Number and
    java.lang.Enum must remain on the list — ObjectInputStream invokes the
    filter for every class descriptor in the inheritance chain, not only
    for the runtime class.
---
 .../tools/ml/libsvm/doccat/SvmDoccatModel.java     | 195 ++++++++++++++++++++-
 .../tools/ml/libsvm/doccat/SvmDoccatModelTest.java | 144 +++++++++++++++
 2 files changed, 336 insertions(+), 3 deletions(-)

diff --git a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
index b1ed01d0..5702a8c0 100644
--- a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
+++ b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
@@ -19,6 +19,7 @@ package opennlp.tools.ml.libsvm.doccat;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.ObjectInputFilter;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.io.OutputStream;
@@ -28,6 +29,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 
 import de.hhn.mi.domain.SvmModel;
 
@@ -35,6 +37,17 @@ import de.hhn.mi.domain.SvmModel;
  * A model for SVM-based document categorization. This model wraps a zlibsvm
  * {@link SvmModel} together with the feature vocabulary, category label
  * mappings, corpus statistics, and configuration required for classification.
+ * <p>
+ * Persistence uses Java object serialization via {@link #serialize(OutputStream)}
+ * and {@link #deserialize(InputStream)}. Reads are guarded by an
+ * {@link java.io.ObjectInputFilter ObjectInputFilter} that allow-lists only the
+ * classes reachable from a legitimate {@code SvmDoccatModel} graph and bounds
+ * graph depth, references, and array length. Foreign payloads are rejected
+ * with {@link java.io.InvalidClassException} before being materialised.
+ * <p>
+ * Treat the filter as defense-in-depth, not as a license to deserialize from
+ * untrusted sources: callers should still ensure the input stream originates
+ * from a location they trust.
  *
  * @see DocumentCategorizerSVM
  */
@@ -72,6 +85,8 @@ public class SvmDoccatModel implements Serializable {
    * @param configuration      The {@link SvmDoccatConfiguration} used for training.
    *                           Must not be {@code null}.
    * @param languageCode      An ISO conform language code.
+   * @throws NullPointerException if any argument other than {@code languageCode}
+   *                              is {@code null}.
    */
   SvmDoccatModel(SvmModel svmModel,
                   Map<String, Integer> featureVocabulary,
@@ -172,28 +187,202 @@ public class SvmDoccatModel implements Serializable {
   }
 
   /**
-   * Serializes this model to the given {@link OutputStream}.
+   * Serializes this model to the given {@link OutputStream} using Java object
+   * serialization. The resulting stream can be read back with
+   * {@link #deserialize(InputStream)}.
    *
    * @param out The {@link OutputStream} to write to. Must not be {@code null}.
    * @throws IOException Thrown if IO errors occurred during serialization.
+   * @throws IllegalArgumentException if {@code out} is {@code null}.
    */
   public void serialize(OutputStream out) throws IOException {
+    if (out == null) {
+      throw new IllegalArgumentException("out must not be null");
+    }
     try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
       oos.writeObject(this);
     }
   }
 
   /**
-   * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}.
+   * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}
+   * using {@link DeserializationLimits#DEFAULT default} resource limits.
+   * <p>
+   * The stream is filtered via an {@link ObjectInputFilter} that allow-lists
+   * only the classes required to reconstruct an {@link SvmDoccatModel}, plus
+   * resource limits on graph depth, references, and array length. Foreign
+   * payloads are rejected with {@link java.io.InvalidClassException} before
+   * {@link ObjectInputStream#readObject()} returns.
+   * <p>
+   * Callers should still treat the filter as defense-in-depth only: invoke
+   * this method solely on streams from trusted sources.
+   * <p>
+   * If the default limits are too tight for an unusually large model — for
+   * example, a map whose backing array exceeds {@value #MAX_ARRAY_DEFAULT}
+   * slots — use {@link #deserialize(InputStream, DeserializationLimits)}
+   * to supply higher limits. The class allow-list is intentionally not
+   * configurable; loosening it would defeat the purpose of the filter.
    *
    * @param in The {@link InputStream} to read from. Must not be {@code null}.
    * @return A valid {@link SvmDoccatModel} instance.
-   * @throws IOException Thrown if IO errors occurred during deserialization.
+   * @throws IOException Thrown if IO errors occurred during deserialization,
+   *                     including {@link java.io.InvalidClassException} when
+   *                     the stream contains a class outside the allow-list or
+   *                     exceeds a resource limit.
    * @throws ClassNotFoundException Thrown if required classes are not found.
+   * @throws IllegalArgumentException if {@code in} is {@code null}.
    */
   public static SvmDoccatModel deserialize(InputStream in) throws IOException, ClassNotFoundException {
+    return deserialize(in, DeserializationLimits.DEFAULT);
+  }
+
+  /**
+   * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}
+   * using the supplied {@link DeserializationLimits resource limits}.
+   * <p>
+   * Use this overload when the {@link DeserializationLimits#DEFAULT default
+   * limits} reject a legitimate model — for example, models trained with very
+   * large feature vocabularies. The class allow-list applied to the stream is
+   * the same as for {@link #deserialize(InputStream)}; only the numeric
+   * limits change.
+   *
+   * @param in     The {@link InputStream} to read from. Must not be {@code null}.
+   * @param limits The {@link DeserializationLimits} to apply. Must not be
+   *               {@code null}.
+   * @return A valid {@link SvmDoccatModel} instance.
+   * @throws IOException Thrown if IO errors occurred during deserialization,
+   *                     including {@link java.io.InvalidClassException} when
+   *                     the stream contains a class outside the allow-list or
+   *                     exceeds one of the supplied limits.
+   * @throws ClassNotFoundException Thrown if required classes are not found.
+   * @throws IllegalArgumentException if {@code in} or {@code limits} is
+   *                                  {@code null}.
+   */
+  public static SvmDoccatModel deserialize(InputStream in, DeserializationLimits limits)
+      throws IOException, ClassNotFoundException {
+    if (in == null) {
+      throw new IllegalArgumentException("in must not be null");
+    }
+    if (limits == null) {
+      throw new IllegalArgumentException("limits must not be null");
+    }
     try (ObjectInputStream ois = new ObjectInputStream(in)) {
+      ois.setObjectInputFilter(buildFilter(limits));
       return (SvmDoccatModel) ois.readObject();
     }
   }
+
+  /**
+   * Resource limits applied to the {@link ObjectInputFilter} used by
+   * {@link SvmDoccatModel#deserialize(InputStream, DeserializationLimits)}.
+   * <p>
+   * The limits bound graph traversal regardless of the class allow-list and
+   * provide defense-in-depth against pathological streams. The
+   * {@linkplain #DEFAULT default values} are generous enough for typical
+   * production models; raise them only if a legitimate model is rejected.
+   *
+   * @param maxDepth       Maximum object-graph nesting depth. Must be {@code > 0}.
+   * @param maxRefs        Maximum number of internal references the stream may
+   *                       create. Must be {@code > 0}.
+   * @param maxArrayLength Maximum length of any single array allocation
+   *                       requested by the stream. Must be {@code > 0}.
+   * @throws IllegalArgumentException if any of {@code maxDepth},
+   *                                  {@code maxRefs}, or {@code maxArrayLength}
+   *                                  is {@code <= 0}.
+   */
+  public record DeserializationLimits(long maxDepth, long maxRefs, long maxArrayLength) {
+
+    /**
+     * Default limits. Sized to allow typical production-scale models to
+     * round-trip while still bounding pathological streams.
+     */
+    public static final DeserializationLimits DEFAULT =
+        new DeserializationLimits(MAX_DEPTH_DEFAULT, MAX_REFS_DEFAULT, MAX_ARRAY_DEFAULT);
+
+    public DeserializationLimits {
+      if (maxDepth <= 0) {
+        throw new IllegalArgumentException("maxDepth must be > 0");
+      }
+      if (maxRefs <= 0) {
+        throw new IllegalArgumentException("maxRefs must be > 0");
+      }
+      if (maxArrayLength <= 0) {
+        throw new IllegalArgumentException("maxArrayLength must be > 0");
+      }
+    }
+  }
+
+  // Default resource limits applied during deserialization. These are
+  // intentionally generous so that legitimate models — which may contain
+  // millions of support vectors — round-trip without issue, while still
+  // bounding pathological inputs.
+  private static final long MAX_DEPTH_DEFAULT = 64;
+  private static final long MAX_REFS_DEFAULT = 5_000_000;
+  private static final long MAX_ARRAY_DEFAULT = 10_000_000;
+
+  // Allow-list of fully qualified class names that may appear in the
+  // serialized graph of a SvmDoccatModel. Anything else is rejected.
+  private static final Set<String> ALLOWED_CLASSES = Set.of(
+      // OpenNLP types persisted by this model
+      "opennlp.tools.ml.libsvm.doccat.SvmDoccatModel",
+      "opennlp.tools.ml.libsvm.doccat.SvmDoccatConfiguration",
+      "opennlp.tools.ml.libsvm.doccat.TermWeightingStrategy",
+      "opennlp.tools.ml.libsvm.doccat.FeatureSelectionStrategy",
+      // zlibsvm domain + configuration types reachable from SvmModel
+      "de.hhn.mi.configuration.SvmConfigurationImpl",
+      "de.hhn.mi.configuration.SvmType",
+      "de.hhn.mi.configuration.KernelType",
+      "de.hhn.mi.domain.SvmModelImpl",
+      "de.hhn.mi.domain.SvmMetaInformationImpl",
+      "de.hhn.mi.domain.SvmFeatureImpl",
+      "de.hhn.mi.domain.SvmClassLabelImpl",
+      // Native libsvm structures embedded in SvmModelImpl
+      "libsvm.svm_model",
+      "libsvm.svm_node",
+      "libsvm.svm_parameter",
+      // JDK types used in field declarations. Note: ObjectInputStream
+      // invokes the filter for every class descriptor in the inheritance
+      // chain, not only for the runtime class — so the abstract superclasses
+      // java.lang.Number (super of Integer/Double) and java.lang.Enum (super
+      // of every concrete enum value in the graph) must be allow-listed even
+      // though no instance of either appears in the stream.
+      "java.lang.String",
+      "java.lang.Number",
+      "java.lang.Integer",
+      "java.lang.Double",
+      "java.lang.Boolean",
+      "java.lang.Enum",
+      "java.util.HashMap",
+      // HashMap.readObject() requests permission to allocate a Map.Entry[]
+      // before reading entries; the array type itself never appears as a
+      // value in the stream.
+      "java.util.Map$Entry"
+  );
+
+  private static ObjectInputFilter buildFilter(DeserializationLimits limits) {
+    return info -> {
+      if (info.depth() > limits.maxDepth()
+          || info.references() > limits.maxRefs()
+          || info.arrayLength() > limits.maxArrayLength()) {
+        return ObjectInputFilter.Status.REJECTED;
+      }
+
+      Class<?> serialClass = info.serialClass();
+      if (serialClass == null) {
+        return ObjectInputFilter.Status.UNDECIDED;
+      }
+
+      Class<?> componentType = serialClass;
+      while (componentType.isArray()) {
+        componentType = componentType.getComponentType();
+      }
+      if (componentType.isPrimitive()) {
+        return ObjectInputFilter.Status.ALLOWED;
+      }
+
+      return ALLOWED_CLASSES.contains(componentType.getName())
+          ? ObjectInputFilter.Status.ALLOWED
+          : ObjectInputFilter.Status.REJECTED;
+    };
+  }
 }
diff --git a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
index b7108a8b..b1978063 100644
--- a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
@@ -19,8 +19,12 @@ package opennlp.tools.ml.libsvm.doccat;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
+import java.io.File;
 import java.io.IOException;
+import java.io.InvalidClassException;
+import java.io.ObjectOutputStream;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -132,6 +136,146 @@ class SvmDoccatModelTest {
         SvmDoccatModel.deserialize(new ByteArrayInputStream(garbage)));
   }
 
+  @Test
+  void testDeserializeRejectsForeignClass() throws IOException {
+    // A well-formed Java serialization stream containing a class that is NOT
+    // part of the SvmDoccatModel graph. Without ObjectInputFilter, readObject()
+    // would fully deserialize the foreign object before the cast threw
+    // ClassCastException — the gadget-chain attack window. With the filter,
+    // the read fails at the class check.
+    byte[] payload = serialize(new File("/tmp/poc"));
+    InvalidClassException ex = assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+    assertTrue(ex.getMessage().contains("filter"),
+        "expected filter rejection, got: " + ex.getMessage());
+  }
+
+  @Test
+  void testDeserializeRejectsForeignCollectionType() throws IOException {
+    // java.util.HashMap is on the allow-list, but java.util.Hashtable is not.
+    // Confirms the filter does not blanket-allow java.util collections.
+    byte[] payload = serialize(new java.util.Hashtable<>(Map.of("k", "v")));
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+  }
+
+  @Test
+  void testDeserializeRejectsUnrelatedSerializable() throws IOException {
+    // A user-defined Serializable that is not in the allow-list must be
+    // rejected even though it is structurally valid.
+    byte[] payload = serialize(new ForeignPayload("yolo"));
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+  }
+
+  @Test
+  void testDeserializeRejectsDeeplyNestedGraph() throws IOException {
+    // Build a chain of HashMaps deeper than the default maxDepth (64).
+    // HashMap is on the allow-list, so this exercises the resource limits.
+    HashMap<String, Object> root = new HashMap<>();
+    HashMap<String, Object> cursor = root;
+    for (int i = 0; i < 200; i++) {
+      HashMap<String, Object> next = new HashMap<>();
+      cursor.put("n", next);
+      cursor = next;
+    }
+    byte[] payload = serialize(root);
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+  }
+
+  @Test
+  void testDeserializeWithCustomLimitsRoundTrips() throws IOException, ClassNotFoundException {
+    // The configurable overload accepts a real model with custom limits and
+    // round-trips it identically to the default-limits path.
+    SvmDoccatModel model = trainSimpleModel();
+
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    model.serialize(baos);
+    byte[] bytes = baos.toByteArray();
+
+    SvmDoccatModel.DeserializationLimits raised =
+        new SvmDoccatModel.DeserializationLimits(128, 50_000_000L, 50_000_000L);
+    SvmDoccatModel restored = SvmDoccatModel.deserialize(
+        new ByteArrayInputStream(bytes), raised);
+
+    assertEquals(model.getLanguageCode(), restored.getLanguageCode());
+    assertEquals(model.getNumberOfCategories(), restored.getNumberOfCategories());
+    assertEquals(model.getFeatureVocabulary(), restored.getFeatureVocabulary());
+  }
+
+  @Test
+  void testDeserializeWithCustomLimitsAcceptsDeeperGraphs() throws IOException, ClassNotFoundException {
+    // A 100-deep HashMap chain exceeds the default maxDepth (64) but fits
+    // within a raised limit. Confirms the configuration knob actually moves.
+    HashMap<String, Object> root = new HashMap<>();
+    HashMap<String, Object> cursor = root;
+    for (int i = 0; i < 100; i++) {
+      HashMap<String, Object> next = new HashMap<>();
+      cursor.put("n", next);
+      cursor = next;
+    }
+    byte[] payload = serialize(root);
+
+    // Default limits reject this graph...
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+
+    // ...but a wider depth limit lets it through far enough to fail at the
+    // final cast (HashMap is not an SvmDoccatModel) — proving the depth check
+    // is no longer the gate.
+    SvmDoccatModel.DeserializationLimits raised =
+        new SvmDoccatModel.DeserializationLimits(256, 5_000_000L, 10_000_000L);
+    assertThrows(ClassCastException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload), raised));
+  }
+
+  @Test
+  void testDeserializationLimitsValidatesArguments() {
+    assertThrows(IllegalArgumentException.class, () ->
+        new SvmDoccatModel.DeserializationLimits(0, 1, 1));
+    assertThrows(IllegalArgumentException.class, () ->
+        new SvmDoccatModel.DeserializationLimits(1, 0, 1));
+    assertThrows(IllegalArgumentException.class, () ->
+        new SvmDoccatModel.DeserializationLimits(1, 1, 0));
+  }
+
+  @Test
+  void testDeserializeRejectsNullLimits() {
+    byte[] empty = {};
+    assertThrows(IllegalArgumentException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(empty), null));
+  }
+
+  @Test
+  void testDeserializeRejectsNullStream() {
+    assertThrows(IllegalArgumentException.class, () ->
+        SvmDoccatModel.deserialize(null));
+    assertThrows(IllegalArgumentException.class, () ->
+        SvmDoccatModel.deserialize(null, SvmDoccatModel.DeserializationLimits.DEFAULT));
+  }
+
+  @Test
+  void testSerializeRejectsNullStream() throws IOException {
+    SvmDoccatModel model = trainSimpleModel();
+    assertThrows(IllegalArgumentException.class, () -> model.serialize(null));
+  }
+
+  private static byte[] serialize(Object obj) throws IOException {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+      oos.writeObject(obj);
+    }
+    return baos.toByteArray();
+  }
+
+  private static class ForeignPayload implements java.io.Serializable {
+    private static final long serialVersionUID = 1L;
+    @SuppressWarnings("unused")
+    private final String name;
+    ForeignPayload(String name) { this.name = name; }
+  }
+
   private SvmDoccatModel trainSimpleModel() throws IOException {
     List<DocumentSample> samples = new ArrayList<>();
     samples.add(new DocumentSample("a", new String[]{"x", "y"}));

Reply via email to