This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch universal-junk-detector in repository https://gitbox.apache.org/repos/asf/tika.git
commit 136e19d2dd16f54a8ae3b275b344b47aa9355354 Author: tballison <[email protected]> AuthorDate: Thu Apr 23 14:53:50 2026 -0400 v3 --- .../ROOT/pages/advanced/junk-detection-build.adoc | 17 +- .../ROOT/pages/advanced/junk-detection.adoc | 48 +- .../apache/tika/quality/TextQualityComparison.java | 94 +++ .../apache/tika/quality/TextQualityDetector.java | 70 ++ .../org/apache/tika/quality/TextQualityScore.java | 45 +- tika-ml/tika-ml-junkdetect/pom.xml | 5 + .../apache/tika/ml/junkdetect/JunkDetector.java | 448 ++++++++----- .../ml/junkdetect/tools/BuildJunkTrainingData.java | 107 ++- .../tika/ml/junkdetect/tools/EvalJunkDetector.java | 97 ++- .../tika/ml/junkdetect/tools/TrainJunkModel.java | 740 ++++++++++++++++++--- .../org.apache.tika.quality.TextQualityDetector | 1 + .../org/apache/tika/ml/junkdetect/junkdetect.bin | Bin 414029 -> 543946 bytes .../tika/ml/junkdetect/JunkDetectorSmokeTest.java | 119 ++-- 13 files changed, 1429 insertions(+), 362 deletions(-) diff --git a/docs/modules/ROOT/pages/advanced/junk-detection-build.adoc b/docs/modules/ROOT/pages/advanced/junk-detection-build.adoc index 27e6e8754e..046099899f 100644 --- a/docs/modules/ROOT/pages/advanced/junk-detection-build.adoc +++ b/docs/modules/ROOT/pages/advanced/junk-detection-build.adoc @@ -392,28 +392,33 @@ The default 50 MB byte budget is a proof-of-concept setting. For production: == Smoke tests -Five smoke tests in `JunkDetectorSmokeTest` verify the bundled model: +Five smoke tests in `JunkDetectorSmokeTest` verify the bundled model. +All tests use the `TextQualityDetector` interface and return `TextQualityScore` +or `TextQualityComparison` from `tika-core`. [cols="1,3"] |=== | Test | What it checks | `cleanVsGarbage` -| Clean English z-score > random high-byte garbage z-score. +| Clean English `TextQualityScore` z-score > random high-byte garbage z-score. + Garbage is decoded from ISO-8859-1 to produce a scoreable string. 
| `forwardVsReversedArabic` | Forward Arabic z-score > codepoint-reversed Arabic z-score. + Reversal is done at codepoint (not byte) granularity, preserving valid Unicode. | `cp1252VsCp1257OnBalticText` -| `compare()` picks cp1257 as the correct encoding for Lithuanian text. +| `compare()` returns `TextQualityComparison` picking cp1257 for Lithuanian text. Delta > 0.1 (weak; Baltic limitation documented above). | `cp1252VsCp1251OnRussianText` -| `compare()` picks cp1251 as the correct encoding for Russian text. - Delta > 1.0 (strong; Cyrillic bigrams are highly distinctive). +| `compare()` picks cp1251 for Russian text. Delta > 1.0 (strong; Cyrillic + bigrams are highly distinctive). | `cleanVsShuffledCjk` -| Clean Japanese UTF-8 z-score > byte-shuffled Japanese z-score. +| Clean Japanese z-score > byte-shuffled Japanese z-score. + Shuffled bytes are decoded as ISO-8859-1 to produce a scoreable string. |=== NOTE: Codepoint reversal of LTR scripts (Russian, Latin) is **not** a useful diff --git a/docs/modules/ROOT/pages/advanced/junk-detection.adoc b/docs/modules/ROOT/pages/advanced/junk-detection.adoc index 6c6f06037f..5425b7a764 100644 --- a/docs/modules/ROOT/pages/advanced/junk-detection.adoc +++ b/docs/modules/ROOT/pages/advanced/junk-detection.adoc @@ -56,6 +56,10 @@ than clean"; a score of −10 means "almost certainly garbled." == Using the API +The public interface is `TextQualityDetector` in `tika-core`. +The implementation lives in `tika-ml-junkdetect`, which registers itself via +the Java `ServiceLoader` mechanism. + Add the dependency to your project: [source,xml] @@ -67,25 +71,31 @@ Add the dependency to your project: </dependency> ---- -=== Scoring a string or byte array +=== Loading the detector [source,java] ---- -JunkDetector detector = JunkDetector.loadFromClasspath(); - -// Score a string directly -JunkScore score = detector.score("The quick brown fox jumps over the lazy dog."); -System.out.println(score.getZScore()); // e.g. 
-0.74 — within normal range -System.out.println(score.getPClean()); // e.g. 0.32 — P(clean) via sigmoid +// Via ServiceLoader — picks up any registered TextQualityDetector implementation +TextQualityDetector detector = ServiceLoader.load(TextQualityDetector.class) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No TextQualityDetector on classpath")); -// Score raw UTF-8 bytes (same result; use when you already have bytes) -byte[] utf8 = text.getBytes(StandardCharsets.UTF_8); -JunkScore score2 = detector.score(utf8); +// Or directly, when you know you want JunkDetector specifically +JunkDetector detector = JunkDetector.loadFromClasspath(); ---- `JunkDetector` is **immutable and thread-safe** after construction. Load it once at application startup. +=== Scoring a string + +[source,java] +---- +TextQualityScore score = detector.score("The quick brown fox jumps over the lazy dog."); +System.out.println(score.getZScore()); // e.g. -0.74 — within normal range +System.out.println(score.getPClean()); // e.g. 0.32 — P(clean) via sigmoid +---- + === Interpreting the score [cols="1,3"] @@ -111,7 +121,7 @@ at application startup. or heavy corruption. |=== -The `JunkScore` also carries: +The `TextQualityScore` also carries: * `getPClean()` — `sigmoid(z)`, a rough probability estimate in [0, 1] that the text is clean. Useful for ranking candidates; the absolute value is not @@ -123,18 +133,24 @@ The `JunkScore` also carries: `"CYRILLIC"`, `"ARABIC"`, `"HAN"`). If `isUnknown()` is true, the dominant script had no model and scoring was not possible. -=== Comparing two charset interpretations +=== Comparing two candidates The `compare()` method is the primary use case for charset detection: given the same raw bytes decoded two different ways, which decoding looks more like natural language? +The caller is responsible for decoding the raw bytes; the detector just compares +the resulting strings. 
Each candidate is given a human-readable label (typically +the charset name) that is echoed back in the result. + [source,java] ---- byte[] rawBytes = ...; // bytes from an unknown-encoding file -JunkDetector.CompareResult result = - detector.compare(rawBytes, "cp1252", "cp1251"); +String ascp1252 = new String(rawBytes, Charset.forName("cp1252")); +String ascp1251 = new String(rawBytes, Charset.forName("cp1251")); + +TextQualityComparison result = detector.compare("cp1252", ascp1252, "cp1251", ascp1251); System.out.println(result.winner()); // "A" or "B" System.out.println(result.delta()); // z-score separation between the two @@ -144,7 +160,7 @@ if (result.winner().equals("B") && result.delta() > 1.0) { } ---- -The `delta()` is the absolute difference in z-scores between the two decodings. +The `delta()` is the absolute difference in z-scores between the two candidates. As a rough guide: [cols="1,3"] @@ -177,7 +193,7 @@ detector.knownScripts(); // returns Set<String> ---- If the dominant script of an input is not in this set, `score()` returns a -`JunkScore` where `isUnknown()` is true and no z-score is available. +`TextQualityScore` where `isUnknown()` is true and no z-score is available. == Thresholds and operating points diff --git a/tika-core/src/main/java/org/apache/tika/quality/TextQualityComparison.java b/tika-core/src/main/java/org/apache/tika/quality/TextQualityComparison.java new file mode 100644 index 0000000000..ce838cd309 --- /dev/null +++ b/tika-core/src/main/java/org/apache/tika/quality/TextQualityComparison.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.quality; + +/** + * Result of comparing two candidate strings for text quality via + * {@link TextQualityDetector#compare}. + * + * <p>A typical use is charset-decoding arbitration: given raw bytes decoded + * two different ways (e.g. cp1251 vs cp1252), pass each decoded string with a + * label and let the detector pick the cleaner one. + * + * <p>The {@code delta} field is the absolute difference between the two z-scores. + * A delta near zero means the model is uncertain; larger values indicate + * confident discrimination. As a rough guide: delta > 1.0 is useful signal, + * delta > 3.0 is confident. + */ +public final class TextQualityComparison { + + private final String winner; + private final float delta; + private final TextQualityScore scoreA; + private final TextQualityScore scoreB; + private final String labelA; + private final String labelB; + + public TextQualityComparison(String winner, float delta, + TextQualityScore scoreA, TextQualityScore scoreB, + String labelA, String labelB) { + this.winner = winner; + this.delta = delta; + this.scoreA = scoreA; + this.scoreB = scoreB; + this.labelA = labelA; + this.labelB = labelB; + } + + /** + * Returns {@code "A"} if candidate A is cleaner, {@code "B"} otherwise. + * Check {@link #delta()} to gauge confidence. + */ + public String winner() { + return winner; + } + + /** + * Absolute difference in z-scores between the two candidates. + * Small delta = uncertain; large delta = confident. 
+ */ + public float delta() { + return delta; + } + + /** Quality score for candidate A. */ + public TextQualityScore scoreA() { + return scoreA; + } + + /** Quality score for candidate B. */ + public TextQualityScore scoreB() { + return scoreB; + } + + /** Label supplied for candidate A (e.g. a charset name or encoding description). */ + public String labelA() { + return labelA; + } + + /** Label supplied for candidate B. */ + public String labelB() { + return labelB; + } + + @Override + public String toString() { + return String.format("TextQualityComparison[winner=%s(%s) delta=%.3f A=%s B=%s]", + winner, winner.equals("A") ? labelA : labelB, + delta, scoreA, scoreB); + } +} diff --git a/tika-core/src/main/java/org/apache/tika/quality/TextQualityDetector.java b/tika-core/src/main/java/org/apache/tika/quality/TextQualityDetector.java new file mode 100644 index 0000000000..d832b5a169 --- /dev/null +++ b/tika-core/src/main/java/org/apache/tika/quality/TextQualityDetector.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.quality; + +/** + * Scores a string for text quality and arbitrates between two candidate strings. 
+ * + * <p>Implementations are expected to be immutable and thread-safe after construction. + * + * <p>Implementations are registered via the standard Java {@link java.util.ServiceLoader} + * mechanism: place the fully-qualified class name in + * {@code META-INF/services/org.apache.tika.quality.TextQualityDetector}. + * + * <p>Typical usage: + * <pre>{@code + * TextQualityDetector detector = ServiceLoader.load(TextQualityDetector.class) + * .findFirst().orElseThrow(); + * + * // Score a string + * TextQualityScore score = detector.score(text); + * if (score.getZScore() < -2.0) { ... flag or re-process ... } + * + * // Arbitrate between two charset decodings + * TextQualityComparison cmp = detector.compare("cp1252", decodedAsCp1252, + * "cp1251", decodedAsCp1251); + * String winner = cmp.winner(); // "A" or "B" + * }</pre> + */ +public interface TextQualityDetector { + + /** + * Scores the given string for text quality. + * + * @param text the string to score; must not be null + * @return a {@link TextQualityScore}; check {@link TextQualityScore#isUnknown()} + * if the input is empty or the script is not covered by the model + */ + TextQualityScore score(String text); + + /** + * Compares two candidate strings and returns which is higher-quality (cleaner text). + * + * <p>A common use case is charset-decoding arbitration: given raw bytes decoded + * via two different charsets, pass each decoded string here with a human-readable + * label (e.g. the charset name) and the detector will pick the one that looks + * more like natural language. + * + * @param labelA human-readable label for candidate A (e.g. {@code "cp1252"}) + * @param candidateA first candidate string + * @param labelB human-readable label for candidate B (e.g. 
{@code "cp1251"}) + * @param candidateB second candidate string + * @return a {@link TextQualityComparison} with the winning label and confidence delta + */ + TextQualityComparison compare(String labelA, String candidateA, + String labelB, String candidateB); +} diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkScore.java b/tika-core/src/main/java/org/apache/tika/quality/TextQualityScore.java similarity index 60% rename from tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkScore.java rename to tika-core/src/main/java/org/apache/tika/quality/TextQualityScore.java index 393976c127..a388689a58 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkScore.java +++ b/tika-core/src/main/java/org/apache/tika/quality/TextQualityScore.java @@ -14,27 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.tika.ml.junkdetect; +package org.apache.tika.quality; /** - * Result of scoring a UTF-8 string for text quality. - * <p> - * {@code zScore} is the primary output: how many standard deviations below - * typical clean text this string scores on its dominant script's bigram model. + * Result of scoring a string for text quality via a {@link TextQualityDetector}. + * + * <p>{@code zScore} is the primary output: how many standard deviations below + * typical clean text this string scores on its dominant script's model. * Negative means worse than average clean text; more negative means worse. - * <p> - * {@code pClean} is the logistic-regression probability of being clean text, - * combining bigram log-prob, block-transition, and scalar features. - * <p> - * {@code ciLow} / {@code ciHigh} are the 95% confidence interval bounds on - * {@code zScore}, derived from a length-dependent variance model. For short - * strings these bounds are wide; for long strings they narrow. 
Use - * {@code ciLow < threshold} rather than {@code zScore < threshold} when - * triggering actions to avoid false positives on short strings. + * + * <p>{@code pClean} is a probability estimate in [0,1] that this is clean text. + * + * <p>{@code ciLow} / {@code ciHigh} are the 95% confidence interval bounds on + * {@code zScore}. For short strings these bounds are wide; for long strings + * they narrow. Prefer {@code ciLow < threshold} over {@code zScore < threshold} + * when triggering actions, to reduce false positives on short strings. */ -public final class JunkScore { +public final class TextQualityScore { - /** Sentinel z-score returned when detection could not be run (e.g. null input, ASCII-only). */ + /** Sentinel z-score returned when scoring could not be run (e.g. null or empty input). */ public static final float UNKNOWN = Float.NaN; private final float zScore; @@ -43,7 +41,9 @@ public final class JunkScore { private final float ciHigh; private final String dominantScript; - public JunkScore(float zScore, float pClean, float ciLow, float ciHigh, String dominantScript) { + public TextQualityScore(float zScore, float pClean, + float ciLow, float ciHigh, + String dominantScript) { this.zScore = zScore; this.pClean = pClean; this.ciLow = ciLow; @@ -56,17 +56,17 @@ public final class JunkScore { return zScore; } - /** Probability in [0,1] that this string is clean text (logistic regression output). */ + /** Probability in [0,1] that this string is clean text. */ public float getPClean() { return pClean; } - /** Lower bound of 95% confidence interval on zScore. */ + /** Lower bound of the 95% confidence interval on zScore. */ public float getCiLow() { return ciLow; } - /** Upper bound of 95% confidence interval on zScore. */ + /** Upper bound of the 95% confidence interval on zScore. */ public float getCiHigh() { return ciHigh; } @@ -76,6 +76,7 @@ public final class JunkScore { return dominantScript; } + /** True if scoring could not be performed (e.g. 
empty or unsupported-script input). */ public boolean isUnknown() { return Float.isNaN(zScore); } @@ -83,9 +84,9 @@ public final class JunkScore { @Override public String toString() { if (isUnknown()) { - return "JunkScore[UNKNOWN script=" + dominantScript + "]"; + return "TextQualityScore[UNKNOWN script=" + dominantScript + "]"; } - return String.format("JunkScore[z=%.3f p=%.3f ci=(%.3f,%.3f) script=%s]", + return String.format("TextQualityScore[z=%.3f p=%.3f ci=(%.3f,%.3f) script=%s]", zScore, pClean, ciLow, ciHigh, dominantScript); } } diff --git a/tika-ml/tika-ml-junkdetect/pom.xml b/tika-ml/tika-ml-junkdetect/pom.xml index 15abcdffee..672e49195a 100644 --- a/tika-ml/tika-ml-junkdetect/pom.xml +++ b/tika-ml/tika-ml-junkdetect/pom.xml @@ -38,6 +38,11 @@ </description> <dependencies> + <dependency> + <groupId>org.apache.tika</groupId> + <artifactId>tika-core</artifactId> + <version>${revision}</version> + </dependency> <dependency> <groupId>org.apache.tika</groupId> <artifactId>tika-ml-core</artifactId> diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java index 091ad1d04b..be3dc97e62 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java @@ -21,40 +21,63 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import java.util.zip.GZIPInputStream; +import org.apache.tika.quality.TextQualityComparison; +import org.apache.tika.quality.TextQualityDetector; +import 
org.apache.tika.quality.TextQualityScore; + /** * Language-agnostic text quality scorer. Discriminates clean UTF-8 text from * mojibake, reversed text, wrong-codec decodings, and other corruption forms. * - * <p>Scoring is based on a per-script byte-bigram log-probability model: a 256×256 - * table of {@code log P(b|a)} values trained on clean Wikipedia and MADLAD-400 text. - * The per-sentence mean bigram log-prob is z-scored against the calibration statistics - * (mean and stddev measured on held-out clean text) to produce a dimensionless quality - * score. Negative z-score = worse than average clean text for that script; - * more negative = worse. + * <p>Scoring combines up to three features, depending on the model version: + * <ol> + * <li><b>Byte-bigram log-probability</b> — 256×256 table of log P(b|a) over + * consecutive byte pairs in the UTF-8 encoding.</li> + * <li><b>Unicode named-block transition log-probability</b> (version 2+) — + * N×N table of log P(block_b | block_a) where block IDs are the named + * {@link Character.UnicodeBlock} values (BASIC_LATIN, ARABIC, + * CJK_UNIFIED_IDEOGRAPHS, etc.).</li> + * <li><b>Control-byte fraction</b> (version 2+) — fraction of bytes in control + * ranges [0x01–0x08, 0x0B, 0x0C, 0x0E–0x1F, 0x7F].</li> + * </ol> + * + * <p>All features are calibrated (mu/sigma) on held-out dev text so their z-scores + * are on a common scale. + * + * <ul> + * <li><b>Version 1</b>: bigrams only; z-score = z1.</li> + * <li><b>Version 2</b>: equal-weight average: {@code (z1 + z2 + z3) / 3}.</li> + * <li><b>Version 3</b>: per-script learned linear combination: + * {@code w1*z1 + w2*z2 + w3*z3 + bias}, where weights are fit by logistic + * regression on clean vs. corrupted dev windows. The natural junk threshold + * is 0 (positive logit = clean); use a negative threshold for conservative + * detection (e.g., {@code score < -1}).</li> + * </ul> * * <p>Instances are immutable and thread-safe after construction. 
* * <p>Typical usage: * <pre>{@code * JunkDetector detector = JunkDetector.loadFromClasspath(); - * JunkScore score = detector.score("some text"); - * if (score.getZScore() < -2.0) { ... re-OCR or flag ... } + * TextQualityScore score = detector.score("some text"); + * if (score.getZScore() < 0) { ... flag as junk ... } * - * // Compare two charset interpretations of the same bytes - * JunkDetector.CompareResult result = detector.compare(rawBytes, "cp1252", "cp1257"); + * // Arbitrate between two charset decodings + * TextQualityComparison result = detector.compare("cp1252", ascp1252, "cp1251", ascp1251); * String winner = result.winner(); // "A" or "B" * }</pre> */ -public final class JunkDetector { +public final class JunkDetector implements TextQualityDetector { /** Classpath resource path for the bundled production model. */ public static final String DEFAULT_MODEL_RESOURCE = @@ -62,13 +85,50 @@ public final class JunkDetector { static final String MAGIC = "JUNKDET1"; - // Per-script model data - private final Map<String, float[]> tables; // script → float[65536] log-prob table + private final int modelVersion; + + // Feature 1: byte bigrams (all versions) + private final Map<String, float[]> tables; // script → float[65536] log-prob private final Map<String, float[]> calibrations; // script → float[2] {mu, sigma} - private JunkDetector(Map<String, float[]> tables, Map<String, float[]> calibrations) { + // Feature 2: named-block transitions (version 2+); null for v1 models + private final Map<String, float[]> blockTables; // script → float[blockN*blockN] + private final Map<String, float[]> blockCalibrations; // script → float[2] {mu, sigma} + private final int blockN; // block table dimension (0 for v1) + + // Feature 3: control-byte fraction (version 2+); null for v1 models + private final Map<String, float[]> controlCalibrations; // script → float[2] {mu, sigma} + + // Feature combination: per-script linear classifier (version 3+); null for v1/v2 models + // 
float[numFeatures+1] = {w1, ..., wN, bias}; positive logit = clean + private final Map<String, float[]> classifierWeights; + + // Shared block index for v2+ models: UnicodeBlock → index [0, blockN-1) + // Index blockN-1 is the "unassigned" bucket (null UnicodeBlock). + private final Map<Character.UnicodeBlock, Integer> blockIndex; + + private JunkDetector(int modelVersion, + Map<String, float[]> tables, + Map<String, float[]> calibrations, + Map<String, float[]> blockTables, + Map<String, float[]> blockCalibrations, + int blockN, + Map<String, float[]> controlCalibrations, + Map<String, float[]> classifierWeights, + Map<Character.UnicodeBlock, Integer> blockIndex) { + this.modelVersion = modelVersion; this.tables = Collections.unmodifiableMap(tables); this.calibrations = Collections.unmodifiableMap(calibrations); + this.blockTables = blockTables != null + ? Collections.unmodifiableMap(blockTables) : null; + this.blockCalibrations = blockCalibrations != null + ? Collections.unmodifiableMap(blockCalibrations) : null; + this.blockN = blockN; + this.controlCalibrations = controlCalibrations != null + ? Collections.unmodifiableMap(controlCalibrations) : null; + this.classifierWeights = classifierWeights != null + ? Collections.unmodifiableMap(classifierWeights) : null; + this.blockIndex = blockIndex; } // ----------------------------------------------------------------------- @@ -103,15 +163,13 @@ public final class JunkDetector { /** * Loads a model from an {@link InputStream}. Gzip-detection is automatic. + * Supports model versions 1, 2, and 3. */ public static JunkDetector load(InputStream rawIs) throws IOException { - // Peek to detect gzip magic - InputStream is = rawIs.markSupported() ? 
rawIs : rawIs; // already have stream - // Wrap in buffered so we can read the first bytes; rely on GZIPInputStream magic - InputStream in; byte[] peek = rawIs.readNBytes(2); InputStream rest = new java.io.SequenceInputStream( new java.io.ByteArrayInputStream(peek), rawIs); + InputStream in; if (peek.length >= 2 && (peek[0] & 0xFF) == 0x1f && (peek[1] & 0xFF) == 0x8b) { in = new GZIPInputStream(rest); } else { @@ -119,160 +177,288 @@ public final class JunkDetector { } try (DataInputStream dis = new DataInputStream(in)) { - // Verify magic byte[] magic = dis.readNBytes(8); if (!new String(magic, StandardCharsets.UTF_8).equals(MAGIC)) { throw new IOException("Not a JunkDetector model file (bad magic)"); } int version = dis.readUnsignedByte(); - if (version != 1) { + if (version < 1 || version > 3) { throw new IOException("Unsupported model version: " + version); } int numScripts = dis.readInt(); - Map<String, float[]> tables = new HashMap<>(numScripts * 2); + + // Version 2+: read global block table dimension + int blockN = 0; + Map<Character.UnicodeBlock, Integer> blockIndex = null; + if (version >= 2) { + blockN = dis.readUnsignedShort(); + blockIndex = buildBlockIndex(); + int expectedN = blockIndex.size() + 1; + if (blockN != expectedN) { + throw new IOException(String.format( + "Block table dimension mismatch: model has %d but JVM gives %d. " + + "Model was trained with a different Java version.", blockN, expectedN)); + } + } + + Map<String, float[]> tables = new HashMap<>(numScripts * 2); Map<String, float[]> calibrations = new HashMap<>(numScripts * 2); + Map<String, float[]> blockTables = version >= 2 ? new HashMap<>(numScripts * 2) : null; + Map<String, float[]> blockCalibrations = version >= 2 ? new HashMap<>(numScripts * 2) : null; + Map<String, float[]> controlCalibrations = version >= 2 ? new HashMap<>(numScripts * 2) : null; + Map<String, float[]> classifierWeights = version >= 3 ? 
new HashMap<>(numScripts * 2) : null; + for (int s = 0; s < numScripts; s++) { int nameLen = dis.readUnsignedShort(); String script = new String(dis.readNBytes(nameLen), StandardCharsets.UTF_8); - float mu = dis.readFloat(); - float sigma = dis.readFloat(); - calibrations.put(script, new float[]{mu, sigma}); - - byte[] tableBytes = dis.readNBytes(65536 * 4); - float[] table = new float[65536]; - ByteBuffer buf = ByteBuffer.wrap(tableBytes).order(ByteOrder.BIG_ENDIAN); - buf.asFloatBuffer().get(table); - tables.put(script, table); + // Feature 1: byte bigrams + float mu1 = dis.readFloat(); + float sigma1 = dis.readFloat(); + calibrations.put(script, new float[]{mu1, sigma1}); + tables.put(script, readFloatTable(dis, 65536)); + + if (version >= 2) { + // Feature 2: named-block transitions + float mu2 = dis.readFloat(); + float sigma2 = dis.readFloat(); + blockCalibrations.put(script, new float[]{mu2, sigma2}); + blockTables.put(script, readFloatTable(dis, blockN * blockN)); + + // Feature 3: control-byte fraction + float mu3 = dis.readFloat(); + float sigma3 = dis.readFloat(); + controlCalibrations.put(script, new float[]{mu3, sigma3}); + + if (version >= 3) { + // Classifier weights: num_features (1 byte) + num_features floats + 1 bias + int numFeatures = dis.readUnsignedByte(); + float[] weights = new float[numFeatures + 1]; // last = bias + for (int j = 0; j <= numFeatures; j++) { + weights[j] = dis.readFloat(); + } + classifierWeights.put(script, weights); + } + } } - return new JunkDetector(tables, calibrations); + return new JunkDetector(version, tables, calibrations, + blockTables, blockCalibrations, blockN, + controlCalibrations, classifierWeights, blockIndex); } } - // ----------------------------------------------------------------------- - // Scoring API - // ----------------------------------------------------------------------- + private static float[] readFloatTable(DataInputStream dis, int size) throws IOException { + byte[] tableBytes = 
dis.readNBytes(size * 4); + float[] table = new float[size]; + ByteBuffer buf = ByteBuffer.wrap(tableBytes).order(ByteOrder.BIG_ENDIAN); + buf.asFloatBuffer().get(table); + return table; + } /** - * Scores a UTF-8 string for text quality. - * - * @param text the string to score (will be encoded to UTF-8 internally) - * @return a {@link JunkScore}; use {@link JunkScore#isUnknown()} to check - * whether scoring was possible + * Builds the stable ordered mapping from {@link Character.UnicodeBlock} to index. + * This must produce the same ordering as {@link TrainJunkModel#buildBlockIndex()}. */ - public JunkScore score(String text) { - if (text == null || text.isEmpty()) { - return unknownScore("UNKNOWN"); + static Map<Character.UnicodeBlock, Integer> buildBlockIndex() { + LinkedHashMap<Character.UnicodeBlock, Integer> index = new LinkedHashMap<>(); + for (int cp = 0; cp <= 0x10FFFF; cp++) { + Character.UnicodeBlock b = Character.UnicodeBlock.of(cp); + if (b != null) index.putIfAbsent(b, index.size()); } - return scoreBytes(text.getBytes(StandardCharsets.UTF_8), text); + return Collections.unmodifiableMap(index); } + // ----------------------------------------------------------------------- + // TextQualityDetector implementation + // ----------------------------------------------------------------------- + /** - * Scores a byte array assumed to be UTF-8 text. + * {@inheritDoc} * - * @param utf8 raw UTF-8 bytes - * @return a {@link JunkScore} + * <p>The string is encoded to UTF-8 internally for bigram and control-byte scoring. + * Codepoints are used directly for block-transition scoring. 
*/ - public JunkScore score(byte[] utf8) { - if (utf8 == null || utf8.length == 0) { + @Override + public TextQualityScore score(String text) { + if (text == null || text.isEmpty()) { return unknownScore("UNKNOWN"); } - String text = new String(utf8, StandardCharsets.UTF_8); - return scoreBytes(utf8, text); + return scoreText(text.getBytes(StandardCharsets.UTF_8), text); } /** - * Compares two charset interpretations of the same raw bytes and returns - * which decoding scores higher (is more likely to be clean natural language). + * {@inheritDoc} + * + * <p>Each candidate is scored independently via {@link #score(String)}. + * The candidate with the higher score wins. * - * @param rawBytes the raw bytes to decode - * @param charsetA first charset name (e.g. {@code "cp1252"}) - * @param charsetB second charset name (e.g. {@code "cp1257"}) - * @return a {@link CompareResult} indicating the winner and confidence + * <p>An UNKNOWN score (script not in model) is treated as neutral (0) rather + * than {@code -∞}. This prevents a garbled-but-recognisable decoding from + * beating a correct decoding whose script happens to be unknown to the model — + * for example, a pure-katakana zip entry name decoded as Shift-JIS (UNKNOWN) + * vs. the same bytes decoded as UTF-8 (garbled LATIN, negative z-score). */ - public CompareResult compare(byte[] rawBytes, String charsetA, String charsetB) { - JunkScore scoreA = decodeAndScore(rawBytes, charsetA); - JunkScore scoreB = decodeAndScore(rawBytes, charsetB); + @Override + public TextQualityComparison compare(String labelA, String candidateA, + String labelB, String candidateB) { + TextQualityScore scoreA = score(candidateA); + TextQualityScore scoreB = score(candidateB); - float zA = scoreA.isUnknown() ? Float.NEGATIVE_INFINITY : scoreA.getZScore(); - float zB = scoreB.isUnknown() ? Float.NEGATIVE_INFINITY : scoreB.getZScore(); + // UNKNOWN = "no evidence" = 0, not -∞. 
A text whose script is not in the + // model is assumed to be neutral, not junk. + float zA = scoreA.isUnknown() ? 0f : scoreA.getZScore(); + float zB = scoreB.isUnknown() ? 0f : scoreB.getZScore(); String winner = zA >= zB ? "A" : "B"; float delta = Math.abs(zA - zB); - return new CompareResult(winner, delta, scoreA, scoreB, charsetA, charsetB); + return new TextQualityComparison(winner, delta, scoreA, scoreB, labelA, labelB); } /** Returns the set of script names this model knows about. */ - public java.util.Set<String> knownScripts() { + public Set<String> knownScripts() { return tables.keySet(); } + /** Returns the version of the loaded model (1, 2, or 3). */ + public int getModelVersion() { + return modelVersion; + } + // ----------------------------------------------------------------------- // Internal scoring // ----------------------------------------------------------------------- - private JunkScore scoreBytes(byte[] utf8, String text) { + private TextQualityScore scoreText(byte[] utf8, String text) { String script = detectDominantScript(text); - float[] table = tables.get(script); - if (table == null) { - // Script not in model — return unknown with script name for diagnostics + float[] bigramTable = tables.get(script); + if (bigramTable == null) { return unknownScore(script); } - if (utf8.length < 2) { return unknownScore(script); } - // Mean byte-bigram log-prob - double sum = 0; - int count = 0; + // Feature 1: byte-bigram mean log-prob + double bigramSum = 0; + int bigramCount = 0; for (int i = 0; i + 1 < utf8.length; i++) { - sum += table[((utf8[i] & 0xFF) << 8) | (utf8[i + 1] & 0xFF)]; - count++; + bigramSum += bigramTable[((utf8[i] & 0xFF) << 8) | (utf8[i + 1] & 0xFF)]; + bigramCount++; } - float meanLogProb = (float) (sum / count); + float meanBigramLogProb = (float) (bigramSum / bigramCount); + float[] cal1 = calibrations.get(script); + float z1 = (meanBigramLogProb - cal1[0]) / cal1[1]; + + // Features 2 & 3 (version 2+) + float z2 = 0f, z3 = 
0f; + if (modelVersion >= 2 && blockTables != null) { + // Feature 2: named-block transition mean log-prob + float[] blockTable = blockTables.get(script); + if (blockTable != null) { + int nullId = blockN - 1; + int prev = -1; + double blockSum = 0; + int blockCount = 0; + for (int i = 0; i < text.length(); ) { + int cp = text.codePointAt(i); + Character.UnicodeBlock b = Character.UnicodeBlock.of(cp); + int blockId = b != null ? blockIndex.getOrDefault(b, nullId) : nullId; + if (prev >= 0) { + blockSum += blockTable[prev * blockN + blockId]; + blockCount++; + } + prev = blockId; + i += Character.charCount(cp); + } + if (blockCount > 0) { + float meanBlockLogProb = (float) (blockSum / blockCount); + float[] cal2 = blockCalibrations.get(script); + z2 = cal2 != null ? (meanBlockLogProb - cal2[0]) / cal2[1] : 0f; + } + } - // Z-score against calibration - float[] cal = calibrations.get(script); - float mu = cal[0]; - float sigma = cal[1]; - float zScore = (meanLogProb - mu) / sigma; + // Feature 3: control-byte fraction (stored as −fraction, so higher = cleaner) + long controlCount = 0; + for (byte b : utf8) { + if (isControlByte(b & 0xFF)) controlCount++; + } + float controlScore = -(float) controlCount / utf8.length; + float[] cal3 = controlCalibrations.get(script); + z3 = cal3 != null ? 
(controlScore - cal3[0]) / cal3[1] : 0f; + } + + // Combine features + float zScore; + if (modelVersion >= 3 && classifierWeights != null) { + // Version 3: per-script linear combination (logistic regression weights) + float[] cw = classifierWeights.get(script); + if (cw != null && cw.length >= 4) { + // cw = {w1, w2, w3, bias}; positive logit = clean + zScore = cw[0] * z1 + cw[1] * z2 + cw[2] * z3 + cw[cw.length - 1]; + } else { + zScore = (z1 + z2 + z3) / 3.0f; // fallback if weights missing + } + } else if (modelVersion >= 2 && blockTables != null) { + // Version 2: equal-weight average + zScore = (z1 + z2 + z3) / 3.0f; + } else { + // Version 1: bigrams only + zScore = z1; + } - // Confidence interval: uncertainty ~ 1.96 * sigma / sqrt(count) - float uncertainty = (float) (1.96 * sigma / Math.sqrt(count)); + // CI is approximated from the bigram count and bigram sigma + float uncertainty = (float) (1.96 * cal1[1] / Math.sqrt(bigramCount)); float ciLow = zScore - uncertainty; float ciHigh = zScore + uncertainty; - // P(clean): sigmoid of z-score (simple calibration-free estimate) float pClean = (float) (1.0 / (1.0 + Math.exp(-zScore))); - return new JunkScore(zScore, pClean, ciLow, ciHigh, script); + return new TextQualityScore(zScore, pClean, ciLow, ciHigh, script); } - private JunkScore decodeAndScore(byte[] raw, String charsetName) { - try { - Charset cs = Charset.forName(charsetName); - byte[] utf8 = new String(raw, cs).getBytes(StandardCharsets.UTF_8); - return score(utf8); - } catch (Exception e) { - return unknownScore(charsetName); - } + /** + * Returns true if the byte value is a control character that should not appear + * in natural-language UTF-8 text: {@code [0x01–0x08, 0x0B, 0x0C, 0x0E–0x1F, 0x7F]}. + * + * <p>Excluded: 0x00 (null), 0x09 (tab), 0x0A (newline), 0x0D (carriage return) + * — all appear legitimately in text. 
+ */ + static boolean isControlByte(int b) { + return (b >= 0x01 && b <= 0x08) + || b == 0x0B || b == 0x0C + || (b >= 0x0E && b <= 0x1F) + || b == 0x7F; } - private static JunkScore unknownScore(String script) { - return new JunkScore(JunkScore.UNKNOWN, Float.NaN, Float.NaN, Float.NaN, script); + private static TextQualityScore unknownScore(String script) { + return new TextQualityScore(TextQualityScore.UNKNOWN, Float.NaN, + Float.NaN, Float.NaN, script); } + /** + * Maps Unicode scripts that share a trained model with a related script. + * Japanese kana (HIRAGANA, KATAKANA) map to HAN because the HAN model is + * trained on mixed Japanese text containing all three writing systems, so + * its byte-bigram and block-transition tables cover kana sequences. + */ + private static final Map<String, String> SCRIPT_MODEL_FALLBACK = Map.of( + "HIRAGANA", "HAN", + "KATAKANA", "HAN" + ); + /** * Detects the dominant Unicode script of the given text by histogramming * {@link Character.UnicodeScript} over all codepoints, excluding COMMON, - * INHERITED, and UNKNOWN pseudo-scripts. Returns "LATIN" for ASCII-only - * text (no non-ASCII codepoints). + * INHERITED, and UNKNOWN pseudo-scripts. Returns "LATIN" for ASCII-only text. + * + * <p>Script names are mapped through {@link #SCRIPT_MODEL_FALLBACK} so that + * scripts without dedicated models fall back to a related trained model + * (e.g. KATAKANA and HIRAGANA both use the HAN model). 
*/ static String detectDominantScript(String text) { Map<Character.UnicodeScript, Integer> counts = new HashMap<>(); @@ -287,74 +473,12 @@ public final class JunkDetector { i += Character.charCount(cp); } if (counts.isEmpty()) { - return "LATIN"; // ASCII-only → use Latin model + return "LATIN"; } - return counts.entrySet().stream() - .max(java.util.Map.Entry.comparingByValue()) + String name = counts.entrySet().stream() + .max(Map.Entry.comparingByValue()) .map(e -> e.getKey().name()) .orElse("LATIN"); - } - - // ----------------------------------------------------------------------- - // Result type for compare() - // ----------------------------------------------------------------------- - - /** - * Result of comparing two charset decodings of the same raw bytes. - */ - public static final class CompareResult { - private final String winner; - private final float delta; - private final JunkScore scoreA; - private final JunkScore scoreB; - private final String charsetA; - private final String charsetB; - - CompareResult(String winner, float delta, - JunkScore scoreA, JunkScore scoreB, - String charsetA, String charsetB) { - this.winner = winner; - this.delta = delta; - this.scoreA = scoreA; - this.scoreB = scoreB; - this.charsetA = charsetA; - this.charsetB = charsetB; - } - - /** "A" if charsetA decodes to cleaner text, "B" otherwise. */ - public String winner() { - return winner; - } - - /** - * Absolute difference in z-scores. Small delta = uncertain; large delta = confident. - * As a rough guide: delta > 1.0 is useful signal, delta > 3.0 is confident. - */ - public float delta() { - return delta; - } - - public JunkScore scoreA() { - return scoreA; - } - - public JunkScore scoreB() { - return scoreB; - } - - public String charsetA() { - return charsetA; - } - - public String charsetB() { - return charsetB; - } - - @Override - public String toString() { - return String.format("CompareResult[winner=%s(%s) delta=%.3f A=%s B=%s]", - winner, winner.equals("A") ? 
charsetA : charsetB, - delta, scoreA, scoreB); - } + return SCRIPT_MODEL_FALLBACK.getOrDefault(name, name); } } diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/BuildJunkTrainingData.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/BuildJunkTrainingData.java index 77d9283f9a..27a5436d5e 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/BuildJunkTrainingData.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/BuildJunkTrainingData.java @@ -113,6 +113,15 @@ public class BuildJunkTrainingData { private static final double DEV_FRAC = 0.10; // remaining (1 - TRAIN_FRAC - DEV_FRAC) goes to the test split + /** + * Minimum number of sentences that must land in the dev split for a script to be + * included in the model. Scripts below this floor have too few samples to reliably + * estimate calibration statistics (mu/sigma), which produces noisy z-scores and + * inflated false positive rates. With DEV_FRAC=0.10 the effective minimum total + * sentence count is minDevSentences / DEV_FRAC (default: 5,000 total sentences). 
+ */ + private static final int DEFAULT_MIN_DEV_SENTENCES = 500; + // ----------------------------------------------------------------------- // Entry point // ----------------------------------------------------------------------- @@ -126,6 +135,7 @@ public class BuildJunkTrainingData { double maxPuncFrac = DEFAULT_MAX_PUNC_FRAC; int seed = 42; boolean dryRun = false; + int minDevSentences = DEFAULT_MIN_DEV_SENTENCES; for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -150,6 +160,9 @@ public class BuildJunkTrainingData { case "--seed": seed = Integer.parseInt(args[++i]); break; + case "--min-dev-sentences": + minDevSentences = Integer.parseInt(args[++i]); + break; case "--dry-run": dryRun = true; break; @@ -167,6 +180,8 @@ public class BuildJunkTrainingData { totalBudgetBytes, totalBudgetBytes / 1_000_000.0); System.out.printf( " min-bytes: %d%n", minBytes); System.out.printf( " max-punc-frac: %.2f%n", maxPuncFrac); + System.out.printf( " min-dev-sentences: %d (min total ≈ %d)%n", + minDevSentences, (int)(minDevSentences / DEV_FRAC)); System.out.println(" dry-run: " + dryRun); if (!Files.isDirectory(dataDir)) { @@ -246,12 +261,13 @@ public class BuildJunkTrainingData { // ----------------------------------------------------------------------- Files.createDirectories(outputDir); - System.out.println("\n--- Phase 4: Collecting and writing per-script files ---"); + System.out.println("\n--- Phase 4a: Round 1 — collecting with initial budgets ---"); Random rng = new Random(seed); - // manifest columns: script, entropy, budget_bytes, written_bytes, sentences, train_bytes, languages - Map<String, long[]> manifestStats = new TreeMap<>(); + // Collect sentences and actual byte counts for every script + Map<String, List<String>> allSentences = new LinkedHashMap<>(); + Map<String, Long> actualBytes = new LinkedHashMap<>(); for (Map.Entry<String, Long> budgetEntry : scriptBudget.entrySet()) { String script = budgetEntry.getKey(); @@ -259,15 +275,12 @@ public class 
BuildJunkTrainingData { List<Path> langDirs = scriptGroups.get(script); long perLangBytes = Math.max(budget / langDirs.size(), 1L); - List<String> sentences = new ArrayList<>(); long totalBytesLoaded = 0; for (Path langDir : langDirs) { long remaining = budget - totalBytesLoaded; - if (remaining <= 0) { - break; - } + if (remaining <= 0) break; long langBytes = loadSentences(langDir, Math.min(perLangBytes, remaining), minBytes, maxPuncFrac, sentences); @@ -277,10 +290,80 @@ public class BuildJunkTrainingData { script, langDir.getFileName(), langBytes); } } + allSentences.put(script, sentences); + actualBytes.put(script, totalBytesLoaded); + } + + // Compute surplus bytes from data-starved scripts (< 90% of budget used) + long surplus = 0; + for (Map.Entry<String, Long> e : scriptBudget.entrySet()) { + long budget = e.getValue(); + long actual = actualBytes.getOrDefault(e.getKey(), 0L); + if (actual < budget * 0.9) { + surplus += (budget - actual); + } + } + + // Round 2: redistribute surplus to saturated scripts proportional to entropy + if (surplus > 0) { + System.out.printf( + "\n--- Phase 4b: Redistributing %,d surplus bytes (%.1f MB) ---\n", + surplus, surplus / 1_000_000.0); + + double saturatedEntropy = scriptBudget.entrySet().stream() + .filter(e -> actualBytes.getOrDefault(e.getKey(), 0L) >= e.getValue() * 0.9) + .mapToDouble(e -> scriptEntropy.getOrDefault(e.getKey(), 0.0)) + .sum(); + + for (Map.Entry<String, Long> budgetEntry : scriptBudget.entrySet()) { + String script = budgetEntry.getKey(); + long budget = budgetEntry.getValue(); + long actual = actualBytes.getOrDefault(script, 0L); + if (actual < budget * 0.9) continue; // data-starved — skip + + long extra = (long) (surplus + * scriptEntropy.getOrDefault(script, 0.0) / saturatedEntropy); + if (extra <= 0) continue; + + long newBudget = budget + extra; + List<Path> langDirs = scriptGroups.get(script); + long perLangBytes = Math.max(newBudget / langDirs.size(), 1L); + + List<String> sentences = new 
ArrayList<>(); + long totalBytesLoaded = 0; + for (Path langDir : langDirs) { + long remaining = newBudget - totalBytesLoaded; + if (remaining <= 0) break; + long langBytes = loadSentences(langDir, + Math.min(perLangBytes, remaining), + minBytes, maxPuncFrac, sentences); + totalBytesLoaded += langBytes; + } + if (!sentences.isEmpty()) { + allSentences.put(script, sentences); + actualBytes.put(script, totalBytesLoaded); + System.out.printf(" %-20s +%,d extra → %,d total bytes, %,d sentences%n", + script, extra, totalBytesLoaded, sentences.size()); + } + } + } + + // Write split files + System.out.println("\n--- Phase 4c: Writing train/dev/test splits ---"); + + // manifest columns: script, entropy, budget_bytes, written_bytes, sentences, train_bytes, languages + Map<String, long[]> manifestStats = new TreeMap<>(); + + for (Map.Entry<String, List<String>> e : allSentences.entrySet()) { + String script = e.getKey(); + List<String> sentences = e.getValue(); - if (sentences.isEmpty()) { - System.out.printf(" SKIP %-12s — no sentences collected%n", script); - manifestStats.put(script, new long[]{0, 0, 0}); + int expectedDevSize = (int) (sentences.size() * DEV_FRAC); + if (sentences.isEmpty() || expectedDevSize < minDevSentences) { + System.out.printf( + " SKIP %-20s — %,d sentences → dev=%d < min-dev-sentences=%d%n", + script, sentences.size(), expectedDevSize, minDevSentences); + manifestStats.put(script, new long[]{0, 0, 0, 0, 0}); continue; } @@ -297,6 +380,7 @@ public class BuildJunkTrainingData { writeGzipped(outputDir.resolve(baseName + ".dev.gz"), dev); writeGzipped(outputDir.resolve(baseName + ".test.gz"), test); + long totalBytesLoaded = actualBytes.getOrDefault(script, 0L); manifestStats.put(script, new long[]{totalBytesLoaded, sentences.size(), nTrain, nDev, test.size()}); System.out.printf( @@ -552,6 +636,9 @@ public class BuildJunkTrainingData { + " (default: 50)"); System.err.println(" --max-punc-frac F Max ASCII punct fraction" + " (default: 0.30)"); + 
System.err.println(" --min-dev-sentences N Min sentences in dev split for a" + + " script to be included (default: 500). Scripts below this floor" + + " have unreliable calibration and inflated FPR."); System.err.println(" --seed N Random seed (default: 42)"); System.err.println(" --dry-run Detect scripts + show budget," + " skip file writing"); diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java index de2494816a..0538f29537 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java @@ -32,7 +32,7 @@ import java.util.stream.Collectors; import java.util.zip.GZIPInputStream; import org.apache.tika.ml.junkdetect.JunkDetector; -import org.apache.tika.ml.junkdetect.JunkScore; +import org.apache.tika.quality.TextQualityScore; /** * Ablation evaluation for the junk detector. 
@@ -232,6 +232,26 @@ public class EvalJunkDetector { detail.println(row.toTsv()); } + // --- wrong-codec: re-read UTF-8 bytes as ISO-8859-1 then re-encode --- + { + List<Float> corruptZ = scoreWrongCodec(detector, sentences, len, + samplesPerCell, new Random(seed + 4)); + Row row = new Row(script, "wrong-codec", "latin1-as-utf8", len, + cleanZ, corruptZ, threshold); + allRows.add(row); + detail.println(row.toTsv()); + } + + // --- byte-swap: swap each adjacent pair of bytes (endianness flip) --- + { + List<Float> corruptZ = scoreByteSwapped(detector, sentences, len, + samplesPerCell, new Random(seed + 5)); + Row row = new Row(script, "byte-swap", "-", len, + cleanZ, corruptZ, threshold); + allRows.add(row); + detail.println(row.toTsv()); + } + detail.flush(); rng = new Random(seed); // reset between lengths for reproducibility } @@ -266,6 +286,8 @@ public class EvalJunkDetector { } conditions.add(new String[]{"char-reverse", "-"}); conditions.add(new String[]{"byte-shuffle", "-"}); + conditions.add(new String[]{"wrong-codec", "latin1-as-utf8"}); + conditions.add(new String[]{"byte-swap", "-"}); for (String[] cond : conditions) { String distortion = cond[0]; @@ -406,7 +428,7 @@ public class EvalJunkDetector { List<Float> results = new ArrayList<>(n); for (int i = 0; i < n; i++) { String s = pickSubstring(sentences, targetLen, rng); - JunkScore score = detector.score(s); + TextQualityScore score = detector.score(s); if (!score.isUnknown()) { results.add(score.getZScore()); } @@ -422,7 +444,7 @@ public class EvalJunkDetector { String s = pickSubstring(sentences, targetLen, rng); byte[] bytes = s.getBytes(StandardCharsets.UTF_8); injectRandomBytes(bytes, rate, rng); - JunkScore score = detector.score(bytes); + TextQualityScore score = detector.score(new String(bytes, StandardCharsets.ISO_8859_1)); if (!score.isUnknown()) { results.add(score.getZScore()); } @@ -435,7 +457,7 @@ public class EvalJunkDetector { List<Float> results = new ArrayList<>(n); for (int i = 0; i < 
n; i++) { String s = reverseCodepoints(pickSubstring(sentences, targetLen, rng)); - JunkScore score = detector.score(s); + TextQualityScore score = detector.score(s); if (!score.isUnknown()) { results.add(score.getZScore()); } @@ -450,7 +472,35 @@ public class EvalJunkDetector { String s = pickSubstring(sentences, targetLen, rng); byte[] bytes = s.getBytes(StandardCharsets.UTF_8); shuffleBytes(bytes, rng); - JunkScore score = detector.score(bytes); + TextQualityScore score = detector.score(new String(bytes, StandardCharsets.ISO_8859_1)); + if (!score.isUnknown()) { + results.add(score.getZScore()); + } + } + return results; + } + + private static List<Float> scoreWrongCodec(JunkDetector detector, List<String> sentences, + int targetLen, int n, Random rng) { + List<Float> results = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + String s = pickSubstring(sentences, targetLen, rng); + byte[] garbled = wrongCodecBytes(s.getBytes(StandardCharsets.UTF_8)); + TextQualityScore score = detector.score(new String(garbled, StandardCharsets.UTF_8)); + if (!score.isUnknown()) { + results.add(score.getZScore()); + } + } + return results; + } + + private static List<Float> scoreByteSwapped(JunkDetector detector, List<String> sentences, + int targetLen, int n, Random rng) { + List<Float> results = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + String s = pickSubstring(sentences, targetLen, rng); + byte[] swapped = swapByteOrder(s.getBytes(StandardCharsets.UTF_8)); + TextQualityScore score = detector.score(new String(swapped, StandardCharsets.ISO_8859_1)); if (!score.isUnknown()) { results.add(score.getZScore()); } @@ -462,14 +512,49 @@ public class EvalJunkDetector { // Distortion primitives // ----------------------------------------------------------------------- + /** + * Injects control characters (0x01–0x09, i.e. below newline/0x0A, excluding null) + * at the given rate. 
These bytes never appear in clean natural-language UTF-8 text + * and simulate binary data leaking into a text stream. 0x00 is excluded because + * null bytes cause problems in many text-processing pipelines. + */ static void injectRandomBytes(byte[] bytes, double rate, Random rng) { for (int i = 0; i < bytes.length; i++) { if (rng.nextDouble() < rate) { - bytes[i] = (byte) (0x80 | rng.nextInt(128)); + // 0x01..0x09 inclusive (9 values): SOH STX ETX EOT ENQ ACK BEL BS HT + bytes[i] = (byte) (0x01 + rng.nextInt(9)); } } } + /** + * Wrong-codec distortion: the UTF-8 bytes of the sentence are re-interpreted + * as ISO-8859-1 (Latin-1) and then re-encoded as UTF-8. This is the classic + * "saved as UTF-8, displayed as Latin-1" mojibake: every byte in 0x80–0xFF + * becomes a two-byte UTF-8 sequence, doubling the byte length of non-ASCII runs + * and producing bogus accented-Latin bigrams. + */ + static byte[] wrongCodecBytes(byte[] utf8) { + String misread = new String(utf8, StandardCharsets.ISO_8859_1); + return misread.getBytes(StandardCharsets.UTF_8); + } + + /** + * Byte-swap distortion: swaps each adjacent pair of bytes — (0,1), (2,3), etc. + * If the array has an odd length the last byte is left unchanged. + * Simulates reading a 2-byte encoding (UTF-16, UCS-2, CP932 two-byte sequences) + * with the wrong byte order. 
+ */ + static byte[] swapByteOrder(byte[] bytes) { + byte[] out = bytes.clone(); + for (int i = 0; i + 1 < out.length; i += 2) { + byte tmp = out[i]; + out[i] = out[i + 1]; + out[i + 1] = tmp; + } + return out; + } + static void shuffleBytes(byte[] bytes, Random rng) { for (int i = bytes.length - 1; i > 0; i--) { int j = rng.nextInt(i + 1); diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java index 34ecffb533..2ba083c139 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java @@ -27,7 +27,11 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.Random; import java.util.TreeMap; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -36,40 +40,96 @@ import java.util.zip.GZIPOutputStream; * Trains the junk detector model from per-script corpus files produced by * {@link BuildJunkTrainingData}. 
* - * <p>For each script group (identified by a {@code {script}.train.gz} file): + * <p>For each script group (identified by a {@code {script}.train.gz} file), + * three features are trained and then combined by a per-script logistic + * regression classifier: * <ol> - * <li>Accumulates byte-bigram counts from the training sentences.</li> - * <li>Applies add-1 (Laplace) smoothing per row, converts to natural - * log-probabilities.</li> - * <li>Computes calibration statistics (mean and stddev of per-sentence mean - * bigram log-prob) from the dev split ({@code {script}.dev.gz}).</li> + * <li><b>Byte-bigram log-probability</b>: 256×256 table of log P(b|a) over + * consecutive byte pairs in the UTF-8 encoding.</li> + * <li><b>Unicode named-block transition log-probability</b>: N×N table of + * log P(block_b | block_a), where block ID is determined by + * {@link Character.UnicodeBlock#of(int)} — one of the ~327 named Unicode + * blocks plus one extra bucket for unassigned codepoints.</li> + * <li><b>Control-byte fraction</b>: fraction of bytes in control-character + * ranges ([0x01–0x08, 0x0B, 0x0C, 0x0E–0x1F, 0x7F]). Stored as + * {@code −fraction} so the z-score convention matches the other features + * (higher = cleaner).</li> * </ol> * - * <p>Output: a single gzipped binary model file ({@code junkdetect.bin}) in the - * following format: + * <p>All three features are calibrated (mu/sigma) on the dev split so their + * z-scores are on a common scale. A per-script binary logistic regression + * classifier is then fit on (z1, z2, z3) using clean dev windows and corrupted + * versions (inject@5%, char-shuffle) as training examples. The learned weights + * replace the fixed equal-weight average, allowing the model to automatically + * downweight noisy features (e.g. high-variance block transitions for MYANMAR) + * and upweight informative ones (e.g. control-byte fraction for inject@0.01). 
+ * + * <p>At inference, the final score is the linear combination + * {@code w1*z1 + w2*z2 + w3*z3 + bias}; positive values indicate clean text. + * The natural threshold is 0 (probability 0.5); use a negative threshold for + * more conservative junk detection. + * + * <p>Output format: {@code JUNKDET1} gzipped binary, <b>version 3</b>. + * Version 1 (bigrams only) and version 2 (equal-weight average) files can + * still be loaded by {@code JunkDetector}. + * * <pre> - * [8 bytes] magic "JUNKDET1" - * [1 byte] version = 1 + * [8 bytes] magic "JUNKDET1" (ASCII) + * [1 byte] version = 3 * [4 bytes] num_scripts (big-endian int) + * [2 bytes] block_N — number of distinct named Unicode blocks + 1 (unassigned) * for each script (sorted by name): - * [2 bytes] name length (big-endian ushort) - * [N bytes] script name (UTF-8) - * [4 bytes] mu — mean of mean_bigram_logprob over dev sentences (float) - * [4 bytes] sigma — stddev (float) - * [65536 * 4 bytes] float32 log-prob table, row a*256+b = log P(b|a) - * </pre> - * - * <p>Usage: - * <pre> - * java TrainJunkModel \ - * --data-dir ~/datasets/madlad/junkdetect \ - * --output ~/datasets/madlad/junkdetect/junkdetect.bin + * [2 bytes] name length (big-endian ushort) + * [name bytes] script name (UTF-8) + * // Feature 1 — byte bigrams + * [4 bytes] mu1 (float32 big-endian) + * [4 bytes] sigma1 (float32 big-endian) + * [65536×4 bytes] byte-bigram log-prob table (256×256) + * // Feature 2 — block transitions + * [4 bytes] mu2 (float32 big-endian) + * [4 bytes] sigma2 (float32 big-endian) + * [block_N²×4 bytes] block-transition log-prob table + * // Feature 3 — control-byte fraction + * [4 bytes] mu3 (float32 big-endian) + * [4 bytes] sigma3 (float32 big-endian) + * // Linear classifier weights + * [1 byte] num_features (= 3) + * [4 bytes] w1 (float32 big-endian) + * [4 bytes] w2 (float32 big-endian) + * [4 bytes] w3 (float32 big-endian) + * [4 bytes] bias (float32 big-endian) * </pre> */ public class TrainJunkModel { static 
final String MAGIC = "JUNKDET1"; - static final byte VERSION = 1; + static final byte VERSION = 3; + + /** Number of clean (and corrupted) windows used to train the per-script classifier. */ + static final int NUM_CLASSIFIER_SAMPLES = 500; + + /** Fraction of characters replaced with control characters for inject distortion. */ + static final double CLASSIFIER_INJECT_RATE = 0.05; + + /** + * Minimum sigma for the control-byte feature. Because clean dev text + * typically has zero control bytes in every sentence, the sample standard + * deviation collapses to 0 and would be clamped to 1.0 by the generic + * {@link #muSigma} helper — making the feature useless. This floor + * ensures a 1% control-byte injection ({@code inject@0.01}) produces + * approximately z = −2, providing meaningful signal. + */ + static final float CONTROL_BYTE_MIN_SIGMA = 0.005f; + + /** + * Target byte-lengths used for calibration sampling, matching the evaluator defaults. + */ + static final int[] CALIB_LENGTHS = {15, 30, 50, 100, 200}; + + /** + * Number of random byte-window samples drawn from the dev set for calibration. + */ + static final int CALIB_SAMPLES = 5000; public static void main(String[] args) throws IOException { Path dataDir = Paths.get(System.getProperty("user.home"), @@ -100,9 +160,19 @@ public class TrainJunkModel { System.exit(1); } - // Collect all script names by finding *.train.gz files - TreeMap<String, float[]> tables = new TreeMap<>(); - TreeMap<String, float[]> calibrations = new TreeMap<>(); + System.out.print("Building Unicode named-block index... 
"); + long t0 = System.currentTimeMillis(); + Map<Character.UnicodeBlock, Integer> blockIndex = buildBlockIndex(); + int blockN = blockIndex.size() + 1; // +1 for unassigned bucket + System.out.printf("%d named blocks → table size %d×%d (%dms)%n", + blockIndex.size(), blockN, blockN, System.currentTimeMillis() - t0); + + TreeMap<String, float[]> bigramTables = new TreeMap<>(); + TreeMap<String, float[]> bigramCalibrations = new TreeMap<>(); + TreeMap<String, float[]> blockTables = new TreeMap<>(); + TreeMap<String, float[]> blockCalibrations = new TreeMap<>(); + TreeMap<String, float[]> controlCalibrations = new TreeMap<>(); + TreeMap<String, float[]> classifierWeights = new TreeMap<>(); try (var stream = Files.list(dataDir)) { List<Path> trainFiles = stream @@ -119,53 +189,108 @@ public class TrainJunkModel { String filename = trainFile.getFileName().toString(); String script = filename.substring(0, filename.length() - ".train.gz".length()) .toUpperCase(); - Path devFile = trainFile.getParent().resolve( filename.replace(".train.gz", ".dev.gz")); System.out.printf("%n--- %s ---%n", script); - System.out.print(" Training bigram table... "); - long t0 = System.currentTimeMillis(); - float[] table = trainBigramTable(trainFile); + t0 = System.currentTimeMillis(); + System.out.print(" Training byte-bigram table... "); + float[] bigramTable = trainBigramTable(trainFile); System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); - float[] cal = new float[]{0f, 1f}; + t0 = System.currentTimeMillis(); + System.out.print(" Training named-block table... 
"); + float[] blockTable = trainBlockTable(trainFile, blockIndex, blockN); + System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); + + float[] bigramCal = new float[]{0f, 1f}; + float[] blockCal = new float[]{0f, 1f}; + float[] controlCal = new float[]{0f, 1f}; + // Default: equal-weight average (w=[1/3,1/3,1/3], bias=0) + float[] weights = new float[]{1f / 3, 1f / 3, 1f / 3, 0f}; + if (Files.exists(devFile)) { - System.out.print(" Calibrating on dev set... "); t0 = System.currentTimeMillis(); - cal = computeCalibration(devFile, table); + System.out.print(" Calibrating byte bigrams on dev... "); + bigramCal = computeBigramCalibration(devFile, bigramTable); System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", - cal[0], cal[1], System.currentTimeMillis() - t0); + bigramCal[0], bigramCal[1], System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Calibrating named blocks on dev... "); + blockCal = computeBlockCalibration(devFile, blockTable, blockIndex, blockN); + System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", + blockCal[0], blockCal[1], System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Calibrating control bytes on dev... "); + controlCal = computeControlByteCalibration(devFile); + System.out.printf("done — mu=%.6f sigma=%.6f (%dms)%n", + controlCal[0], controlCal[1], System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Training linear classifier... 
"); + weights = trainClassifier(devFile, bigramTable, bigramCal, + blockTable, blockCal, controlCal, blockIndex, blockN); + System.out.printf("done — w=[%.3f,%.3f,%.3f] bias=%.3f (%dms)%n", + weights[0], weights[1], weights[2], weights[3], + System.currentTimeMillis() - t0); } else { System.out.println(" WARNING: no dev file found, using uncalibrated defaults"); } - tables.put(script, table); - calibrations.put(script, cal); + bigramTables.put(script, bigramTable); + bigramCalibrations.put(script, bigramCal); + blockTables.put(script, blockTable); + blockCalibrations.put(script, blockCal); + controlCalibrations.put(script, controlCal); + classifierWeights.put(script, weights); } } - System.out.printf("%nWriting model (%d scripts) → %s%n", tables.size(), output); - saveModel(tables, calibrations, output); + System.out.printf("%nWriting model (%d scripts, blockN=%d) → %s%n", + bigramTables.size(), blockN, output); + saveModel(bigramTables, bigramCalibrations, + blockTables, blockCalibrations, + controlCalibrations, classifierWeights, blockN, output); System.out.printf("Model size: %,d bytes (%.1f MB)%n", Files.size(output), Files.size(output) / 1_000_000.0); System.out.println("Done."); } // ----------------------------------------------------------------------- - // Training + // Block index // ----------------------------------------------------------------------- /** - * Trains a 256×256 byte-bigram log-probability table from a gzipped - * sentence file (one UTF-8 sentence per line). + * Builds a stable ordered mapping from {@link Character.UnicodeBlock} to integer index + * by scanning all valid Unicode codepoints in order (U+0000 to U+10FFFF) and + * recording each block's first occurrence. + * + * <p>The resulting map has {@code size()} entries (one per named block). + * Callers should reserve index {@code size()} as the "unassigned" bucket + * (for codepoints where {@code UnicodeBlock.of(cp)} returns null). 
* - * <p>All 256×256 consecutive byte-pair counts are accumulated, then - * add-1 (Laplace) smoothing is applied per row before converting to - * natural log-probabilities: {@code log P(b|a) = log((C[a][b]+1) / sum_b(C[a][b]+1))}. + * @return immutable ordered map: UnicodeBlock → integer index [0, size) + */ + static Map<Character.UnicodeBlock, Integer> buildBlockIndex() { + LinkedHashMap<Character.UnicodeBlock, Integer> index = new LinkedHashMap<>(); + for (int cp = 0; cp <= 0x10FFFF; cp++) { + Character.UnicodeBlock b = Character.UnicodeBlock.of(cp); + if (b != null) index.putIfAbsent(b, index.size()); + } + return Collections.unmodifiableMap(index); + } + + // ----------------------------------------------------------------------- + // Training + // ----------------------------------------------------------------------- + + /** + * Trains a 256×256 byte-bigram log-probability table from a gzipped sentence file. * - * @return float[65536] table where index {@code a*256+b} = log P(b|a) + * @return float[65536] where index {@code a*256+b} = log P(b|a) */ static float[] trainBigramTable(Path trainGz) throws IOException { long[] counts = new long[65536]; @@ -184,116 +309,530 @@ public class TrainJunkModel { } } - System.out.printf(" %,d sentences, %,d bigrams%n", sentences, totalBigrams); + System.out.printf(" %,d sentences, %,d byte bigrams%n", sentences, totalBigrams); + return laplaceSmoothLogProb(counts, 256); + } - // Add-1 smoothing per row, then log-prob - float[] table = new float[65536]; - for (int a = 0; a < 256; a++) { - long rowTotal = 256; // add 1 for each of the 256 possible next bytes - for (int b = 0; b < 256; b++) { - rowTotal += counts[a * 256 + b]; + /** + * Trains a {@code blockN×blockN} named-Unicode-block transition log-probability table. 
+ * + * @param blockIndex ordered mapping from UnicodeBlock to index [0, blockIndex.size()) + * @param blockN blockIndex.size() + 1 (includes the null bucket) + * @return float[blockN*blockN] where index {@code a*blockN+b} = log P(block_b | block_a) + */ + static float[] trainBlockTable(Path trainGz, + Map<Character.UnicodeBlock, Integer> blockIndex, + int blockN) throws IOException { + long[] counts = new long[blockN * blockN]; + int nullId = blockN - 1; + long totalBigrams = 0; + long sentences = 0; + + try (BufferedReader r = openGzipped(trainGz)) { + String line; + while ((line = r.readLine()) != null) { + int prev = -1; + for (int i = 0; i < line.length(); ) { + int cp = line.codePointAt(i); + Character.UnicodeBlock b = Character.UnicodeBlock.of(cp); + int blockId = b != null ? blockIndex.getOrDefault(b, nullId) : nullId; + if (prev >= 0) { + counts[prev * blockN + blockId]++; + totalBigrams++; + } + prev = blockId; + i += Character.charCount(cp); + } + sentences++; + } + } + + System.out.printf(" %,d sentences, %,d block bigrams%n", sentences, totalBigrams); + return laplaceSmoothLogProb(counts, blockN); + } + + /** + * Applies Laplace (add-1) smoothing per row and converts to log-probabilities. 
+ * + * @param counts raw bigram counts, length = size*size + * @param size number of distinct symbols (256 for byte table, blockN for block table) + * @return float[size*size] log-prob table + */ + private static float[] laplaceSmoothLogProb(long[] counts, int size) { + float[] table = new float[size * size]; + for (int a = 0; a < size; a++) { + long rowTotal = size; // add-1 pseudocount for each possible next symbol + for (int b = 0; b < size; b++) { + rowTotal += counts[a * size + b]; } - for (int b = 0; b < 256; b++) { - table[a * 256 + b] = (float) Math.log((counts[a * 256 + b] + 1.0) / rowTotal); + for (int b = 0; b < size; b++) { + table[a * size + b] = + (float) Math.log((counts[a * size + b] + 1.0) / rowTotal); } } return table; } + // ----------------------------------------------------------------------- + // Calibration + // ----------------------------------------------------------------------- + /** - * Computes calibration statistics for a script by scoring each sentence - * in the dev set with the given bigram table. + * Loads all sentences from a gzipped file and draws {@code nSamples} random + * byte-window substrings of target lengths cycling through {@code lengths}. * - * <p>For each sentence, the per-sentence score is the mean log-probability - * of its byte bigrams. The mean (mu) and stddev (sigma) of those scores - * across all dev sentences are returned. At inference, z-score = - * (score - mu) / sigma. + * <p>This mirrors the evaluator's {@code pickSubstring}: takes a random + * UTF-8-aligned window of {@code targetLen} bytes from a randomly chosen + * sentence, or the whole sentence if it is shorter. 
* - * @return float[2] = {mu, sigma} + * @param nSamples number of windows to sample + * @param lengths target byte-lengths to cycle through (round-robin) + * @param seed RNG seed for reproducibility */ - static float[] computeCalibration(Path devGz, float[] table) throws IOException { - List<Double> scores = new ArrayList<>(); - + static List<String> sampleSubstrings(Path devGz, int nSamples, + int[] lengths, long seed) throws IOException { + List<byte[]> sentenceBytes = new ArrayList<>(); try (BufferedReader r = openGzipped(devGz)) { String line; while ((line = r.readLine()) != null) { - byte[] bytes = line.getBytes(StandardCharsets.UTF_8); - if (bytes.length < 2) { - continue; - } - double sum = 0; - for (int i = 0; i + 1 < bytes.length; i++) { - sum += table[((bytes[i] & 0xFF) << 8) | (bytes[i + 1] & 0xFF)]; - } - scores.add(sum / (bytes.length - 1)); + byte[] b = line.getBytes(StandardCharsets.UTF_8); + if (b.length >= 2) sentenceBytes.add(b); } } + if (sentenceBytes.isEmpty()) return Collections.emptyList(); - System.out.printf(" %,d dev sentences%n", scores.size()); + Random rng = new Random(seed); + List<String> result = new ArrayList<>(nSamples); + for (int i = 0; i < nSamples; i++) { + byte[] bytes = sentenceBytes.get(rng.nextInt(sentenceBytes.size())); + int targetLen = lengths[i % lengths.length]; - if (scores.isEmpty()) { - return new float[]{0f, 1f}; + if (bytes.length <= targetLen) { + result.add(new String(bytes, StandardCharsets.UTF_8)); + continue; + } + int start = rng.nextInt(bytes.length - targetLen); + while (start > 0 && (bytes[start] & 0xC0) == 0x80) { + start--; + } + int end = Math.min(start + targetLen, bytes.length); + while (end < bytes.length && (bytes[end] & 0xC0) == 0x80) { + end++; + } + result.add(new String(bytes, start, end - start, StandardCharsets.UTF_8)); } + return result; + } + /** @return float[2] = {mu, sigma} of byte-bigram mean log-prob on dev windows */ + static float[] computeBigramCalibration(Path devGz, float[] 
bigramTable) throws IOException { + List<String> windows = sampleSubstrings(devGz, CALIB_SAMPLES, CALIB_LENGTHS, 42); + List<Double> scores = new ArrayList<>(windows.size()); + for (String window : windows) { + byte[] bytes = window.getBytes(StandardCharsets.UTF_8); + if (bytes.length < 2) continue; + double sum = 0; + for (int i = 0; i + 1 < bytes.length; i++) { + sum += bigramTable[((bytes[i] & 0xFF) << 8) | (bytes[i + 1] & 0xFF)]; + } + scores.add(sum / (bytes.length - 1)); + } + System.out.printf(" %,d dev windows%n", scores.size()); + return muSigma(scores); + } + + /** @return float[2] = {mu, sigma} of block-transition mean log-prob on dev windows */ + static float[] computeBlockCalibration(Path devGz, float[] blockTable, + Map<Character.UnicodeBlock, Integer> blockIndex, + int blockN) throws IOException { + List<String> windows = sampleSubstrings(devGz, CALIB_SAMPLES, CALIB_LENGTHS, 43); + List<Double> scores = new ArrayList<>(windows.size()); + int nullId = blockN - 1; + for (String window : windows) { + int[] ids = new int[window.length()]; + int len = 0; + for (int i = 0; i < window.length(); ) { + int cp = window.codePointAt(i); + Character.UnicodeBlock b = Character.UnicodeBlock.of(cp); + ids[len++] = b != null ? 
blockIndex.getOrDefault(b, nullId) : nullId; + i += Character.charCount(cp); + } + if (len < 2) continue; + double sum = 0; + for (int i = 0; i + 1 < len; i++) { + sum += blockTable[ids[i] * blockN + ids[i + 1]]; + } + scores.add(sum / (len - 1)); + } + System.out.printf(" %,d dev windows%n", scores.size()); + return muSigma(scores); + } + + /** @return float[2] = {mu, sigma} of control-byte fraction on dev windows */ + static float[] computeControlByteCalibration(Path devGz) throws IOException { + List<String> windows = sampleSubstrings(devGz, CALIB_SAMPLES, CALIB_LENGTHS, 44); + List<Double> scores = new ArrayList<>(windows.size()); + for (String window : windows) { + byte[] bytes = window.getBytes(StandardCharsets.UTF_8); + if (bytes.length == 0) continue; + long controlCount = 0; + for (byte b : bytes) { + if (isControlByte(b & 0xFF)) controlCount++; + } + scores.add(-(double) controlCount / bytes.length); + } + System.out.printf(" %,d dev windows%n", scores.size()); + if (scores.isEmpty()) return new float[]{0f, CONTROL_BYTE_MIN_SIGMA}; double mu = scores.stream().mapToDouble(Double::doubleValue).average().orElse(0); double variance = scores.stream() .mapToDouble(s -> (s - mu) * (s - mu)) - .average().orElse(1.0); - double sigma = Math.sqrt(variance); - if (sigma < 1e-9) { - sigma = 1.0; - } + .average().orElse(0); + double sigma = Math.max(Math.sqrt(variance), CONTROL_BYTE_MIN_SIGMA); return new float[]{(float) mu, (float) sigma}; } + // ----------------------------------------------------------------------- + // Linear classifier training + // ----------------------------------------------------------------------- + + /** + * Trains a per-script binary logistic regression classifier on (z1, z2, z3). + * + * <p>Clean examples: {@link #NUM_CLASSIFIER_SAMPLES} random dev windows (seed 100). + * Corrupted examples: same count, alternating inject@5% (seed 102, even indices) + * and char-shuffle (odd indices) applied to windows sampled with seed 101. 
+ * + * @return float[4] = {w1, w2, w3, bias} — classifier weights; positive logit = clean + */ + static float[] trainClassifier(Path devGz, + float[] bigramTable, float[] bigramCal, + float[] blockTable, float[] blockCal, + float[] controlCal, + Map<Character.UnicodeBlock, Integer> blockIndex, + int blockN) throws IOException { + int nEach = NUM_CLASSIFIER_SAMPLES; + + // Clean windows + List<String> cleanWindows = sampleSubstrings(devGz, nEach, CALIB_LENGTHS, 100); + + // Corrupted windows: sample base windows (seed 101), then distort + List<String> baseWindows = sampleSubstrings(devGz, nEach, CALIB_LENGTHS, 101); + Random rng = new Random(102); + List<String> corruptedWindows = new ArrayList<>(nEach); + for (int i = 0; i < baseWindows.size(); i++) { + String w = baseWindows.get(i); + if (i % 2 == 0) { + corruptedWindows.add(injectControlChars(w, CLASSIFIER_INJECT_RATE, rng)); + } else { + corruptedWindows.add(shuffleChars(w, rng)); + } + } + + // Build (z1, z2, z3) feature matrix + List<float[]> features = new ArrayList<>(cleanWindows.size() + corruptedWindows.size()); + List<Integer> labels = new ArrayList<>(cleanWindows.size() + corruptedWindows.size()); + + for (String w : cleanWindows) { + features.add(extractFeatures(w, bigramTable, bigramCal, + blockTable, blockCal, blockN, controlCal, blockIndex)); + labels.add(1); // clean + } + for (String w : corruptedWindows) { + features.add(extractFeatures(w, bigramTable, bigramCal, + blockTable, blockCal, blockN, controlCal, blockIndex)); + labels.add(0); // corrupted + } + + float[] weights = fitLogisticRegression(features, labels, 3); + + // Calibrate bias using only short (len=15) windows so that FPR ≤ 2.5% + // even at the worst-case (shortest) window length. Longer windows have + // lower logit variance and will score well above this threshold naturally. 
+ List<String> shortWindows = sampleSubstrings(devGz, nEach, new int[]{15}, 200); + List<Float> shortLogits = new ArrayList<>(shortWindows.size()); + int nFeat = weights.length - 1; + for (String w : shortWindows) { + float[] x = extractFeatures(w, bigramTable, bigramCal, + blockTable, blockCal, blockN, controlCal, blockIndex); + float logit = weights[nFeat]; + for (int j = 0; j < nFeat; j++) logit += weights[j] * x[j]; + shortLogits.add(logit); + } + if (!shortLogits.isEmpty()) { + Collections.sort(shortLogits); + int pIdx = (int) (0.025 * shortLogits.size()); + float p025 = shortLogits.get(Math.max(0, pIdx)); + weights[nFeat] -= p025; // shift bias so p2.5 of len=15 logits = 0 + } + + return weights; + } + + /** + * Extracts calibrated z-scores (z1, z2, z3) for a single text window. + * + * @return float[3] = {z1_bigram, z2_block, z3_control} + */ + static float[] extractFeatures(String window, + float[] bigramTable, float[] bigramCal, + float[] blockTable, float[] blockCal, + int blockN, float[] controlCal, + Map<Character.UnicodeBlock, Integer> blockIndex) { + byte[] utf8 = window.getBytes(StandardCharsets.UTF_8); + + // z1: byte-bigram mean log-prob + float z1 = 0f; + if (utf8.length >= 2) { + double sum = 0; + int count = 0; + for (int i = 0; i + 1 < utf8.length; i++) { + sum += bigramTable[((utf8[i] & 0xFF) << 8) | (utf8[i + 1] & 0xFF)]; + count++; + } + z1 = ((float) (sum / count) - bigramCal[0]) / bigramCal[1]; + } + + // z2: block-transition mean log-prob + float z2 = 0f; + if (blockTable != null && window.length() >= 2) { + int nullId = blockN - 1; + int prev = -1; + double sum = 0; + int count = 0; + for (int i = 0; i < window.length(); ) { + int cp = window.codePointAt(i); + Character.UnicodeBlock b = Character.UnicodeBlock.of(cp); + int blockId = b != null ? 
blockIndex.getOrDefault(b, nullId) : nullId; + if (prev >= 0) { + sum += blockTable[prev * blockN + blockId]; + count++; + } + prev = blockId; + i += Character.charCount(cp); + } + if (count > 0) { + z2 = ((float) (sum / count) - blockCal[0]) / blockCal[1]; + } + } + + // z3: control-byte fraction (stored as −fraction, so higher = cleaner) + float z3 = 0f; + if (utf8.length > 0 && controlCal != null) { + long controlCount = 0; + for (byte b : utf8) { + if (isControlByte(b & 0xFF)) controlCount++; + } + float score = -(float) controlCount / utf8.length; + z3 = (score - controlCal[0]) / controlCal[1]; + } + + return new float[]{z1, z2, z3}; + } + + /** + * Replaces a random fraction of characters with Unicode control characters. + * Operates at the codepoint level to produce well-formed strings with actual + * control bytes in the UTF-8 encoding. + * + * @param rate fraction of characters to replace [0, 1] + */ + static String injectControlChars(String text, double rate, Random rng) { + if (text.isEmpty()) return text; + int[] codepoints = text.codePoints().toArray(); + int[] controlChars = {0x01, 0x02, 0x03, 0x04, 0x07, 0x0B, 0x0C, 0x0E, 0x0F, 0x1A, 0x1B, 0x7F}; + for (int i = 0; i < codepoints.length; i++) { + if (rng.nextDouble() < rate) { + codepoints[i] = controlChars[rng.nextInt(controlChars.length)]; + } + } + return new String(codepoints, 0, codepoints.length); + } + + /** + * Randomly permutes all characters in the text (Fisher-Yates shuffle). + * Destroys both bigram and block-transition structure while preserving script + * distribution, making it a good test of transition-based features. 
+ */ + static String shuffleChars(String text, Random rng) { + if (text.isEmpty()) return text; + int[] codepoints = text.codePoints().toArray(); + for (int i = codepoints.length - 1; i > 0; i--) { + int j = rng.nextInt(i + 1); + int tmp = codepoints[i]; + codepoints[i] = codepoints[j]; + codepoints[j] = tmp; + } + return new String(codepoints, 0, codepoints.length); + } + + /** + * Fits a binary logistic regression classifier on the given feature matrix. + * + * <p>Label convention: 1 = clean, 0 = corrupted. At inference, positive + * logit → clean text; negative logit → corrupted text. + * + * <p>Uses full-batch gradient descent with L2 regularization. Converges + * reliably for {@code numFeatures} ≤ 10 with the default hyperparameters. + * + * @param features list of feature vectors, each of length {@code numFeatures} + * @param labels parallel list of labels (0 or 1) + * @param numFeatures number of features + * @return float[numFeatures + 1] = {w[0], ..., w[numFeatures-1], bias} + */ + static float[] fitLogisticRegression(List<float[]> features, List<Integer> labels, + int numFeatures) { + int n = features.size(); + float[] w = new float[numFeatures]; // zero-initialized + float bias = 0f; + + if (n == 0) { + float[] result = new float[numFeatures + 1]; + for (int i = 0; i < numFeatures; i++) result[i] = 1f / numFeatures; + return result; + } + + float lr = 0.05f; + float lambda = 0.01f; // L2 regularization + int epochs = 500; + + for (int epoch = 0; epoch < epochs; epoch++) { + double[] gradW = new double[numFeatures]; + double gradB = 0; + + for (int i = 0; i < n; i++) { + float[] x = features.get(i); + int y = labels.get(i); + + double logit = bias; + for (int j = 0; j < numFeatures; j++) logit += w[j] * x[j]; + + // Numerically stable sigmoid + double p; + if (logit >= 0) { + double e = Math.exp(-logit); + p = 1.0 / (1.0 + e); + } else { + double e = Math.exp(logit); + p = e / (1.0 + e); + } + + double err = p - y; + for (int j = 0; j < numFeatures; j++) 
gradW[j] += err * x[j]; + gradB += err; + } + + for (int j = 0; j < numFeatures; j++) { + w[j] -= lr * (float) (gradW[j] / n + lambda * w[j]); + } + bias -= lr * (float) (gradB / n); + } + + float[] result = new float[numFeatures + 1]; + for (int j = 0; j < numFeatures; j++) result[j] = w[j]; + result[numFeatures] = bias; + return result; + } + // ----------------------------------------------------------------------- // Model serialisation // ----------------------------------------------------------------------- /** - * Writes the trained model to a gzipped binary file. + * Writes the trained model (version 3) to a gzipped binary file. + * + * <p>Format documented in the class Javadoc. All multi-byte integers are + * big-endian; floats are IEEE 754 big-endian. * - * <p>Format: {@code [magic:8][version:1][num_scripts:4] - * ([name_len:2][name:N][mu:4][sigma:4][table:65536*4])*} - * All multi-byte integers are big-endian. Floats are IEEE 754 big-endian. + * @param classifierWeights per-script float[4] = {w1, w2, w3, bias} + * @param blockN the block table dimension (blockIndex.size() + 1) */ - static void saveModel(TreeMap<String, float[]> tables, - TreeMap<String, float[]> calibrations, + static void saveModel(TreeMap<String, float[]> bigramTables, + TreeMap<String, float[]> bigramCalibrations, + TreeMap<String, float[]> blockTables, + TreeMap<String, float[]> blockCalibrations, + TreeMap<String, float[]> controlCalibrations, + TreeMap<String, float[]> classifierWeights, + int blockN, Path output) throws IOException { try (DataOutputStream dos = new DataOutputStream( new GZIPOutputStream(Files.newOutputStream(output)))) { - // Magic + version + count dos.write(MAGIC.getBytes(StandardCharsets.UTF_8)); dos.writeByte(VERSION); - dos.writeInt(tables.size()); + dos.writeInt(bigramTables.size()); + dos.writeShort(blockN); // global: block table dimension - for (var entry : tables.entrySet()) { + for (var entry : bigramTables.entrySet()) { String script = entry.getKey(); 
- float[] table = entry.getValue(); - float[] cal = calibrations.getOrDefault(script, new float[]{0f, 1f}); + float[] bigramTable = entry.getValue(); + float[] bigramCal = bigramCalibrations.getOrDefault(script, new float[]{0f, 1f}); + float[] blockTable = blockTables.getOrDefault(script, new float[blockN * blockN]); + float[] blockCal = blockCalibrations.getOrDefault(script, new float[]{0f, 1f}); + float[] controlCal = controlCalibrations.getOrDefault(script, new float[]{0f, 1f}); + float[] weights = classifierWeights.getOrDefault(script, + new float[]{1f / 3, 1f / 3, 1f / 3, 0f}); byte[] nameBytes = script.getBytes(StandardCharsets.UTF_8); dos.writeShort(nameBytes.length); dos.write(nameBytes); - dos.writeFloat(cal[0]); // mu - dos.writeFloat(cal[1]); // sigma + // Feature 1: byte bigrams + dos.writeFloat(bigramCal[0]); + dos.writeFloat(bigramCal[1]); + dos.write(toBytes(bigramTable)); - // Write 65536 float32 values in big-endian - ByteBuffer buf = ByteBuffer.allocate(65536 * 4).order(ByteOrder.BIG_ENDIAN); - for (float v : table) { - buf.putFloat(v); - } - dos.write(buf.array()); + // Feature 2: named-block transitions + dos.writeFloat(blockCal[0]); + dos.writeFloat(blockCal[1]); + dos.write(toBytes(blockTable)); + + // Feature 3: control-byte fraction + dos.writeFloat(controlCal[0]); + dos.writeFloat(controlCal[1]); + + // Classifier weights: num_features (1 byte) + weights + bias + int numFeatures = weights.length - 1; // last element is bias + dos.writeByte(numFeatures); + for (float v : weights) dos.writeFloat(v); } } } + private static byte[] toBytes(float[] table) { + ByteBuffer buf = ByteBuffer.allocate(table.length * 4).order(ByteOrder.BIG_ENDIAN); + for (float v : table) buf.putFloat(v); + return buf.array(); + } + // ----------------------------------------------------------------------- // Helpers // ----------------------------------------------------------------------- + /** + * Returns true if the byte value is a control character that should not 
appear + * in natural-language UTF-8 text: {@code [0x01–0x08, 0x0B, 0x0C, 0x0E–0x1F, 0x7F]}. + * + * <p>Excluded: 0x00 (null), 0x09 (tab), 0x0A (newline), 0x0D (carriage return) + * — all appear legitimately in text. + */ + static boolean isControlByte(int b) { + return (b >= 0x01 && b <= 0x08) + || b == 0x0B || b == 0x0C + || (b >= 0x0E && b <= 0x1F) + || b == 0x7F; + } + + private static float[] muSigma(List<Double> scores) { + if (scores.isEmpty()) return new float[]{0f, 1f}; + double mu = scores.stream().mapToDouble(Double::doubleValue).average().orElse(0); + double variance = scores.stream() + .mapToDouble(s -> (s - mu) * (s - mu)) + .average().orElse(1.0); + double sigma = Math.sqrt(variance); + if (sigma < 1e-9) sigma = 1.0; + return new float[]{(float) mu, (float) sigma}; + } + static BufferedReader openGzipped(Path path) throws IOException { return new BufferedReader( new InputStreamReader( @@ -308,5 +847,4 @@ public class TrainJunkModel { System.err.println(" --output <path> Output model file"); System.err.println(" (default: {data-dir}/junkdetect.bin)"); } - } diff --git a/tika-ml/tika-ml-junkdetect/src/main/resources/META-INF/services/org.apache.tika.quality.TextQualityDetector b/tika-ml/tika-ml-junkdetect/src/main/resources/META-INF/services/org.apache.tika.quality.TextQualityDetector new file mode 100644 index 0000000000..b9abdc1347 --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/src/main/resources/META-INF/services/org.apache.tika.quality.TextQualityDetector @@ -0,0 +1 @@ +org.apache.tika.ml.junkdetect.JunkDetector diff --git a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin index 623b60df16..bebc3293fa 100644 Binary files a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin and b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin differ 
diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java index 822eae19cf..88a5a8c16f 100644 --- a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java +++ b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java @@ -25,10 +25,12 @@ import java.util.Random; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.apache.tika.quality.TextQualityComparison; +import org.apache.tika.quality.TextQualityScore; + /** - * Smoke tests corresponding to Phase 5.1 target cases in the design doc. - * These should pass once the model is trained; failures indicate the model - * needs more data or the feature extraction is wrong. + * Smoke tests verifying the bundled model meets minimum quality thresholds. + * Failures indicate the model needs more data or feature extraction is wrong. */ public class JunkDetectorSmokeTest { @@ -40,43 +42,44 @@ public class JunkDetectorSmokeTest { } /** - * Clean English should score higher than random high-byte garbage. - * Also serves as the byte-reversal baseline: garbage bytes ~ byte-reversed text. + * Clean English should score higher than random high-byte garbage interpreted + * as ISO-8859-1. Simulates binary data mixed into a text extraction. */ @Test void cleanVsGarbage() { - JunkScore clean = detector.score("The quick brown fox jumps over the lazy dog. " - + "Pack my box with five dozen liquor jugs."); + String clean = "The quick brown fox jumps over the lazy dog. 
" + + "Pack my box with five dozen liquor jugs."; + byte[] garbageBytes = new byte[80]; new Random(42).nextBytes(garbageBytes); - // Force all bytes >= 0x80 so it's clearly invalid UTF-8-looking garbage for (int i = 0; i < garbageBytes.length; i++) { garbageBytes[i] = (byte) (0x80 | (garbageBytes[i] & 0x7F)); } - JunkScore garbage = detector.score(new String(garbageBytes, StandardCharsets.ISO_8859_1) - .getBytes(StandardCharsets.UTF_8)); + // Decode as ISO-8859-1 so the string contains high-codepoint characters + String garbage = new String(garbageBytes, StandardCharsets.ISO_8859_1); + + TextQualityScore cleanScore = detector.score(clean); + TextQualityScore garbageScore = detector.score(garbage); - System.out.println("clean: " + clean); - System.out.println("garbage: " + garbage); + System.out.println("clean: " + cleanScore); + System.out.println("garbage: " + garbageScore); - assertTrue(clean.getZScore() > garbage.getZScore(), + assertTrue(cleanScore.getZScore() > garbageScore.getZScore(), "Clean text should score higher than garbage"); } /** * Forward Arabic should score higher than character-reversed Arabic. - * Character (codepoint) reversal is a realistic distortion: it produces - * valid UTF-8 but wrong reading order — analogous to bidirectional rendering - * failures or incorrectly stored RTL text. + * Character (codepoint) reversal produces valid UTF-8 but wrong reading order — + * analogous to bidirectional rendering failures or incorrectly stored RTL text. 
*/ @Test void forwardVsReversedArabic() { String arabic = "اللغة العربية جميلة وغنية بالمفردات والتعبيرات"; - byte[] forward = arabic.getBytes(StandardCharsets.UTF_8); - byte[] reversed = reverseString(arabic).getBytes(StandardCharsets.UTF_8); + String reversed = reverseString(arabic); - JunkScore fwd = detector.score(forward); - JunkScore rev = detector.score(reversed); + TextQualityScore fwd = detector.score(arabic); + TextQualityScore rev = detector.score(reversed); System.out.println("arabic forward: " + fwd); System.out.println("arabic reversed: " + rev); @@ -88,19 +91,22 @@ public class JunkDetectorSmokeTest { /** * cp1257 (Baltic) decoding of Lithuanian text should win over cp1252. * - * <p>This tests the {@link JunkDetector#compare} API: given raw bytes that were - * encoded as cp1257, scoring both decodings should prefer the correct one. + * <p>Tests the {@link JunkDetector#compare} API: given raw bytes that were + * encoded as cp1257, comparing both decodings should prefer the correct one. * A low delta is expected because the LATIN model is trained across ~322 languages * and Baltic-specific bigrams are diluted. * - * <p>TODO: improve separation by adding a Baltic sub-model or Baltic-weighted retraining. + * <p>TODO: improve separation with a Baltic sub-model or Baltic-weighted retraining. */ @Test void cp1252VsCp1257OnBalticText() throws Exception { String lithuanian = "Lietuvių kalba yra labai graži ir turtinga"; byte[] cp1257bytes = lithuanian.getBytes("cp1257"); - JunkDetector.CompareResult result = detector.compare(cp1257bytes, "cp1252", "cp1257"); + String ascp1252 = new String(cp1257bytes, "cp1252"); + String ascp1257 = new String(cp1257bytes, "cp1257"); + + TextQualityComparison result = detector.compare("cp1252", ascp1252, "cp1257", ascp1257); System.out.println("Baltic comparison: " + result); @@ -115,21 +121,24 @@ public class JunkDetectorSmokeTest { /** * cp1251 decoding of Russian text should win over cp1252. 
* - * <p>This is the canonical Cyrillic mojibake scenario: Windows-1251-encoded - * Russian text misinterpreted as Windows-1252 (Western European). The cp1252 - * decoding produces Latin symbols interspersed with control characters, while - * cp1251 produces proper Cyrillic. The model should strongly prefer cp1251. + * <p>This is the canonical Cyrillic mojibake scenario: Windows-1251-encoded Russian + * text misinterpreted as Windows-1252 (Western European). The cp1252 decoding + * produces Latin symbols interspersed with control characters, while cp1251 produces + * proper Cyrillic. The model should strongly prefer cp1251. * - * <p>Note: character-reversal of LTR Cyrillic is NOT a useful test here — - * byte-bigram statistics are nearly identical forward and backward for LTR scripts, - * so the model cannot distinguish them. Use codec-confusion tests for LTR scripts. + * <p>Note: character-reversal of LTR Cyrillic is NOT a useful test — byte-bigram + * statistics are nearly identical forward and backward for LTR scripts. Codec + * comparison is the correct test for LTR scripts. */ @Test void cp1252VsCp1251OnRussianText() throws Exception { String russian = "Русский язык является одним из восточнославянских языков"; byte[] cp1251bytes = russian.getBytes("cp1251"); - JunkDetector.CompareResult result = detector.compare(cp1251bytes, "cp1252", "cp1251"); + String ascp1252 = new String(cp1251bytes, "cp1252"); + String ascp1251 = new String(cp1251bytes, "cp1251"); + + TextQualityComparison result = detector.compare("cp1252", ascp1252, "cp1251", ascp1251); System.out.println("Russian Cyrillic comparison: " + result); @@ -140,16 +149,19 @@ public class JunkDetectorSmokeTest { } /** - * Clean Japanese (CJK) should score higher than shuffled bytes. + * Clean Japanese (CJK) should score higher than byte-shuffled Japanese. 
*/ @Test void cleanVsShuffledCjk() { String japanese = "日本語は美しい言語であり、世界中で約1億3千万人が話している。"; - byte[] clean = japanese.getBytes(StandardCharsets.UTF_8); - byte[] shuffled = shuffled(clean, 42); + byte[] cleanBytes = japanese.getBytes(StandardCharsets.UTF_8); + byte[] shuffledBytes = shuffled(cleanBytes, 42); - JunkScore cleanScore = detector.score(clean); - JunkScore shuffledScore = detector.score(shuffled); + // Shuffled bytes are not valid UTF-8; decode as ISO-8859-1 to get a scoreable string + String shuffledText = new String(shuffledBytes, StandardCharsets.ISO_8859_1); + + TextQualityScore cleanScore = detector.score(japanese); + TextQualityScore shuffledScore = detector.score(shuffledText); System.out.println("Japanese clean: " + cleanScore); System.out.println("Japanese shuffled: " + shuffledScore); @@ -158,12 +170,41 @@ public class JunkDetectorSmokeTest { "Clean Japanese should score higher than shuffled bytes"); } + /** + * Shift-JIS zip entry name (9 bytes) decoded as Shift-JIS should beat the same + * bytes decoded as UTF-8 (which produces mojibake with FFFD replacement chars). + * + * <p>This is the canonical short-text use case: zip parsers encounter raw filename + * bytes with no BOM or language tag. At 9 bytes the z-score signal is weak, but + * the corrupted UTF-8 decode contains FFFD sequences (0xEF 0xBF 0xBD) which are + * very unlikely in LATIN text, yielding a clearly negative bigram z-score. + * + * <p>"テスト.tx" is pure katakana — KATAKANA script maps to the HAN model via + * {@link JunkDetector#SCRIPT_MODEL_FALLBACK}. 
+ */ + @Test + void shiftJisZipEntryNameVsUtf8() throws Exception { + // 9 Shift-JIS bytes: テスト.tx + byte[] sjisBytes = "テスト.tx".getBytes("Shift_JIS"); + assertEquals(9, sjisBytes.length, "fixture sanity: expect exactly 9 Shift-JIS bytes"); + + String asShiftJis = new String(sjisBytes, "Shift_JIS"); // "テスト.tx" + String asUtf8 = new String(sjisBytes, StandardCharsets.UTF_8); // "?e?X?g.tx" (mojibake) + + TextQualityComparison result = detector.compare("Shift-JIS", asShiftJis, "UTF-8", asUtf8); + + System.out.println("Shift-JIS zip entry: " + result); + + assertEquals("A", result.winner(), + "Shift-JIS decode should beat garbled UTF-8 for short Japanese filename"); + } + // ----------------------------------------------------------------------- /** * Reverses the string at codepoint granularity (not char granularity), so - * surrogate pairs are kept intact. This produces valid Unicode text in - * reverse reading order — a realistic distortion for RTL-language tests. + * surrogate pairs are kept intact. Produces valid Unicode in reverse reading + * order — a realistic distortion for RTL-language tests. */ static String reverseString(String s) { int[] codepoints = s.codePoints().toArray();
