This is an automated email from the ASF dual-hosted git repository. rzo1 pushed a commit to branch OPENNLP-1823 in repository https://gitbox.apache.org/repos/asf/opennlp.git
commit 547cea3c85969e21c2382f0fbf99fd32f775e60a
Author: Richard Zowalla <[email protected]>
AuthorDate: Sat May 2 10:52:30 2026 +0200

    OPENNLP-1823: Harden SvmDoccatModel.deserialize() with ObjectInputFilter and resource limits

    Apply a JEP 290 ObjectInputFilter to SvmDoccatModel.deserialize() that
    allow-lists only the classes reachable from a legitimate model graph and
    bounds graph depth, references, and array length. Foreign payloads are now
    rejected with InvalidClassException while the stream is being read,
    instead of being fully materialised by readObject() and only then failing
    the cast with a ClassCastException.

    The per-stream filter is merged with any process-wide jdk.serialFilter
    rather than replacing it, so operator-configured restrictions still apply.

    Add a public DeserializationLimits record and a
    deserialize(InputStream, DeserializationLimits) overload so callers with
    unusually large models can raise the resource limits without touching the
    class allow-list. The original deserialize(InputStream) signature is
    preserved and now delegates to DeserializationLimits.DEFAULT.
---
 .../tools/ml/libsvm/doccat/SvmDoccatModel.java     | 189 ++++++++++++++++++++-
 .../tools/ml/libsvm/doccat/SvmDoccatModelTest.java | 130 +++++++++++++++
 2 files changed, 316 insertions(+), 3 deletions(-)

diff --git a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
index b1ed01d0..0368d632 100644
--- a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
+++ b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/main/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModel.java
@@ -19,6 +19,7 @@ package opennlp.tools.ml.libsvm.doccat;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.ObjectInputFilter;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.io.OutputStream;
@@ -28,6 +29,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 
 import de.hhn.mi.domain.SvmModel;
 
@@ -35,6 +37,17 @@ import de.hhn.mi.domain.SvmModel;
  * A model for SVM-based document categorization. This model wraps a zlibsvm
  * {@link SvmModel} together with the feature vocabulary, category label
  * mappings, corpus statistics, and configuration required for classification.
