This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch junk-detector-v6 in repository https://gitbox.apache.org/repos/asf/tika.git
commit c9bc39e6414046679de8d9e7da54878f23aa26d1 Author: tballison <[email protected]> AuthorDate: Thu May 14 12:02:51 2026 -0400 junk-detector: add --min-bigram-count to TrainJunkModel New optional flag prunes F1 bigrams whose global per-pair count is below the threshold from the codepoint-bigram hash table and Bloom filter. Unigram counts and backoff are unaffected. When the flag is omitted (or set to 1), behavior is byte-identical to the previous code path; the existing 2-arg overload of trainCodepointHashTables is preserved as a thin wrapper. When >= 2, the trainer makes a pre-pass over all *.train.gz files to tally per-pair occurrence counts in a HashMap<Long,long[]>, then in the main pass only emits bigrams whose tally meets the cutoff. Pre-pass memory is bounded by the distinct-pair count (~450K pairs on the current 34-script madlad corpus, ~50 MB heap). Rationale: ablation on the dev split (held-out from training) shows that min_bigram_count=3 cuts the v6 model from 1456 KB -> 889 KB (-39%) and macro FPR from 0.018 -> 0.007 (-61%) with macro TPR only moving 0.890 -> 0.883. Per-distortion Cohen's d goes up on the realistic junk modes (byte-shuffle, byte-swap, wrong-codec) and only down on the synthetic inject distortion, where baseline d ~ 11.86 saturates well past any operating threshold anyway. See discussion in 20260514-junk-retrain-v6.md. The singletons dropped are mostly OCR artifacts, proper nouns, and typos that inflate the clean-side distribution tail without contributing real distributional information. All 24 existing tests pass with the change. 
Co-authored-by: Cursor <[email protected]> --- .../tika/ml/junkdetect/tools/TrainJunkModel.java | 96 +++++++++++++++++++--- 1 file changed, 85 insertions(+), 11 deletions(-) diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java index 8855ac9338..2d95d7db5e 100644 --- a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java @@ -202,6 +202,7 @@ public class TrainJunkModel { "datasets", "madlad", "junkdetect"); Path output = dataDir.resolve("junkdetect.bin"); int bloomBits = V6_BLOOM_BITS_DEFAULT; + int minBigramCount = 1; for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -218,6 +219,13 @@ public class TrainJunkModel { System.exit(1); } break; + case "--min-bigram-count": + minBigramCount = Integer.parseInt(args[++i]); + if (minBigramCount < 1) { + System.err.println("ERROR: --min-bigram-count must be >= 1"); + System.exit(1); + } + break; default: System.err.println("Unknown argument: " + args[i]); printUsage(); @@ -234,6 +242,7 @@ public class TrainJunkModel { bloomBits, bloomBits / 8 / 1024, V6_BLOOM_K); System.out.printf( " fnv_seed: 0x%08X%n", V6_FNV_SEED); System.out.printf( " backoff_alpha: %.2f%n", V6_BACKOFF_ALPHA); + System.out.printf( " min_bigram_count: %d%n", minBigramCount); if (!Files.isDirectory(dataDir)) { System.err.println("ERROR: data-dir not found: " + dataDir); @@ -273,7 +282,7 @@ public class TrainJunkModel { System.out.println("\n--- Phase 1: global codepoint-hash tables + Bloom ---"); t0 = System.currentTimeMillis(); System.out.print(" Training global codepoint-bigram + unigram + Bloom... 
"); - F1Tables f1Tables = trainCodepointHashTables(trainFiles, bloomBits); + F1Tables f1Tables = trainCodepointHashTables(trainFiles, bloomBits, minBigramCount); System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); System.out.print(f1Tables.statsString()); @@ -735,12 +744,62 @@ public class TrainJunkModel { */ public static F1Tables trainCodepointHashTables(List<Path> trainFiles, int bloomBits) throws IOException { + return trainCodepointHashTables(trainFiles, bloomBits, 1); + } + + /** + * Same as the 2-arg overload, but only bigrams with global per-pair count + * >= {@code minBigramCount} contribute to the bigram hash table and the + * Bloom filter. Unigrams are always counted (used for backoff). When + * {@code minBigramCount == 1} this is a no-op and the single-pass code + * path runs. Otherwise a first pass tallies per-pair counts into an + * in-memory map and a second pass emits only frequent pairs. + */ + public static F1Tables trainCodepointHashTables(List<Path> trainFiles, int bloomBits, + int minBigramCount) throws IOException { long[] bigramCounts = new long[V6_BIGRAM_BUCKETS]; long[] unigramCounts = new long[V6_UNIGRAM_BUCKETS]; long bigramTotal = 0; long unigramTotal = 0; long[] bloomBitArr = new long[(bloomBits + 63) >> 6]; + HashMap<Long, long[]> pairTallies = null; + if (minBigramCount > 1) { + System.out.printf(" pre-pass: tallying per-pair counts " + + "(min_bigram_count=%d)%n", minBigramCount); + pairTallies = new HashMap<>(1 << 18); + for (Path trainFile : trainFiles) { + try (BufferedReader r = openGzipped(trainFile)) { + String line; + while ((line = r.readLine()) != null) { + int prevCp = -1; + for (int i = 0; i < line.length(); ) { + int cp = line.codePointAt(i); + i += Character.charCount(cp); + if (prevCp >= 0) { + long packed = ((long) prevCp << 24) | (cp & 0xFFFFFFL); + long[] c = pairTallies.get(packed); + if (c == null) { + pairTallies.put(packed, new long[]{1L}); + } else { + c[0]++; + } + } + prevCp = cp; + } + } + } 
+ } + int kept = 0; + int dropped = 0; + for (long[] c : pairTallies.values()) { + if (c[0] >= minBigramCount) kept++; + else dropped++; + } + System.out.printf(" pre-pass: distinct pairs=%,d kept=%,d dropped=%,d%n", + pairTallies.size(), kept, dropped); + } + for (Path trainFile : trainFiles) { try (BufferedReader r = openGzipped(trainFile)) { String line; @@ -754,12 +813,20 @@ public class TrainJunkModel { unigramCounts[uBucket]++; unigramTotal++; if (prevCp >= 0) { - int bBucket = (int) (fnv1aBigramV6(prevCp, cp, V6_FNV_SEED) - & (V6_BIGRAM_BUCKETS - 1)); - bigramCounts[bBucket]++; - bigramTotal++; - bloomAddV6(bloomBitArr, bloomBits, V6_BLOOM_K, - prevCp, cp, V6_FNV_SEED); + boolean accept = true; + if (pairTallies != null) { + long packed = ((long) prevCp << 24) | (cp & 0xFFFFFFL); + long[] c = pairTallies.get(packed); + accept = c != null && c[0] >= minBigramCount; + } + if (accept) { + int bBucket = (int) (fnv1aBigramV6(prevCp, cp, V6_FNV_SEED) + & (V6_BIGRAM_BUCKETS - 1)); + bigramCounts[bBucket]++; + bigramTotal++; + bloomAddV6(bloomBitArr, bloomBits, V6_BLOOM_K, + prevCp, cp, V6_FNV_SEED); + } } prevCp = cp; } @@ -1458,9 +1525,16 @@ public class TrainJunkModel { private static void printUsage() { System.err.println("Usage: TrainJunkModel [options]"); - System.err.println(" --data-dir <path> Directory with {script}.train.gz / .dev.gz files"); - System.err.println(" (default: ~/datasets/madlad/junkdetect)"); - System.err.println(" --output <path> Output model file"); - System.err.println(" (default: {data-dir}/junkdetect.bin)"); + System.err.println(" --data-dir <path> Directory with {script}.train.gz / .dev.gz files"); + System.err.println(" (default: ~/datasets/madlad/junkdetect)"); + System.err.println(" --output <path> Output model file"); + System.err.println(" (default: {data-dir}/junkdetect.bin)"); + System.err.println(" --bloom-bits <n> F1 Bloom filter size in bits (multiple of 64)"); + System.err.println(" --min-bigram-count <n> Drop F1 bigrams 
with global per-pair count < n."); + System.err.println(" n>=2 enables a pre-pass that tallies per-pair"); + System.err.println(" counts; rare bigrams (typically OCR/proper-noun"); + System.err.println(" noise) are excluded from the hash table and"); + System.err.println(" Bloom filter, cutting model size and FPR with"); + System.err.println(" negligible TPR impact. Default: 1 (no pruning)."); } }
