This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch universal-junk-detector in repository https://gitbox.apache.org/repos/asf/tika.git
commit 7e65b5e8a7d16d8ed51e9095cad40736f7fb5013 Author: tballison <[email protected]> AuthorDate: Thu Apr 23 16:15:27 2026 -0400 model v4 --- .../apache/tika/ml/junkdetect/JunkDetector.java | 108 +++- .../tika/ml/junkdetect/tools/EvalJunkDetector.java | 355 ++++++++--- .../tika/ml/junkdetect/tools/TrainJunkModel.java | 648 +++++++++++++++++---- .../org/apache/tika/ml/junkdetect/junkdetect.bin | Bin 543946 -> 468582 bytes 4 files changed, 899 insertions(+), 212 deletions(-) diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java index d954df3207..eec94f9c13 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java @@ -105,6 +105,14 @@ public final class JunkDetector implements TextQualityDetector { // float[numFeatures+1] = {w1, ..., wN, bias}; positive logit = clean private final Map<String, float[]> classifierWeights; + // Feature 4: global script-transition (version 4+); null for v1/v2/v3 models + // One global table: float[numScriptBuckets * numScriptBuckets] log P(script_b | script_a) + // Uses raw UnicodeScript names (not SCRIPT_MODEL_FALLBACK) to distinguish HIRAGANA/KATAKANA/HAN. + private final float[] scriptTransitionTable; + private final float[] scriptTransitionCalibration; // float[2] = {mu, sigma} + private final Map<String, Integer> scriptBucketIndex; // raw UnicodeScript name → bucket ID + private final int numScriptBuckets; // 0 for v1/v2/v3 + // Shared block index for v2+ models: UnicodeBlock → index [0, blockN-1) // Index blockN-1 is the "unassigned" bucket (null UnicodeBlock). 
private final Map<Character.UnicodeBlock, Integer> blockIndex; @@ -117,7 +125,11 @@ public final class JunkDetector implements TextQualityDetector { int blockN, Map<String, float[]> controlCalibrations, Map<String, float[]> classifierWeights, - Map<Character.UnicodeBlock, Integer> blockIndex) { + Map<Character.UnicodeBlock, Integer> blockIndex, + float[] scriptTransitionTable, + float[] scriptTransitionCalibration, + Map<String, Integer> scriptBucketIndex, + int numScriptBuckets) { this.modelVersion = modelVersion; this.tables = Collections.unmodifiableMap(tables); this.calibrations = Collections.unmodifiableMap(calibrations); @@ -131,6 +143,11 @@ public final class JunkDetector implements TextQualityDetector { this.classifierWeights = classifierWeights != null ? Collections.unmodifiableMap(classifierWeights) : null; this.blockIndex = blockIndex; + this.scriptTransitionTable = scriptTransitionTable; + this.scriptTransitionCalibration = scriptTransitionCalibration; + this.scriptBucketIndex = scriptBucketIndex != null + ? Collections.unmodifiableMap(scriptBucketIndex) : null; + this.numScriptBuckets = numScriptBuckets; } // ----------------------------------------------------------------------- @@ -184,7 +201,7 @@ public final class JunkDetector implements TextQualityDetector { throw new IOException("Not a JunkDetector model file (bad magic)"); } int version = dis.readUnsignedByte(); - if (version < 1 || version > 3) { + if (version < 1 || version > 4) { throw new IOException("Unsupported model version: " + version); } @@ -212,6 +229,26 @@ public final class JunkDetector implements TextQualityDetector { Map<String, float[]> controlCalibrations = version >= 2 ? new HashMap<>(numScripts * 2) : null; Map<String, float[]> classifierWeights = version >= 3 ? 
new HashMap<>(numScripts * 2) : null; + // Version 4+: global script-transition section + float[] scriptTransitionTable = null; + float[] scriptTransitionCalibration = null; + Map<String, Integer> scriptBucketIndex = null; + int numScriptBuckets = 0; + + if (version >= 4) { + numScriptBuckets = dis.readUnsignedByte(); + scriptBucketIndex = new LinkedHashMap<>(numScriptBuckets * 2); + for (int i = 0; i < numScriptBuckets; i++) { + int nameLen = dis.readUnsignedShort(); + String bucketName = new String(dis.readNBytes(nameLen), StandardCharsets.UTF_8); + scriptBucketIndex.put(bucketName, i); + } + scriptTransitionTable = readFloatTable(dis, numScriptBuckets * numScriptBuckets); + float mu4 = dis.readFloat(); + float sigma4 = dis.readFloat(); + scriptTransitionCalibration = new float[]{mu4, sigma4}; + } + for (int s = 0; s < numScripts; s++) { int nameLen = dis.readUnsignedShort(); String script = new String(dis.readNBytes(nameLen), StandardCharsets.UTF_8); @@ -248,7 +285,9 @@ public final class JunkDetector implements TextQualityDetector { return new JunkDetector(version, tables, calibrations, blockTables, blockCalibrations, blockN, - controlCalibrations, classifierWeights, blockIndex); + controlCalibrations, classifierWeights, blockIndex, + scriptTransitionTable, scriptTransitionCalibration, + scriptBucketIndex, numScriptBuckets); } } @@ -341,6 +380,10 @@ public final class JunkDetector implements TextQualityDetector { private TextQualityScore scoreText(String text) { List<ScriptRun> runs = buildScriptRuns(text); + // Global z4: script-transition feature over the whole input string. + // Computed before chunking because it captures document-level script mixing. + float z4 = computeScriptTransitionZ(text); + // Score each run against its own model; aggregate weighted by byte count. 
float totalBytes = 0; float weightedLogit = 0; @@ -357,7 +400,7 @@ public final class JunkDetector implements TextQualityDetector { if (runUtf8.length < 2) { continue; // too short to score } - float logit = scoreChunk(runUtf8, run.text, run.script); + float logit = scoreChunk(runUtf8, run.text, run.script, z4); int n = runUtf8.length; weightedLogit += logit * n; totalBytes += n; @@ -370,14 +413,12 @@ public final class JunkDetector implements TextQualityDetector { } if (totalBytes == 0 || dominantScript == null) { - // No scoreable runs; return UNKNOWN keyed on the first run's script (for debug) String label = runs.isEmpty() ? "LATIN" : runs.get(0).script; return unknownScore(label); } float zScore = weightedLogit / totalBytes; - // CI: standard error of the weighted mean, approximated via dominant script's sigma float uncertainty = (dominantCal1 != null && totalBigramCount > 0) ? (float) (1.96 * dominantCal1[1] / Math.sqrt(totalBigramCount)) : 0f; float ciLow = zScore - uncertainty; @@ -392,7 +433,7 @@ public final class JunkDetector implements TextQualityDetector { * Positive = clean, negative = junk. Returns 0 (neutral) if the chunk * has no model or is too short. 
*/ - private float scoreChunk(byte[] utf8, String text, String script) { + private float scoreChunk(byte[] utf8, String text, String script, float z4) { float[] bigramTable = tables.get(script); if (bigramTable == null || utf8.length < 2) { return 0f; @@ -448,10 +489,16 @@ public final class JunkDetector implements TextQualityDetector { if (modelVersion >= 3 && classifierWeights != null) { float[] cw = classifierWeights.get(script); - if (cw != null && cw.length >= 4) { - return cw[0] * z1 + cw[1] * z2 + cw[2] * z3 + cw[cw.length - 1]; + if (cw != null) { + int nFeat = cw.length - 1; // bias is last + float logit = cw[nFeat]; // bias + if (nFeat >= 1) logit += cw[0] * z1; + if (nFeat >= 2) logit += cw[1] * z2; + if (nFeat >= 3) logit += cw[2] * z3; + if (nFeat >= 4) logit += cw[3] * z4; + return logit; } - return (z1 + z2 + z3) / 3.0f; + return modelVersion >= 4 ? (z1 + z2 + z3 + z4) / 4.0f : (z1 + z2 + z3) / 3.0f; // fallback: equal weight; z4 only for v4+ (v3 has no transition table) } else if (modelVersion >= 2 && blockTables != null) { return (z1 + z2 + z3) / 3.0f; } else { @@ -459,6 +506,47 @@ public final class JunkDetector implements TextQualityDetector { } } + /** + * Computes the global script-transition z-score for the whole input string. + * Uses raw {@link Character.UnicodeScript} values — NOT {@link #SCRIPT_MODEL_FALLBACK} — + * so that HIRAGANA, KATAKANA, and HAN remain distinct, preserving the + * characteristic script-mixing pattern of Japanese text. + * + * <p>Returns 0 if no v4 model is loaded or the string has fewer than two + * non-neutral codepoints.
+ */ + private float computeScriptTransitionZ(String text) { + if (scriptTransitionTable == null || scriptBucketIndex == null + || scriptTransitionCalibration == null || numScriptBuckets == 0) { + return 0f; + } + int otherBucket = numScriptBuckets - 1; + int prev = -1; + double sum = 0; + int count = 0; + for (int i = 0; i < text.length(); ) { + int cp = text.codePointAt(i); + i += Character.charCount(cp); + Character.UnicodeScript s = Character.UnicodeScript.of(cp); + if (s == Character.UnicodeScript.COMMON + || s == Character.UnicodeScript.INHERITED + || s == Character.UnicodeScript.UNKNOWN) { + continue; + } + int bucket = scriptBucketIndex.getOrDefault(s.name(), otherBucket); + if (prev >= 0) { + sum += scriptTransitionTable[prev * numScriptBuckets + bucket]; + count++; + } + prev = bucket; + } + if (count == 0) { + return 0f; + } + float mean = (float) (sum / count); + return (mean - scriptTransitionCalibration[0]) / scriptTransitionCalibration[1]; + } + /** * Splits text into maximal runs of the same Unicode script. 
* COMMON, INHERITED, and UNKNOWN codepoints (spaces, punctuation, digits) diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java index 0538f29537..6b6057fc34 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java @@ -20,63 +20,76 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; +import java.nio.charset.UnsupportedCharsetException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.stream.Collectors; import java.util.zip.GZIPInputStream; import org.apache.tika.ml.junkdetect.JunkDetector; +import org.apache.tika.quality.TextQualityComparison; import org.apache.tika.quality.TextQualityScore; /** * Ablation evaluation for the junk detector. * - * <p>For each script's dev set, scores clean sentences alongside three corruption - * modes — random-byte injection, codepoint-reversal, and byte-shuffle — at several - * injection rates and string lengths. Computes per-cell Cohen's d (discrimination - * power) and TPR/FPR at a fixed z-score threshold. + * <p>For each script's dev set, scores clean sentences alongside several corruption + * modes at various injection rates and string lengths. Computes per-cell Cohen's d + * (discrimination power) and TPR/FPR at a fixed z-score threshold. * - * <p>Output: two TSV files. 
+ * <p>Output files in {@code --output-dir}: * <ul> * <li><b>detail.tsv</b> — one row per (script, distortion, rate, length): * {@code script, distortion, param, length, n_clean, n_corrupt, * mean_clean_z, mean_corrupt_z, cohens_d, fpr, tpr} * <li><b>summary.tsv</b> — macro-averaged Cohen's d and FPR/TPR per * (distortion, rate, length) across all scripts. + * <li><b>compare.tsv</b> — pairwise codec-comparison accuracy using the + * {@link JunkDetector#compare} API, stratified by string length. + * This is the primary metric for the charset-arbitration use case; + * larger mean delta = better discrimination at that length. * </ul> * + * <p><b>Why char-remap is not in summary.tsv:</b> The character-level wrong-codec + * substitution (e.g. CP1252→CP1255, replacing umlauts with Hebrew letters) is added + * to training at a 5% rate. At that rate it is too subtle to detect via the absolute + * {@link JunkDetector#score} API — z-score distributions barely separate (Cohen's d ≈ 0). + * The distortion trains the LR to distinguish subtly-wrong from correct decodings, which + * only manifests as larger pairwise deltas in {@link JunkDetector#compare}. Measuring it + * via summary.tsv would produce misleading d≈0 "failure" rows; see compare.tsv instead. + * * <p>Cohen's d = (mean_clean_z − mean_corrupt_z) / pooled_std. * Higher d = better discrimination. FPR = fraction of clean text falsely flagged; * TPR = fraction of corrupted text correctly flagged. Both use threshold = −2.0. * * <p>To compare two model versions: run eval before and after, then diff the - * summary TSVs. The "macro_d" column in summary.tsv is the single headline metric. + * summary and compare TSVs. The "macro_d" column in summary.tsv and the + * "mean_delta" columns in compare.tsv are the headline metrics. 
* * <p>Usage: * <pre> * java EvalJunkDetector \ - * --model /path/to/junkdetect.bin (default: classpath) - * --data-dir ~/datasets/madlad/junkdetect - * --output-dir /path/to/results (default: data-dir/eval) - * --split dev|test (default: dev — use test only for final reporting) - * --samples 200 - * --seed 42 - * --lengths 15,30,50,100,200 - * --rates 0.01,0.05,0.10,0.25,0.50,0.90 - * --threshold -2.0 + * --model /path/to/junkdetect.bin (default: classpath) + * --data-dir ~/datasets/madlad/junkdetect + * --output-dir /path/to/results (default: data-dir/eval) + * --split dev|test (default: dev) + * --samples 200 + * --compare-n 200 (qualifying pairs per codec pair per length) + * --seed 42 + * --lengths 5,9,15,30,50,100,200 + * --compare-lengths 5,9,15,30,50 + * --rates 0.01,0.05,0.10,0.25,0.50,0.90 + * --threshold -2.0 * </pre> - * - * <p><b>Which split to use:</b> Use {@code --split dev} during iterative development - * (dev data is seen by the calibration step, so numbers are slightly optimistic for - * calibration quality, but still valid for relative comparisons between model versions). - * Use {@code --split test} only when reporting final numbers — the test split is - * completely held out and was never used to make any model or threshold decision. 
*/ public class EvalJunkDetector { @@ -86,10 +99,12 @@ public class EvalJunkDetector { Path dataDir = Paths.get(System.getProperty("user.home"), "datasets", "madlad", "junkdetect"); Path outputDir = null; - String split = "dev"; // dev during development; test for final reporting + String split = "dev"; int samplesPerCell = 200; + int compareN = 200; long seed = 42L; - int[] lengths = {15, 30, 50, 100, 200}; + int[] lengths = {5, 9, 15, 30, 50, 100, 200}; + int[] compareLengths = {5, 9, 15, 30, 50}; double[] rates = {0.01, 0.05, 0.10, 0.25, 0.50, 0.90}; float threshold = -2.0f; @@ -114,6 +129,9 @@ public class EvalJunkDetector { case "--samples": samplesPerCell = Integer.parseInt(args[++i]); break; + case "--compare-n": + compareN = Integer.parseInt(args[++i]); + break; case "--seed": seed = Long.parseLong(args[++i]); break; @@ -121,6 +139,10 @@ public class EvalJunkDetector { lengths = Arrays.stream(args[++i].split(",")) .mapToInt(Integer::parseInt).toArray(); break; + case "--compare-lengths": + compareLengths = Arrays.stream(args[++i].split(",")) + .mapToInt(Integer::parseInt).toArray(); + break; case "--rates": rates = Arrays.stream(args[++i].split(",")) .mapToDouble(Double::parseDouble).toArray(); @@ -144,12 +166,20 @@ public class EvalJunkDetector { : JunkDetector.loadFromClasspath(); System.err.println("=== EvalJunkDetector ==="); - System.err.println(" data-dir: " + dataDir); - System.err.println(" output-dir: " + outputDir); - System.err.println(" split: " + split + System.err.println(" data-dir: " + dataDir); + System.err.println(" output-dir: " + outputDir); + System.err.println(" split: " + split + (split.equals("test") ? 
" [FINAL REPORTING MODE]" : "")); System.err.println(" scripts in model: " + detector.knownScripts().size()); - System.err.println(" threshold: " + threshold); + System.err.println(" threshold: " + threshold); + + // Build wrong-codec remap tables for char-remap distortion + List<Map<Character, Character>> remapTables = new ArrayList<>(); + for (String[] pair : TrainJunkModel.WRONG_CODEC_PAIRS) { + Map<Character, Character> table = TrainJunkModel.buildRemapTable(pair[0], pair[1]); + if (!table.isEmpty()) remapTables.add(table); + } + System.err.println(" remap tables: " + remapTables.size()); String suffix = "." + split + ".gz"; List<Path> devFiles; @@ -167,8 +197,8 @@ public class EvalJunkDetector { Path detailPath = outputDir.resolve("detail.tsv"); Path summaryPath = outputDir.resolve("summary.tsv"); + Path comparePath = outputDir.resolve("compare.tsv"); - // Accumulate all rows for summary aggregation List<Row> allRows = new ArrayList<>(); try (PrintWriter detail = new PrintWriter( @@ -193,10 +223,6 @@ public class EvalJunkDetector { continue; } - Random rng = new Random(seed); - - // Score clean baseline once per (script, length) - // Reuse the same clean scores for all distortion comparisons at that length for (int len : lengths) { List<Float> cleanZ = scoreClean(detector, sentences, len, samplesPerCell, new Random(seed)); @@ -242,7 +268,7 @@ public class EvalJunkDetector { detail.println(row.toTsv()); } - // --- byte-swap: swap each adjacent pair of bytes (endianness flip) --- + // --- byte-swap --- { List<Float> corruptZ = scoreByteSwapped(detector, sentences, len, samplesPerCell, new Random(seed + 5)); @@ -252,16 +278,20 @@ public class EvalJunkDetector { detail.println(row.toTsv()); } + // char-remap distortion is evaluated only via compare.tsv (pairwise delta), + // not via absolute score() — see class Javadoc for rationale. 
+ detail.flush(); - rng = new Random(seed); // reset between lengths for reproducibility } } } writeSummary(summaryPath, allRows, lengths, rates, threshold); + writeCompareEval(detector, dataDir, suffix, comparePath, compareN, compareLengths, seed); System.err.println("\nWrote " + detailPath); System.err.println("Wrote " + summaryPath); + System.err.println("Wrote " + comparePath); System.err.println("Done."); } @@ -278,8 +308,6 @@ public class EvalJunkDetector { out.println("distortion\tparam\tlength\tn_scripts" + "\tmacro_cohens_d\tmacro_fpr\tmacro_tpr"); - // For each unique (distortion, param, length), average across scripts - // Build groups: inject@rate, char-reverse, byte-shuffle List<String[]> conditions = new ArrayList<>(); for (double rate : rates) { conditions.add(new String[]{"inject", String.format("%.2f", rate)}); @@ -288,6 +316,11 @@ public class EvalJunkDetector { conditions.add(new String[]{"byte-shuffle", "-"}); conditions.add(new String[]{"wrong-codec", "latin1-as-utf8"}); conditions.add(new String[]{"byte-swap", "-"}); + // char-remap is intentionally excluded from summary: at 5% rate the character-level + // wrong-codec substitution is too subtle to detect via absolute score() — both the + // clean and corrupted strings score similarly. The right metric for char-remap is + // compare.tsv (pairwise delta), where it shows up strongly. Including it here would + // make d≈0 rows that look like failures but are actually expected. 
for (String[] cond : conditions) { String distortion = cond[0]; @@ -298,9 +331,8 @@ public class EvalJunkDetector { && r.param.equals(param) && r.length == len) .collect(Collectors.toList()); - if (matching.isEmpty()) { - continue; - } + if (matching.isEmpty()) continue; + double macroCohensD = matching.stream() .filter(r -> !Double.isNaN(r.cohensD)) .mapToDouble(r -> r.cohensD) @@ -318,7 +350,6 @@ public class EvalJunkDetector { } } - // Overall headline: macro-average Cohen's d across everything double overallD = rows.stream() .filter(r -> !Double.isNaN(r.cohensD)) .mapToDouble(r -> r.cohensD) @@ -338,6 +369,155 @@ public class EvalJunkDetector { } } + // ----------------------------------------------------------------------- + // Compare eval — pairwise codec arbitration, stratified by string length + // ----------------------------------------------------------------------- + + /** + * For each entry in {@link TrainJunkModel#WRONG_CODEC_PAIRS}, encodes sentences + * from the appropriate script's dev file as the source charset, then calls + * {@link JunkDetector#compare} with the correct decoding (A) vs the wrong + * decoding (B). Reports accuracy (how often A wins) and mean/median delta + * at each requested string length. + * + * <p>Mean delta is the headline metric: larger delta means the model more + * confidently picks the correct decoding. At short lengths (5–9 bytes) + * delta is expected to be small; at 50 bytes it should be decisive. 
+ */ + private static void writeCompareEval(JunkDetector detector, + Path dataDir, String suffix, + Path comparePath, + int nPerCell, int[] lengths, + long seed) throws IOException { + try (PrintWriter out = new PrintWriter( + Files.newBufferedWriter(comparePath, StandardCharsets.UTF_8))) { + + out.println("source_codec\twrong_codec\tlength" + + "\tn_tested\taccuracy\tmean_delta\tmedian_delta\tn_no_diff"); + + System.err.printf("%n--- compare() eval ---%n"); + + for (String[] pair : TrainJunkModel.WRONG_CODEC_PAIRS) { + String sourceCodec = pair[0]; + String wrongCodec = pair[1]; + + Charset srcCharset, wrongCharset; + try { + srcCharset = Charset.forName(sourceCodec); + wrongCharset = Charset.forName(wrongCodec); + } catch (UnsupportedCharsetException e) { + System.err.printf(" [%s→%s] charset unavailable, skipping%n", + sourceCodec, wrongCodec); + continue; + } + + String script = codecToScript(sourceCodec); + Path devFile = dataDir.resolve(script.toLowerCase(java.util.Locale.ROOT) + suffix); + if (!Files.exists(devFile)) { + System.err.printf(" [%s→%s] no dev file for %s, skipping%n", + sourceCodec, wrongCodec, script); + continue; + } + + // Load a large pool; we'll filter down per-length + List<String> allSentences = loadSentences(devFile, nPerCell * 50); + + // Pre-filter: keep only sentences that roundtrip through sourceCodec + // and produce at least one differing character vs wrongCodec.
+ List<String[]> candidates = new ArrayList<>(); // {asSource, asWrong} + for (String sentence : allSentences) { + byte[] bytes = sentence.getBytes(srcCharset); + String asSource = new String(bytes, srcCharset); + if (!asSource.equals(sentence)) continue; // encoding lost data + String asWrong = new String(bytes, wrongCharset); + if (asSource.equals(asWrong)) continue; // no differentiating bytes + candidates.add(new String[]{asSource, asWrong}); + } + + if (candidates.isEmpty()) { + System.err.printf(" [%s→%s] no qualifying sentences%n", + sourceCodec, wrongCodec); + continue; + } + + System.err.printf(" [%s→%s] %d candidates from %s%n", + sourceCodec, wrongCodec, candidates.size(), script); + + for (int targetLen : lengths) { + Random rng = new Random(seed); + // Shuffle candidates for this length independently + List<String[]> shuffled = new ArrayList<>(candidates); + Collections.shuffle(shuffled, rng); + + List<Float> deltas = new ArrayList<>(); + int nCorrect = 0; + int nNoDiff = 0; + + for (String[] cand : shuffled) { + if (deltas.size() + nNoDiff >= nPerCell * 3 && deltas.size() >= nPerCell) { + break; + } + String asSource = trimToLength(cand[0], targetLen); + String asWrong = trimToLength(cand[1], targetLen); + + if (asSource.equals(asWrong)) { + nNoDiff++; + continue; + } + if (asSource.isEmpty() || asWrong.isEmpty()) continue; + + TextQualityComparison result = detector.compare( + sourceCodec, asSource, wrongCodec, asWrong); + + deltas.add(result.delta()); + if ("A".equals(result.winner())) nCorrect++; + } + + if (deltas.isEmpty()) continue; + + double accuracy = (double) nCorrect / deltas.size(); + double meanDelta = deltas.stream().mapToDouble(Float::floatValue).average().orElse(0); + List<Float> sorted = new ArrayList<>(deltas); + Collections.sort(sorted); + float medianDelta = sorted.get(sorted.size() / 2); + + System.err.printf(" len=%3d n=%3d acc=%.3f mean_delta=%.3f median_delta=%.3f%n", + targetLen, deltas.size(), accuracy, meanDelta, 
medianDelta); + + out.printf("%s\t%s\t%d\t%d\t%.3f\t%.3f\t%.3f\t%d%n", + sourceCodec, wrongCodec, targetLen, + deltas.size(), accuracy, meanDelta, medianDelta, nNoDiff); + } + out.flush(); + } + } + } + + /** + * Returns which dev-file script to use for a given source codec. + * Case-insensitive: windows-1251 → CYRILLIC, windows-1253 → GREEK, windows-1255 → HEBREW; any other name (including aliases such as CP1251) → LATIN. + */ + private static String codecToScript(String codec) { + switch (codec.toLowerCase(java.util.Locale.ROOT)) { + case "windows-1251": return "CYRILLIC"; + case "windows-1253": return "GREEK"; + case "windows-1255": return "HEBREW"; + default: return "LATIN"; + } + } + + /** + * Trims a string to approximately {@code targetLen} UTF-8 bytes, aligned to + * a codepoint boundary. Used to produce short-string variants for compare() testing. + */ + private static String trimToLength(String s, int targetLen) { + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); + if (bytes.length <= targetLen) return s; + int end = targetLen; + while (end < bytes.length && (bytes[end] & 0xC0) == 0x80) end++; + return new String(bytes, 0, end, StandardCharsets.UTF_8); + } + // ----------------------------------------------------------------------- // Row (one evaluation cell) // ----------------------------------------------------------------------- @@ -383,23 +563,14 @@ public class EvalJunkDetector { // Statistics // ----------------------------------------------------------------------- - /** - * Cohen's d = (mean_clean − mean_corrupt) / pooled_std. - * Positive = clean scores higher than corrupt (desirable). - * Higher absolute value = better discrimination.
- */ private static double computeCohensD(List<Float> clean, List<Float> corrupt) { - if (clean.isEmpty() || corrupt.isEmpty()) { - return Double.NaN; - } + if (clean.isEmpty() || corrupt.isEmpty()) return Double.NaN; double mc = mean(clean); double mj = mean(corrupt); double vc = variance(clean, mc); double vj = variance(corrupt, mj); double pooledStd = Math.sqrt((vc + vj) / 2.0); - if (pooledStd < 1e-9) { - return Double.NaN; - } + if (pooledStd < 1e-9) return Double.NaN; return (mc - mj) / pooledStd; } @@ -412,9 +583,7 @@ public class EvalJunkDetector { } private static double fractionBelow(List<Float> zs, float threshold) { - if (zs.isEmpty()) { - return Double.NaN; - } + if (zs.isEmpty()) return Double.NaN; long count = zs.stream().filter(z -> z < threshold).count(); return (double) count / zs.size(); } @@ -429,9 +598,7 @@ public class EvalJunkDetector { for (int i = 0; i < n; i++) { String s = pickSubstring(sentences, targetLen, rng); TextQualityScore score = detector.score(s); - if (!score.isUnknown()) { - results.add(score.getZScore()); - } + if (!score.isUnknown()) results.add(score.getZScore()); } return results; } @@ -445,9 +612,7 @@ public class EvalJunkDetector { byte[] bytes = s.getBytes(StandardCharsets.UTF_8); injectRandomBytes(bytes, rate, rng); TextQualityScore score = detector.score(new String(bytes, StandardCharsets.ISO_8859_1)); - if (!score.isUnknown()) { - results.add(score.getZScore()); - } + if (!score.isUnknown()) results.add(score.getZScore()); } return results; } @@ -458,9 +623,7 @@ public class EvalJunkDetector { for (int i = 0; i < n; i++) { String s = reverseCodepoints(pickSubstring(sentences, targetLen, rng)); TextQualityScore score = detector.score(s); - if (!score.isUnknown()) { - results.add(score.getZScore()); - } + if (!score.isUnknown()) results.add(score.getZScore()); } return results; } @@ -473,9 +636,7 @@ public class EvalJunkDetector { byte[] bytes = s.getBytes(StandardCharsets.UTF_8); shuffleBytes(bytes, rng); 
TextQualityScore score = detector.score(new String(bytes, StandardCharsets.ISO_8859_1)); - if (!score.isUnknown()) { - results.add(score.getZScore()); - } + if (!score.isUnknown()) results.add(score.getZScore()); } return results; } @@ -487,9 +648,7 @@ public class EvalJunkDetector { String s = pickSubstring(sentences, targetLen, rng); byte[] garbled = wrongCodecBytes(s.getBytes(StandardCharsets.UTF_8)); TextQualityScore score = detector.score(new String(garbled, StandardCharsets.UTF_8)); - if (!score.isUnknown()) { - results.add(score.getZScore()); - } + if (!score.isUnknown()) results.add(score.getZScore()); } return results; } @@ -501,9 +660,27 @@ public class EvalJunkDetector { String s = pickSubstring(sentences, targetLen, rng); byte[] swapped = swapByteOrder(s.getBytes(StandardCharsets.UTF_8)); TextQualityScore score = detector.score(new String(swapped, StandardCharsets.ISO_8859_1)); - if (!score.isUnknown()) { - results.add(score.getZScore()); - } + if (!score.isUnknown()) results.add(score.getZScore()); + } + return results; + } + + /** + * Applies a randomly chosen wrong-codec character remap at {@code rate} to each + * sample. Simulates real-world charset misdetection at the character level + * (e.g. CP1252-encoded text decoded as CP1255, replacing umlauts with Hebrew letters). 
+ */ + private static List<Float> scoreWithRemap(JunkDetector detector, List<String> sentences, + int targetLen, + List<Map<Character, Character>> remapTables, + double rate, int n, Random rng) { + List<Float> results = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + String s = pickSubstring(sentences, targetLen, rng); + Map<Character, Character> table = remapTables.get(rng.nextInt(remapTables.size())); + String corrupted = TrainJunkModel.wrongCodecRemap(s, table, rate, rng); + TextQualityScore score = detector.score(corrupted); + if (!score.isUnknown()) results.add(score.getZScore()); } return results; } @@ -513,26 +690,19 @@ public class EvalJunkDetector { // ----------------------------------------------------------------------- /** - * Injects control characters (0x01–0x09, i.e. below newline/0x0A, excluding null) - * at the given rate. These bytes never appear in clean natural-language UTF-8 text - * and simulate binary data leaking into a text stream. 0x00 is excluded because - * null bytes cause problems in many text-processing pipelines. + * Injects control characters (0x01–0x09) at the given rate. */ static void injectRandomBytes(byte[] bytes, double rate, Random rng) { for (int i = 0; i < bytes.length; i++) { if (rng.nextDouble() < rate) { - // 0x01..0x09 inclusive (9 values): SOH STX ETX EOT ENQ ACK BEL BS HT bytes[i] = (byte) (0x01 + rng.nextInt(9)); } } } /** - * Wrong-codec distortion: the UTF-8 bytes of the sentence are re-interpreted - * as ISO-8859-1 (Latin-1) and then re-encoded as UTF-8. This is the classic - * "saved as UTF-8, displayed as Latin-1" mojibake: every byte in 0x80–0xFF - * becomes a two-byte UTF-8 sequence, doubling the byte length of non-ASCII runs - * and producing bogus accented-Latin bigrams. + * Wrong-codec distortion: UTF-8 bytes re-interpreted as ISO-8859-1, then + * re-encoded as UTF-8. Produces bogus two-byte sequences for any non-ASCII byte. 
*/ static byte[] wrongCodecBytes(byte[] utf8) { String misread = new String(utf8, StandardCharsets.ISO_8859_1); @@ -540,10 +710,8 @@ public class EvalJunkDetector { } /** - * Byte-swap distortion: swaps each adjacent pair of bytes — (0,1), (2,3), etc. - * If the array has an odd length the last byte is left unchanged. - * Simulates reading a 2-byte encoding (UTF-16, UCS-2, CP932 two-byte sequences) - * with the wrong byte order. + * Swaps each adjacent pair of bytes — simulates reading a 2-byte encoding + * (UTF-16, CP932 two-byte sequences) with wrong byte order. */ static byte[] swapByteOrder(byte[] bytes) { byte[] out = bytes.clone(); @@ -581,18 +749,11 @@ public class EvalJunkDetector { private static String pickSubstring(List<String> sentences, int targetLen, Random rng) { String s = sentences.get(rng.nextInt(sentences.size())); byte[] bytes = s.getBytes(StandardCharsets.UTF_8); - if (bytes.length <= targetLen) { - return s; - } - // Pick a random window of targetLen bytes, aligned to a codepoint boundary + if (bytes.length <= targetLen) return s; int start = rng.nextInt(bytes.length - targetLen); - while (start > 0 && (bytes[start] & 0xC0) == 0x80) { - start--; - } + while (start > 0 && (bytes[start] & 0xC0) == 0x80) start--; int end = Math.min(start + targetLen, bytes.length); - while (end < bytes.length && (bytes[end] & 0xC0) == 0x80) { - end++; - } + while (end < bytes.length && (bytes[end] & 0xC0) == 0x80) end++; return new String(bytes, start, end - start, StandardCharsets.UTF_8); } @@ -606,7 +767,7 @@ public class EvalJunkDetector { while ((line = r.readLine()) != null && result.size() < maxSentences) { String trimmed = line.strip(); if (!trimmed.isEmpty() - && trimmed.getBytes(StandardCharsets.UTF_8).length >= 15) { + && trimmed.getBytes(StandardCharsets.UTF_8).length >= 5) { result.add(trimmed); } } diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java 
b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java index 2ba083c139..d5e7ce2430 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java @@ -22,16 +22,21 @@ import java.io.IOException; import java.io.InputStreamReader; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; +import java.nio.charset.UnsupportedCharsetException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.TreeMap; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -41,7 +46,7 @@ import java.util.zip.GZIPOutputStream; * {@link BuildJunkTrainingData}. * * <p>For each script group (identified by a {@code {script}.train.gz} file), - * three features are trained and then combined by a per-script logistic + * four features are trained and then combined by a per-script logistic * regression classifier: * <ol> * <li><b>Byte-bigram log-probability</b>: 256×256 table of log P(b|a) over @@ -54,30 +59,42 @@ import java.util.zip.GZIPOutputStream; * ranges ([0x01–0x08, 0x0B, 0x0C, 0x0E–0x1F, 0x7F]). 
Stored as * {@code −fraction} so the z-score convention matches the other features * (higher = cleaner).</li> + * <li><b>Script-transition log-probability</b>: global table of log P(script_b | script_a) + * over raw {@link Character.UnicodeScript} values (excluding COMMON, INHERITED, UNKNOWN), + * pooled across all training scripts (z4).</li> * </ol> * - * <p>All three features are calibrated (mu/sigma) on the dev split so their + * <p>All four features are calibrated (mu/sigma) on the dev split so their * z-scores are on a common scale. A per-script binary logistic regression - * classifier is then fit on (z1, z2, z3) using clean dev windows and corrupted + * classifier is then fit on (z1, z2, z3, z4) using clean dev windows and corrupted * versions (inject@5%, char-shuffle) as training examples. The learned weights * replace the fixed equal-weight average, allowing the model to automatically * downweight noisy features (e.g. high-variance block transitions for MYANMAR) * and upweight informative ones (e.g. control-byte fraction for [email protected]). * * <p>At inference, the final score is the linear combination - * {@code w1*z1 + w2*z2 + w3*z3 + bias}; positive values indicate clean text. + * {@code w1*z1 + w2*z2 + w3*z3 + w4*z4 + bias}; positive values indicate clean text. * The natural threshold is 0 (probability 0.5); use a negative threshold for * more conservative junk detection. * - * <p>Output format: {@code JUNKDET1} gzipped binary, <b>version 3</b>. - * Version 1 (bigrams only) and version 2 (equal-weight average) files can + * <p>Output format: {@code JUNKDET1} gzipped binary, <b>version 4</b>. + * Version 1 (bigrams only), version 2 (equal-weight average), and version 3 files can * still be loaded by {@code JunkDetector}. 
* * <pre> * [8 bytes] magic "JUNKDET1" (ASCII) - * [1 byte] version = 3 + * [1 byte] version = 4 * [4 bytes] num_scripts (big-endian int) * [2 bytes] block_N — number of distinct named Unicode blocks + 1 (unassigned) + * // Global script-transition section (version 4+) + * [1 byte] num_script_buckets + * for each bucket: + * [2 bytes] name length (big-endian ushort) + * [name bytes] bucket name (UnicodeScript.name() or "OTHER") + * [num_script_buckets² × 4 bytes] script-transition log-prob table + * [4 bytes] mu4 (float32 big-endian) + * [4 bytes] sigma4 (float32 big-endian) + * // Per-script data (same as v3 but num_features = 4) * for each script (sorted by name): * [2 bytes] name length (big-endian ushort) * [name bytes] script name (UTF-8) @@ -93,17 +110,18 @@ import java.util.zip.GZIPOutputStream; * [4 bytes] mu3 (float32 big-endian) * [4 bytes] sigma3 (float32 big-endian) * // Linear classifier weights - * [1 byte] num_features (= 3) + * [1 byte] num_features (= 4 for v4) * [4 bytes] w1 (float32 big-endian) * [4 bytes] w2 (float32 big-endian) * [4 bytes] w3 (float32 big-endian) + * [4 bytes] w4 (float32 big-endian) * [4 bytes] bias (float32 big-endian) * </pre> */ public class TrainJunkModel { static final String MAGIC = "JUNKDET1"; - static final byte VERSION = 3; + static final byte VERSION = 4; /** Number of clean (and corrupted) windows used to train the per-script classifier. */ static final int NUM_CLASSIFIER_SAMPLES = 500; @@ -121,6 +139,28 @@ public class TrainJunkModel { */ static final float CONTROL_BYTE_MIN_SIGMA = 0.005f; + /** + * Codec pairs used to build wrong-codec remap tables for training. + * Each entry is {sourceCodec, wrongCodec}: text encoded in sourceCodec but + * decoded as wrongCodec. Pairs within the same script family (e.g. CP1250↔CP1252) + * produce wrong-accent distortions that shift characters between Unicode blocks + * while staying in LATIN. 
Cross-script pairs (CP1252↔CP1255) additionally change + * the Unicode script, which z4 also detects. + */ + static final String[][] WRONG_CODEC_PAIRS = { + {"windows-1252", "windows-1250"}, // Western ↔ Central European (wrong accents) + {"windows-1250", "windows-1252"}, // reverse + {"windows-1252", "windows-1257"}, // Western ↔ Baltic (wrong accents) + {"windows-1257", "windows-1252"}, // reverse + {"windows-1252", "windows-1254"}, // Western ↔ Turkish (wrong accents) + {"windows-1251", "windows-1252"}, // Cyrillic → Latin (cross-script) + {"windows-1252", "windows-1251"}, // Latin → Cyrillic (cross-script) + {"windows-1253", "windows-1252"}, // Greek → Latin (cross-script) + {"windows-1252", "windows-1253"}, // Latin → Greek (cross-script) + {"windows-1255", "windows-1252"}, // Hebrew → Latin (cross-script) + {"windows-1252", "windows-1255"}, // Latin → Hebrew (the German vcard case) + }; + /** * Target byte-lengths used for calibration sampling, matching the evaluator defaults. */ @@ -151,7 +191,7 @@ public class TrainJunkModel { } } - System.out.println("=== TrainJunkModel ==="); + System.out.println("=== TrainJunkModel (v4) ==="); System.out.println(" data-dir: " + dataDir); System.out.println(" output: " + output); @@ -163,7 +203,7 @@ public class TrainJunkModel { System.out.print("Building Unicode named-block index... 
"); long t0 = System.currentTimeMillis(); Map<Character.UnicodeBlock, Integer> blockIndex = buildBlockIndex(); - int blockN = blockIndex.size() + 1; // +1 for unassigned bucket + int blockN = blockIndex.size() + 1; System.out.printf("%d named blocks → table size %d×%d (%dms)%n", blockIndex.size(), blockN, blockN, System.currentTimeMillis() - t0); @@ -173,87 +213,153 @@ public class TrainJunkModel { TreeMap<String, float[]> blockCalibrations = new TreeMap<>(); TreeMap<String, float[]> controlCalibrations = new TreeMap<>(); TreeMap<String, float[]> classifierWeights = new TreeMap<>(); + TreeMap<String, Path> devFilePaths = new TreeMap<>(); + List<Path> allTrainFiles = new ArrayList<>(); + List<Path> allDevFiles = new ArrayList<>(); + List<Path> trainFiles; try (var stream = Files.list(dataDir)) { - List<Path> trainFiles = stream + trainFiles = stream .filter(p -> p.getFileName().toString().endsWith(".train.gz")) .sorted() .toList(); + } - if (trainFiles.isEmpty()) { - System.err.println("ERROR: no *.train.gz files found in " + dataDir); - System.exit(1); - } - - for (Path trainFile : trainFiles) { - String filename = trainFile.getFileName().toString(); - String script = filename.substring(0, filename.length() - ".train.gz".length()) - .toUpperCase(); - Path devFile = trainFile.getParent().resolve( - filename.replace(".train.gz", ".dev.gz")); + if (trainFiles.isEmpty()) { + System.err.println("ERROR: no *.train.gz files found in " + dataDir); + System.exit(1); + } - System.out.printf("%n--- %s ---%n", script); + // ----------------------------------------------------------------------- + // Phase 1 — per-script bigram tables, block tables, calibrations + // ----------------------------------------------------------------------- + System.out.println("\n--- Phase 1: per-script tables and calibrations ---"); + for (Path trainFile : trainFiles) { + String filename = trainFile.getFileName().toString(); + String script = filename.substring(0, filename.length() - 
".train.gz".length()) + .toUpperCase(); + Path devFile = trainFile.getParent().resolve( + filename.replace(".train.gz", ".dev.gz")); + + System.out.printf("%n [%s]%n", script); + allTrainFiles.add(trainFile); + + t0 = System.currentTimeMillis(); + System.out.print(" Training byte-bigram table... "); + float[] bigramTable = trainBigramTable(trainFile); + System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Training named-block table... "); + float[] blockTable = trainBlockTable(trainFile, blockIndex, blockN); + System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); + + float[] bigramCal = new float[]{0f, 1f}; + float[] blockCal = new float[]{0f, 1f}; + float[] controlCal = new float[]{0f, 1f}; + + if (Files.exists(devFile)) { + t0 = System.currentTimeMillis(); + System.out.print(" Calibrating byte bigrams on dev... "); + bigramCal = computeBigramCalibration(devFile, bigramTable); + System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", + bigramCal[0], bigramCal[1], System.currentTimeMillis() - t0); t0 = System.currentTimeMillis(); - System.out.print(" Training byte-bigram table... "); - float[] bigramTable = trainBigramTable(trainFile); - System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); + System.out.print(" Calibrating named blocks on dev... "); + blockCal = computeBlockCalibration(devFile, blockTable, blockIndex, blockN); + System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", + blockCal[0], blockCal[1], System.currentTimeMillis() - t0); t0 = System.currentTimeMillis(); - System.out.print(" Training named-block table... 
"); - float[] blockTable = trainBlockTable(trainFile, blockIndex, blockN); - System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); - - float[] bigramCal = new float[]{0f, 1f}; - float[] blockCal = new float[]{0f, 1f}; - float[] controlCal = new float[]{0f, 1f}; - // Default: equal-weight average (w=[1/3,1/3,1/3], bias=0) - float[] weights = new float[]{1f / 3, 1f / 3, 1f / 3, 0f}; - - if (Files.exists(devFile)) { - t0 = System.currentTimeMillis(); - System.out.print(" Calibrating byte bigrams on dev... "); - bigramCal = computeBigramCalibration(devFile, bigramTable); - System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", - bigramCal[0], bigramCal[1], System.currentTimeMillis() - t0); - - t0 = System.currentTimeMillis(); - System.out.print(" Calibrating named blocks on dev... "); - blockCal = computeBlockCalibration(devFile, blockTable, blockIndex, blockN); - System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", - blockCal[0], blockCal[1], System.currentTimeMillis() - t0); - - t0 = System.currentTimeMillis(); - System.out.print(" Calibrating control bytes on dev... "); - controlCal = computeControlByteCalibration(devFile); - System.out.printf("done — mu=%.6f sigma=%.6f (%dms)%n", - controlCal[0], controlCal[1], System.currentTimeMillis() - t0); - - t0 = System.currentTimeMillis(); - System.out.print(" Training linear classifier... 
"); - weights = trainClassifier(devFile, bigramTable, bigramCal, - blockTable, blockCal, controlCal, blockIndex, blockN); - System.out.printf("done — w=[%.3f,%.3f,%.3f] bias=%.3f (%dms)%n", - weights[0], weights[1], weights[2], weights[3], - System.currentTimeMillis() - t0); - } else { - System.out.println(" WARNING: no dev file found, using uncalibrated defaults"); - } + System.out.print(" Calibrating control bytes on dev..."); + controlCal = computeControlByteCalibration(devFile); + System.out.printf("done — mu=%.6f sigma=%.6f (%dms)%n", + controlCal[0], controlCal[1], System.currentTimeMillis() - t0); - bigramTables.put(script, bigramTable); - bigramCalibrations.put(script, bigramCal); - blockTables.put(script, blockTable); - blockCalibrations.put(script, blockCal); - controlCalibrations.put(script, controlCal); - classifierWeights.put(script, weights); + devFilePaths.put(script, devFile); + allDevFiles.add(devFile); + } else { + System.out.println(" WARNING: no dev file found, using uncalibrated defaults"); } + + bigramTables.put(script, bigramTable); + bigramCalibrations.put(script, bigramCal); + blockTables.put(script, blockTable); + blockCalibrations.put(script, blockCal); + controlCalibrations.put(script, controlCal); + // Placeholder — set in phase 3 + classifierWeights.put(script, new float[]{1f / 4, 1f / 4, 1f / 4, 1f / 4, 0f}); } - System.out.printf("%nWriting model (%d scripts, blockN=%d) → %s%n", - bigramTables.size(), blockN, output); + // ----------------------------------------------------------------------- + // Phase 2 — global script-transition table + // ----------------------------------------------------------------------- + System.out.println("\n--- Phase 2: global script-transition table ---"); + List<String> scriptBuckets = buildScriptBuckets(); + int numScriptBuckets = scriptBuckets.size(); + Map<String, Integer> scriptBucketMap = new LinkedHashMap<>(); + for (int i = 0; i < numScriptBuckets; i++) { + 
scriptBucketMap.put(scriptBuckets.get(i), i); + } + System.out.printf(" %d script buckets (including OTHER)%n", numScriptBuckets); + + t0 = System.currentTimeMillis(); + System.out.print(" Training script-transition table... "); + float[] scriptTransTable = trainScriptTransitionTable(allTrainFiles, scriptBucketMap, numScriptBuckets); + System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Calibrating script transitions... "); + float[] scriptTransCal = calibrateScriptTransitions(allDevFiles, scriptTransTable, + scriptBucketMap, numScriptBuckets); + System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", + scriptTransCal[0], scriptTransCal[1], System.currentTimeMillis() - t0); + + t0 = System.currentTimeMillis(); + System.out.print(" Collecting per-script codepoint pools... "); + Map<String, List<Integer>> scriptCodepoints = collectScriptCodepoints(allTrainFiles, 200); + System.out.printf("done — %d scripts (%dms)%n", + scriptCodepoints.size(), System.currentTimeMillis() - t0); + + System.out.print(" Building wrong-codec remap tables... 
"); + List<Map<Character, Character>> remapTables = new ArrayList<>(); + for (String[] pair : WRONG_CODEC_PAIRS) { + Map<Character, Character> table = buildRemapTable(pair[0], pair[1]); + if (!table.isEmpty()) remapTables.add(table); + } + System.out.printf("%d tables built%n", remapTables.size()); + + // ----------------------------------------------------------------------- + // Phase 3 — per-script linear classifiers (now with z4) + // ----------------------------------------------------------------------- + System.out.println("\n--- Phase 3: per-script linear classifiers (z1,z2,z3,z4) ---"); + for (String script : bigramTables.keySet()) { + Path devFile = devFilePaths.get(script); + if (devFile == null) { + System.out.printf(" [%s] WARNING: no dev file, keeping equal-weight defaults%n", script); + continue; + } + t0 = System.currentTimeMillis(); + System.out.printf(" [%s] training classifier... ", script); + float[] weights = trainClassifier(devFile, + bigramTables.get(script), bigramCalibrations.get(script), + blockTables.get(script), blockCalibrations.get(script), + controlCalibrations.get(script), blockIndex, blockN, + scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets, + scriptCodepoints, remapTables); + classifierWeights.put(script, weights); + System.out.printf("done — w=[%.3f,%.3f,%.3f,%.3f] bias=%.3f (%dms)%n", + weights[0], weights[1], weights[2], weights[3], weights[4], + System.currentTimeMillis() - t0); + } + + System.out.printf("%nWriting model (%d scripts, blockN=%d, scriptBuckets=%d) → %s%n", + bigramTables.size(), blockN, numScriptBuckets, output); saveModel(bigramTables, bigramCalibrations, blockTables, blockCalibrations, - controlCalibrations, classifierWeights, blockN, output); + controlCalibrations, classifierWeights, + blockN, scriptBuckets, scriptTransTable, scriptTransCal, output); System.out.printf("Model size: %,d bytes (%.1f MB)%n", Files.size(output), Files.size(output) / 1_000_000.0); System.out.println("Done."); @@ 
-496,64 +602,98 @@ public class TrainJunkModel { // ----------------------------------------------------------------------- /** - * Trains a per-script binary logistic regression classifier on (z1, z2, z3). + * Trains a per-script binary logistic regression classifier on (z1, z2, z3, z4). * * <p>Clean examples: {@link #NUM_CLASSIFIER_SAMPLES} random dev windows (seed 100). - * Corrupted examples: same count, alternating inject@5% (seed 102, even indices) - * and char-shuffle (odd indices) applied to windows sampled with seed 101. + * Corrupted examples: same count, cycling through four distortions (seed 102): + * <ol> + * <li>inject@5% control chars</li> + * <li>char-shuffle</li> + * <li>cross-script substitution — replaces ~5% of characters with codepoints from + * foreign scripts, simulating charset encoding errors such as German umlauts + * becoming Hebrew letters when CP1252 text is decoded as CP1255</li> + * <li>wrong-codec remap — replaces ~5% of characters using a random pre-computed + * charset remap table (e.g. 
CP1252→CP1250 for wrong accents, CP1252→CP1255 + * for script crossings), simulating real-world charset misdetection</li> + * </ol> * - * @return float[4] = {w1, w2, w3, bias} — classifier weights; positive logit = clean + * @param remapTables list of pre-built wrong-codec remap tables from {@link #buildRemapTable} + * @return float[5] = {w1, w2, w3, w4, bias} — classifier weights; positive logit = clean */ static float[] trainClassifier(Path devGz, float[] bigramTable, float[] bigramCal, float[] blockTable, float[] blockCal, float[] controlCal, Map<Character.UnicodeBlock, Integer> blockIndex, - int blockN) throws IOException { + int blockN, + float[] scriptTransTable, float[] scriptTransCal, + Map<String, Integer> scriptBucketMap, int numScriptBuckets, + Map<String, List<Integer>> scriptCodepoints, + List<Map<Character, Character>> remapTables) + throws IOException { int nEach = NUM_CLASSIFIER_SAMPLES; // Clean windows List<String> cleanWindows = sampleSubstrings(devGz, nEach, CALIB_LENGTHS, 100); // Corrupted windows: sample base windows (seed 101), then distort + // Four-way rotation: inject / shuffle / cross-script / wrong-codec remap List<String> baseWindows = sampleSubstrings(devGz, nEach, CALIB_LENGTHS, 101); Random rng = new Random(102); List<String> corruptedWindows = new ArrayList<>(nEach); for (int i = 0; i < baseWindows.size(); i++) { String w = baseWindows.get(i); - if (i % 2 == 0) { - corruptedWindows.add(injectControlChars(w, CLASSIFIER_INJECT_RATE, rng)); - } else { - corruptedWindows.add(shuffleChars(w, rng)); + switch (i % 4) { + case 0: + corruptedWindows.add(injectControlChars(w, CLASSIFIER_INJECT_RATE, rng)); + break; + case 1: + corruptedWindows.add(shuffleChars(w, rng)); + break; + case 2: + corruptedWindows.add(injectCrossScriptChars(w, CLASSIFIER_INJECT_RATE, rng, + scriptCodepoints)); + break; + default: + if (!remapTables.isEmpty()) { + Map<Character, Character> table = + remapTables.get(rng.nextInt(remapTables.size())); + 
corruptedWindows.add(wrongCodecRemap(w, table, CLASSIFIER_INJECT_RATE, rng)); + } else { + corruptedWindows.add(injectControlChars(w, CLASSIFIER_INJECT_RATE, rng)); + } + break; } } - // Build (z1, z2, z3) feature matrix + // Build (z1, z2, z3, z4) feature matrix List<float[]> features = new ArrayList<>(cleanWindows.size() + corruptedWindows.size()); List<Integer> labels = new ArrayList<>(cleanWindows.size() + corruptedWindows.size()); for (String w : cleanWindows) { features.add(extractFeatures(w, bigramTable, bigramCal, - blockTable, blockCal, blockN, controlCal, blockIndex)); + blockTable, blockCal, blockN, controlCal, blockIndex, + scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets)); labels.add(1); // clean } for (String w : corruptedWindows) { features.add(extractFeatures(w, bigramTable, bigramCal, - blockTable, blockCal, blockN, controlCal, blockIndex)); + blockTable, blockCal, blockN, controlCal, blockIndex, + scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets)); labels.add(0); // corrupted } - float[] weights = fitLogisticRegression(features, labels, 3); + float[] weights = fitLogisticRegression(features, labels, 4); // Calibrate bias using only short (len=15) windows so that FPR ≤ 2.5% - // even at the worst-case (shortest) window length. Longer windows have - // lower logit variance and will score well above this threshold naturally. + // even at the worst-case (shortest) window length. 
List<String> shortWindows = sampleSubstrings(devGz, nEach, new int[]{15}, 200); List<Float> shortLogits = new ArrayList<>(shortWindows.size()); int nFeat = weights.length - 1; for (String w : shortWindows) { float[] x = extractFeatures(w, bigramTable, bigramCal, - blockTable, blockCal, blockN, controlCal, blockIndex); + blockTable, blockCal, blockN, controlCal, blockIndex, + scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets); float logit = weights[nFeat]; for (int j = 0; j < nFeat; j++) logit += weights[j] * x[j]; shortLogits.add(logit); @@ -562,22 +702,24 @@ public class TrainJunkModel { Collections.sort(shortLogits); int pIdx = (int) (0.025 * shortLogits.size()); float p025 = shortLogits.get(Math.max(0, pIdx)); - weights[nFeat] -= p025; // shift bias so p2.5 of len=15 logits = 0 + weights[nFeat] -= p025; } return weights; } /** - * Extracts calibrated z-scores (z1, z2, z3) for a single text window. + * Extracts calibrated z-scores (z1, z2, z3, z4) for a single text window. 
* - * @return float[3] = {z1_bigram, z2_block, z3_control} + * @return float[4] = {z1_bigram, z2_block, z3_control, z4_scriptTrans} */ static float[] extractFeatures(String window, float[] bigramTable, float[] bigramCal, float[] blockTable, float[] blockCal, int blockN, float[] controlCal, - Map<Character.UnicodeBlock, Integer> blockIndex) { + Map<Character.UnicodeBlock, Integer> blockIndex, + float[] scriptTransTable, float[] scriptTransCal, + Map<String, Integer> scriptBucketMap, int numScriptBuckets) { byte[] utf8 = window.getBytes(StandardCharsets.UTF_8); // z1: byte-bigram mean log-prob @@ -626,7 +768,17 @@ public class TrainJunkModel { z3 = (score - controlCal[0]) / controlCal[1]; } - return new float[]{z1, z2, z3}; + // z4: script-transition mean log-prob (raw UnicodeScript, no model fallback) + float z4 = 0f; + if (scriptTransTable != null && scriptTransCal != null) { + double raw = rawScriptTransitionLogProb(window, scriptTransTable, + scriptBucketMap, numScriptBuckets, numScriptBuckets - 1); + if (!Double.isNaN(raw)) { + z4 = ((float) raw - scriptTransCal[0]) / scriptTransCal[1]; + } + } + + return new float[]{z1, z2, z3, z4}; } /** @@ -738,13 +890,16 @@ public class TrainJunkModel { // ----------------------------------------------------------------------- /** - * Writes the trained model (version 3) to a gzipped binary file. + * Writes the trained model (version 4) to a gzipped binary file. * * <p>Format documented in the class Javadoc. All multi-byte integers are * big-endian; floats are IEEE 754 big-endian. 
* - * @param classifierWeights per-script float[4] = {w1, w2, w3, bias} + * @param classifierWeights per-script float[5] = {w1, w2, w3, w4, bias} * @param blockN the block table dimension (blockIndex.size() + 1) + * @param scriptBuckets ordered list of script bucket names (last = "OTHER") + * @param scriptTransTable global script-transition log-prob table + * @param scriptTransCal float[2] = {mu, sigma} for script-transition feature */ static void saveModel(TreeMap<String, float[]> bigramTables, TreeMap<String, float[]> bigramCalibrations, @@ -753,6 +908,9 @@ public class TrainJunkModel { TreeMap<String, float[]> controlCalibrations, TreeMap<String, float[]> classifierWeights, int blockN, + List<String> scriptBuckets, + float[] scriptTransTable, + float[] scriptTransCal, Path output) throws IOException { try (DataOutputStream dos = new DataOutputStream( new GZIPOutputStream(Files.newOutputStream(output)))) { @@ -760,7 +918,19 @@ public class TrainJunkModel { dos.write(MAGIC.getBytes(StandardCharsets.UTF_8)); dos.writeByte(VERSION); dos.writeInt(bigramTables.size()); - dos.writeShort(blockN); // global: block table dimension + dos.writeShort(blockN); + + // Global script-transition section (v4+) + int numBuckets = scriptBuckets.size(); + dos.writeByte(numBuckets); + for (String bucketName : scriptBuckets) { + byte[] nameBytes = bucketName.getBytes(StandardCharsets.UTF_8); + dos.writeShort(nameBytes.length); + dos.write(nameBytes); + } + dos.write(toBytes(scriptTransTable)); + dos.writeFloat(scriptTransCal[0]); // mu + dos.writeFloat(scriptTransCal[1]); // sigma for (var entry : bigramTables.entrySet()) { String script = entry.getKey(); @@ -770,28 +940,24 @@ public class TrainJunkModel { float[] blockCal = blockCalibrations.getOrDefault(script, new float[]{0f, 1f}); float[] controlCal = controlCalibrations.getOrDefault(script, new float[]{0f, 1f}); float[] weights = classifierWeights.getOrDefault(script, - new float[]{1f / 3, 1f / 3, 1f / 3, 0f}); + new float[]{1f / 
4, 1f / 4, 1f / 4, 1f / 4, 0f}); byte[] nameBytes = script.getBytes(StandardCharsets.UTF_8); dos.writeShort(nameBytes.length); dos.write(nameBytes); - // Feature 1: byte bigrams dos.writeFloat(bigramCal[0]); dos.writeFloat(bigramCal[1]); dos.write(toBytes(bigramTable)); - // Feature 2: named-block transitions dos.writeFloat(blockCal[0]); dos.writeFloat(blockCal[1]); dos.write(toBytes(blockTable)); - // Feature 3: control-byte fraction dos.writeFloat(controlCal[0]); dos.writeFloat(controlCal[1]); - // Classifier weights: num_features (1 byte) + weights + bias - int numFeatures = weights.length - 1; // last element is bias + int numFeatures = weights.length - 1; dos.writeByte(numFeatures); for (float v : weights) dos.writeFloat(v); } @@ -840,6 +1006,278 @@ public class TrainJunkModel { StandardCharsets.UTF_8)); } + /** + * Returns an ordered list of all recognized {@link Character.UnicodeScript} names + * (excluding COMMON, INHERITED, UNKNOWN pseudo-scripts), sorted alphabetically, + * with "OTHER" appended as the final fallback bucket. + * + * <p>Using raw UnicodeScript names (not the SCRIPT_MODEL_FALLBACK-mapped names) + * preserves discrimination power: clean Japanese text has characteristic + * KANJI→HIRAGANA→KATAKANA transitions that char-shuffle disrupts, which would + * be lost if all three were merged into "HAN". + */ + static List<String> buildScriptBuckets() { + List<String> buckets = new ArrayList<>(); + for (Character.UnicodeScript s : Character.UnicodeScript.values()) { + if (s != Character.UnicodeScript.COMMON + && s != Character.UnicodeScript.INHERITED + && s != Character.UnicodeScript.UNKNOWN) { + buckets.add(s.name()); + } + } + Collections.sort(buckets); + buckets.add("OTHER"); + return buckets; + } + + /** + * Trains a global {@code numBuckets×numBuckets} script-transition log-probability + * table by pooling all training files. 
Uses raw {@link Character.UnicodeScript} + * values (not the SCRIPT_MODEL_FALLBACK mapping) so that HIRAGANA, KATAKANA, and + * HAN remain distinct buckets. + * + * @return float[numBuckets * numBuckets] where index {@code a*numBuckets+b} = log P(script_b | script_a) + */ + static float[] trainScriptTransitionTable(List<Path> trainFiles, + Map<String, Integer> scriptBucketMap, + int numBuckets) throws IOException { + long[] counts = new long[numBuckets * numBuckets]; + int otherBucket = numBuckets - 1; + long totalTransitions = 0; + + for (Path trainFile : trainFiles) { + try (BufferedReader r = openGzipped(trainFile)) { + String line; + while ((line = r.readLine()) != null) { + int prev = -1; + for (int i = 0; i < line.length(); ) { + int cp = line.codePointAt(i); + i += Character.charCount(cp); + Character.UnicodeScript s = Character.UnicodeScript.of(cp); + if (s == Character.UnicodeScript.COMMON + || s == Character.UnicodeScript.INHERITED + || s == Character.UnicodeScript.UNKNOWN) { + continue; + } + int bucket = scriptBucketMap.getOrDefault(s.name(), otherBucket); + if (prev >= 0) { + counts[prev * numBuckets + bucket]++; + totalTransitions++; + } + prev = bucket; + } + } + } + } + System.out.printf("%,d script transitions across %d files%n", totalTransitions, trainFiles.size()); + return laplaceSmoothLogProb(counts, numBuckets); + } + + /** + * Calibrates the script-transition feature by computing mu and sigma over pooled + * dev windows from all scripts. 
+ * + * @return float[2] = {mu, sigma} + */ + static float[] calibrateScriptTransitions(List<Path> devFiles, + float[] scriptTransTable, + Map<String, Integer> scriptBucketMap, + int numBuckets) throws IOException { + List<Double> scores = new ArrayList<>(); + int otherBucket = numBuckets - 1; + for (Path devFile : devFiles) { + List<String> windows = sampleSubstrings(devFile, 200, CALIB_LENGTHS, 45); + for (String window : windows) { + double raw = rawScriptTransitionLogProb(window, scriptTransTable, + scriptBucketMap, numBuckets, otherBucket); + if (!Double.isNaN(raw)) { + scores.add(raw); + } + } + } + System.out.printf("%,d dev windows pooled%n", scores.size()); + return muSigma(scores); + } + + /** + * Returns the mean script-transition log-probability for a string, or + * {@link Double#NaN} if there are fewer than two non-neutral codepoints. + * Uses raw {@link Character.UnicodeScript} values (no SCRIPT_MODEL_FALLBACK). + */ + private static double rawScriptTransitionLogProb(String text, float[] table, + Map<String, Integer> bucketMap, + int numBuckets, int otherBucket) { + int prev = -1; + double sum = 0; + int count = 0; + for (int i = 0; i < text.length(); ) { + int cp = text.codePointAt(i); + i += Character.charCount(cp); + Character.UnicodeScript s = Character.UnicodeScript.of(cp); + if (s == Character.UnicodeScript.COMMON + || s == Character.UnicodeScript.INHERITED + || s == Character.UnicodeScript.UNKNOWN) { + continue; + } + int bucket = bucketMap.getOrDefault(s.name(), otherBucket); + if (prev >= 0) { + sum += table[prev * numBuckets + bucket]; + count++; + } + prev = bucket; + } + return count > 0 ? sum / count : Double.NaN; + } + + /** + * Builds a character→character remap table for a (sourceCodec, wrongCodec) pair. + * For every byte 0x80–0xFF, if the two codecs decode it to different characters + * (and neither produces the replacement character U+FFFD), the source character + * maps to the wrong-codec character. 
+ * + * <p>Returns an empty map if either codec is unavailable on this JVM. + */ + static Map<Character, Character> buildRemapTable(String sourceCodec, String wrongCodec) { + Charset src, wrong; + try { + src = Charset.forName(sourceCodec); + wrong = Charset.forName(wrongCodec); + } catch (UnsupportedCharsetException e) { + return Collections.emptyMap(); + } + Map<Character, Character> table = new HashMap<>(); + byte[] singleByte = new byte[1]; + for (int b = 0x80; b <= 0xFF; b++) { + singleByte[0] = (byte) b; + String fromSrc = new String(singleByte, src); + String fromWrong = new String(singleByte, wrong); + if (fromSrc.length() == 1 && fromWrong.length() == 1 + && fromSrc.charAt(0) != '\uFFFD' && fromWrong.charAt(0) != '\uFFFD' + && fromSrc.charAt(0) != fromWrong.charAt(0)) { + table.put(fromSrc.charAt(0), fromWrong.charAt(0)); + } + } + return table; + } + + /** + * Replaces characters using a pre-computed wrong-codec remap table, simulating + * the effect of encoding text in one charset and decoding it in another. + * Only characters present in the remap table are candidates for replacement. + * + * <p>This produces realistic mojibake: German umlauts becoming Hebrew letters, + * Polish characters becoming Western accents, Cyrillic becoming Latin symbols, etc. 
+ * + * @param remapTable source-char → wrong-char substitution table (from {@link #buildRemapTable}) + * @param rate fraction of remappable characters to replace [0, 1] + */ + static String wrongCodecRemap(String text, Map<Character, Character> remapTable, + double rate, Random rng) { + if (text.isEmpty() || remapTable.isEmpty()) { + return text; + } + int[] codepoints = text.codePoints().toArray(); + for (int i = 0; i < codepoints.length; i++) { + if (codepoints[i] < 0x10000 && rng.nextDouble() < rate) { + Character replacement = remapTable.get((char) codepoints[i]); + if (replacement != null) { + codepoints[i] = replacement; + } + } + } + return new String(codepoints, 0, codepoints.length); + } + + /** + * Collects a sample of codepoints from each raw {@link Character.UnicodeScript} + * found across all training files. Used to build the foreign-script codepoint + * pools for the cross-script substitution distortion. + * + * <p>Codepoints whose script is COMMON, INHERITED or UNKNOWN are ignored. A script's + * pool stops growing once it holds {@code maxPerScript} distinct codepoints, so it + * keeps the first ones encountered in file order. + * + * @param trainFiles files to scan; each is opened with {@code openGzipped} and read line by line + * @param maxPerScript maximum distinct codepoints to collect per script + * @return map from raw UnicodeScript name → list of sampled codepoints + * @throws IOException if reading a training file fails + */ + static Map<String, List<Integer>> collectScriptCodepoints(List<Path> trainFiles, + int maxPerScript) + throws IOException { + Map<String, Set<Integer>> collected = new HashMap<>(); + for (Path trainFile : trainFiles) { + try (BufferedReader r = openGzipped(trainFile)) { + String line; + while ((line = r.readLine()) != null) { + for (int i = 0; i < line.length(); ) { + int cp = line.codePointAt(i); + i += Character.charCount(cp); + Character.UnicodeScript s = Character.UnicodeScript.of(cp); + if (s == Character.UnicodeScript.COMMON + || s == Character.UnicodeScript.INHERITED + || s == Character.UnicodeScript.UNKNOWN) { + continue; + } + Set<Integer> pool = collected.computeIfAbsent( + s.name(), k -> new HashSet<>()); + if (pool.size() < maxPerScript) { + pool.add(cp); + } + } + } + } + } + Map<String, List<Integer>> result = new HashMap<>(collected.size() * 2); + for (Map.Entry<String, 
Set<Integer>> e : collected.entrySet()) { + result.put(e.getKey(), new ArrayList<>(e.getValue())); + } + return result; + } + + /** + * Replaces a random fraction of characters with codepoints drawn from scripts + * that do NOT appear in the source text. Simulates real-world charset encoding + * errors where accented characters in one script are misread as characters from + * a completely different script — e.g., German umlauts (ä, ö, ü) becoming + * Hebrew letters when CP1252-encoded text is decoded as CP1255. + * + * <p>Every codepoint of {@code text} is a replacement candidate, including those whose + * script is COMMON, INHERITED or UNKNOWN. The text is returned unchanged when it is + * empty, when {@code scriptCodepoints} is empty, or when no codepoints remain after + * excluding the scripts found in the text. + * + * @param text source text to distort + * @param rate fraction of characters to replace [0, 1] + * @param rng random source; drawn once per codepoint, and once more per replacement + * @param scriptCodepoints map from raw UnicodeScript name → pool of codepoints + * @return the distorted text, or {@code text} itself in the unchanged cases above + */ + static String injectCrossScriptChars(String text, double rate, Random rng, + Map<String, List<Integer>> scriptCodepoints) { + if (text.isEmpty() || scriptCodepoints.isEmpty()) { + return text; + } + + // Identify which scripts appear in the source text + Set<String> sourceScripts = new HashSet<>(); + for (int i = 0; i < text.length(); ) { + int cp = text.codePointAt(i); + i += Character.charCount(cp); + Character.UnicodeScript s = Character.UnicodeScript.of(cp); + if (s != Character.UnicodeScript.COMMON + && s != Character.UnicodeScript.INHERITED + && s != Character.UnicodeScript.UNKNOWN) { + sourceScripts.add(s.name()); + } + } + + // Build pool of codepoints from all other scripts + List<Integer> foreignPool = new ArrayList<>(); + for (Map.Entry<String, List<Integer>> e : scriptCodepoints.entrySet()) { + if (!sourceScripts.contains(e.getKey())) { + foreignPool.addAll(e.getValue()); + } + } + if (foreignPool.isEmpty()) { + return text; + } + + int[] codepoints = text.codePoints().toArray(); + for (int i = 0; i < codepoints.length; i++) { + if (rng.nextDouble() < rate) { + codepoints[i] = foreignPool.get(rng.nextInt(foreignPool.size())); + } + } + return new String(codepoints, 0, codepoints.length); + } + private static void printUsage() { System.err.println("Usage: TrainJunkModel [options]"); 
System.err.println(" --data-dir <path> Directory with {script}.train.gz / .dev.gz files"); diff --git a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin index bebc3293fa..050a5c4054 100644 Binary files a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin and b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin differ