+ * <p>
+ * Persistence uses Java object serialization via {@link #serialize(OutputStream)}
+ * and {@link #deserialize(InputStream)}. Reads are guarded by an
+ * {@link java.io.ObjectInputFilter ObjectInputFilter} that allow-lists only the
+ * classes reachable from a legitimate {@code SvmDoccatModel} graph and bounds
+ * graph depth, references, and array length. Foreign payloads are rejected
+ * with {@link java.io.InvalidClassException} before being materialised.
+ * <p>
+ * Treat the filter as defense-in-depth, not as a license to deserialize from
+ * untrusted sources: callers should still ensure the input stream originates
+ * from a location they trust.
  *
  * @see DocumentCategorizerSVM
  */
@@ -172,7 +185,9 @@ public class SvmDoccatModel implements Serializable {
   }
 
   /**
-   * Serializes this model to the given {@link OutputStream}.
+   * Serializes this model to the given {@link OutputStream} using Java object
+   * serialization. The resulting stream can be read back with
+   * {@link #deserialize(InputStream)}.
    *
    * @param out The {@link OutputStream} to write to. Must not be {@code null}.
    * @throws IOException Thrown if IO errors occurred during serialization.
@@ -184,16 +199,184 @@ public class SvmDoccatModel implements Serializable {
   }
 
   /**
-   * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}.
+   * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}
+   * using {@link DeserializationLimits#DEFAULT default} resource limits.
+   * <p>
+   * The stream is filtered via an {@link ObjectInputFilter} that allow-lists
+   * only the classes required to reconstruct an {@link SvmDoccatModel}, plus
+   * resource limits on graph depth, references, and array length. Foreign
+   * payloads are rejected with {@link java.io.InvalidClassException} before
+   * {@link ObjectInputStream#readObject()} returns. A process-wide filter
+   * configured via {@code jdk.serialFilter}, if any, is still consulted and
+   * can reject the stream as well.
+   * <p>
+   * Callers should still treat this method as defense-in-depth: only invoke
+   * it on streams from trusted sources.
+   * <p>
+   * If the default limits are too tight for an unusually large model, use
+   * {@link #deserialize(InputStream, DeserializationLimits)} to supply higher
+   * limits. Note that the array limit is checked against each map's backing
+   * hash table, whose capacity is roughly the entry count divided by the load
+   * factor rounded up to a power of two, and that every distinct object in
+   * the stream (for example each map key and value, or each libsvm node)
+   * counts toward the reference limit. The class allow-list is intentionally
+   * not configurable; loosening it would defeat the purpose of the filter.
    *
    * @param in The {@link InputStream} to read from. Must not be {@code null}.
    * @return A valid {@link SvmDoccatModel} instance.
-   * @throws IOException Thrown if IO errors occurred during deserialization.
+   * @throws IOException Thrown if IO errors occurred during deserialization,
+   *                     including {@link java.io.InvalidClassException} when
+   *                     the stream contains a class outside the allow-list or
+   *                     exceeds a resource limit.
    * @throws ClassNotFoundException Thrown if required classes are not found.
    */
   public static SvmDoccatModel deserialize(InputStream in) throws IOException, ClassNotFoundException {
+    return deserialize(in, DeserializationLimits.DEFAULT);
+  }
+
+  /**
+   * Deserializes a {@link SvmDoccatModel} from the given {@link InputStream}
+   * using the supplied {@link DeserializationLimits resource limits}.
+   * <p>
+   * Use this overload when the {@link DeserializationLimits#DEFAULT default
+   * limits} reject a legitimate model — for example, models trained with very
+   * large feature vocabularies. The class allow-list applied to the stream is
+   * the same as for {@link #deserialize(InputStream)}; only the numeric
+   * limits change.
+   *
+   * @param in The {@link InputStream} to read from. Must not be {@code null}.
+   * @param limits The {@link DeserializationLimits} to apply. Must not be
+   *               {@code null}.
+   * @return A valid {@link SvmDoccatModel} instance.
+   * @throws IOException Thrown if IO errors occurred during deserialization,
+   *                     including {@link java.io.InvalidClassException} when
+   *                     the stream contains a class outside the allow-list or
+   *                     exceeds one of the supplied limits.
+   * @throws ClassNotFoundException Thrown if required classes are not found.
+   */
+  public static SvmDoccatModel deserialize(InputStream in, DeserializationLimits limits)
+      throws IOException, ClassNotFoundException {
+    Objects.requireNonNull(limits, "limits must not be null");
     try (ObjectInputStream ois = new ObjectInputStream(in)) {
+      ois.setObjectInputFilter(buildFilter(limits));
       return (SvmDoccatModel) ois.readObject();
     }
   }
+
+  /**
+   * Resource limits applied to the {@link ObjectInputFilter} used by
+   * {@link SvmDoccatModel#deserialize(InputStream, DeserializationLimits)}.
+   * <p>
+   * The limits bound graph traversal regardless of the class allow-list and
+   * provide defense-in-depth against pathological streams. The
+   * {@linkplain #DEFAULT default values} are generous enough for typical
+   * production models; raise them only if a legitimate model is rejected.
+   *
+   * @param maxDepth Maximum object-graph nesting depth. Must be {@code > 0}.
+   * @param maxRefs Maximum number of internal references the stream may
+   *                create. Must be {@code > 0}.
+   * @param maxArrayLength Maximum length of any single array allocation
+   *                       requested by the stream. Must be {@code > 0}.
+   * @throws IllegalArgumentException if any of {@code maxDepth},
+   *                                  {@code maxRefs}, or {@code maxArrayLength}
+   *                                  is {@code <= 0}.
+   */
+  public record DeserializationLimits(long maxDepth, long maxRefs, long maxArrayLength) {
+
+    /**
+     * Default limits. Sized to allow typical production-scale models to
+     * round-trip while still bounding pathological streams.
+     */
+    public static final DeserializationLimits DEFAULT =
+        new DeserializationLimits(MAX_DEPTH_DEFAULT, MAX_REFS_DEFAULT, MAX_ARRAY_DEFAULT);
+
+    public DeserializationLimits {
+      if (maxDepth <= 0) {
+        throw new IllegalArgumentException("maxDepth must be > 0, but was " + maxDepth);
+      }
+      if (maxRefs <= 0) {
+        throw new IllegalArgumentException("maxRefs must be > 0, but was " + maxRefs);
+      }
+      if (maxArrayLength <= 0) {
+        throw new IllegalArgumentException("maxArrayLength must be > 0, but was " + maxArrayLength);
+      }
+    }
+  }
+
+  // Default resource limits applied during deserialization. They fit typical
+  // models while still bounding pathological inputs. Every distinct object in
+  // the stream (e.g. each libsvm svm_node, map key and map value) counts as a
+  // reference, so models with very many support vectors need a higher maxRefs.
+  private static final long MAX_DEPTH_DEFAULT = 64;
+  private static final long MAX_REFS_DEFAULT = 5_000_000;
+  private static final long MAX_ARRAY_DEFAULT = 10_000_000;
+
+  // Allow-list of fully qualified class names that may appear in the
+  // serialized graph of a SvmDoccatModel. Anything else is rejected.
+  private static final Set<String> ALLOWED_CLASSES = Set.of(
+      // OpenNLP types persisted by this model
+      "opennlp.tools.ml.libsvm.doccat.SvmDoccatModel",
+      "opennlp.tools.ml.libsvm.doccat.SvmDoccatConfiguration",
+      "opennlp.tools.ml.libsvm.doccat.TermWeightingStrategy",
+      "opennlp.tools.ml.libsvm.doccat.FeatureSelectionStrategy",
+      // zlibsvm domain + configuration types reachable from SvmModel
+      "de.hhn.mi.configuration.SvmConfigurationImpl",
+      "de.hhn.mi.configuration.SvmType",
+      "de.hhn.mi.configuration.KernelType",
+      "de.hhn.mi.domain.SvmModelImpl",
+      "de.hhn.mi.domain.SvmMetaInformationImpl",
+      "de.hhn.mi.domain.SvmFeatureImpl",
+      "de.hhn.mi.domain.SvmClassLabelImpl",
+      // Native libsvm structures embedded in SvmModelImpl
+      "libsvm.svm_model",
+      "libsvm.svm_node",
+      "libsvm.svm_parameter",
+      // JDK types used in field declarations
+      "java.lang.String",
+      "java.lang.Number",
+      "java.lang.Integer",
+      "java.lang.Double",
+      "java.lang.Boolean",
+      "java.lang.Enum",
+      "java.util.HashMap",
+      // HashMap.readObject() requests permission to allocate a Map.Entry[]
+      // before reading entries; the array type itself never appears as a
+      // value in the stream.
+      "java.util.Map$Entry"
+  );
+
+  // Builds the per-stream filter: resource limits first, then the class
+  // allow-list (array types are judged by their innermost component type).
+  private static ObjectInputFilter buildFilter(DeserializationLimits limits) {
+    ObjectInputFilter modelFilter = info -> {
+      if (info.depth() > limits.maxDepth()
+          || info.references() > limits.maxRefs()
+          || info.arrayLength() > limits.maxArrayLength()) {
+        return ObjectInputFilter.Status.REJECTED;
+      }
+
+      Class<?> serialClass = info.serialClass();
+      if (serialClass == null) {
+        return ObjectInputFilter.Status.UNDECIDED;
+      }
+
+      Class<?> componentType = serialClass;
+      while (componentType.isArray()) {
+        componentType = componentType.getComponentType();
+      }
+      if (componentType.isPrimitive()) {
+        return ObjectInputFilter.Status.ALLOWED;
+      }
+
+      return ALLOWED_CLASSES.contains(componentType.getName())
+          ? ObjectInputFilter.Status.ALLOWED
+          : ObjectInputFilter.Status.REJECTED;
+    };
+
+    // ObjectInputStream#setObjectInputFilter replaces the process-wide
+    // jdk.serialFilter for this stream; merge it back in so an operator's
+    // filter can still reject what the model filter would accept.
+    ObjectInputFilter processWide = ObjectInputFilter.Config.getSerialFilter();
+    return processWide == null ? modelFilter : ObjectInputFilter.merge(modelFilter, processWide);
+  }
 }
diff --git a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
index b7108a8b..d3b1087b 100644
--- a/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-ml-libsvm/src/test/java/opennlp/tools/ml/libsvm/doccat/SvmDoccatModelTest.java
@@ -19,8 +19,12 @@ package opennlp.tools.ml.libsvm.doccat;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
+import java.io.File;
 import java.io.IOException;
+import java.io.InvalidClassException;
+import java.io.ObjectOutputStream;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -132,6 +136,132 @@ class SvmDoccatModelTest {
         SvmDoccatModel.deserialize(new ByteArrayInputStream(garbage)));
   }
 
