This is an automated email from the ASF dual-hosted git repository.

tballison pushed a commit to branch junk-detector-v6
in repository https://gitbox.apache.org/repos/asf/tika.git

commit c9bc39e6414046679de8d9e7da54878f23aa26d1
Author: tballison <[email protected]>
AuthorDate: Thu May 14 12:02:51 2026 -0400

    junk-detector: add --min-bigram-count to TrainJunkModel
    
    New optional flag prunes, from the codepoint-bigram hash table and
    Bloom filter, any F1 bigram whose global per-pair count is below the
    threshold.  Unigram counts and backoff are unaffected.
    
    When the flag is omitted (or set to 1), behavior is byte-identical to
    the previous code path; the existing 2-arg overload of
    trainCodepointHashTables is preserved as a thin wrapper.
    
    When the flag is >= 2, the trainer makes a pre-pass over all
    *.train.gz files to tally per-pair occurrence counts in a
    HashMap<Long,long[]>, then in the main pass only emits bigrams whose
    tally meets the cutoff.  Pre-pass memory is bounded by the distinct-pair
    count (~450K pairs on the current 34-script madlad corpus, ~50 MB heap).
    
    Rationale: ablation on the dev split (held-out from training) shows
    that min_bigram_count=3 cuts the v6 model from 1456 KB -> 889 KB
    (-39%) and macro FPR from 0.018 -> 0.007 (-61%) with macro TPR only
    moving 0.890 -> 0.883.  Per-distortion Cohen's d goes up on the
    realistic junk modes (byte-shuffle, byte-swap, wrong-codec) and only
    down on the synthetic inject distortion, where baseline d ~ 11.86
    saturates well past any operating threshold anyway.  See discussion
    in 20260514-junk-retrain-v6.md.
    
    The low-count pairs dropped are mostly OCR artifacts, proper nouns, and
    typos that inflate the clean-side distribution tail without
    contributing real distributional information.
    
    All 24 existing tests pass with the change.
    
    Co-authored-by: Cursor <[email protected]>
---
 .../tika/ml/junkdetect/tools/TrainJunkModel.java   | 96 +++++++++++++++++++---
 1 file changed, 85 insertions(+), 11 deletions(-)

diff --git 
a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java
 
b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java
index 8855ac9338..2d95d7db5e 100644
--- 
a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java
+++ 
b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java
@@ -202,6 +202,7 @@ public class TrainJunkModel {
                 "datasets", "madlad", "junkdetect");
         Path output = dataDir.resolve("junkdetect.bin");
         int bloomBits = V6_BLOOM_BITS_DEFAULT;
