This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch charset-ship-today in repository https://gitbox.apache.org/repos/asf/tika.git
commit 4d0c37bf90716c9eea0f86e6d6cd89c7a13e1041 Author: tallison <[email protected]> AuthorDate: Fri Apr 17 08:26:58 2026 -0400 ship today checkpoint --- .../ml/chardetect/MojibusterEncodingDetector.java | 49 ++- .../apache/tika/ml/chardetect/ScoredCandidate.java | 56 +++ .../tika/ml/chardetect/SpecialistOutput.java | 67 ++++ .../tika/ml/chardetect/StatisticalSpecialist.java | 28 +- .../ml/chardetect/StructuralEncodingRules.java | 34 +- .../ml/chardetect/Utf16ColumnFeatureExtractor.java | 240 ++++++++++++ .../Utf16SpecialistEncodingDetector.java | 344 +++++++++++++++++ .../tika/ml/chardetect/WideUnicodeDetector.java | 152 +------- ...apache.tika.ml.chardetect.StatisticalSpecialist | 4 + .../apache/tika/ml/chardetect/utf16-specialist.bin | Bin 0 -> 95 bytes .../ml/chardetect/ModelResourceUniquenessTest.java | 91 +++++ .../Utf16ColumnFeatureExtractorTest.java | 412 +++++++++++++++++++++ .../Utf16SpecialistEncodingDetectorTest.java | 369 ++++++++++++++++++ ...tf16SpecialistEncodingDetectorTestFixtures.java | 69 ++++ .../java/org/apache/tika/ml/FeatureExtractor.java | 23 ++ .../main/java/org/apache/tika/ml/LinearModel.java | 146 +++++++- .../apache/tika/ml/LinearModelCalibrationTest.java | 145 ++++++++ 17 files changed, 2066 insertions(+), 163 deletions(-) diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/MojibusterEncodingDetector.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/MojibusterEncodingDetector.java index e14d4c6b84..c650284f53 100644 --- a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/MojibusterEncodingDetector.java +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/MojibusterEncodingDetector.java @@ -191,6 +191,16 @@ public class MojibusterEncodingDetector implements EncodingDetector { private final 
ByteNgramFeatureExtractor extractor; private final EnumSet<Rule> enabledRules; private final int maxProbeBytes; + /** + * UTF-16 specialist. Replaces the legacy structural UTF-16 detection + * in {@link WideUnicodeDetector}: correctly distinguishes LE from BE + * for Latin, Cyrillic, Arabic, Hebrew, Indic, Thai, CJK Unified and + * Hangul alike — the last two of which the structural detector + * explicitly could not handle. Loaded eagerly at construction; the + * detector refuses to start if the specialist model is not on the + * classpath. + */ + private final Utf16SpecialistEncodingDetector utf16Specialist; /** * Load the model from its default classpath location with all rules enabled @@ -246,6 +256,15 @@ public class MojibusterEncodingDetector implements EncodingDetector { this.extractor = new ByteNgramFeatureExtractor(model.getNumBuckets()); this.enabledRules = rules.isEmpty() ? EnumSet.noneOf(Rule.class) : EnumSet.copyOf(rules); this.maxProbeBytes = maxProbeBytes; + try { + this.utf16Specialist = new Utf16SpecialistEncodingDetector(); + } catch (IOException e) { + throw new IllegalStateException( + "UTF-16 specialist model could not be loaded. Mojibuster " + + "refuses to run without it — silent no-op produces " + + "wrong answers. Ensure utf16-specialist.bin is on " + + "the classpath.", e); + } } /** @@ -384,9 +403,12 @@ public class MojibusterEncodingDetector implements EncodingDetector { public List<EncodingResult> detectAll(byte[] probe, int topN) { boolean gates = enabledRules.contains(Rule.STRUCTURAL_GATES); - // Wide-Unicode analysis: positive detection and/or invalidity flags. - // Must run BEFORE isPureAscii: scripts like Cyrillic in UTF-16-LE have - // all bytes < 0x80 with no nulls, so isPureAscii would misclassify them. + // Wide-Unicode analysis: UTF-32 positive detection + UTF-16 surrogate + // invalidity flags. 
UTF-16 positive detection is delegated to the + // trained Utf16 specialist below (which handles CJK/Hangul that the + // structural detector cannot). Must run BEFORE isPureAscii: scripts + // like Cyrillic in UTF-16-LE have all bytes < 0x80 with no nulls, so + // isPureAscii would misclassify them. WideUnicodeDetector.Result wideResult = gates ? WideUnicodeDetector.analyze(probe) : WideUnicodeDetector.Result.EMPTY; @@ -395,6 +417,27 @@ public class MojibusterEncodingDetector implements EncodingDetector { EncodingResult.ResultType.STRUCTURAL, topN); } + // UTF-16 specialist: evidence-based column-asymmetry prefilter (the + // conservative "true-on-short-probe" default used for the main SBCS + // model's negative gate is wrong here — absence of evidence must + // mean "not UTF-16"), then a trained maxent over per-column + // byte-range counts decides LE vs BE. Refuses if the chosen + // endianness is surrogate-invalid. + if (gates && StructuralEncodingRules.has2ByteColumnAsymmetryEvidence(probe)) { + List<EncodingResult> utf16 = utf16Specialist.detect(probe); + if (!utf16.isEmpty()) { + EncodingResult er = utf16.get(0); + String name = er.getCharset().name(); + boolean invalid = + ("UTF-16LE".equals(name) && wideResult.invalidUtf16Le) + || ("UTF-16BE".equals(name) && wideResult.invalidUtf16Be); + if (!invalid) { + return singleResult(name, 1.0f, + EncodingResult.ResultType.STRUCTURAL, topN); + } + } + } + if (gates) { // Structural rules: byte-grammar proof (ISO-2022, sparse UTF-8). 
Charset structural = applyStructuralRules(probe); diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/ScoredCandidate.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/ScoredCandidate.java new file mode 100644 index 0000000000..60564bf884 --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/ScoredCandidate.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.chardetect; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +/** + * Pooled candidate from {@link LogLinearCombiner}: label, raw summed score + * (larger is better, not normalized), and the specialists that contributed. 
+ */ +public final class ScoredCandidate { + + private final String label; + private final float score; + private final Set<String> contributingSpecialists; + + public ScoredCandidate(String label, float score, Set<String> contributingSpecialists) { + this.label = label; + this.score = score; + this.contributingSpecialists = + Collections.unmodifiableSet(new LinkedHashSet<>(contributingSpecialists)); + } + + public String getLabel() { + return label; + } + + public float getScore() { + return score; + } + + public Set<String> getContributingSpecialists() { + return contributingSpecialists; + } + + @Override + public String toString() { + return "ScoredCandidate{" + label + "=" + score + " from " + contributingSpecialists + "}"; + } +} diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/SpecialistOutput.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/SpecialistOutput.java new file mode 100644 index 0000000000..debb56cbad --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/SpecialistOutput.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.chardetect; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Raw per-class logits from a single MoE specialist. Labels the specialist + * doesn't cover are absent from the map (no OTHER class). Logits are raw + * (pre-softmax); pooling happens in the combiner. + */ +public final class SpecialistOutput { + + private final String specialistName; + private final Map<String, Float> classLogits; + + public SpecialistOutput(String specialistName, Map<String, Float> classLogits) { + if (specialistName == null) { + throw new IllegalArgumentException("specialistName is required"); + } + if (classLogits == null) { + throw new IllegalArgumentException("classLogits is required"); + } + this.specialistName = specialistName; + this.classLogits = Collections.unmodifiableMap(new LinkedHashMap<>(classLogits)); + } + + public String getSpecialistName() { + return specialistName; + } + + public Map<String, Float> getClassLogits() { + return classLogits; + } + + public Iterable<String> getCoveredLabels() { + return classLogits.keySet(); + } + + /** + * Raw logit for {@code label}, or {@code null} if not covered. 
+ */ + public Float getLogit(String label) { + return classLogits.get(label); + } + + @Override + public String toString() { + return "SpecialistOutput{" + specialistName + "=" + classLogits + "}"; + } +} diff --git a/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/FeatureExtractor.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/StatisticalSpecialist.java similarity index 56% copy from tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/FeatureExtractor.java copy to tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/StatisticalSpecialist.java index 33aff831b5..39594da81f 100644 --- a/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/FeatureExtractor.java +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/StatisticalSpecialist.java @@ -14,27 +14,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.tika.ml; +package org.apache.tika.ml.chardetect; /** - * Generic feature extractor that maps an input of type {@code T} to a - * fixed-length integer feature vector suitable for a {@link LinearModel}. - * - * @param <T> the raw input type (e.g. {@code String} for text, {@code byte[]} - * for raw bytes) + * SPI contract for an MoE charset-detection specialist. Discovered via + * {@link java.util.ServiceLoader} at + * {@code META-INF/services/org.apache.tika.ml.chardetect.StatisticalSpecialist}. + * Implementations must be thread-safe. */ -public interface FeatureExtractor<T> { +public interface StatisticalSpecialist { /** - * Extract features from the given input. - * - * @param input raw input (may be {@code null}) - * @return int array of length {@link #getNumBuckets()} with feature counts + * Short name: {@code "utf16"}, {@code "sbcs"}, etc. 
*/ - int[] extract(T input); + String getName(); - /** - * @return number of hash buckets (feature-vector dimension) - */ - int getNumBuckets(); + /** Per-class logits for the probe, or {@code null} to decline + * (probe too short, hard-gated, etc.). Declining contributes nothing; + * a low-scoring result contributes weak signal. */ + SpecialistOutput score(byte[] probe); } diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/StructuralEncodingRules.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/StructuralEncodingRules.java index beaffc7475..e7de52ad82 100644 --- a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/StructuralEncodingRules.java +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/StructuralEncodingRules.java @@ -309,6 +309,38 @@ public final class StructuralEncodingRules { if (bytes == null || bytes.length < MIN_COLUMN_ASYMMETRY_PROBE) { return true; } + return computeColumnAsymmetry(bytes); + } + + /** + * Evidence-based variant of {@link #has2ByteColumnAsymmetry} with no + * conservative short-probe default: returns {@code true} only when the + * bytes themselves demonstrate column asymmetry, regardless of probe + * length. Use this to gate <em>positive</em> UTF-16 detection (e.g. + * invoking {@code Utf16SpecialistEncodingDetector}), where absence of + * evidence must mean "not UTF-16", not "unknown". 
+ * + * <p>Rejects probes below {@value #MIN_COLUMN_EVIDENCE_PROBE} bytes + * outright: with fewer than 8 pairs, column-distinct counts don't + * discriminate any UTF-16 variant from legacy double-byte encodings + * like GBK or Shift_JIS, which also have constrained lead-byte columns + * on short samples.</p> + */ + public static boolean has2ByteColumnAsymmetryEvidence(byte[] bytes) { + if (bytes == null || bytes.length < MIN_COLUMN_EVIDENCE_PROBE) { + return false; + } + return computeColumnAsymmetry(bytes); + } + + /** + * Minimum bytes required for {@link #has2ByteColumnAsymmetryEvidence}. + * Below this, legacy CJK double-byte encodings (GBK, Shift_JIS) can + * produce apparent column asymmetry indistinguishable from UTF-16. + */ + private static final int MIN_COLUMN_EVIDENCE_PROBE = 16; + + private static boolean computeColumnAsymmetry(byte[] bytes) { int sample = Math.min(bytes.length, 4096); boolean[] evenSeen = new boolean[256]; boolean[] oddSeen = new boolean[256]; @@ -330,7 +362,7 @@ public final class StructuralEncodingRules { } int min = Math.min(evenDistinct, oddDistinct); int max = Math.max(evenDistinct, oddDistinct); - return max >= min * 3; + return min > 0 && max >= min * 3; } public static boolean checkIbm424(byte[] bytes, int offset, int length) { diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/Utf16ColumnFeatureExtractor.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/Utf16ColumnFeatureExtractor.java new file mode 100644 index 0000000000..d487766534 --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/Utf16ColumnFeatureExtractor.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.chardetect; + +import org.apache.tika.ml.FeatureExtractor; + +/** + * Feature extractor for the UTF-16 specialist of the mixture-of-experts + * charset detector. Produces a small, dense, position-aware feature vector + * that is <strong>immune to HTML markup by construction</strong>: features + * capture the 2-byte alignment asymmetry that UTF-16 content produces and + * HTML content (which has no 2-byte alignment) cannot. + * + * <h3>Feature vector</h3> + * + * <p>12 dense integer features: byte counts across six byte-value ranges, + * split by column (even-offset vs odd-offset in the probe). 
Indexing:</p> + * + * <table> + * <tr><th>Index</th><th>Feature</th></tr> + * <tr><td>0</td><td>count_even(0x00)</td></tr> + * <tr><td>1</td><td>count_odd(0x00)</td></tr> + * <tr><td>2</td><td>count_even(0x01-0x1F, excluding 0x09/0x0A/0x0D)</td></tr> + * <tr><td>3</td><td>count_odd(0x01-0x1F, excluding 0x09/0x0A/0x0D)</td></tr> + * <tr><td>4</td><td>count_even(0x20-0x7E, plus 0x09, 0x0A, 0x0D)</td></tr> + * <tr><td>5</td><td>count_odd(0x20-0x7E, plus 0x09, 0x0A, 0x0D)</td></tr> + * <tr><td>6</td><td>count_even(0x7F)</td></tr> + * <tr><td>7</td><td>count_odd(0x7F)</td></tr> + * <tr><td>8</td><td>count_even(0x80-0x9F)</td></tr> + * <tr><td>9</td><td>count_odd(0x80-0x9F)</td></tr> + * <tr><td>10</td><td>count_even(0xA0-0xFF)</td></tr> + * <tr><td>11</td><td>count_odd(0xA0-0xFF)</td></tr> + * </table> + * + * <h3>Why this is HTML-immune</h3> + * + * <p>HTML has no 2-byte alignment — tags are variable-length ({@code <br>} + * is 4 bytes, {@code <div>} is 5, {@code </span>} is 7), entities and + * whitespace are arbitrary. Under random byte-offset content, any byte + * range has equal expected frequency at even vs odd positions. The + * maxent model pairing this extractor learns weights that reward column + * asymmetry: HTML produces near-zero asymmetry on every range → + * near-zero contribution to every UTF-16 class logit.</p> + * + * <p>UTF-16 has strict 2-byte alignment by definition. The "high byte" of + * every codepoint lands in one column, the "low byte" in the other. 
This + * alignment cannot be faked by non-UTF-16 content without deliberately + * constructing 2-byte-aligned patterns, which organic text content never + * does.</p> + * + * <h3>Why raw counts instead of asymmetry ratios</h3> + * + * <p>The maxent model learns asymmetry weights naturally from raw counts: + * a positive weight on {@code count_even(X)} paired with a negative weight + * on {@code count_odd(X)} produces a dot-product proportional to + * {@code count_even(X) - count_odd(X)}, which IS the asymmetry signal up + * to normalization. Explicit asymmetry features would add redundancy + * without adding information.</p> + * + * <h3>What it doesn't do</h3> + * + * <ul> + * <li>No UTF-32 detection. UTF-32 stays structural (4-byte alignment + * check) and doesn't need a statistical model.</li> + * <li>No discrimination between UTF-16 content languages (Japanese vs + * Chinese vs Korean). CharSoup's language scoring handles that + * after decoding. The UTF-16 specialist returns only + * {@code UTF-16-LE} or {@code UTF-16-BE}.</li> + * <li>No BOM handling — the caller is responsible for stripping BOM + * before feeding bytes to this extractor.</li> + * </ul> + * + * @see org.apache.tika.ml.LinearModel + */ +public class Utf16ColumnFeatureExtractor implements FeatureExtractor<byte[]> { + + /** Number of byte-value ranges tracked. */ + public static final int NUM_RANGES = 6; + + /** Number of columns (even-offset vs odd-offset). */ + public static final int NUM_COLUMNS = 2; + + /** Total feature-vector dimension: ranges * columns. */ + public static final int NUM_FEATURES = NUM_RANGES * NUM_COLUMNS; + + /** + * Precomputed byte-to-range-index lookup. Populated at class init. 
+ * Ranges chosen to cover all UTF-16 high-byte distributions: + * <ul> + * <li>Range 0 — 0x00: null column (UTF-16 Latin signal)</li> + * <li>Range 1 — 0x01-0x1F excluding 0x09/0x0A/0x0D: C0 controls + * (non-Latin BMP scripts have their high byte here: Cyrillic + * 0x04, Greek 0x03, Hebrew 0x05, Arabic 0x06, Thai 0x0E)</li> + * <li>Range 2 — 0x20-0x7E + 0x09/0x0A/0x0D: printable ASCII + common + * whitespace (UTF-16 Latin text column + CJK low bytes + HTML + * content)</li> + * <li>Range 3 — 0x7F: DEL (rare)</li> + * <li>Range 4 — 0x80-0x9F: C1 controls; UTF-16 CJK high byte for + * codepoints U+8000-U+9FFF. <strong>HTML never emits these + * bytes</strong> — a crucial HTML-uncontaminable signal.</li> + * <li>Range 5 — 0xA0-0xFF: extended Latin high bytes, CJK + * codepoints U+A000+.</li> + * </ul> + */ + private static final int[] RANGE_OF_BYTE = new int[256]; + + static { + for (int b = 0; b < 256; b++) { + if (b == 0x00) { + RANGE_OF_BYTE[b] = 0; + } else if (b < 0x20 && b != 0x09 && b != 0x0A && b != 0x0D) { + RANGE_OF_BYTE[b] = 1; + } else if (b <= 0x7E) { // includes 0x09, 0x0A, 0x0D (not in range 1) and 0x20-0x7E + RANGE_OF_BYTE[b] = 2; + } else if (b == 0x7F) { + RANGE_OF_BYTE[b] = 3; + } else if (b <= 0x9F) { + RANGE_OF_BYTE[b] = 4; + } else { + RANGE_OF_BYTE[b] = 5; + } + } + } + + @Override + public int[] extract(byte[] input) { + int[] counts = new int[NUM_FEATURES]; + if (input == null || input.length == 0) { + return counts; + } + extractInto(input, 0, input.length, counts); + return counts; + } + + /** + * Extract from a sub-range of a byte array. + */ + public int[] extract(byte[] input, int offset, int length) { + int[] counts = new int[NUM_FEATURES]; + if (input == null || length == 0) { + return counts; + } + extractInto(input, offset, offset + length, counts); + return counts; + } + + /** + * Sparse extraction into caller-owned, reusable buffers. 
For this + * small dense vector, "sparse" just means "write non-zero feature + * indices into {@code touched}". Buckets with zero count are not + * listed. + * + * @param input raw bytes + * @param dense scratch buffer of length {@link #NUM_FEATURES}, + * all-zeros on entry; caller clears used entries afterwards + * @param touched buffer receiving indices of non-zero features + * @return number of entries written into {@code touched} + */ + public int extractSparseInto(byte[] input, int[] dense, int[] touched) { + if (input == null || input.length == 0) { + return 0; + } + extractInto(input, 0, input.length, dense); + int n = 0; + for (int i = 0; i < NUM_FEATURES; i++) { + if (dense[i] != 0) { + touched[n++] = i; + } + } + return n; + } + + private static void extractInto(byte[] b, int from, int to, int[] counts) { + for (int i = from; i < to; i++) { + int v = b[i] & 0xFF; + int range = RANGE_OF_BYTE[v]; + int column = (i - from) & 1; // 0 = even offset within probe, 1 = odd + counts[range * NUM_COLUMNS + column]++; + } + } + + @Override + public int getNumBuckets() { + return NUM_FEATURES; + } + + /** Human-readable label for feature index {@code i} (for debugging). */ + public static String featureLabel(int i) { + if (i < 0 || i >= NUM_FEATURES) { + return "(invalid: " + i + ")"; + } + int range = i / NUM_COLUMNS; + int column = i % NUM_COLUMNS; + String rangeName; + switch (range) { + case 0: + rangeName = "0x00"; + break; + case 1: + rangeName = "0x01-1F-nws"; + break; + case 2: + rangeName = "0x20-7E+tab/lf/cr"; + break; + case 3: + rangeName = "0x7F"; + break; + case 4: + rangeName = "0x80-9F"; + break; + case 5: + rangeName = "0xA0-FF"; + break; + default: + rangeName = "?"; + break; + } + String columnName = (column == 0) ? 
"even" : "odd"; + return "count_" + columnName + "(" + rangeName + ")"; + } + + @Override + public String toString() { + return "Utf16ColumnFeatureExtractor{features=" + NUM_FEATURES + "}"; + } +} diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetector.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetector.java new file mode 100644 index 0000000000..e72c883b7f --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetector.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.tika.ml.chardetect; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.io.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.tika.config.TikaComponent; +import org.apache.tika.detect.EncodingDetector; +import org.apache.tika.detect.EncodingResult; +import org.apache.tika.io.TikaInputStream; +import org.apache.tika.metadata.Metadata; +import org.apache.tika.ml.LinearModel; +import org.apache.tika.parser.ParseContext; + +/** + * UTF-16 specialist detector of the mixture-of-experts charset detection + * architecture. Uses a tiny dense-feature maxent model paired with + * {@link Utf16ColumnFeatureExtractor} to produce a column-asymmetry-based + * judgment of UTF-16-LE vs UTF-16-BE. + * + * <h3>HTML-immune by construction</h3> + * + * <p>The feature set the model consumes (12 per-column byte-range counts) + * captures the 2-byte alignment asymmetry that UTF-16 content produces and + * HTML content cannot — HTML has no 2-byte alignment, so any byte range + * appears with equal expected frequency at even vs odd positions. No + * amount of HTML markup can fire this specialist. See + * {@link Utf16ColumnFeatureExtractor} for the detailed argument.</p> + * + * <h3>Stage 1 of the MoE migration</h3> + * + * <p>Runs alongside the existing {@code MojibusterEncodingDetector} + * rather than replacing any piece of it. Emits a single + * {@link EncodingResult.ResultType#STATISTICAL} candidate for CharSoup to + * arbitrate against the other detectors in the chain. 
The existing + * {@code WideUnicodeDetector}-based structural UTF-16 detection inside + * Mojibuster is not removed yet — both can operate in parallel during + * Stage 1 validation.</p> + * + * <h3>Model loading</h3> + * + * <p>The default constructor loads a trained model from the classpath at + * {@link #DEFAULT_MODEL_RESOURCE}. If the resource is absent or + * malformed, construction throws {@link IOException} — the detector + * never operates in a no-op state because silent no-ops produce wrong + * answers without any indication that something's wrong. Deploy the + * detector only when a trained model is bundled; remove it from the + * chain otherwise.</p> + * + * <h3>Probe size</h3> + * + * <p>Reads up to {@link #MAX_PROBE_BYTES} bytes. UTF-16 column-asymmetry + * signal stabilises quickly — even ~100 bytes is usually enough for a + * strong call. Default 512 is generous.</p> + */ +@TikaComponent(spi = false) +public class Utf16SpecialistEncodingDetector + implements EncodingDetector, StatisticalSpecialist { + + private static final Logger LOG = + LoggerFactory.getLogger(Utf16SpecialistEncodingDetector.class); + + /** + * Default classpath resource for the trained UTF-16 specialist model. + * Missing resource → detector is a noop (logged once at construction). + */ + public static final String DEFAULT_MODEL_RESOURCE = + "/org/apache/tika/ml/chardetect/utf16-specialist.bin"; + + /** Default number of probe bytes read. */ + public static final int MAX_PROBE_BYTES = 512; + + /** + * Minimum raw-logit margin (winner − loser) required to return a + * candidate via the standalone {@link #detect} path. + */ + private static final float MIN_LOGIT_MARGIN = 1.0f; + + /** + * Minimum probe length in bytes to attempt UTF-16 classification. + * Column-asymmetry features on 2-6 byte probes are dominated by + * noise — one stray null at even position pushes LE features hard. 
+ * 8 bytes (4 pairs) matches the old structural {@code WideUnicodeDetector} + * threshold and is enough for the learned asymmetry boundary to separate + * real UTF-16 Latin ("a\0b\0c\0d\0") from coincidence. + */ + private static final int MIN_PROBE_BYTES = 8; + + + /** + * Maximum confidence emitted on {@code STATISTICAL} results. Kept + * below 1.0 so {@code CharSoupEncodingDetector} never mistakes a + * model output for a {@code DECLARATIVE} / {@code STRUCTURAL} + * result. + */ + private static final float MAX_STATISTICAL_CONFIDENCE = 0.99f; + + private final LinearModel model; + private final Utf16ColumnFeatureExtractor extractor; + private final int maxProbeBytes; + + /** + * Load the model from the default classpath location. + * + * @throws IOException if the model resource is missing or malformed — + * the detector does not operate in a no-op state. + */ + public Utf16SpecialistEncodingDetector() throws IOException { + this(loadModel(DEFAULT_MODEL_RESOURCE), MAX_PROBE_BYTES); + } + + /** + * {@link java.util.ServiceLoader}-compatible provider method. Wraps + * the checked {@link IOException} from the no-arg constructor in a + * {@link java.util.ServiceConfigurationError} so the arbiter can catch + * it and skip a specialist whose model is not bundled — without + * hiding the cause. + */ + public static Utf16SpecialistEncodingDetector provider() { + try { + return new Utf16SpecialistEncodingDetector(); + } catch (IOException e) { + throw new java.util.ServiceConfigurationError( + "UTF-16 specialist model not available: " + e.getMessage(), e); + } + } + + /** + * Package-visible constructor for tests. 
+ */ + Utf16SpecialistEncodingDetector(LinearModel model, int maxProbeBytes) { + if (model == null) { + throw new IllegalArgumentException( + "UTF-16 specialist model is required; pass a valid " + + "LinearModel or use the classpath-loading constructor"); + } + validateModel(model); + this.model = model; + this.extractor = new Utf16ColumnFeatureExtractor(); + this.maxProbeBytes = maxProbeBytes; + } + + private static LinearModel loadModel(String resourcePath) throws IOException { + try (InputStream is = + Utf16SpecialistEncodingDetector.class.getResourceAsStream(resourcePath)) { + if (is == null) { + throw new IOException( + "UTF-16 specialist model resource not found at " + + resourcePath + ". The specialist must be trained " + + "and the model file bundled on the classpath before " + + "this detector can be instantiated. Either bundle " + + "the trained model or remove this detector from the " + + "encoding-detector chain."); + } + return LinearModel.load(is); + } + } + + private static void validateModel(LinearModel model) { + if (model.getNumBuckets() != Utf16ColumnFeatureExtractor.NUM_FEATURES) { + throw new IllegalArgumentException( + "UTF-16 specialist model has " + model.getNumBuckets() + + " buckets but extractor expects " + + Utf16ColumnFeatureExtractor.NUM_FEATURES); + } + if (model.getNumClasses() != 2) { + throw new IllegalArgumentException( + "UTF-16 specialist model must have exactly 2 classes " + + "(UTF-16-LE, UTF-16-BE), found " + + model.getNumClasses()); + } + } + + /** + * Specialist name used in {@link SpecialistOutput} for provenance. + */ + public static final String SPECIALIST_NAME = "utf16"; + + @Override + public String getName() { + return SPECIALIST_NAME; + } + + /** + * {@link StatisticalSpecialist} entry point: raw per-class logits, + * or {@code null} for a probe too short to evaluate (fewer than 2 + * bytes) or missing a model. Returning {@code null} declines to + * contribute; an all-low logit vector would muddy the combiner. 
+ * + * <p>Unlike {@link #detect}, this method does not apply a margin + * threshold — downstream pooling sees raw logits for both classes.</p> + */ + @Override + public SpecialistOutput score(byte[] probe) { + // score() returns raw logits for the MoE combiner; MIN_PROBE_BYTES + // applies only to the standalone detect() path where we emit a + // charset decision. The combiner is responsible for deciding + // whether the margin is large enough to trust on short probes. + if (probe == null || probe.length < 2) { + return null; + } + int len = Math.min(probe.length, maxProbeBytes); + int[] features = extractor.extract(probe, 0, len); + float[] logits = model.predictCalibratedLogits(features); + Map<String, Float> classLogits = new LinkedHashMap<>(2); + for (int c = 0; c < logits.length; c++) { + classLogits.put(model.getLabel(c), logits[c]); + } + return new SpecialistOutput(SPECIALIST_NAME, classLogits); + } + + /** + * Convenience: mark/reset the stream, read a probe, and score it. + * Returns {@code null} if the probe is too short. + */ + public SpecialistOutput score(TikaInputStream tis) throws IOException { + byte[] probe = readProbe(tis); + return score(probe); + } + + /** + * @deprecated use {@link #score(byte[])}. Kept for existing tests. + */ + @Deprecated + public SpecialistOutput scoreBytes(byte[] probe) { + return score(probe); + } + + @Override + public List<EncodingResult> detect(TikaInputStream tis, Metadata metadata, + ParseContext parseContext) throws IOException { + return detect(readProbe(tis)); + } + + /** + * Byte-array entry point for callers that already hold a probe + * (e.g. {@link MojibusterEncodingDetector}'s pipeline). Returns an + * empty list for probes below {@link #MIN_PROBE_BYTES} or when the + * winning class has margin < {@link #MIN_LOGIT_MARGIN}. 
+ */ + public List<EncodingResult> detect(byte[] probe) { + if (probe == null || probe.length < MIN_PROBE_BYTES) { + return Collections.emptyList(); + } + int len = Math.min(probe.length, maxProbeBytes); + int[] features = extractor.extract(probe, 0, len); + float[] logits = model.predictLogits(features); + + int winnerIdx = 0; + int loserIdx = 1; + if (logits[1] > logits[0]) { + winnerIdx = 1; + loserIdx = 0; + } + float margin = logits[winnerIdx] - logits[loserIdx]; + if (margin < MIN_LOGIT_MARGIN) { + // No confident winner — probe is either not UTF-16 or too + // ambiguous between LE and BE. + return Collections.emptyList(); + } + + String label = model.getLabel(winnerIdx); + Charset charset; + try { + charset = Charset.forName(toJavaCharsetName(label)); + } catch (Exception e) { + LOG.debug("Unknown charset from UTF-16 model label '{}'", label, e); + return Collections.emptyList(); + } + float confidence = confidenceFromMargin(margin); + return List.of(new EncodingResult(charset, confidence, label, + EncodingResult.ResultType.STATISTICAL)); + } + + private byte[] readProbe(TikaInputStream tis) throws IOException { + tis.mark(maxProbeBytes); + byte[] buf = new byte[maxProbeBytes]; + try { + int n = IOUtils.read(tis, buf); + if (n < buf.length) { + byte[] trimmed = new byte[n]; + System.arraycopy(buf, 0, trimmed, 0, n); + return trimmed; + } + return buf; + } finally { + tis.reset(); + } + } + + /** + * Map training-label charset names (e.g. {@code "UTF-16-LE"} with + * hyphens) to Java's canonical charset names ({@code "UTF-16LE"} no + * hyphen). Mirrors the mapping in {@link MojibusterEncodingDetector}. + */ + private static String toJavaCharsetName(String label) { + switch (label) { + case "UTF-16-LE": + return "UTF-16LE"; + case "UTF-16-BE": + return "UTF-16BE"; + default: + return label; + } + } + + /** + * Map a raw-logit margin to a 0..{@link #MAX_STATISTICAL_CONFIDENCE} + * confidence via a sigmoid-like squash. 
The specific function is a + * tunable mapping — what matters is that larger margins produce higher + * confidences and the output stays in the valid range. + */ + private static float confidenceFromMargin(float margin) { + // Sigmoid centred at 0: f(0) = 0.5, f(large) -> 1.0. + // We'll steer f so that margin=1 maps to ~0.73, margin=5 maps to ~0.99. + float s = (float) (1.0 / (1.0 + Math.exp(-margin))); + return Math.min(s, MAX_STATISTICAL_CONFIDENCE); + } + +} diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/WideUnicodeDetector.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/WideUnicodeDetector.java index bf2721f945..76e0102082 100644 --- a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/WideUnicodeDetector.java +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/java/org/apache/tika/ml/chardetect/WideUnicodeDetector.java @@ -17,14 +17,11 @@ package org.apache.tika.ml.chardetect; import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; /** - * Structural analysis for UTF-16 LE/BE and UTF-32 LE/BE based on - * byte-position patterns. This is an internal component of - * {@link MojibusterEncodingDetector}'s pipeline — not a standalone - * {@code EncodingDetector}. It intentionally does not handle CJK UTF-16 - * (which falls through to the statistical model) and requires upstream + * Structural analysis for UTF-32 LE/BE, plus UTF-16 surrogate validity + * flags. This is an internal component of {@link MojibusterEncodingDetector}'s + * pipeline — not a standalone {@code EncodingDetector}. Requires upstream * BOM stripping. * * <h3>UTF-32</h3> @@ -34,30 +31,17 @@ import java.nio.charset.StandardCharsets; * non-UTF-32 data almost always produces out-of-range values immediately. 
* Inspired by ICU4J's {@code CharsetRecog_UTF_32}.</p> * - * <h3>UTF-16</h3> - * <p>Two phases, each targeting a different script family: - * <ol> - * <li><strong>Null-column</strong> — Latin/ASCII BMP content: one byte - * column (even or odd positions at stride-2) has a high null rate. - * Safe: no legacy encoding produces alternating nulls.</li> - * <li><strong>Low-block-prefix</strong> — scripts whose UTF-16 high byte - * is below {@code 0x20} (Cyrillic 0x04, Arabic 0x06, Hebrew 0x05, - * Devanagari 0x09, Bengali 0x09, Thai 0x0E, etc.): the constrained - * column has all non-null values below {@code 0x20}, the other column - * is more diverse. Safe: Big5/Shift-JIS/GBK lead bytes are always - * ≥ 0x81.</li> - * </ol> - * - * <p>CJK Unified (block prefix 0x4E–0x9F) and Hangul (0xAC–0xD7) are - * intentionally not handled — their block prefixes overlap with - * Big5/Shift-JIS/GBK lead bytes (0x81+) and with ISO-2022-JP JIS row - * bytes, making structural discrimination unsafe. Those cases fall - * through to the statistical model.</p> - * - * <p>In addition to positive detection, {@link Result} carries surrogate- - * invalidity flags for each endianness. When no positive detection fires, - * these flags allow the caller to suppress UTF-16 model predictions for - * probes that are structurally impossible as UTF-16.</p> + * <h3>UTF-16 surrogate validation</h3> + * <p>UTF-16 positive detection is handled by + * {@link Utf16SpecialistEncodingDetector}, which uses a trained maxent + * model over per-column byte-range counts and correctly distinguishes + * LE from BE for Latin, Cyrillic, Arabic, Hebrew, Indic, Thai, CJK + * Unified, and Hangul content alike. 
This class only performs surrogate- + * invalidity validation: {@link Result#invalidUtf16Be} and + * {@link Result#invalidUtf16Le} carry whether the probe contains + * structurally impossible UTF-16 surrogate sequences under each + * endianness, so callers can suppress UTF-16 labels from statistical + * models when the bytes cannot be valid UTF-16.</p> * * <p>All methods are stateless and safe to call from multiple threads.</p> */ @@ -190,33 +174,13 @@ final class WideUnicodeDetector { // ----------------------------------------------------------------------- /** - * Null-column threshold: the null rate in one column must exceed - * {@code 1 / NULL_DENOM} of pairs. Set to 4 (25%) to avoid false - * positives on OLE2 and bzip2 which have 12–20% null at one column. - * Real Latin UTF-16 has >90% null in the null column. - */ - private static final int NULL_DENOM = 4; - - /** - * Variety-ratio minimum: the diverse column must have at least this - * many times more distinct values than the constrained column. - */ - private static final double VARIETY_RATIO = 2.0; - - /** - * The constrained column must have fewer than this fraction of pairs - * as distinct values. Guards against uniformly random data. - */ - private static final double CONSTRAINED_MAX_RATIO = 0.40; - - /** - * Upper bound for the low-block-prefix phase. Scripts with UTF-16 high - * bytes below this value are safely distinguishable from all legacy CJK - * lead bytes (which start at 0x81). + * Surrogate-validation scan over {@code length} bytes starting at + * {@code offset}. Does not attempt UTF-16 positive detection — that is + * the job of {@link Utf16SpecialistEncodingDetector}. Returns only + * surrogate-invalidity flags under each endianness, used by + * {@link MojibusterEncodingDetector} to suppress UTF-16 labels from + * the main statistical model on probes that cannot be valid UTF-16. 
*/ - private static final int LOW_PREFIX_MAX = 0x20; - - private static Result tryUtf16(byte[] bytes, int offset, int length) { int sampleLen = (Math.min(length, 512) / 2) * 2; if (sampleLen < 8) { @@ -224,12 +188,6 @@ final class WideUnicodeDetector { } int pairs = sampleLen / 2; - int nullsAtEven = 0; - int nullsAtOdd = 0; - int[] countsEven = new int[256]; - int[] countsOdd = new int[256]; - - // Surrogate validation boolean awaitLowBe = false, awaitLowLe = false; boolean invalidBe = false, invalidLe = false; @@ -237,12 +195,6 @@ final class WideUnicodeDetector { int even = bytes[offset + p * 2] & 0xFF; int odd = bytes[offset + p * 2 + 1] & 0xFF; - if (even == 0) nullsAtEven++; - if (odd == 0) nullsAtOdd++; - countsEven[even]++; - countsOdd[odd]++; - - // UTF-16BE surrogate validation (high byte = even) if (!invalidBe) { if (awaitLowBe) { if (even >= 0xDC && even <= 0xDF) { @@ -259,7 +211,6 @@ final class WideUnicodeDetector { } } - // UTF-16LE surrogate validation (high byte = odd) if (!invalidLe) { if (awaitLowLe) { if (odd >= 0xDC && odd <= 0xDF) { @@ -279,70 +230,7 @@ final class WideUnicodeDetector { if (awaitLowBe) invalidBe = true; if (awaitLowLe) invalidLe = true; - int uniqueEven = countUnique(countsEven); - int uniqueOdd = countUnique(countsOdd); - - // Phase 1: null-column (Latin/ASCII BMP content) - boolean highEven = nullsAtEven * NULL_DENOM > pairs; - boolean highOdd = nullsAtOdd * NULL_DENOM > pairs; - if (highOdd && !highEven && !invalidLe) { - return new Result(StandardCharsets.UTF_16LE, invalidBe, false); - } - if (highEven && !highOdd && !invalidBe) { - return new Result(StandardCharsets.UTF_16BE, false, invalidLe); - } - - // Phase 2: low-block-prefix (Cyrillic, Arabic, Hebrew, Indic, Thai, …) - // The constrained column has all non-null values < 0x20. - // Safe: no legacy CJK lead byte is below 0x81. 
- double constrainedMax = pairs * CONSTRAINED_MAX_RATIO; - - // Check LE: odd column is constrained (block-prefix), even is diverse - if (!invalidLe - && allNonNullBelow(countsOdd, LOW_PREFIX_MAX) - && uniqueOdd <= constrainedMax - && (double) uniqueEven / uniqueOdd >= VARIETY_RATIO - && hasNonNull(countsOdd)) { - return new Result(StandardCharsets.UTF_16LE, invalidBe, false); - } - // Check BE: even column is constrained, odd is diverse - if (!invalidBe - && allNonNullBelow(countsEven, LOW_PREFIX_MAX) - && uniqueEven <= constrainedMax - && (double) uniqueOdd / uniqueEven >= VARIETY_RATIO - && hasNonNull(countsEven)) { - return new Result(StandardCharsets.UTF_16BE, false, invalidLe); - } - return new Result(null, invalidBe, invalidLe); } - // ----------------------------------------------------------------------- - // Helpers - // ----------------------------------------------------------------------- - - private static int countUnique(int[] counts) { - int n = 0; - for (int c : counts) { - if (c > 0) n++; - } - return n; - } - - /** True if every non-null byte value in {@code counts} is < {@code max}. */ - private static boolean allNonNullBelow(int[] counts, int max) { - for (int v = max; v < counts.length; v++) { - if (counts[v] > 0) return false; - } - return true; - } - - /** True if at least one non-null byte value has a positive count. 
*/ - private static boolean hasNonNull(int[] counts) { - for (int v = 1; v < counts.length; v++) { - if (counts[v] > 0) return true; - } - return false; - } - } diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/resources/META-INF/services/org.apache.tika.ml.chardetect.StatisticalSpecialist b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/resources/META-INF/services/org.apache.tika.ml.chardetect.StatisticalSpecialist new file mode 100644 index 0000000000..2004264e18 --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/resources/META-INF/services/org.apache.tika.ml.chardetect.StatisticalSpecialist @@ -0,0 +1,4 @@ +# MoE statistical specialists bundled with the core mojibuster detector. +# Additional specialists (SBCS, extended EBCDIC, IBM-DOS-OEM) register themselves +# via their own META-INF/services file when their JAR is on the classpath. +org.apache.tika.ml.chardetect.Utf16SpecialistEncodingDetector diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/resources/org/apache/tika/ml/chardetect/utf16-specialist.bin b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/resources/org/apache/tika/ml/chardetect/utf16-specialist.bin new file mode 100644 index 0000000000..be48708ae6 Binary files /dev/null and b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/main/resources/org/apache/tika/ml/chardetect/utf16-specialist.bin differ diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/ModelResourceUniquenessTest.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/ModelResourceUniquenessTest.java new file mode 100644 index 0000000000..1368ff8f96 --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/ModelResourceUniquenessTest.java @@ -0,0 +1,91 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.chardetect; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; + +import org.apache.commons.io.IOUtils; +import org.junit.jupiter.api.Test; + +/** + * Belt-and-suspenders check for a failure mode we've been burned by: + * a test-tree copy of a model file shadowing the production copy and + * quietly producing wrong eval numbers. These tests assert there is + * exactly one copy of each specialist's model on the classpath, so + * accidentally planting a second (test or stale) copy fails the build + * immediately instead of at eval time. 
+ */ +public class ModelResourceUniquenessTest { + + private static final String UTF16_RESOURCE = + "org/apache/tika/ml/chardetect/utf16-specialist.bin"; + + private static List<URL> findAll(String resource) throws IOException { + Enumeration<URL> urls = + Thread.currentThread().getContextClassLoader().getResources(resource); + return Collections.list(urls); + } + + @Test + public void utf16ModelResourceIsUnique() throws IOException { + List<URL> urls = findAll(UTF16_RESOURCE); + assertEquals(1, urls.size(), + "Expected exactly one copy of " + UTF16_RESOURCE + + " on the classpath, found: " + urls); + } + + @Test + public void specialistConstructorLoadsSameBytesAsClasspathResource() + throws IOException { + // The specialist classes load via their own DEFAULT_MODEL_RESOURCE + // constants. If those constants ever drift from the production + // resource path, both the md5 match and the load would succeed but + // point at different files. Assert bytes-equal. + byte[] utf16ResourceBytes; + try (InputStream is = Thread.currentThread().getContextClassLoader() + .getResourceAsStream(UTF16_RESOURCE)) { + assertNotNull(is, "classpath missing " + UTF16_RESOURCE); + utf16ResourceBytes = IOUtils.toByteArray(is); + } + byte[] utf16ViaConstant; + try (InputStream is = Utf16SpecialistEncodingDetector.class + .getResourceAsStream( + Utf16SpecialistEncodingDetector.DEFAULT_MODEL_RESOURCE)) { + assertNotNull(is, "constant resolves to null: " + + Utf16SpecialistEncodingDetector.DEFAULT_MODEL_RESOURCE); + utf16ViaConstant = IOUtils.toByteArray(is); + } + assertArraysEqual(utf16ResourceBytes, utf16ViaConstant, + "UTF-16 model loaded via DEFAULT_MODEL_RESOURCE differs from " + + "classpath " + UTF16_RESOURCE); + } + + private static void assertArraysEqual(byte[] a, byte[] b, String message) { + if (!java.util.Arrays.equals(a, b)) { + throw new AssertionError(message + + " (len " + a.length + " vs " + b.length + ")"); + } + } +} diff --git 
a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16ColumnFeatureExtractorTest.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16ColumnFeatureExtractorTest.java new file mode 100644 index 0000000000..e2b88d12ea --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16ColumnFeatureExtractorTest.java @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.chardetect; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link Utf16ColumnFeatureExtractor}. These verify that the + * raw column-count features correctly capture the alignment asymmetry + * that distinguishes UTF-16 from non-UTF-16 content — including the + * HTML-immunity property. 
+ * + * <p>Feature indexing (must match the extractor):</p> + * <pre> + * 0 = count_even(0x00) 1 = count_odd(0x00) + * 2 = count_even(0x01-1F) 3 = count_odd(0x01-1F) (controls excl. 0x09/0x0A/0x0D) + * 4 = count_even(0x20-7E+) 5 = count_odd(0x20-7E+) (printable + tab/lf/cr) + * 6 = count_even(0x7F) 7 = count_odd(0x7F) + * 8 = count_even(0x80-9F) 9 = count_odd(0x80-9F) + * 10 = count_even(0xA0-FF) 11 = count_odd(0xA0-FF) + * </pre> + */ +public class Utf16ColumnFeatureExtractorTest { + + private static final int NUL_EVEN = 0; + private static final int NUL_ODD = 1; + private static final int CTRL_EVEN = 2; + private static final int CTRL_ODD = 3; + private static final int ASCII_EVEN = 4; + private static final int ASCII_ODD = 5; + private static final int DEL_EVEN = 6; + private static final int DEL_ODD = 7; + private static final int C1_EVEN = 8; + private static final int C1_ODD = 9; + private static final int HI_EVEN = 10; + private static final int HI_ODD = 11; + + private final Utf16ColumnFeatureExtractor extractor = new Utf16ColumnFeatureExtractor(); + + // --- basic sanity --- + + @Test + public void emptyInputReturnsAllZeros() { + int[] features = extractor.extract(new byte[0]); + assertEquals(12, features.length); + for (int i = 0; i < 12; i++) { + assertEquals(0, features[i], "feature " + i + " should be 0"); + } + } + + @Test + public void nullInputReturnsAllZeros() { + int[] features = extractor.extract(null); + assertEquals(12, features.length); + for (int i = 0; i < 12; i++) { + assertEquals(0, features[i]); + } + } + + @Test + public void numBucketsIs12() { + assertEquals(12, extractor.getNumBuckets()); + } + + @Test + public void featuresSumToProbeLength() { + byte[] probe = "some mixed content\r\n\0\0\0".getBytes(StandardCharsets.ISO_8859_1); + int[] features = extractor.extract(probe); + int sum = 0; + for (int c : features) { + sum += c; + } + assertEquals(probe.length, sum, "features must cover every byte exactly once"); + } + + // --- UTF-16 
Latin cases --- + + @Test + public void utf16LeLatinPutsNullsInOddColumn() { + // "Hello World" in UTF-16LE = 48 00 65 00 6C 00 6C 00 6F 00 20 00 57 00 6F 00 72 00 6C 00 64 00 + byte[] probe = "Hello World".getBytes(Charset.forName("UTF-16LE")); + int[] f = extractor.extract(probe); + + // 11 characters, each 2 bytes: + // even positions → ASCII letters (0x20-7E range) + // odd positions → 0x00 (null range) + assertEquals(0, f[NUL_EVEN], "no nulls in even column"); + assertEquals(11, f[NUL_ODD], "every odd position is null"); + assertEquals(11, f[ASCII_EVEN], "every even position is ASCII letter/space"); + assertEquals(0, f[ASCII_ODD], "no ASCII in odd column"); + // strong asymmetry: nulls in odd, ASCII in even → UTF-16LE Latin signal + } + + @Test + public void utf16BeLatinPutsNullsInEvenColumn() { + byte[] probe = "Hello World".getBytes(Charset.forName("UTF-16BE")); + int[] f = extractor.extract(probe); + + assertEquals(11, f[NUL_EVEN], "every even position is null"); + assertEquals(0, f[NUL_ODD], "no nulls in odd column"); + assertEquals(0, f[ASCII_EVEN]); + assertEquals(11, f[ASCII_ODD]); + } + + // --- UTF-16 non-Latin BMP cases (high byte in 0x03-0x0E, the "controls" range) --- + + @Test + public void utf16LeCyrillicPutsHighByteInOddColumn() { + // Russian "Привет" in UTF-16LE. Codepoints U+041F U+0440 U+0438 U+0432 U+0435 U+0442. 
+ // Bytes: 1F 04 40 04 38 04 32 04 35 04 42 04 + // even positions = 0x1F, 0x40, 0x38, 0x32, 0x35, 0x42 — all in 0x20-7E (except 0x1F which is control) + // odd positions = 0x04 × 6 — in the 0x01-0x1F control range + byte[] probe = "Привет".getBytes(Charset.forName("UTF-16LE")); + int[] f = extractor.extract(probe); + + // Odd column: all six 0x04 bytes → control range + assertEquals(6, f[CTRL_ODD], "every odd position is 0x04 (control range)"); + // Even column: П=0x1F (ctrl), р=0x40, и=0x38, в=0x32, е=0x35, т=0x42 → 1 ctrl + 5 printable + assertEquals(1, f[CTRL_EVEN], "0x1F from П lands in control range on even side"); + assertEquals(5, f[ASCII_EVEN], "the other 5 even bytes are in 0x20-7E range"); + assertEquals(0, f[ASCII_ODD]); + // No nulls, no high bytes + assertEquals(0, f[NUL_EVEN] + f[NUL_ODD]); + assertEquals(0, f[HI_EVEN] + f[HI_ODD]); + } + + // --- UTF-16 CJK (the hard case) --- + + @Test + public void utf16LeCjkPutsHighByteInOddColumn() { + // "精密過濾旋流器" in UTF-16LE. Codepoints in U+4E00-U+9FFF range. + // 精 U+7CBE → BE 7C + // 密 U+5BC6 → C6 5B + // 過 U+904E → 4E 90 + // 濾 U+6FFE → FE 6F + // 旋 U+65CB → CB 65 + // 流 U+6D41 → 41 6D + // 器 U+5668 → 68 56 + // Even column (low bytes of codepoints): BE, C6, 4E, FE, CB, 41, 68 + // Odd column (high bytes of codepoints): 7C, 5B, 90, 6F, 65, 6D, 56 + byte[] probe = "精密過濾旋流器".getBytes(Charset.forName("UTF-16LE")); + int[] f = extractor.extract(probe); + + // Odd column: all bytes in 0x4E-0x90 range. 
+ // 0x7C, 0x5B, 0x6F, 0x65, 0x6D, 0x56 → range 2 (ASCII 0x20-7E) + // 0x90 → range 4 (C1 range 0x80-9F) + assertEquals(6, f[ASCII_ODD], "most odd bytes fall in ASCII-printable range for CJK low half"); + assertEquals(1, f[C1_ODD], "0x90 from 過 lands in C1 range"); + + // Even column: BE, C6, 4E, FE, CB, 41, 68 + // 0x41, 0x68, 0x4E → range 2 (ASCII 0x20-7E) + // 0xBE, 0xC6, 0xFE, 0xCB → range 5 (0xA0-FF) + assertEquals(3, f[ASCII_EVEN]); + assertEquals(4, f[HI_EVEN]); + + // No nulls anywhere for CJK + assertEquals(0, f[NUL_EVEN] + f[NUL_ODD]); + } + + @Test + public void utf16BeCjkPutsHighByteInEvenColumn() { + // Same CJK text in UTF-16BE — roles of columns swap. + byte[] probe = "精密過濾旋流器".getBytes(Charset.forName("UTF-16BE")); + int[] f = extractor.extract(probe); + + // Even column now has codepoint high bytes (7C, 5B, 90, 6F, 65, 6D, 56). + assertEquals(6, f[ASCII_EVEN], "BE even column has codepoint high bytes in ASCII range"); + assertEquals(1, f[C1_EVEN], "0x90 from 過 lands in C1 range on even side for BE"); + + // Odd column has codepoint low bytes (BE, C6, 4E, FE, CB, 41, 68). + assertEquals(3, f[ASCII_ODD]); + assertEquals(4, f[HI_ODD]); + } + + @Test + public void utf16LeUpperCjkHitsC1Range() { + // Codepoints U+8000-U+9FFF have high byte in 0x80-0x9F (the C1 range). + // Under UTF-16LE, this high byte lands in the ODD column. + // 試 U+8A66 → 66 8A (LE) + // 験 U+9A13 → 13 9A (LE) — wait, 0x13 is control + // 誠 U+8AA0 → A0 8A (LE) + byte[] probe = "試験誠".getBytes(Charset.forName("UTF-16LE")); + int[] f = extractor.extract(probe); + + // Odd column (codepoint high bytes): 8A, 9A, 8A → all in 0x80-9F (C1 range). 
+ assertEquals(3, f[C1_ODD], "all three odd-column bytes in C1 range"); + assertEquals(0, f[C1_EVEN]); + } + + // --- HTML — must produce minimal asymmetry --- + + @Test + public void htmlProducesSymmetricColumns() { + String html = "<html><head><title>Hello</title></head>" + + "<body><p class=\"a\">Content here</p></body></html>"; + byte[] probe = html.getBytes(StandardCharsets.US_ASCII); + int[] f = extractor.extract(probe); + + // All bytes are ASCII (0x20-0x7E range). Expect rough even/odd balance. + int totalAscii = f[ASCII_EVEN] + f[ASCII_ODD]; + assertEquals(probe.length, totalAscii, "all bytes should be ASCII"); + int diff = Math.abs(f[ASCII_EVEN] - f[ASCII_ODD]); + assertTrue(diff <= 2, "HTML columns should be near-symmetric, diff=" + diff); + + // No UTF-16-signature ranges: no nulls, no C1, no high bytes. + assertEquals(0, f[NUL_EVEN] + f[NUL_ODD], "HTML has no nulls"); + assertEquals(0, f[C1_EVEN] + f[C1_ODD], "HTML never emits C1 bytes"); + assertEquals(0, f[HI_EVEN] + f[HI_ODD], "ASCII HTML has no high bytes"); + } + + @Test + public void largeHtmlStillSymmetric() { + // Simulate a larger HTML probe — symmetry should hold across columns. + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 200; i++) { + sb.append("<div class=\"node-").append(i).append("\">text ") + .append(i).append("</div>\n"); + } + byte[] probe = sb.toString().getBytes(StandardCharsets.US_ASCII); + int[] f = extractor.extract(probe); + + int asymmetry = Math.abs(f[ASCII_EVEN] - f[ASCII_ODD]); + double asymmetryRatio = (double) asymmetry / probe.length; + assertTrue(asymmetryRatio < 0.02, + "HTML column asymmetry ratio should be very small, got " + asymmetryRatio); + assertEquals(0, f[NUL_EVEN] + f[NUL_ODD]); + assertEquals(0, f[C1_EVEN] + f[C1_ODD]); + } + + // --- pure ASCII text (symmetric, like HTML) --- + + @Test + public void pureAsciiEnglishProducesSymmetricColumns() { + byte[] probe = ("The quick brown fox jumps over the lazy dog. 
" + + "Pack my box with five dozen liquor jugs.") + .getBytes(StandardCharsets.US_ASCII); + int[] f = extractor.extract(probe); + + int diff = Math.abs(f[ASCII_EVEN] - f[ASCII_ODD]); + assertTrue(diff <= 2, "pure ASCII should be near-symmetric, diff=" + diff); + assertEquals(0, f[NUL_EVEN] + f[NUL_ODD]); + assertEquals(0, f[C1_EVEN] + f[C1_ODD]); + } + + // --- adversarial: pure 2-byte Shift_JIS --- + + @Test + public void pure2ByteShiftJisProducesWeakerAsymmetryThanUtf16Cjk() { + // Japanese "テスト" in Shift_JIS (all 2-byte chars, no ASCII interruptions). + // テ 0x83 0x65 + // ス 0x83 0x58 + // ト 0x83 0x67 + // Even column: 0x83, 0x83, 0x83 (all in C1 range 0x80-9F) + // Odd column: 0x65, 0x58, 0x67 (all in ASCII printable range) + byte[] probe = "テスト".getBytes(Charset.forName("Shift_JIS")); + int[] f = extractor.extract(probe); + + // This looks LIKE UTF-16BE CJK (even column has high bytes, odd column has printable). + // Combiner should still pick Shift_JIS because the CJK specialist's logit is higher. + assertEquals(3, f[C1_EVEN], "Shift_JIS leads in C1 range for this probe"); + assertEquals(3, f[ASCII_ODD], "Shift_JIS trails in ASCII range"); + // We don't assert the UTF-16 logit — this is just the raw feature vector. + // The interesting question is what the trained model does with it, which is a + // training-and-evaluation concern, not a feature-extraction concern. + } + + @Test + public void mixedShiftJisWithAsciiBreaksAlignment() { + // Realistic Shift_JIS with ASCII interruptions. Alignment shifts per ASCII byte. + byte[] probe = ("test " + "テスト" + " text").getBytes(Charset.forName("Shift_JIS")); + int[] f = extractor.extract(probe); + + // Hard to predict exact counts, but asymmetry in C1 range should be much + // weaker than the pure-2-byte case because the leading "test " (5 ASCII + // chars) shifts alignment of the Japanese bytes. 
+ int c1Asymmetry = Math.abs(f[C1_EVEN] - f[C1_ODD]); + // Some non-zero asymmetry is likely, but should be small vs pure-2-byte. + assertTrue(c1Asymmetry <= 3, "ASCII interruption should weaken column asymmetry"); + } + + // --- scattered null — the §P1 false-positive case --- + + @Test + public void scatteredNullsProduceSymmetricColumns() { + // Synthesize a probe with low-density scattered nulls: 1% null rate, + // distributed randomly across both columns. + byte[] probe = new byte[1000]; + java.util.Random rng = new java.util.Random(42); // deterministic + int nullsPlaced = 0; + for (int i = 0; i < probe.length; i++) { + if (rng.nextDouble() < 0.01) { + probe[i] = 0x00; + nullsPlaced++; + } else { + // random printable ASCII + probe[i] = (byte) (0x20 + rng.nextInt(95)); + } + } + int[] f = extractor.extract(probe); + + assertEquals(nullsPlaced, f[NUL_EVEN] + f[NUL_ODD], + "all nulls accounted for in NUL range"); + // Nulls should be roughly balanced across columns (noisy but symmetric in expectation). + int nullAsymmetry = Math.abs(f[NUL_EVEN] - f[NUL_ODD]); + assertTrue(nullAsymmetry <= nullsPlaced / 2 + 3, + "scattered nulls should be roughly balanced, asymmetry=" + nullAsymmetry); + } + + // --- controls and whitespace handling --- + + @Test + public void whitespaceCountsAsAsciiTextNotAsControls() { + // 0x09 (tab), 0x0A (LF), 0x0D (CR) should land in the ASCII range, not the control range. 
+ byte[] probe = new byte[]{ + 0x09, 0x0A, 0x0D, ' ', 'a', // 5 bytes, all in ASCII range + 0x01, 0x02, 0x03 // 3 bytes in control range + }; + int[] f = extractor.extract(probe); + + assertEquals(5, f[ASCII_EVEN] + f[ASCII_ODD], + "tab/LF/CR plus ' ' and 'a' = 5 ASCII-range bytes"); + assertEquals(3, f[CTRL_EVEN] + f[CTRL_ODD], + "0x01/0x02/0x03 = 3 control-range bytes"); + } + + @Test + public void delByteLandsInDelRange() { + byte[] probe = new byte[]{0x7E, 0x7F, (byte) 0x80}; + int[] f = extractor.extract(probe); + assertEquals(1, f[ASCII_EVEN] + f[ASCII_ODD], "0x7E is ASCII"); + assertEquals(1, f[DEL_EVEN] + f[DEL_ODD], "0x7F is DEL"); + assertEquals(1, f[C1_EVEN] + f[C1_ODD], "0x80 is C1"); + } + + // --- sparse extraction interface --- + + @Test + public void sparseExtractionMatchesDense() { + byte[] probe = "Hello World".getBytes(Charset.forName("UTF-16LE")); + + int[] dense = extractor.extract(probe); + int[] sparseDense = new int[12]; + int[] touched = new int[12]; + int n = extractor.extractSparseInto(probe, sparseDense, touched); + + // dense[] values should match between paths + for (int i = 0; i < 12; i++) { + assertEquals(dense[i], sparseDense[i], + "feature " + i + " should match between dense and sparse"); + } + // touched[] should list exactly the non-zero indices + int nonZero = 0; + for (int i = 0; i < 12; i++) { + if (dense[i] != 0) { + nonZero++; + } + } + assertEquals(nonZero, n, "touched count should equal number of non-zero features"); + } + + @Test + public void sparseExtractionWithEmptyProbe() { + int[] dense = new int[12]; + int[] touched = new int[12]; + int n = extractor.extractSparseInto(new byte[0], dense, touched); + assertEquals(0, n); + } + + // --- range offset extraction --- + + @Test + public void subRangeExtractionIsCorrect() { + byte[] probe = "XXHelloXX".getBytes(StandardCharsets.US_ASCII); + // Extract from offset 2, length 5 ("Hello") + int[] f = extractor.extract(probe, 2, 5); + // "Hello" = 5 bytes, all ASCII. 
Column assignment relative to the sub-range start. + assertEquals(5, f[ASCII_EVEN] + f[ASCII_ODD]); + // 5 bytes: positions (relative) 0,1,2,3,4 → even,odd,even,odd,even → 3 even, 2 odd + assertEquals(3, f[ASCII_EVEN]); + assertEquals(2, f[ASCII_ODD]); + } + + // --- feature label sanity --- + + @Test + public void featureLabelsAreReasonable() { + assertEquals("count_even(0x00)", Utf16ColumnFeatureExtractor.featureLabel(0)); + assertEquals("count_odd(0x00)", Utf16ColumnFeatureExtractor.featureLabel(1)); + assertEquals("count_even(0x80-9F)", Utf16ColumnFeatureExtractor.featureLabel(8)); + assertEquals("count_odd(0xA0-FF)", Utf16ColumnFeatureExtractor.featureLabel(11)); + } +} diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetectorTest.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetectorTest.java new file mode 100644 index 0000000000..98917392f5 --- /dev/null +++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetectorTest.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tika.ml.chardetect;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Random;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.tika.detect.EncodingResult;
+import org.apache.tika.io.TikaInputStream;
+import org.apache.tika.metadata.Metadata;
+import org.apache.tika.ml.LinearModel;
+import org.apache.tika.parser.ParseContext;
+
+/**
+ * Tests for {@link Utf16SpecialistEncodingDetector}. Uses a synthetic
+ * {@link LinearModel} with hand-picked weights to exercise the inference
+ * pipeline without requiring a trained model.
+ *
+ * <p>Synthetic model design:</p>
+ * <ul>
+ *   <li>Class 0 = {@code UTF-16-LE}</li>
+ *   <li>Class 1 = {@code UTF-16-BE}</li>
+ *   <li>Weights encode asymmetry between paired features: a feature firing
+ *       on the "LE-characteristic" column pulls class 0 up; the same feature
+ *       firing on the "BE-characteristic" column pulls class 1 up.</li>
+ *   <li>Specifically: for feature pairs like {@code count_even(0x00)} vs
+ *       {@code count_odd(0x00)}, we give class 0 negative weight on even
+ *       and positive weight on odd (so UTF-16LE Latin with nulls in odd
+ *       column produces a positive class-0 logit), and class 1 gets the
+ *       mirror.</li>
+ * </ul>
+ *
+ * <p>The synthetic model doesn't need to be accurate — it just needs to be
+ * well-defined so we can predict which side "should win" for each test
+ * probe and verify the detector behaves correspondingly.</p>
+ */
+public class Utf16SpecialistEncodingDetectorTest {
+
+    /**
+     * Synthetic UTF-16 specialist model (class 0 = LE, class 1 = BE).
+     *
+     * <p>Delegates to {@link Utf16SpecialistEncodingDetectorTestFixtures#syntheticModel()}
+     * so the hand-tuned weight table exists in exactly one place. Column
+     * asymmetry (high count in the odd column for LE, in the even column for
+     * BE) drives the logits. Weights and scale were chosen so that (a) long
+     * Latin probes don't saturate the per-feature clip (1.5 * sqrt(nnz)) into
+     * a tie, and (b) short CJK probes clear the MIN_LOGIT_MARGIN threshold,
+     * which is why the CJK-discriminating C1-range weights are boosted.</p>
+     */
+    private static LinearModel syntheticModel() {
+        return Utf16SpecialistEncodingDetectorTestFixtures.syntheticModel();
+    }
+
+    private Utf16SpecialistEncodingDetector detector() {
+        return new Utf16SpecialistEncodingDetector(syntheticModel(), 512);
+    }
+
+    private static List<EncodingResult> detect(Utf16SpecialistEncodingDetector d,
+                                               byte[] probe) throws IOException {
+        try (TikaInputStream tis = TikaInputStream.get(probe)) {
+            return d.detect(tis, new Metadata(), new ParseContext());
+        }
+    }
+
+    // --- model-loading semantics ---
+
+    @Test
+    public void nullModelRejected() {
+        assertThrows(IllegalArgumentException.class,
+                () -> new Utf16SpecialistEncodingDetector(null, 512));
+    }
+
+    @Test
+    public void wrongBucketCountRejected() {
+        byte[][] weights = new byte[2][5]; // wrong bucket count
+        float[] scales = {1.0f, 1.0f};
+        float[] biases = {0.0f, 0.0f};
+        LinearModel bad = new LinearModel(5, 2,
+                new String[]{"UTF-16-LE", "UTF-16-BE"}, scales, biases, weights);
+        assertThrows(IllegalArgumentException.class,
+                () -> new Utf16SpecialistEncodingDetector(bad, 512));
+    }
+
+    @Test
+    public void wrongClassCountRejected() {
+        byte[][] weights = new byte[3][Utf16ColumnFeatureExtractor.NUM_FEATURES];
+        float[] scales = {1.0f, 1.0f, 1.0f};
+        float[] biases = {0.0f, 0.0f, 0.0f};
+        LinearModel bad = new LinearModel(Utf16ColumnFeatureExtractor.NUM_FEATURES, 3,
+                new String[]{"A", "B", "C"}, scales, biases, weights);
+        assertThrows(IllegalArgumentException.class,
+                () -> new Utf16SpecialistEncodingDetector(bad, 512));
+    }
+
+    @Test
+    public void bundledClasspathResourceLoads() throws IOException {
+        // The trained model ships as a classpath resource in the mojibuster
+        // module. No-arg constructor must load it successfully, and the
+        // loaded model must have the expected shape for the UTF-16 extractor.
+        Utf16SpecialistEncodingDetector d = new Utf16SpecialistEncodingDetector();
+        // A clean UTF-16LE probe should produce a confident LE result.
+        byte[] probe = "Hello World. This is a UTF-16LE sanity check."
+                .getBytes(StandardCharsets.UTF_16LE);
+        SpecialistOutput out = d.score(probe);
+        assertEquals(2, out.getClassLogits().size());
+        assertTrue(out.getClassLogits().containsKey("UTF-16-LE"));
+        assertTrue(out.getClassLogits().containsKey("UTF-16-BE"));
+        assertTrue(out.getLogit("UTF-16-LE") > out.getLogit("UTF-16-BE"),
+                "bundled model should rank LE > BE on LE bytes; got "
+                        + out.getClassLogits());
+    }
+
+    // --- detection outputs ---
+
+    @Test
+    public void emptyProbeReturnsEmpty() throws IOException {
+        List<EncodingResult> results = detect(detector(), new byte[0]);
+        assertEquals(0, results.size());
+    }
+
+    @Test
+    public void singleByteProbeReturnsEmpty() throws IOException {
+        // Can't tell alignment from fewer than 2 bytes.
+        List<EncodingResult> results = detect(detector(), new byte[]{0x41});
+        assertEquals(0, results.size());
+    }
+
+    @Test
+    public void utf16LeLatinDetectedAsLE() throws IOException {
+        byte[] probe = "Hello World. This is a UTF-16LE Latin probe."
+                .getBytes(StandardCharsets.UTF_16LE);
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(1, results.size(), "should return exactly one candidate");
+        EncodingResult r = results.get(0);
+        assertEquals("UTF-16-LE", r.getLabel());
+        assertEquals(StandardCharsets.UTF_16LE, r.getCharset());
+        assertEquals(EncodingResult.ResultType.STATISTICAL, r.getResultType());
+        assertTrue(r.getConfidence() > 0.5f,
+                "confidence should be substantial, got " + r.getConfidence());
+    }
+
+    @Test
+    public void utf16BeLatinDetectedAsBE() throws IOException {
+        byte[] probe = "Hello World. This is a UTF-16BE Latin probe."
+                .getBytes(StandardCharsets.UTF_16BE);
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(1, results.size());
+        EncodingResult r = results.get(0);
+        assertEquals("UTF-16-BE", r.getLabel());
+        assertEquals(StandardCharsets.UTF_16BE, r.getCharset());
+    }
+
+    @Test
+    public void utf16LeCjkDetectedAsLE() throws IOException {
+        byte[] probe = "精密過濾旋流器は日本の製品です。東京で製造されています。"
+                .getBytes(StandardCharsets.UTF_16LE);
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(1, results.size());
+        assertEquals("UTF-16-LE", results.get(0).getLabel());
+    }
+
+    @Test
+    public void utf16BeCjkDetectedAsBE() throws IOException {
+        byte[] probe = "精密過濾旋流器は日本の製品です。東京で製造されています。"
+                .getBytes(StandardCharsets.UTF_16BE);
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(1, results.size());
+        assertEquals("UTF-16-BE", results.get(0).getLabel());
+    }
+
+    @Test
+    public void htmlProducesNoResult() throws IOException {
+        // HTML: near-symmetric columns → neither LE nor BE exceeds the
+        // logit-margin threshold → detector returns empty.
+        StringBuilder html = new StringBuilder();
+        for (int i = 0; i < 30; i++) {
+            html.append("<div class=\"item-").append(i).append("\">content ")
+                    .append(i).append("</div>\n");
+        }
+        byte[] probe = html.toString().getBytes(StandardCharsets.US_ASCII);
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(0, results.size(),
+                "HTML should produce empty result (column-symmetric) — "
+                        + "this is the HTML-immunity property");
+    }
+
+    @Test
+    public void pureAsciiEnglishProducesNoResult() throws IOException {
+        byte[] probe = ("The quick brown fox jumps over the lazy dog. "
+                + "Pack my box with five dozen liquor jugs.")
+                .getBytes(StandardCharsets.US_ASCII);
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(0, results.size(),
+                "pure ASCII should produce empty result");
+    }
+
+    @Test
+    public void scatteredNullsProduceNoResult() throws IOException {
+        // Regression case P1: random bytes with ~1% null density that
+        // previously tricked the old structural UTF-16 detector.
+        byte[] probe = new byte[1000];
+        Random rng = new Random(42);
+        for (int i = 0; i < probe.length; i++) {
+            if (rng.nextDouble() < 0.01) {
+                probe[i] = 0x00;
+            } else {
+                probe[i] = (byte) (0x20 + rng.nextInt(95));
+            }
+        }
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(0, results.size(),
+                "scattered nulls with no 2-byte alignment should not trigger");
+    }
+
+    @Test
+    public void probeLongerThanBudgetIsTrimmed() throws IOException {
+        // Build a probe much longer than the default 512-byte budget but with
+        // clear UTF-16LE structure. Detector should still handle it correctly
+        // (reading only the prefix) and produce a confident result.
+        String text = "This is a sufficiently long UTF-16LE Latin test probe "
+                + "with plenty of content to exercise the probe-size bound. ";
+        StringBuilder sb = new StringBuilder();
+        while (sb.length() < 2000) {
+            sb.append(text);
+        }
+        byte[] probe = sb.toString().getBytes(StandardCharsets.UTF_16LE);
+        List<EncodingResult> results = detect(detector(), probe);
+
+        assertEquals(1, results.size());
+        assertEquals("UTF-16-LE", results.get(0).getLabel());
+    }
+
+    // --- logit-level (combiner) entry points ---
+
+    @Test
+    public void scoreEmitsBothClassLogitsWithoutThreshold() throws IOException {
+        // detect() returns [] for short probes where margin < threshold.
+        // score() returns raw logits regardless — the combiner decides.
+        byte[] probe = "Hi".getBytes(StandardCharsets.UTF_16LE);
+        Utf16SpecialistEncodingDetector d = detector();
+        try (TikaInputStream tis = TikaInputStream.get(probe)) {
+            SpecialistOutput out = d.score(tis);
+            assertEquals("utf16", out.getSpecialistName());
+            assertEquals(2, out.getClassLogits().size());
+            assertTrue(out.getClassLogits().containsKey("UTF-16-LE"));
+            assertTrue(out.getClassLogits().containsKey("UTF-16-BE"));
+        }
+    }
+
+    @Test
+    public void scoreReturnsNullForTooShortProbe() throws IOException {
+        Utf16SpecialistEncodingDetector d = detector();
+        try (TikaInputStream tis = TikaInputStream.get(new byte[]{0x41})) {
+            assertNull(d.score(tis));
+        }
+    }
+
+    @Test
+    public void scoreBytesGivesLeHigherLogitForLePattern() {
+        byte[] probe = "Hello World. This is UTF-16LE."
+                .getBytes(StandardCharsets.UTF_16LE);
+        SpecialistOutput out = detector().scoreBytes(probe);
+        float le = out.getLogit("UTF-16-LE");
+        float be = out.getLogit("UTF-16-BE");
+        assertTrue(le > be, "LE should score higher than BE, got LE=" + le + " BE=" + be);
+    }
+
+    @Test
+    public void streamPositionIsPreserved() throws IOException {
+        // The detector marks/resets the stream — a subsequent read should see
+        // exactly the same bytes as if we hadn't called detect at all.
+        byte[] probe = "Hello World.".getBytes(StandardCharsets.UTF_16LE);
+        try (TikaInputStream tis = TikaInputStream.get(probe)) {
+            detector().detect(tis, new Metadata(), new ParseContext());
+            // Compare the complete remaining content so that both a shortened
+            // and a lengthened stream fail, which a fixed-size buffer would hide.
+            assertArrayEquals(probe, tis.readAllBytes(),
+                    "stream should be fully re-readable after detect/reset");
+        }
+    }
+}
diff --git a/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetectorTestFixtures.java b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetectorTestFixtures.java
new file mode 100644
index 0000000000..28eed6a836
--- /dev/null
+++ b/tika-encoding-detectors/tika-encoding-detector-mojibuster/src/test/java/org/apache/tika/ml/chardetect/Utf16SpecialistEncodingDetectorTestFixtures.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tika.ml.chardetect;
+
+import org.apache.tika.ml.LinearModel;
+
+/**
+ * Test-only home of the hand-tuned synthetic UTF-16 specialist model, i.e.
+ * the same weights as those used by {@link Utf16SpecialistEncodingDetectorTest}.
+ * Kept in its own class so combiner integration tests can share the model
+ * instead of copying the weight table.
+ */
+final class Utf16SpecialistEncodingDetectorTestFixtures {
+
+    // Feature indices; these must agree with Utf16ColumnFeatureExtractor.
+    private static final int NUL_EVEN = 0, NUL_ODD = 1;
+    private static final int CTRL_EVEN = 2, CTRL_ODD = 3;
+    private static final int ASCII_EVEN = 4, ASCII_ODD = 5;
+    private static final int C1_EVEN = 8, C1_ODD = 9;
+    private static final int HI_EVEN = 10, HI_ODD = 11;
+
+    private Utf16SpecialistEncodingDetectorTestFixtures() {
+    }
+
+    /**
+     * Builds the two-class model: class 0 = {@code UTF-16-LE}, class 1 =
+     * {@code UTF-16-BE}. The BE weight row is the column-mirror of the LE row.
+     */
+    static LinearModel syntheticModel() {
+        int buckets = Utf16ColumnFeatureExtractor.NUM_FEATURES;
+        byte[] le = new byte[buckets];
+        byte[] be = new byte[buckets];
+
+        favor(le, NUL_ODD, NUL_EVEN, 10);
+        favor(le, CTRL_ODD, CTRL_EVEN, 10);
+        favor(le, ASCII_ODD, ASCII_EVEN, 3);
+        favor(le, C1_ODD, C1_EVEN, 100);
+        favor(le, HI_EVEN, HI_ODD, 3);
+
+        favor(be, NUL_EVEN, NUL_ODD, 10);
+        favor(be, CTRL_EVEN, CTRL_ODD, 10);
+        favor(be, ASCII_EVEN, ASCII_ODD, 3);
+        favor(be, C1_EVEN, C1_ODD, 100);
+        favor(be, HI_ODD, HI_EVEN, 3);
+
+        return new LinearModel(buckets, 2,
+                new String[]{"UTF-16-LE", "UTF-16-BE"},
+                new float[]{0.002f, 0.002f},
+                new float[]{0.0f, 0.0f},
+                new byte[][]{le, be});
+    }
+
+    /**
+     * Sets {@code row[up] = +magnitude} and {@code row[down] = -magnitude}.
+     */
+    private static void favor(byte[] row, int up, int down, int magnitude) {
+        row[up] = (byte) magnitude;
+        row[down] = (byte) -magnitude;
+    }
+}
diff --git a/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/FeatureExtractor.java b/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/FeatureExtractor.java
index
33aff831b5..bf59b64d1c 100644
--- a/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/FeatureExtractor.java
+++ b/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/FeatureExtractor.java
@@ -37,4 +37,27 @@ public interface FeatureExtractor<T> {
      * @return number of hash buckets (feature-vector dimension)
      */
     int getNumBuckets();
+
+    /**
+     * Sparse extraction into caller-owned reusable buffers: populates
+     * {@code dense} with feature counts, writes the indices of non-zero
+     * entries into {@code touched}, and returns how many indices were
+     * written. Callers are responsible for clearing the touched entries
+     * of {@code dense} before reuse.
+     *
+     * <p>Default implementation delegates to {@link #extract}. Extractors
+     * that can do better (avoid allocating the full dense vector, or scan
+     * the input only once) should override.</p>
+     */
+    default int extractSparseInto(T input, int[] dense, int[] touched) {
+        int[] features = extract(input);
+        int n = 0;
+        for (int i = 0; i < features.length; i++) {
+            if (features[i] != 0) {
+                dense[i] = features[i];
+                touched[n++] = i;
+            }
+        }
+        return n;
+    }
 }
diff --git a/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/LinearModel.java b/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/LinearModel.java
index 5fb8484c8f..1434f20b67 100644
--- a/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/LinearModel.java
+++ b/tika-ml/tika-ml-core/src/main/java/org/apache/tika/ml/LinearModel.java
@@ -33,12 +33,15 @@ import java.util.zip.GZIPInputStream;
  * <pre>
  * Offset  Field
  * 0       4B magic: 0x4C444D31
- * 4       4B version: 1
+ * 4       4B version: 1 or 2
  * 8       4B numBuckets (B)
  * 12      4B numClasses (C)
  * 16+     Labels: C entries of [2B length + UTF-8 bytes]
  *         Scales: C × 4B float (per-class dequantization)
  *         Biases: C × 4B float (per-class bias term)
+ *         (V2 only)
+ *         1B hasCalibration flag
+ *         If hasCalibration: ClassMean: C × 4B float, ClassStd: C × 4B float
  *         Weights: B × C bytes (bucket-major, INT8 signed)
  * </pre>
  * <p>
@@ -48,17 +51,36 @@ import java.util.zip.GZIPInputStream;
  * — each non-zero bucket reads a contiguous run of
  * {@code numClasses} bytes, ideal for SIMD and cache
  * prefetching.
+ * <p>
+ * Calibration (V2): optional per-class mean/std of training-set logits.
+ * When present, {@link #predictCalibratedLogits} standardizes raw logits
+ * so cross-specialist pooling can compare "unusually confident" signals on
+ * equal footing. V1 files are still readable; calibration is absent and
+ * {@link #predictCalibratedLogits} falls back to raw logits.
  */
 public class LinearModel {
 
     public static final int MAGIC = 0x4C444D31; // "LDM1"
-    public static final int VERSION = 1;
+    public static final int VERSION_V1 = 1;
+    public static final int VERSION_V2 = 2;
+    /**
+     * Latest version we emit.
+     */
+    public static final int VERSION = VERSION_V2;
 
     private final int numBuckets;
     private final int numClasses;
     private final String[] labels;
     private final float[] scales;
     private final float[] biases;
+    /**
+     * Optional per-class logit mean for calibration; {@code null} if absent.
+     */
+    private final float[] classMean;
+    /**
+     * Optional per-class logit std; every entry is strictly positive when present.
+     */
+    private final float[] classStd;
 
     /**
      * Flat INT8 weight array in bucket-major order:
@@ -67,29 +89,85 @@ public class LinearModel {
     private final byte[] flatWeights;
 
     /**
-     * Construct from class-major {@code byte[][]} weights.
-     * Transposes to bucket-major flat layout internally.
+     * Construct without calibration (V1-compatible).
+     * Transposes class-major weights to bucket-major flat layout internally.
      */
     public LinearModel(int numBuckets, int numClasses,
                        String[] labels, float[] scales,
                        float[] biases, byte[][] weights) {
+        this(numBuckets, numClasses, labels, scales, biases, weights, null, null);
+    }
+
+    /**
+     * Construct with optional calibration. Pass {@code classMean} and
+     * {@code classStd} (each of length {@code numClasses}) to enable
+     * z-score calibration in {@link #predictCalibratedLogits}; pass
+     * {@code null} for both to skip. Both arrays are copied; any
+     * non-positive or NaN {@code classStd[c]} becomes {@code 1.0f}.
+     */
+    public LinearModel(int numBuckets, int numClasses,
+                       String[] labels, float[] scales,
+                       float[] biases, byte[][] weights,
+                       float[] classMean, float[] classStd) {
         this.numBuckets = numBuckets;
         this.numClasses = numClasses;
         this.labels = labels;
         this.scales = scales;
         this.biases = biases;
+        this.classMean = classMean == null ? null : classMean.clone(); // defensive copy
+        this.classStd = sanitizeStd(classStd);
         this.flatWeights = transposeToBucketMajor(weights, numBuckets, numClasses);
+        validateCalibration();
     }
 
     private LinearModel(int numBuckets, int numClasses,
                         String[] labels, float[] scales,
-                        float[] biases, byte[] flatWeights) {
+                        float[] biases, byte[] flatWeights,
+                        float[] classMean, float[] classStd) {
         this.numBuckets = numBuckets;
         this.numClasses = numClasses;
         this.labels = labels;
         this.scales = scales;
         this.biases = biases;
+        this.classMean = classMean;
+        this.classStd = sanitizeStd(classStd);
         this.flatWeights = flatWeights;
+        validateCalibration();
+    }
+
+    /**
+     * Returns a copy of {@code std} with every entry that is not strictly
+     * positive (zero, negative or NaN) replaced by {@code 1.0f}, or
+     * {@code null} if {@code std} is {@code null}.
+     */
+    private static float[] sanitizeStd(float[] std) {
+        if (std == null) {
+            return null;
+        }
+        float[] out = new float[std.length];
+        for (int i = 0; i < std.length; i++) {
+            out[i] = std[i] > 0f ? std[i] : 1.0f;
+        }
+        return out;
+    }
+
+    /**
+     * Rejects calibration arrays that are only half-present or whose length
+     * differs from {@code numClasses}.
+     */
+    private void validateCalibration() {
+        if ((classMean == null) != (classStd == null)) {
+            throw new IllegalArgumentException(
+                    "classMean and classStd must both be provided or both null");
+        }
+        if (classMean != null && classMean.length != numClasses) {
+            throw new IllegalArgumentException(
+                    "classMean length " + classMean.length + " != numClasses " + numClasses);
+        }
+        if (classStd != null && classStd.length != numClasses) {
+            throw new IllegalArgumentException(
+                    "classStd length " + classStd.length + " != numClasses " + numClasses);
+        }
     }
 
     private static byte[] transposeToBucketMajor(
@@ -154,7 +232,9 @@ public class LinearModel {
         return loadRaw(is);
     }
 
-    /** Read LDM1 from an already-unwrapped (non-gzip) stream. */
+    /**
+     * Read LDM from an already-unwrapped (non-gzip) stream.
+     */
     private static LinearModel loadRaw(InputStream is) throws IOException {
         DataInputStream dis = new DataInputStream(is);
         int magic = dis.readInt();
@@ -163,9 +243,10 @@ public class LinearModel {
                     "Invalid magic: expected 0x%08X, got 0x%08X", MAGIC, magic));
         }
         int version = dis.readInt();
-        if (version != VERSION) {
+        if (version != VERSION_V1 && version != VERSION_V2) {
             throw new IOException(
-                    "Unsupported version: " + version + " (expected " + VERSION + ")");
+                    "Unsupported version: " + version
+                            + " (expected " + VERSION_V1 + " or " + VERSION_V2 + ")");
         }
 
         int numBuckets = dis.readInt();
@@ -175,10 +256,21 @@ public class LinearModel {
         float[] scales = readFloats(dis, numClasses);
         float[] biases = readFloats(dis, numClasses);
 
+        float[] classMean = null;
+        float[] classStd = null;
+        if (version >= VERSION_V2) {
+            boolean hasCalibration = dis.readBoolean();
+            if (hasCalibration) {
+                classMean = readFloats(dis, numClasses);
+                classStd = readFloats(dis, numClasses);
+            }
+        }
+
         byte[] flat = new byte[numBuckets * numClasses];
         dis.readFully(flat);
 
-        return new LinearModel(numBuckets, numClasses, labels, scales, biases, flat);
+        return new LinearModel(numBuckets, numClasses, labels, scales, biases,
+                flat, classMean, classStd);
     }
 
     // ================================================================
@@ -186,17 +278,24 @@ public class LinearModel {
     // ================================================================
 
     /**
-     * Write the model in LDM1 binary format.
+     * Write the model in LDM binary format. Emits V2 (with or without
+     * calibration block depending on whether this model has calibration).
      */
     public void save(OutputStream os) throws IOException {
         DataOutputStream dos = new DataOutputStream(os);
         dos.writeInt(MAGIC);
-        dos.writeInt(VERSION);
+        dos.writeInt(VERSION_V2);
         dos.writeInt(numBuckets);
         dos.writeInt(numClasses);
         writeLabels(dos);
         writeFloats(dos, scales);
         writeFloats(dos, biases);
+        boolean hasCal = hasCalibration();
+        dos.writeBoolean(hasCal);
+        if (hasCal) {
+            writeFloats(dos, classMean);
+            writeFloats(dos, classStd);
+        }
         dos.write(flatWeights);
         dos.flush();
     }
@@ -254,6 +353,50 @@ public class LinearModel {
         return softmax(predictLogits(features));
     }
 
+    /**
+     * Compute calibrated logits: {@code (raw - classMean[c]) / classStd[c]}
+     * for each class, if the model carries calibration statistics, else raw
+     * logits (no-op). Calibrated logits are comparable across specialists
+     * with different natural logit scales — they express "how many standard
+     * deviations above this class's training-set mean" rather than raw weight
+     * arithmetic.
+     */
+    public float[] predictCalibratedLogits(int[] features) {
+        float[] raw = predictLogits(features);
+        if (!hasCalibration()) {
+            return raw;
+        }
+        for (int c = 0; c < numClasses; c++) {
+            raw[c] = (raw[c] - classMean[c]) / classStd[c];
+        }
+        return raw;
+    }
+
+    /**
+     * {@code true} if this model carries per-class calibration statistics.
+     */
+    public boolean hasCalibration() {
+        return classMean != null && classStd != null;
+    }
+
+    /**
+     * Per-class logit means used by {@link #predictCalibratedLogits}, or
+     * {@code null} if this model carries no calibration. Returns a copy;
+     * modifying it does not affect the model.
+     */
+    public float[] getClassMean() {
+        return classMean == null ? null : classMean.clone();
+    }
+
+    /**
+     * Per-class logit standard deviations (every entry {@code > 0}), or
+     * {@code null} if this model carries no calibration. Returns a copy so
+     * callers cannot reintroduce a zero divisor into the model.
+     */
+    public float[] getClassStd() {
+        return classStd == null ? null : classStd.clone();
+    }
+
     /**
      * In-place softmax with numerical stability.
      */
diff --git a/tika-ml/tika-ml-core/src/test/java/org/apache/tika/ml/LinearModelCalibrationTest.java b/tika-ml/tika-ml-core/src/test/java/org/apache/tika/ml/LinearModelCalibrationTest.java
new file mode 100644
index 0000000000..2401dce2c5
--- /dev/null
+++ b/tika-ml/tika-ml-core/src/test/java/org/apache/tika/ml/LinearModelCalibrationTest.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tika.ml;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for the optional per-class calibration block of {@link LinearModel}:
+ * constructor validation and std sanitization, z-score standardization in
+ * {@link LinearModel#predictCalibratedLogits}, and V1/V2 serialization.
+ */
+public class LinearModelCalibrationTest {
+
+    /**
+     * Two-class, four-bucket model: bucket 0 feeds class A and bucket 1 feeds
+     * class B (weight 10, scale 1, zero bias), with the given calibration.
+     */
+    private static LinearModel modelWithCalibration(float[] mean, float[] std) {
+        byte[][] weights = new byte[2][4];
+        weights[0][0] = 10;
+        weights[1][1] = 10;
+        return new LinearModel(4, 2,
+                new String[]{"A", "B"},
+                new float[]{1.0f, 1.0f}, new float[]{0.0f, 0.0f},
+                weights, mean, std);
+    }
+
+    /**
+     * Two-class, four-bucket model with all-zero weights and no calibration.
+     */
+    private static LinearModel uncalibratedModel() {
+        return new LinearModel(4, 2,
+                new String[]{"A", "B"},
+                new float[]{1.0f, 1.0f}, new float[]{0.0f, 0.0f},
+                new byte[2][4]);
+    }
+
+    @Test
+    public void hasCalibrationReflectsConstructor() {
+        LinearModel cal = modelWithCalibration(
+                new float[]{0.5f, -0.5f}, new float[]{1.0f, 1.0f});
+        assertTrue(cal.hasCalibration());
+
+        assertFalse(uncalibratedModel().hasCalibration());
+    }
+
+    @Test
+    public void calibrationArraysMustBeProvidedTogether() {
+        assertThrows(IllegalArgumentException.class,
+                () -> modelWithCalibration(new float[]{0.0f, 0.0f}, null));
+        assertThrows(IllegalArgumentException.class,
+                () -> modelWithCalibration(null, new float[]{1.0f, 1.0f}));
+    }
+
+    @Test
+    public void calibrationArraysMustMatchNumClasses() {
+        assertThrows(IllegalArgumentException.class,
+                () -> modelWithCalibration(new float[]{0.0f}, new float[]{1.0f}));
+        assertThrows(IllegalArgumentException.class,
+                () -> modelWithCalibration(new float[]{0.0f, 0.0f},
+                        new float[]{1.0f, 1.0f, 1.0f}));
+    }
+
+    @Test
+    public void predictCalibratedLogitsFallsBackToRawWithoutCalibration() {
+        LinearModel raw = uncalibratedModel();
+        int[] features = {1, 0, 0, 0};
+        float[] rawLogits = raw.predictLogits(features);
+        float[] calibrated = raw.predictCalibratedLogits(features);
+        assertArrayEquals(rawLogits, calibrated, 1e-6f);
+    }
+
+    @Test
+    public void predictCalibratedLogitsStandardizes() {
+        // Class A: mean=2, std=0.5; class B: mean=0, std=2.
+        // The raw logit itself depends on predictLogits' own scaling (and any
+        // clipping it applies), so only (raw - mean) / std is asserted here.
+        LinearModel cal = modelWithCalibration(
+                new float[]{2.0f, 0.0f}, new float[]{0.5f, 2.0f});
+        int[] features = {5, 0, 0, 0};
+        float[] raw = cal.predictLogits(features);
+        float[] calibrated = cal.predictCalibratedLogits(features);
+        assertEquals((raw[0] - 2.0f) / 0.5f, calibrated[0], 1e-5f);
+        assertEquals((raw[1] - 0.0f) / 2.0f, calibrated[1], 1e-5f);
+    }
+
+    @Test
+    public void zeroStdIsSanitizedToOne() {
+        // std=0 would divide-by-zero; constructor must rewrite to 1.0.
+        LinearModel cal = modelWithCalibration(
+                new float[]{1.0f, 1.0f}, new float[]{0.0f, 0.0f});
+        assertEquals(1.0f, cal.getClassStd()[0], 0.0f);
+        assertEquals(1.0f, cal.getClassStd()[1], 0.0f);
+    }
+
+    @Test
+    public void negativeOrNanStdIsSanitizedToOne() {
+        LinearModel cal = modelWithCalibration(
+                new float[]{1.0f, 1.0f}, new float[]{-2.0f, Float.NaN});
+        assertEquals(1.0f, cal.getClassStd()[0], 0.0f);
+        assertEquals(1.0f, cal.getClassStd()[1], 0.0f);
+    }
+
+    @Test
+    public void saveLoadRoundTripPreservesCalibration() throws IOException {
+        LinearModel src = modelWithCalibration(
+                new float[]{1.5f, -0.25f}, new float[]{0.7f, 2.3f});
+        ByteArrayOutputStream bos = new ByteArrayOutputStream();
+        src.save(bos);
+        LinearModel loaded = LinearModel.load(new ByteArrayInputStream(bos.toByteArray()));
+
+        assertTrue(loaded.hasCalibration());
+        assertArrayEquals(src.getClassMean(), loaded.getClassMean(), 1e-6f);
+        assertArrayEquals(src.getClassStd(), loaded.getClassStd(), 1e-6f);
+    }
+
+    @Test
+    public void saveLoadRoundTripWithoutCalibration() throws IOException {
+        LinearModel src = uncalibratedModel();
+        ByteArrayOutputStream bos = new ByteArrayOutputStream();
+        src.save(bos);
+        LinearModel loaded = LinearModel.load(new ByteArrayInputStream(bos.toByteArray()));
+
+        assertFalse(loaded.hasCalibration());
+    }
+
+    @Test
+    public void v1FormatStillLoadable() throws IOException {
+        // Hand-build a V1 file (no calibration bytes) and verify it loads.
+        ByteArrayOutputStream bos = new ByteArrayOutputStream();
+        DataOutputStream dos = new DataOutputStream(bos);
+        dos.writeInt(LinearModel.MAGIC);
+        dos.writeInt(LinearModel.VERSION_V1); // version 1, no calibration
+        dos.writeInt(4); // numBuckets
+        dos.writeInt(2); // numClasses
+        for (String lbl : new String[]{"A", "B"}) {
+            byte[] utf8 = lbl.getBytes(StandardCharsets.UTF_8);
+            dos.writeShort(utf8.length);
+            dos.write(utf8);
+        }
+        for (int c = 0; c < 2; c++) {
+            dos.writeFloat(1.0f); // scales
+        }
+        for (int c = 0; c < 2; c++) {
+            dos.writeFloat(0.0f); // biases
+        }
+        // No hasCalibration byte in V1. Weights follow directly.
+        for (int b = 0; b < 4 * 2; b++) {
+            dos.write(0);
+        }
+        dos.flush();
+
+        LinearModel loaded = LinearModel.load(new ByteArrayInputStream(bos.toByteArray()));
+        assertFalse(loaded.hasCalibration());
+        assertEquals(4, loaded.getNumBuckets());
+        assertEquals(2, loaded.getNumClasses());
+    }
+}