+  @Test
+  void testDeserializeRejectsForeignClass() throws IOException {
+    // A well-formed Java serialization stream containing a class that is NOT
+    // part of the SvmDoccatModel graph. Without ObjectInputFilter, readObject()
+    // would fully deserialize the foreign object before the cast threw
+    // ClassCastException — the gadget-chain attack window. With the filter,
+    // the read fails at the class check.
+    byte[] payload = serialize(new File("/tmp/poc"));
+    InvalidClassException ex = assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+    assertTrue(ex.getMessage().contains("filter"),
+        "expected filter rejection, got: " + ex.getMessage());
+  }
+
+  @Test
+  void testDeserializeRejectsForeignCollectionType() throws IOException {
+    // java.util.HashMap is on the allow-list, but java.util.Hashtable is not.
+    // Confirms the filter does not blanket-allow java.util collections.
+    byte[] payload = serialize(new java.util.Hashtable<>(Map.of("k", "v")));
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+  }
+
+  @Test
+  void testDeserializeRejectsUnrelatedSerializable() throws IOException {
+    // A user-defined Serializable that is not in the allow-list must be
+    // rejected even though it is structurally valid.
+    byte[] payload = serialize(new ForeignPayload("yolo"));
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+  }
+
+  @Test
+  void testDeserializeRejectsDeeplyNestedGraph() throws IOException {
+    // Build a chain of HashMaps deeper than the default maxDepth (64).
+    // HashMap is on the allow-list, so this exercises the resource limits.
+    HashMap<String, Object> root = new HashMap<>();
+    HashMap<String, Object> cursor = root;
+    for (int i = 0; i < 200; i++) {
+      HashMap<String, Object> next = new HashMap<>();
+      cursor.put("n", next);
+      cursor = next;
+    }
+    byte[] payload = serialize(root);
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+  }
+
+  @Test
+  void testDeserializeWithCustomLimitsRoundTrips() throws IOException, ClassNotFoundException {
+    // The configurable overload accepts a real model with custom limits and
+    // round-trips it identically to the default-limits path.
+    SvmDoccatModel model = trainSimpleModel();
+
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    model.serialize(baos);
+    byte[] bytes = baos.toByteArray();
+
+    SvmDoccatModel.DeserializationLimits raised =
+        new SvmDoccatModel.DeserializationLimits(128, 50_000_000L, 50_000_000L);
+    SvmDoccatModel restored = SvmDoccatModel.deserialize(
+        new ByteArrayInputStream(bytes), raised);
+
+    assertEquals(model.getLanguageCode(), restored.getLanguageCode());
+    assertEquals(model.getNumberOfCategories(), restored.getNumberOfCategories());
+    assertEquals(model.getFeatureVocabulary(), restored.getFeatureVocabulary());
+  }
+
+  @Test
+  void testDeserializeWithCustomLimitsAcceptsDeeperGraphs() throws IOException, ClassNotFoundException {
+    // A 100-deep HashMap chain exceeds the default maxDepth (64) but fits
+    // within a raised limit. Confirms the configuration knob actually moves.
+    HashMap<String, Object> root = new HashMap<>();
+    HashMap<String, Object> cursor = root;
+    for (int i = 0; i < 100; i++) {
+      HashMap<String, Object> next = new HashMap<>();
+      cursor.put("n", next);
+      cursor = next;
+    }
+    byte[] payload = serialize(root);
+
+    // Default limits reject this graph...
+    assertThrows(InvalidClassException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload)));
+
+    // ...but a wider depth limit lets it through far enough to fail at the
+    // final cast (HashMap is not an SvmDoccatModel) — proving the depth check
+    // is no longer the gate.
+    SvmDoccatModel.DeserializationLimits raised =
+        new SvmDoccatModel.DeserializationLimits(256, 5_000_000L, 10_000_000L);
+    assertThrows(ClassCastException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(payload), raised));
+  }
+
+  @Test
+  void testDeserializationLimitsValidatesArguments() {
+    assertThrows(IllegalArgumentException.class, () ->
+        new SvmDoccatModel.DeserializationLimits(0, 1, 1));
+    assertThrows(IllegalArgumentException.class, () ->
+        new SvmDoccatModel.DeserializationLimits(1, 0, 1));
+    assertThrows(IllegalArgumentException.class, () ->
+        new SvmDoccatModel.DeserializationLimits(1, 1, 0));
+  }
+
+  @Test
+  void testDeserializeRejectsNullLimits() {
+    byte[] empty = {};
+    assertThrows(NullPointerException.class, () ->
+        SvmDoccatModel.deserialize(new ByteArrayInputStream(empty), null));
+  }
+
+  private static byte[] serialize(Object obj) throws IOException {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+      oos.writeObject(obj);
+    }
+    return baos.toByteArray();
+  }
+
+  private static class ForeignPayload implements java.io.Serializable {
+    private static final long serialVersionUID = 1L;
+    @SuppressWarnings("unused")
+    private final String name;
+    ForeignPayload(String name) { this.name = name; }
+  }
+
   private SvmDoccatModel trainSimpleModel() throws IOException {
     List<DocumentSample> samples = new ArrayList<>();
     samples.add(new DocumentSample("a", new String[]{"x", "y"}));