+        int minBigramCount = 1;
 
         for (int i = 0; i < args.length; i++) {
             switch (args[i]) {
@@ -218,6 +219,13 @@ public class TrainJunkModel {
                         System.exit(1);
                     }
                     break;
+                case "--min-bigram-count":
+                    minBigramCount = Integer.parseInt(args[++i]);
+                    if (minBigramCount < 1) {
+                        System.err.println("ERROR: --min-bigram-count must be 
>= 1");
+                        System.exit(1);
+                    }
+                    break;
                 default:
                     System.err.println("Unknown argument: " + args[i]);
                     printUsage();
@@ -234,6 +242,7 @@ public class TrainJunkModel {
                 bloomBits, bloomBits / 8 / 1024, V6_BLOOM_K);
         System.out.printf( "  fnv_seed:           0x%08X%n", V6_FNV_SEED);
         System.out.printf( "  backoff_alpha:      %.2f%n", V6_BACKOFF_ALPHA);
+        System.out.printf( "  min_bigram_count:   %d%n", minBigramCount);
 
         if (!Files.isDirectory(dataDir)) {
             System.err.println("ERROR: data-dir not found: " + dataDir);
@@ -273,7 +282,7 @@ public class TrainJunkModel {
         System.out.println("\n--- Phase 1: global codepoint-hash tables + 
Bloom ---");
         t0 = System.currentTimeMillis();
         System.out.print("  Training global codepoint-bigram + unigram + 
Bloom... ");
-        F1Tables f1Tables = trainCodepointHashTables(trainFiles, bloomBits);
+        F1Tables f1Tables = trainCodepointHashTables(trainFiles, bloomBits, 
minBigramCount);
         System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0);
         System.out.print(f1Tables.statsString());
 
@@ -735,12 +744,62 @@ public class TrainJunkModel {
      */
     public static F1Tables trainCodepointHashTables(List<Path> trainFiles, int 
bloomBits)
             throws IOException {
+        return trainCodepointHashTables(trainFiles, bloomBits, 1);
+    }
+
+    /**
+     * Same as the 2-arg overload, but only bigrams with global per-pair count
+     * &gt;= {@code minBigramCount} contribute to the bigram hash table and the
+     * Bloom filter.  Unigrams are always counted (used for backoff).  When
+     * {@code minBigramCount == 1} this is a no-op and the single-pass code
+     * path runs.  Otherwise a first pass tallies per-pair counts into an
+     * in-memory map and a second pass emits only frequent pairs.
+     */
+    public static F1Tables trainCodepointHashTables(List<Path> trainFiles, int 
bloomBits,
+            int minBigramCount) throws IOException {
         long[] bigramCounts = new long[V6_BIGRAM_BUCKETS];
         long[] unigramCounts = new long[V6_UNIGRAM_BUCKETS];
         long bigramTotal = 0;
         long unigramTotal = 0;
         long[] bloomBitArr = new long[(bloomBits + 63) >> 6];
 
+        HashMap<Long, long[]> pairTallies = null;
+        if (minBigramCount > 1) {
+            System.out.printf("  pre-pass: tallying per-pair counts "
+                    + "(min_bigram_count=%d)%n", minBigramCount);
+            pairTallies = new HashMap<>(1 << 18);
+            for (Path trainFile : trainFiles) {
+                try (BufferedReader r = openGzipped(trainFile)) {
+                    String line;
+                    while ((line = r.readLine()) != null) {
+                        int prevCp = -1;
+                        for (int i = 0; i < line.length(); ) {
+                            int cp = line.codePointAt(i);
+                            i += Character.charCount(cp);
+                            if (prevCp >= 0) {
+                                long packed = ((long) prevCp << 24) | (cp & 
0xFFFFFFL);
+                                long[] c = pairTallies.get(packed);
+                                if (c == null) {
+                                    pairTallies.put(packed, new long[]{1L});
+                                } else {
+                                    c[0]++;
+                                }
+                            }
+                            prevCp = cp;
+                        }
+                    }
+                }
+            }
+            int kept = 0;
+            int dropped = 0;
+            for (long[] c : pairTallies.values()) {
+                if (c[0] >= minBigramCount) kept++;
+                else dropped++;
+            }
+            System.out.printf("  pre-pass: distinct pairs=%,d  kept=%,d  
dropped=%,d%n",
+                    pairTallies.size(), kept, dropped);
+        }
+
         for (Path trainFile : trainFiles) {
             try (BufferedReader r = openGzipped(trainFile)) {
                 String line;
@@ -754,12 +813,20 @@ public class TrainJunkModel {
                         unigramCounts[uBucket]++;
                         unigramTotal++;
                         if (prevCp >= 0) {
-                            int bBucket = (int) (fnv1aBigramV6(prevCp, cp, 
V6_FNV_SEED)
-                                    & (V6_BIGRAM_BUCKETS - 1));
-                            bigramCounts[bBucket]++;
-                            bigramTotal++;
-                            bloomAddV6(bloomBitArr, bloomBits, V6_BLOOM_K,
-                                    prevCp, cp, V6_FNV_SEED);
+                            boolean accept = true;
+                            if (pairTallies != null) {
+                                long packed = ((long) prevCp << 24) | (cp & 
0xFFFFFFL);
+                                long[] c = pairTallies.get(packed);
+                                accept = c != null && c[0] >= minBigramCount;
+                            }
+                            if (accept) {
+                                int bBucket = (int) (fnv1aBigramV6(prevCp, cp, 
V6_FNV_SEED)
+                                        & (V6_BIGRAM_BUCKETS - 1));
+                                bigramCounts[bBucket]++;
+                                bigramTotal++;
+                                bloomAddV6(bloomBitArr, bloomBits, V6_BLOOM_K,
+                                        prevCp, cp, V6_FNV_SEED);
+                            }
                         }
                         prevCp = cp;
                     }
@@ -1458,9 +1525,16 @@ public class TrainJunkModel {
 
     private static void printUsage() {
         System.err.println("Usage: TrainJunkModel [options]");
-        System.err.println("  --data-dir <path>  Directory with 
{script}.train.gz / .dev.gz files");
-        System.err.println("                     (default: 
~/datasets/madlad/junkdetect)");
-        System.err.println("  --output   <path>  Output model file");
-        System.err.println("                     (default: 
{data-dir}/junkdetect.bin)");
+        System.err.println("  --data-dir <path>         Directory with 
{script}.train.gz / .dev.gz files");
+        System.err.println("                            (default: 
~/datasets/madlad/junkdetect)");
+        System.err.println("  --output   <path>         Output model file");
+        System.err.println("                            (default: 
{data-dir}/junkdetect.bin)");
+        System.err.println("  --bloom-bits <n>          F1 Bloom filter size 
in bits (multiple of 64)");
+        System.err.println("  --min-bigram-count <n>    Drop F1 bigrams with 
global per-pair count < n.");
+        System.err.println("                            n>=2 enables a 
pre-pass that tallies per-pair");
+        System.err.println("                            counts; rare bigrams 
(typically OCR/proper-noun");
+        System.err.println("                            noise) are excluded 
from the hash table and");
+        System.err.println("                            Bloom filter, cutting 
model size and FPR with");
+        System.err.println("                            negligible TPR impact. 
 Default: 1 (no pruning).");
     }
 }

Reply via email to