This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch TIKA-4731-common-script in repository https://gitbox.apache.org/repos/asf/tika.git
commit 628b83f0f4a2f7e5cddc35330409235492282be7 Author: tallison <[email protected]> AuthorDate: Wed May 20 21:17:33 2026 -0400 TIKA-4731 - checkpoint --- .../apache/tika/ml/junkdetect/JunkDetector.java | 305 ++++++++++----- .../tika/ml/junkdetect/tools/TrainJunkModel.java | 431 ++++++++++++++++++--- .../org/apache/tika/ml/junkdetect/junkdetect.bin | Bin 2898974 -> 2466921 bytes .../tika/ml/junkdetect/JunkDetectorV7Test.java | 43 +- 4 files changed, 602 insertions(+), 177 deletions(-) diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java index aa14812cc2..b0541f0e1c 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java @@ -95,7 +95,10 @@ public final class JunkDetector implements TextQualityDetector { * prior versions live in git history and are not loadable by this * build. We deliberately don't keep dual-version paths so it's * impossible to confuse model versions. */ - static final int VERSION = 13; + // v14: COMMON-as-a-script — COMMON runs route to a pooled COMMON table, + // all runs scored with kind-typed boundary sentinels (int[] codepoints). + // Pre-v14 models score incompatibly and are rejected at load. + public static final int VERSION = 14; // Feature 1 — per-script open-addressed codepoint-bigram tables. // No global Bloom: empty-slot is the membership oracle. @@ -460,16 +463,11 @@ public final class JunkDetector implements TextQualityDetector { // ----------------------------------------------------------------------- private TextQualityScore scoreText(String text) { - // NFD-normalize before scoring so we match the training pipeline. - // NFD decomposes precomposed accented letters into base + combining - // marks (e.g. `ề` → `e` + U+0302 + U+0300); the trainer's - // extractFeaturesV7 + sampleSubstrings apply the same NFD so the - // per-script bigram tables index the decomposed form. NFD chosen - // over NFC so combining-mark scripts (Vietnamese precomposed, - // Indic, Thai) all surface their marks as separate codepoints, - // letting z5 (letter-adjacent-to-mark) discriminate uniformly. + // NFC-normalize before scoring so we match the training pipeline: + // the trainer's extractFeaturesV7 applies the same NFC, so the + // per-script bigram tables are indexed on the composed form. text = java.text.Normalizer.normalize(text, java.text.Normalizer.Form.NFC); - List<ScriptRun> runs = buildScriptRuns(text); + List<Run> runs = segmentRuns(text); // Document-level features computed once per scoring call. // z4 = script-transition log-prob (cross-script mixing). @@ -494,25 +492,25 @@ public final class JunkDetector implements TextQualityDetector { int totalBigramCount = 0; float[] dominantCal1 = null; - for (ScriptRun run : runs) { + for (Run run : runs) { if (!calibrations.containsKey(run.script)) { continue; // skip scripts not in model; handled by no-script fallback below } - byte[] runUtf8 = run.text.getBytes(StandardCharsets.UTF_8); - // Skip if too short to form a bigram by either metric. A single - // CJK char is 3 UTF-8 bytes (passes the byte filter) but 1 UTF-16 - // unit, and computeF1MeanLogP filters by text.length() < 2 and - // returns NaN — which would poison the weighted sum here. - if (runUtf8.length < 2 || run.text.length() < 2) { + byte[] runUtf8 = run.text().getBytes(StandardCharsets.UTF_8); + if (runUtf8.length == 0) { continue; } - float logit = scoreChunk(runUtf8, run.text, run.script, - z4, z5, z6, z7, z8, z9); + // Sentinels make even single-codepoint runs scorable (^x$), so no + // length<2 skip is needed; computeF1MeanLogP(withSentinels) is never + // NaN for a non-empty run. + float logit = scoreChunk(run, runUtf8, z4, z5, z6, z7, z8, z9); int n = runUtf8.length; weightedLogit += logit * n; totalBytes += n; - totalBigramCount += n - 1; - if (n > maxBytes) { + totalBigramCount += run.cps.length + 1; // sentinel-bounded bigrams + // COMMON is not a real script — never report it as dominant, but it + // still contributes its (cancelling / junk-flagging) logit above. + if (!run.isCommon() && n > maxBytes) { maxBytes = n; dominantScript = run.script; dominantCal1 = calibrations.get(run.script); @@ -562,7 +560,7 @@ public final class JunkDetector implements TextQualityDetector { } // Same NFC normalization as scoreText — keep train/infer aligned. text = java.text.Normalizer.normalize(text, java.text.Normalizer.Form.NFC); - List<ScriptRun> runs = buildScriptRuns(text); + List<Run> runs = segmentRuns(text); float z4 = computeScriptTransitionZ(text); float z5 = computeZ5LetterAdjacentToMarkRatio(text); float z6 = computeZ6ReplacementRatio(text); @@ -580,15 +578,15 @@ public final class JunkDetector implements TextQualityDetector { String dominantScript = null; int maxBytes = 0; - for (ScriptRun run : runs) { + for (Run run : runs) { if (!calibrations.containsKey(run.script)) { continue; } - byte[] runUtf8 = run.text.getBytes(StandardCharsets.UTF_8); - if (runUtf8.length < 2 || run.text.length() < 2) { - continue; // see scoreText: paired filter avoids NaN poisoning + byte[] runUtf8 = run.text().getBytes(StandardCharsets.UTF_8); + if (runUtf8.length == 0) { + continue; } - float[] zs = computeChunkZs(runUtf8, run.text, run.script); + float[] zs = computeChunkZs(run, runUtf8); float chunkLogit = combineLogit(zs[0], zs[1], zs[2], z4, z5, z6, z7, z8, z9, run.script); int n = runUtf8.length; @@ -597,7 +595,7 @@ public final class JunkDetector implements TextQualityDetector { weightedZ3 += zs[2] * n; weightedLogit += chunkLogit * n; totalBytes += n; - if (n > maxBytes) { + if (!run.isCommon() && n > maxBytes) { maxBytes = n; dominantScript = run.script; } @@ -681,14 +679,14 @@ public final class JunkDetector implements TextQualityDetector { * <p>z4/z5/z6 are document-level features passed in by the caller — * the chunk reuses the same document-wide values. */ - private float scoreChunk(byte[] utf8, String text, String script, + private float scoreChunk(Run run, byte[] utf8, float z4, float z5, float z6, float z7, float z8, float z9) { - if (utf8.length < 2 || !calibrations.containsKey(script)) { + if (!calibrations.containsKey(run.script)) { return 0f; } - float[] zs = computeChunkZs(utf8, text, script); - return combineLogit(zs[0], zs[1], zs[2], z4, z5, z6, z7, z8, z9, script); + float[] zs = computeChunkZs(run, utf8); + return combineLogit(zs[0], zs[1], zs[2], z4, z5, z6, z7, z8, z9, run.script); } // ----------------------------------------------------------------------- @@ -757,13 +755,17 @@ public final class JunkDetector implements TextQualityDetector { * via the public {@code computeZ2/3/4...} static helpers so * training and inference share the same math. */ - private float[] computeChunkZs(byte[] utf8, String text, String script) { - // Feature 1: per-script codepoint-bigram, calibrated per-script + private float[] computeChunkZs(Run run, byte[] utf8) { + String script = run.script; + // Feature 1: per-script codepoint-bigram over the SENTINEL-BOUNDED run, + // calibrated per-script. COMMON runs route to the COMMON table. V7Tables tables = f1TablesByScript.get(script); - float meanF1LogProb = computeCodepointF1MeanLogP(text, tables); + float meanF1LogProb = computeCodepointF1MeanLogP(run.withSentinels(), tables); float[] cal1 = calibrations.get(script); float z1 = (meanF1LogProb - cal1[0]) / cal1[1]; + // z2/z3 score the literal run text (sentinels have no bytes/blocks). + String text = run.text(); float z2 = computeZ2BlockTransitionQuantized(text, blockTables.get(script), blockTableQuant.get(script), blockCalibrations.get(script)); @@ -811,9 +813,9 @@ public final class JunkDetector implements TextQualityDetector { return ((float) (sum / count) - blockCal[0]) / blockCal[1]; } - private static float computeCodepointF1MeanLogP(String text, V7Tables tables) { + private static float computeCodepointF1MeanLogP(int[] cps, V7Tables tables) { if (tables == null) return Float.NaN; - double v = computeF1MeanLogP(text, tables); + double v = computeF1MeanLogP(cps, tables); return Double.isNaN(v) ? Float.NaN : (float) v; } @@ -989,22 +991,35 @@ public final class JunkDetector implements TextQualityDetector { if (text == null || text.length() < 2 || tables == null) { return Double.NaN; } + return computeF1MeanLogP(text.codePoints().toArray(), tables); + } + + /** + * Codepoint-array form of {@link #computeF1MeanLogP(String, V7Tables)}. + * + * <p>Operates on a pre-decoded codepoint sequence rather than a + * {@code String} so callers can splice run-boundary sentinel + * pseudo-codepoints (> {@code 0x10FFFF}, which a UTF-16 {@code String} + * cannot represent) into the sequence before scoring. Same F1 math — + * this is the single authoritative implementation; the {@code String} + * overload just decodes and delegates. + */ + public static double computeF1MeanLogP(int[] cps, V7Tables tables) { + if (cps == null || cps.length < 2 || tables == null) { + return Double.NaN; + } + // Every adjacent pair is scored. The old whitespace+whitespace skip + // (HTML-indentation guard) is gone: whitespace is now COMMON and lives + // in its own COMMON run/table, so it no longer pollutes a per-script + // mean. Dropping the skip also makes the training tally trivially + // match this scorer — both just count every adjacent pair. double sum = 0; int n = 0; int prevCp = -1; int prevIdx = -1; - for (int i = 0; i < text.length(); ) { - int cp = text.codePointAt(i); - i += Character.charCount(cp); + for (int cp : cps) { int curIdx = codepointToIndex(tables, cp); - if (prevCp >= 0 - && !(isAsciiWhitespace(prevCp) && isAsciiWhitespace(cp))) { - // γ-analog of NaiveBayesBigramEncodingDetector's - // whitespace-bigram skip: only the whitespace+whitespace - // case is dropped. (letter, space) and (space, letter) - // still score so that real inter-word context is kept, - // but (space, space) runs from HTML indentation don't - // dominate the mean with unigram-fallback penalties. + if (prevCp >= 0) { sum += scorePairF1V7(prevCp, prevIdx, cp, curIdx, tables); n++; } @@ -1014,19 +1029,6 @@ public final class JunkDetector implements TextQualityDetector { return n == 0 ? Double.NaN : sum / n; } - /** - * ASCII whitespace per the γ filter in - * {@code NaiveBayesBigramEncodingDetector}: tab, LF, VT, FF, CR, space. - * Deliberately ASCII-only (not {@link Character#isWhitespace(int)}) - * to match the encoding-detector's filter exactly and to leave the - * Unicode whitespace separators (no-break space, ideographic space, - * etc.) inside the bigram model. - */ - private static boolean isAsciiWhitespace(int cp) { - return cp == ' ' || cp == '\t' || cp == '\n' || cp == '\r' - || cp == 0x0B /* VT */ || cp == 0x0C /* FF */; - } - /** * Binary-search a codepoint in the script's index. * @@ -1164,64 +1166,155 @@ public final class JunkDetector implements TextQualityDetector { / scriptTransitionCalibration[1]; } + + // ----------------------------------------------------------------------- + // Shared run segmentation + boundary sentinels (COMMON-as-a-script). + // Used by BOTH inference (scoreText) and training table-building so the + // two tally / score exactly the same sentinel-bounded sequences. + // ----------------------------------------------------------------------- + + /** Model script key for the pooled COMMON (digits/punctuation/symbols) table. */ + public static final String COMMON_SCRIPT = "COMMON"; + + // Run-boundary sentinel pseudo-codepoints, typed by neighbor KIND + // (doc edge / COMMON run / script run). Chosen above U+10FFFF so they can + // never collide with a real codepoint; a UTF-16 String cannot hold them, + // which is why z1 scoring runs on int[] (see computeF1MeanLogP(int[],...)). + // They are appended to each table's codepointIndex at training time so + // binarySearch resolves them. Typing is by neighbor KIND, never identity + // (which script): identity typing would make the same charset-invariant + // COMMON run score differently per candidate decode — the pollution bug. + public static final int SENT_START_DOC = 0x110000; + public static final int SENT_START_COMMON = 0x110001; + public static final int SENT_START_SCRIPT = 0x110002; + public static final int SENT_END_DOC = 0x110003; + public static final int SENT_END_COMMON = 0x110004; + public static final int SENT_END_SCRIPT = 0x110005; + + /** All sentinels, ascending — appended to every per-script codepointIndex. */ + public static final int[] SENTINEL_CODEPOINTS = { + SENT_START_DOC, SENT_START_COMMON, SENT_START_SCRIPT, + SENT_END_DOC, SENT_END_COMMON, SENT_END_SCRIPT + }; + + /** Kind of an adjacent run, for sentinel typing. */ + public enum NeighborKind { DOC, COMMON, SCRIPT } + + private static int startSentinel(NeighborKind k) { + switch (k) { + case COMMON: return SENT_START_COMMON; + case SCRIPT: return SENT_START_SCRIPT; + default: return SENT_START_DOC; + } + } + + private static int endSentinel(NeighborKind k) { + switch (k) { + case COMMON: return SENT_END_COMMON; + case SCRIPT: return SENT_END_SCRIPT; + default: return SENT_END_DOC; + } + } + /** - * Splits text into maximal runs of the same Unicode script. - * COMMON, INHERITED, and UNKNOWN codepoints (spaces, punctuation, digits) - * are attached to the preceding script run so that inter-word bigrams are - * preserved within each run. Any leading COMMON characters are prepended - * to the first non-COMMON run. + * A maximal same-class run: either one model script or a COMMON run + * (digits/punctuation/symbols/space — {@code script == }{@link #COMMON_SCRIPT}). + * {@code left}/{@code right} record the KIND of the neighboring run so the + * boundary sentinels can be typed. */ - private List<ScriptRun> buildScriptRuns(String text) { - List<ScriptRun> runs = new ArrayList<>(); - String currentScript = null; - StringBuilder currentText = new StringBuilder(); - StringBuilder leadingCommon = new StringBuilder(); + public static final class Run { + public final String script; + public final int[] cps; // literal codepoints, no sentinels + public NeighborKind left = NeighborKind.DOC; + public NeighborKind right = NeighborKind.DOC; - for (int i = 0; i < text.length(); ) { - int cp = text.codePointAt(i); - i += Character.charCount(cp); + Run(String script, int[] cps) { + this.script = script; + this.cps = cps; + } - Character.UnicodeScript s = Character.UnicodeScript.of(cp); - if (s == Character.UnicodeScript.COMMON - || s == Character.UnicodeScript.INHERITED - || s == Character.UnicodeScript.UNKNOWN) { - if (currentScript != null) { - currentText.appendCodePoint(cp); - } else { - leadingCommon.appendCodePoint(cp); - } - continue; - } + public boolean isCommon() { + return COMMON_SCRIPT.equals(script); + } - String scriptName = SCRIPT_MODEL_FALLBACK.getOrDefault(s.name(), s.name()); + NeighborKind kind() { + return isCommon() ? NeighborKind.COMMON : NeighborKind.SCRIPT; + } - if (!scriptName.equals(currentScript)) { - if (currentScript != null && currentText.length() > 0) { - runs.add(new ScriptRun(currentScript, currentText.toString())); - } - currentScript = scriptName; - currentText = new StringBuilder(); - if (leadingCommon.length() > 0) { - currentText.append(leadingCommon); - leadingCommon.setLength(0); - } + /** Literal run text (sentinels excluded; every cp is ≤ U+10FFFF). */ + public String text() { + return new String(cps, 0, cps.length); + } + + /** + * Sentinel-bounded codepoint sequence for z1 scoring: + * {@code [startSentinel(left), cps..., endSentinel(right)]}. Only z1 + * (the bigram LM) sees sentinels; z2/z3 score the literal {@link #text()}. + */ + public int[] withSentinels() { + int[] out = new int[cps.length + 2]; + out[0] = startSentinel(left); + System.arraycopy(cps, 0, out, 1, cps.length); + out[out.length - 1] = endSentinel(right); + return out; + } + } + + /** COMMON-class predicate: COMMON, INHERITED, UNKNOWN all pool into COMMON. */ + private static String classKey(int cp) { + Character.UnicodeScript s = Character.UnicodeScript.of(cp); + if (s == Character.UnicodeScript.COMMON + || s == Character.UnicodeScript.INHERITED + || s == Character.UnicodeScript.UNKNOWN) { + return COMMON_SCRIPT; + } + return SCRIPT_MODEL_FALLBACK.getOrDefault(s.name(), s.name()); + } + + /** + * Segments a codepoint sequence into maximal same-class runs. COMMON / + * INHERITED / UNKNOWN codepoints form their own COMMON runs rather than + * attaching to a neighboring script — this is what stops charset-invariant + * content (digits, punctuation) from being charged to a per-script bigram + * table. Real scripts map through {@link #SCRIPT_MODEL_FALLBACK}; + * consecutive same-key codepoints form one run. Each run records its + * left/right neighbor KIND for boundary-sentinel typing. + */ + public static List<Run> segmentRuns(int[] cps) { + List<Run> runs = new ArrayList<>(); + if (cps == null || cps.length == 0) { + return runs; + } + int[] scratch = new int[cps.length]; + int len = 0; + String curKey = null; + for (int cp : cps) { + String key = classKey(cp); + if (curKey == null) { + curKey = key; + } else if (!key.equals(curKey)) { + runs.add(new Run(curKey, java.util.Arrays.copyOf(scratch, len))); + curKey = key; + len = 0; } - currentText.appendCodePoint(cp); + scratch[len++] = cp; } + runs.add(new Run(curKey, java.util.Arrays.copyOf(scratch, len))); - if (currentScript != null && currentText.length() > 0) { - runs.add(new ScriptRun(currentScript, currentText.toString())); + for (int i = 0; i < runs.size(); i++) { + if (i > 0) { + runs.get(i).left = runs.get(i - 1).kind(); + } + if (i < runs.size() - 1) { + runs.get(i).right = runs.get(i + 1).kind(); + } } return runs; } - private static final class ScriptRun { - final String script; - final String text; - ScriptRun(String script, String text) { - this.script = script; - this.text = text; - } + /** {@code String} convenience for {@link #segmentRuns(int[])}. */ + public static List<Run> segmentRuns(String text) { + return text == null ? new ArrayList<>() : segmentRuns(text.codePoints().toArray()); } /** diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java index 77517a4e3b..f087e20e3b 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java @@ -125,11 +125,6 @@ import org.apache.tika.ml.junkdetect.V7Tables; public class TrainJunkModel { static final String MAGIC = "JUNKDET1"; - /** Current file-format version produced by this trainer. v8 adds two - * global calibrations (z5 letter-adjacent-to-mark, z6 replacement-char) - * after the script-transition calibration and writes 6-feature LR - * weights per script. Matches {@link JunkDetector#VERSION}. */ - static final byte VERSION = 13; // ----------------------------------------------------------------------- // v7 model constants (per-script open-addressing codepoint-bigram tables) @@ -312,8 +307,13 @@ public class TrainJunkModel { public static void main(String[] args) throws IOException { Path dataDir = Paths.get(System.getProperty("user.home"), - "datasets", "madlad", "junkdetect"); - Path output = dataDir.resolve("junkdetect.bin"); + "data", "junk-augmented-symbolboost"); + // Output is written straight to the bundled resource (run from the repo + // root). Prior model versions live in git history; the model ships with + // the code and is intentionally NOT backwards compatible. + Path output = Paths.get( + "tika-ml/tika-ml-junkdetect/src/main/resources", + "org/apache/tika/ml/junkdetect/junkdetect.bin"); // Durable training parameters live in JunkDetectorTrainingConfig; this // tool deliberately refuses CLI overrides so a built model file's @@ -402,6 +402,12 @@ public class TrainJunkModel { // F3 control-byte calibration // ----------------------------------------------------------------------- TreeMap<String, V7Tables> f1TablesByScript = new TreeMap<>(); + // Shared COMMON pool: digits/punctuation/symbols/space runs from EVERY + // corpus go into one COMMON table so charset-invariant content cancels + // across candidate decodes instead of polluting a per-script table. + HashMap<Long, long[]> commonPairs = new HashMap<>(1 << 16); + HashMap<Integer, long[]> commonUnigrams = new HashMap<>(1 << 12); + long[] commonTotals = new long[2]; System.out.println("\n--- Phase 1: per-script F1 tables + calibrations ---"); for (Path trainFile : trainFiles) { String filename = trainFile.getFileName().toString(); @@ -413,8 +419,15 @@ public class TrainJunkModel { t0 = System.currentTimeMillis(); System.out.print(" Training V7 F1 tables (cp index + OA).."); - V7Tables v7 = trainV7TablesForScript(trainFile, minBigramCount, - loadFactor, keyIndexBits); + // Segment + route: this file's non-COMMON runs build its script + // table; its COMMON runs accumulate into the shared COMMON pool. + HashMap<Long, long[]> scriptPairs = new HashMap<>(1 << 14); + HashMap<Integer, long[]> scriptUnigrams = new HashMap<>(1 << 12); + long[] scriptTotals = new long[2]; + tallyFileRuns(trainFile, scriptPairs, scriptUnigrams, scriptTotals, + commonPairs, commonUnigrams, commonTotals); + V7Tables v7 = buildV7TablesFromCounts(scriptPairs, scriptUnigrams, + scriptTotals[0], minBigramCount, loadFactor, keyIndexBits); System.out.printf(" done (%dms)%n", System.currentTimeMillis() - t0); System.out.println(v7.statsString()); f1TablesByScript.put(script, v7); @@ -456,6 +469,33 @@ public class TrainJunkModel { classifierWeights.put(script, new float[]{1f / 4, 1f / 4, 1f / 4, 1f / 4, 0f}); } + // ----------------------------------------------------------------------- + // Phase 1b — pooled COMMON table (digits/punctuation/symbols/space). + // Built from COMMON runs accumulated across every corpus. Registered + // into f1Calibrations/classifierWeights only AFTER Phase 3 so no COMMON + // classifier is trained — COMMON uses z1-passthrough weights (approach + // (i)): its z1 cancels across candidate decodes and flags symbol salad. + // ----------------------------------------------------------------------- + System.out.println("\n--- Phase 1b: pooled COMMON table ---"); + t0 = System.currentTimeMillis(); + V7Tables commonV7 = buildV7TablesFromCounts(commonPairs, commonUnigrams, + commonTotals[0], minBigramCount, loadFactor, keyIndexBits); + f1TablesByScript.put(JunkDetector.COMMON_SCRIPT, commonV7); + System.out.printf(" %s done (%dms)%n", commonV7.statsString(), + System.currentTimeMillis() - t0); + List<Double> commonScores = new ArrayList<>(); + for (Path f : trainFiles) { + for (String window : sampleSubstrings(f, CALIB_SAMPLES, CALIB_LENGTHS, 77)) { + double s = windowMeanRunF1(window, commonV7, true); + if (!Double.isNaN(s)) { + commonScores.add(s); + } + } + } + float[] commonCal = muSigma(commonScores); + System.out.printf(" COMMON F1 cal: mu=%.4f sigma=%.4f (%,d windows)%n", + commonCal[0], commonCal[1], commonScores.size()); + // ----------------------------------------------------------------------- // Phase 2 — global script-transition table + supporting pools // ----------------------------------------------------------------------- @@ -484,29 +524,38 @@ public class TrainJunkModel { scriptTransCal[0], scriptTransCal[1], System.currentTimeMillis() - t0); // ----------------------------------------------------------------------- - // Phase 3 — per-script linear classifiers (9 features: z1-z9) + // Phase 3 — ONE global combiner, trained pointwise (clean>garbage, the + // absolute junkness scale) + contrastive (correct-decode>wrong, incl. RTL + // logical>reversed, the ranking task). Replaces the per-script LRs and + // their corruption recipes. Features are extracted through a temp model + // so they are exactly the inference features (no train/infer drift). // ----------------------------------------------------------------------- - System.out.println("\n--- Phase 3: per-script linear classifiers (z1..z9) ---"); - for (String script : f1Calibrations.keySet()) { - Path trainFile = trainFilePaths.get(script); - if (trainFile == null) { - System.out.printf(" [%s] WARNING: no train file, keeping equal-weight defaults%n", script); - continue; + System.out.println("\n--- Phase 3: global contrastive combiner ---"); + // COMMON (z1-passthrough) is registered before the temp save so feature + // extraction routes COMMON runs to the COMMON table. + f1Calibrations.put(JunkDetector.COMMON_SCRIPT, commonCal); + classifierWeights.put(JunkDetector.COMMON_SCRIPT, + new float[]{1f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f}); + + Path tmpModel = Files.createTempFile("junkdetect-feat", ".bin"); + saveModel(f1TablesByScript, f1Calibrations, blockTables, blockCalibrations, + controlCalibrations, classifierWeights, scriptBuckets, + scriptTransTable, scriptTransCal, tmpModel); + JunkDetector featExtractor = JunkDetector.loadFromPath(tmpModel); + Files.deleteIfExists(tmpModel); + + t0 = System.currentTimeMillis(); + float[] global = trainGlobalCombiner(featExtractor, trainFilePaths); + System.out.printf( + " global w=[%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f] bias=%.3f (%dms)%n", + global[0], global[1], global[2], global[3], global[4], global[5], + global[6], global[7], global[8], global[9], + System.currentTimeMillis() - t0); + // Apply the global combiner to every real script; COMMON keeps passthrough. + for (String s : f1TablesByScript.keySet()) { + if (!JunkDetector.COMMON_SCRIPT.equals(s)) { + classifierWeights.put(s, global); } - t0 = System.currentTimeMillis(); - System.out.printf(" [%s] training classifier... ", script); - float[] weights = trainClassifierV7(script, trainFile, - f1TablesByScript.get(script), f1Calibrations.get(script), - blockTables.get(script), blockCalibrations.get(script), - controlCalibrations.get(script), - scriptTransTable, scriptTransCal, scriptBucketMap, numScriptBuckets); - classifierWeights.put(script, weights); - System.out.printf( - "done — w=[%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f] bias=%.3f (%dms)%n", - weights[0], weights[1], weights[2], weights[3], - weights[4], weights[5], weights[6], weights[7], weights[8], - weights[9], - System.currentTimeMillis() - t0); } System.out.printf("%nWriting model (%d scripts, blockN=%d, scriptBuckets=%d) → %s%n", @@ -1152,40 +1201,100 @@ public class TrainJunkModel { int minBigramCount, double loadFactor, int keyIndexBits) throws IOException { - // --- Pass 1: tally pair and unigram counts. --- + // Single-file convenience: pools only this file's own (non-COMMON) + // script runs. The full pipeline routes COMMON across all corpora via + // tallyFileRuns into a shared COMMON table (the throwaway maps here + // discard this file's COMMON runs). HashMap<Long, long[]> pairCounts = new HashMap<>(1 << 14); HashMap<Integer, long[]> unigramCounts = new HashMap<>(1 << 12); - long bigramTotal = 0; - long unigramTotal = 0; + long[] scriptTotals = new long[2]; + tallyFileRuns(trainFile, pairCounts, unigramCounts, scriptTotals, + new HashMap<>(), new HashMap<>(), new long[2]); + return buildV7TablesFromCounts(pairCounts, unigramCounts, scriptTotals[0], + minBigramCount, loadFactor, keyIndexBits); + } + /** + * Reads a {@code *.train.gz}, NFC-normalizes each line, segments it via + * {@link JunkDetector#segmentRuns} and tallies sentinel-bounded bigrams: + * non-COMMON runs into the per-script maps, COMMON runs into the shared + * COMMON maps. NFC + sentinels match inference exactly (no train/infer + * drift). {@code *Totals[0]}=unigram count, {@code *Totals[1]}=bigrams. + */ + static void tallyFileRuns(Path trainFile, + HashMap<Long, long[]> scriptPairs, + HashMap<Integer, long[]> scriptUnigrams, + long[] scriptTotals, + HashMap<Long, long[]> commonPairs, + HashMap<Integer, long[]> commonUnigrams, + long[] commonTotals) throws IOException { try (BufferedReader r = openGzipped(trainFile)) { String line; while ((line = r.readLine()) != null) { - int prevCp = -1; - for (int i = 0; i < line.length(); ) { - int cp = line.codePointAt(i); - i += Character.charCount(cp); - long[] uc = unigramCounts.get(cp); - if (uc == null) { - unigramCounts.put(cp, new long[]{1L}); + if (line.isEmpty()) { + continue; + } + String norm = java.text.Normalizer.normalize( + line, java.text.Normalizer.Form.NFC); + for (JunkDetector.Run run : JunkDetector.segmentRuns(norm)) { + int[] seq = run.withSentinels(); + if (run.isCommon()) { + tallySeq(seq, commonPairs, commonUnigrams, commonTotals); } else { - uc[0]++; - } - unigramTotal++; - if (prevCp >= 0) { - long packed = ((long) prevCp << 32) | (cp & 0xFFFFFFFFL); - long[] bc = pairCounts.get(packed); - if (bc == null) { - pairCounts.put(packed, new long[]{1L}); - } else { - bc[0]++; - } - bigramTotal++; + tallySeq(seq, scriptPairs, scriptUnigrams, scriptTotals); } - prevCp = cp; } } } + } + + /** + * Tallies unigram + adjacent-pair counts of a codepoint sequence into the + * given maps ({@code totals[0]}+=unigrams, {@code totals[1]}+=bigrams). + * Matches {@link JunkDetector#computeF1MeanLogP(int[], V7Tables)} exactly + * (every adjacent pair, no skips) so trained tables fit scored sequences. + */ + private static void tallySeq(int[] seq, + HashMap<Long, long[]> pairs, + HashMap<Integer, long[]> unigrams, + long[] totals) { + int prevCp = -1; + for (int cp : seq) { + long[] uc = unigrams.get(cp); + if (uc == null) { + unigrams.put(cp, new long[]{1L}); + } else { + uc[0]++; + } + totals[0]++; + if (prevCp >= 0) { + long packed = ((long) prevCp << 32) | (cp & 0xFFFFFFFFL); + long[] bc = pairs.get(packed); + if (bc == null) { + pairs.put(packed, new long[]{1L}); + } else { + bc[0]++; + } + totals[1]++; + } + prevCp = cp; + } + } + + /** + * Builds the {@link V7Tables} F1 carrier from pre-tallied pair/unigram + * counts (see {@link #tallyFileRuns}). Drops pairs below + * {@code minBigramCount}, assigns dense codepoint indices, and packs an + * open-addressing bigram table; unigram log-probs use {@code unigramTotal} + * as the denominator. + */ + public static V7Tables buildV7TablesFromCounts( + HashMap<Long, long[]> pairCounts, + HashMap<Integer, long[]> unigramCounts, + long unigramTotal, + int minBigramCount, + double loadFactor, + int keyIndexBits) { // --- Filter pairs by count, collect kept-codepoint set. --- int totalDistinct = pairCounts.size(); @@ -1216,8 +1325,7 @@ public class TrainJunkModel { throw new IllegalStateException("Per-script codepoint count " + cpIndex.length + " exceeds 2^KEY_INDEX_BITS (= " + (maxIndex + 1) + "). Increase KEY_INDEX_BITS or apply" - + " a tighter pair-count filter for " - + trainFile.getFileName()); + + " a tighter pair-count filter."); } // --- Compute per-pair log-prob (add-α smoothed over kept pairs). --- @@ -1329,7 +1437,7 @@ public class TrainJunkModel { List<String> windows = sampleSubstrings(devGz, CALIB_SAMPLES, CALIB_LENGTHS, 42); List<Double> scores = new ArrayList<>(windows.size()); for (String window : windows) { - double score = JunkDetector.computeF1MeanLogP(window, tables); + double score = windowMeanRunF1(window, tables, false); if (!Double.isNaN(score)) { scores.add(score); } @@ -1338,6 +1446,33 @@ public class TrainJunkModel { return muSigma(scores); } + /** + * Byte-weighted mean of sentinel-bounded per-run F1 over the runs of + * {@code window} selected by {@code wantCommon} (COMMON runs if true, else + * non-COMMON script runs), scored against {@code tables}. This is the + * windowed analog of inference's per-run z1 input (pre-calibration): it + * segments and sentinel-bounds exactly as {@link JunkDetector#scoreText} + * does, so calibration and classifier z1 cannot drift from inference. + */ + static double windowMeanRunF1(String window, V7Tables tables, boolean wantCommon) { + String norm = java.text.Normalizer.normalize(window, java.text.Normalizer.Form.NFC); + double weighted = 0; + long bytes = 0; + for (JunkDetector.Run run : JunkDetector.segmentRuns(norm)) { + if (run.isCommon() != wantCommon) { + continue; + } + double f1 = JunkDetector.computeF1MeanLogP(run.withSentinels(), tables); + if (Double.isNaN(f1)) { + continue; + } + int n = run.text().getBytes(StandardCharsets.UTF_8).length; + weighted += f1 * n; + bytes += n; + } + return bytes == 0 ? Double.NaN : weighted / bytes; + } + // ----------------------------------------------------------------------- // v7 Phase 3: classifier feature extractor + orchestrator // ----------------------------------------------------------------------- @@ -1357,15 +1492,16 @@ public class TrainJunkModel { float[] scriptTransTable, float[] scriptTransCal, Map<String, Integer> scriptBucketMap, int numScriptBuckets) { - // NFD-normalize defensively — corruption modes (utf8AsWindows1252- + // NFC-normalize defensively — corruption modes (utf8AsWindows1252- // Mojibake, etc.) produce text in whatever form the encoder yields. // Matches JunkDetector.scoreText / scoreWithFeatureComponents. window = java.text.Normalizer.normalize(window, java.text.Normalizer.Form.NFC); byte[] utf8 = window.getBytes(StandardCharsets.UTF_8); - // z1: per-script codepoint-bigram mean log-prob + // z1: per-script codepoint-bigram mean log-prob over the window's + // non-COMMON runs, sentinel-bounded — mirrors inference's per-run z1. float z1 = 0f; - double rawF1 = JunkDetector.computeF1MeanLogP(window, tables); + double rawF1 = windowMeanRunF1(window, tables, false); if (!Double.isNaN(rawF1) && f1Cal != null && f1Cal[1] > 0) { z1 = ((float) rawF1 - f1Cal[0]) / f1Cal[1]; } @@ -1426,6 +1562,179 @@ public class TrainJunkModel { * (sample windows, corrupt half, fit LR, bias-calibrate on short * windows) but uses v7 per-script F1 tables. */ + private static final java.util.Set<String> RTL_SCRIPTS = java.util.Set.of( + "ARABIC", "HEBREW", "SYRIAC", "NKO", "THAANA"); + + private static boolean isRtl(String script) { + return RTL_SCRIPTS.contains(script); + } + + /** + * Trains ONE global combiner over z1..z9 + bias. The pointwise term anchors + * the absolute junkness scale (clean windows positive, generic garbage + * negative); the contrastive term ranks decodes (a clean window must outscore + * every wrong decode of its own bytes, including RTL logical vs reversed). + * Features come from {@code fx} (a temp model) so they match inference + * exactly. Reads only {@code *.train.gz}. + */ + static float[] trainGlobalCombiner(JunkDetector fx, + TreeMap<String, Path> trainFiles) + throws IOException { + List<float[]> good = new ArrayList<>(); + List<float[]> bad = new ArrayList<>(); + List<float[]> pairCorrect = new ArrayList<>(); + List<float[]> pairWrong = new ArrayList<>(); + Random rng = new Random(303); + + for (Map.Entry<String, Path> e : trainFiles.entrySet()) { + String script = e.getKey(); + boolean rtl = isRtl(script); + List<String> windows = sampleSubstrings(e.getValue(), + NUM_CLASSIFIER_SAMPLES, CALIB_LENGTHS, 300); + for (String w : windows) { + float[] fc = featureVector(fx, w); + if (fc == null) { + continue; + } + good.add(fc); + + // Contrastive: the clean decode must beat wrong decodes of its bytes. + for (int k = 0; k < 2; k++) { + String[] pr = BYTE_LEVEL_MOJIBAKE_PAIRS[ + rng.nextInt(BYTE_LEVEL_MOJIBAKE_PAIRS.length)]; + addContrastivePair(fx, w, byteLevelMojibake(w, pr[0], pr[1]), + fc, pairCorrect, pairWrong); + } + if ("LATIN".equals(script)) { + String[] pr = LATIN_TO_CJK_PAIRS[ + rng.nextInt(LATIN_TO_CJK_PAIRS.length)]; + addContrastivePair(fx, w, byteLevelMojibake(w, pr[0], pr[1]), + fc, pairCorrect, pairWrong); + } + if (rtl) { + addContrastivePair(fx, w, reverseRtlText(w), fc, + pairCorrect, pairWrong); + } + + // Pointwise garbage anchor (generic junk, no correct counterpart). + String junk; + int mode = rng.nextInt(3); + if (mode == 0) { + junk = injectControlChars(w, 0.15, rng); + } else if (mode == 1) { + junk = shuffleChars(w, rng); + } else { + junk = injectPrivateUseAreaChars(w, 0.12, rng); + } + float[] fb = featureVector(fx, junk); + if (fb != null) { + bad.add(fb); + } + } + } + System.out.printf(" examples: good=%,d bad=%,d pairs=%,d%n", + good.size(), bad.size(), pairCorrect.size()); + return fitContrastiveCombiner(good, bad, pairCorrect, pairWrong); + } + + private static void addContrastivePair(JunkDetector fx, String correct, + String wrong, float[] correctFeat, + List<float[]> pc, List<float[]> pw) { + if (wrong.equals(correct)) { + return; + } + float[] fw = featureVector(fx, wrong); + if (fw == null) { + return; + } + pc.add(correctFeat); + pw.add(fw); + } + + /** z1..z9 for {@code text} via the inference feature path; null if any NaN. */ + private static float[] featureVector(JunkDetector fx, String text) { + JunkDetector.FeatureComponents f = fx.scoreWithFeatureComponents(text); + float[] z = {f.z1, f.z2, f.z3, f.z4, f.z5, f.z6, f.z7, f.z8, f.z9}; + for (float v : z) { + if (Float.isNaN(v)) { + return null; + } + } + return z; + } + + /** + * Fits {@code w[9]+bias} by full-batch gradient descent: pointwise + * cross-entropy (good=1, bad=0) + pairwise logistic (correct must outscore + * wrong) + L2. Feature weights are projected non-negative (same orientation + * convention as {@link #fitLogisticRegression}); the bias is unconstrained. + */ + static float[] fitContrastiveCombiner(List<float[]> good, List<float[]> bad, + List<float[]> pairCorrect, + List<float[]> pairWrong) { + int f = 9; + double[] w = new double[f]; + double bias = 0; + double lr = 0.05; + double lambda = 0.01; + int epochs = 3000; + int ng = Math.max(1, good.size()); + int nb = Math.max(1, bad.size()); + int np = Math.max(1, pairCorrect.size()); + + for (int epoch = 0; epoch < epochs; epoch++) { + double[] grad = new double[f]; + double gradB = 0; + for (float[] x : good) { + double p = sigmoid(dotF(w, x) + bias); + for (int j = 0; j < f; j++) { + grad[j] += (p - 1) * x[j] / ng; + } + gradB += (p - 1) / ng; + } + for (float[] x : bad) { + double p = sigmoid(dotF(w, x) + bias); + for (int j = 0; j < f; j++) { + grad[j] += p * x[j] / nb; + } + gradB += p / nb; + } + for (int i = 0; i < pairCorrect.size(); i++) { + float[] c = pairCorrect.get(i); + float[] wr = pairWrong.get(i); + double s = sigmoid(-(dotF(w, c) - dotF(w, wr))); + for (int j = 0; j < f; j++) { + grad[j] += -(c[j] - wr[j]) * s / np; + } + } + for (int j = 0; j < f; j++) { + w[j] -= lr * (grad[j] + lambda * w[j]); + if (w[j] < 0) { + w[j] = 0; + } + } + bias -= lr * gradB; + } + float[] out = new float[f + 1]; + for (int j = 0; j < f; j++) { + out[j] = (float) w[j]; + } + out[f] = (float) bias; + return out; + } + + private static double dotF(double[] w, float[] x) { + double s = 0; + for (int j = 0; j < w.length; j++) { + s += w[j] * x[j]; + } + return s; + } + + private static double sigmoid(double z) { + return 1.0 / (1.0 + Math.exp(-z)); + } + static float[] trainClassifierV7(String script, Path devGz, V7Tables tables, float[] f1Cal, @@ -1576,7 +1885,7 @@ public class TrainJunkModel { /** * Writes a model file in the current binary format. Layout: gzip - * envelope around {@code JUNKDET1} magic + {@link #VERSION} byte + + * envelope around {@code JUNKDET1} magic + {@link JunkDetector#VERSION} byte + * global script-transition section + z5/z6 calibrations + per-script * sections (F1 tables, F2 block transitions, F3 control calibration, * 7-element LR weight vector = 6 weights + bias). See @@ -1596,7 +1905,7 @@ public class TrainJunkModel { new GZIPOutputStream(Files.newOutputStream(output)))) { dos.write(MAGIC.getBytes(StandardCharsets.UTF_8)); - dos.writeByte(VERSION); + dos.writeByte(JunkDetector.VERSION); // single source of truth dos.writeInt(f1Calibrations.size()); // Block-scheme version byte — bound to the JVM-independent diff --git a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin index c09c38cdb4..17b96ea848 100644 Binary files a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin and b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin differ diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV7Test.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV7Test.java index 20219e854a..1763e245c9 100644 --- a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV7Test.java +++ b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorV7Test.java @@ -83,8 +83,9 @@ public class JunkDetectorV7Test { assertEquals("LATIN", score.getDominantScript(), "Dominant script should be LATIN"); // Quantization of [-4, -1] to 8 bits introduces ~0.012 nat / level. // Net z-error over 3 pairs bounded ~0.05; allow 0.3 to be safe. - assertEquals(3.0f, score.getZScore(), 0.3f, - "Expected z ≈ +3.0 for 'ABAB' (seen-pair + backoff mix)"); + // Sentinel-bounded scoring: expected = calibrated F1 over ^_doc A B A B $_doc. + assertEquals(expectedRunZ(tables, "ABAB", -5.0f, 1.0f), score.getZScore(), 0.05f, + "Inference z must match the authoritative F1 on the sentinel-bounded run"); } @Test @@ -117,9 +118,8 @@ public class JunkDetectorV7Test { JunkDetector detector = JunkDetector.loadFromPath(modelFile); TextQualityScore score = detector.score("ABAB"); - // mean = -1.0, z1 = (-1 - -5) / 1 = +4.0 - assertEquals(4.0f, score.getZScore(), 0.3f, - "All-seen 'ABAB' should score z ≈ +4"); + assertEquals(expectedRunZ(tables, "ABAB", -5.0f, 1.0f), score.getZScore(), 0.05f, + "All-seen 'ABAB': inference z must match authoritative sentinel-bounded F1"); } /** @@ -225,12 +225,12 @@ public class JunkDetectorV7Test { // JunkDetector.computeF1MeanLogP on the same text — if these // two ever disagree, the model's calibration is silently wrong. String probe = "pack my box with five dozen liquor jugs"; - double trainerRawMean = JunkDetector.computeF1MeanLogP(probe, tables); - float expectedZ1 = (float) ((trainerRawMean - f1CalLatin[0]) / f1CalLatin[1]); + float expectedZ1 = expectedRunZ(tables, probe, f1CalLatin[0], f1CalLatin[1]); TextQualityScore probeScore = detector.score(probe); - // logit = w1 * z1 + 0 + 0 + 0 + 0 = z1 in this test configuration. - assertEquals(expectedZ1, probeScore.getZScore(), 0.001f, - "Inference z1 must match trainer-computed z1 " + // logit = w1*z1 (rest 0); inference aggregates the same per-run + // sentinel-bounded F1 the helper computes — must agree (no drift). + assertEquals(expectedZ1, probeScore.getZScore(), 0.02f, + "Inference z1 must match the per-run sentinel-bounded F1 " + "(train/infer F1 math drift)"); } @@ -277,6 +277,29 @@ public class JunkDetectorV7Test { -10.0f, 1.0f); } + /** + * Mirrors inference's z1: byte-weighted mean over non-COMMON runs of the + * sentinel-bounded F1, then calibrated. COMMON runs are skipped (these + * minimal models have no COMMON table, so inference skips them too). + */ + private static float expectedRunZ(V7Tables tables, String text, float mu, float sigma) { + double weighted = 0; + long bytes = 0; + for (JunkDetector.Run r : JunkDetector.segmentRuns(text)) { + if (r.isCommon()) { + continue; + } + double f1 = JunkDetector.computeF1MeanLogP(r.withSentinels(), tables); + if (Double.isNaN(f1)) { + continue; + } + int n = r.text().getBytes(java.nio.charset.StandardCharsets.UTF_8).length; + weighted += f1 * n; + bytes += n; + } + return (float) ((weighted / bytes - mu) / sigma); + } + /** Quantize a single float to 8-bit unsigned using the explicit range. */ private static byte quantizeOne(float v, float min, float max) { float range = max - min;
