This is an automated email from the ASF dual-hosted git repository. tballison pushed a commit to branch universal-junk-detector in repository https://gitbox.apache.org/repos/asf/tika.git
commit 517d771aeacdee9eb635bd7aa3e95d1fbd11244f Author: tballison <[email protected]> AuthorDate: Thu Apr 23 11:20:16 2026 -0400 universal junk detector, take 1 --- docs/modules/ROOT/nav.adoc | 2 + .../ROOT/pages/advanced/junk-detection-build.adoc | 423 ++++++++++++++++ .../ROOT/pages/advanced/junk-detection.adoc | 216 ++++++++ tika-ml/pom.xml | 1 + tika-ml/tika-ml-junkdetect/pom.xml | 148 ++++++ .../apache/tika/ml/junkdetect/JunkDetector.java | 360 +++++++++++++ .../org/apache/tika/ml/junkdetect/JunkScore.java | 91 ++++ .../ml/junkdetect/tools/BuildJunkTrainingData.java | 559 +++++++++++++++++++++ .../tika/ml/junkdetect/tools/EvalJunkDetector.java | 531 +++++++++++++++++++ .../tika/ml/junkdetect/tools/TrainJunkModel.java | 312 ++++++++++++ .../org/apache/tika/ml/junkdetect/junkdetect.bin | Bin 0 -> 414029 bytes .../tika/ml/junkdetect/JunkDetectorSmokeTest.java | 189 +++++++ 12 files changed, 2832 insertions(+) diff --git a/docs/modules/ROOT/nav.adoc b/docs/modules/ROOT/nav.adoc index 819f9e2098..c533786bcd 100644 --- a/docs/modules/ROOT/nav.adoc +++ b/docs/modules/ROOT/nav.adoc @@ -49,6 +49,8 @@ ** xref:advanced/language-detection.adoc[Language Detection] ** xref:advanced/generative-language-model.adoc[Generative Language Model] ** xref:advanced/language-detection-build.adoc[Building the Language Detector] +** xref:advanced/junk-detection.adoc[Text Quality Scoring (Junk Detection)] +** xref:advanced/junk-detection-build.adoc[Building the Junk Detector] ** xref:advanced/robustness.adoc[Robustness] ** xref:advanced/setting-limits.adoc[Setting Limits] ** xref:advanced/spooling.adoc[Spooling] diff --git a/docs/modules/ROOT/pages/advanced/junk-detection-build.adoc b/docs/modules/ROOT/pages/advanced/junk-detection-build.adoc new file mode 100644 index 0000000000..27e6e8754e --- /dev/null +++ b/docs/modules/ROOT/pages/advanced/junk-detection-build.adoc @@ -0,0 +1,423 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license 
agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + += Building the Junk Detector + +This page documents the training pipeline, model format, evaluation methodology, +and guidance for improving the junk detector model. For usage, see +xref:advanced/junk-detection.adoc[Text Quality Scoring (Junk Detection)]. + +== Overview + +The junk detector is a per-script byte-bigram language model. For each +Unicode script (Latin, Cyrillic, Arabic, Han, etc.) it maintains a 256×256 +table of `log P(byte_b | byte_a)` values — the log-probability of seeing byte `b` +immediately after byte `a` in clean UTF-8 text of that script. + +The pipeline has three stages: + +[source] +---- +1. BuildJunkTrainingData — collect and split corpus per script group +2. TrainJunkModel — train bigram tables and calibrate z-scores +3. EvalJunkDetector — measure discrimination quality +---- + +All three tools are packaged as a fat JAR via the `train` Maven profile: + +[source,bash] +---- +mvn -pl tika-ml/tika-ml-junkdetect package -Ptrain -DskipTests +---- + +The resulting JAR is `tika-ml-junkdetect-*-tools.jar` (the shade plugin attaches it with the `tools` classifier).
+ +== Stage 1: Corpus collection (`BuildJunkTrainingData`) + +This tool collects clean UTF-8 sentences from language-specific source files, +groups them by Unicode script, allocates a byte budget proportional to +per-script bigram entropy, and writes 80/10/10 train/dev/test splits. + +=== Data format + +Source data lives in one directory per language (ISO 639 code), each containing +up to two files: + +`sentences_wikipedia.txt`:: + Line-numbered Wikipedia sentences: `{lineNum}{TAB}{text}`. + One sentence per line. + +`sentences_madlad.txt`:: + Line-numbered MADLAD-400 documents: `{lineNum}{TAB}{text}`. + Documents contain literal two-character `\n` escape sequences as + sub-sentence separators. The tool splits on these before processing. + +=== Script group detection + +For each language directory the dominant Unicode script is detected by +sampling up to 2,000 lines and histogramming `Character.UnicodeScript` over +all codepoints. The `COMMON`, `INHERITED`, and `UNKNOWN` pseudo-scripts are +excluded. The plurality script (with a 1% minimum floor to suppress spurious +wins on mixed-script text) determines which group that language belongs to. + +Languages that share the same dominant script are pooled together into one +training group. No script groups are hardcoded — the set of groups is derived +entirely from the data. + +=== Entropy-proportional byte budget + +Not all scripts are equal: CJK text has thousands of distinct 3-byte UTF-8 +codepoints producing high byte-bigram entropy (~10.4 bits), while Arabic text +clusters in a narrow 0xD8–0xDB high-byte range (~7.2 bits). A naïve +sentence-count budget would badly over-represent low-entropy scripts.
+ +Instead the tool allocates a **total byte budget** (default 50 MB) across +script groups in proportion to their empirical byte-bigram Shannon entropy, +estimated from a 200 KB sample per group: + +[source] +---- +H(script) = -Σ p(a,b) · log₂ p(a,b) over all observed bigrams (a,b) + +budget(script) = totalBudget × H(script) / Σ H(all scripts) +---- + +Within each script group the budget is distributed evenly across its member +languages, ensuring no single language dominates the training data. + +=== Train/dev/test split + +After collecting and shuffling sentences, the tool writes three gzipped files +per script: + +[cols="1,1,3"] +|=== +| File | Split | Purpose + +| `{script}.train.gz` +| 80% +| Bigram count accumulation in `TrainJunkModel`. + +| `{script}.dev.gz` +| 10% +| Calibration (mu/sigma estimation) in `TrainJunkModel`. + Also used for iterative evaluation during development. + +| `{script}.test.gz` +| 10% +| **Held out completely.** Use only for final reported evaluation numbers. + Never use to make model or threshold decisions. +|=== + +=== Running corpus collection + +[source,bash] +---- +java -cp tika-ml-junkdetect-*-tools.jar \ + org.apache.tika.ml.junkdetect.tools.BuildJunkTrainingData \ + --data-dir ~/datasets/madlad/data \ + --output-dir ~/datasets/madlad/junkdetect \ + --total-budget-bytes 50000000 +---- + +Key options: + +[cols="2,1,3"] +|=== +| Option | Default | Description + +| `--data-dir` +| `~/datasets/madlad/data` +| Root directory containing per-language subdirectories. + +| `--output-dir` +| `~/datasets/madlad/junkdetect` +| Where to write `{script}.train.gz`, `.dev.gz`, `.test.gz`, and `manifest.tsv`. + +| `--total-budget-bytes` +| `50000000` +| Total UTF-8 byte budget across all scripts. Increase for production runs. + +| `--min-bytes` +| `50` +| Minimum UTF-8 byte length for a sentence to be accepted. + +| `--max-punc-frac` +| `0.30` +| Maximum fraction of codepoints that may be ASCII punctuation or digits.
+ Filters out bullet lists, code snippets, and other non-prose content. + +| `--seed` +| `42` +| Random seed for reproducible shuffles. + +| `--dry-run` +| `false` +| Print script detection and entropy results without writing files. +|=== + +== Stage 2: Training (`TrainJunkModel`) + +For each script, this tool reads the `.train.gz` file, accumulates +byte-bigram counts, applies Laplace smoothing, computes log-probabilities, +then calibrates z-score statistics from the `.dev.gz` file. + +=== Bigram table training + +[source] +---- +for each sentence in {script}.train.gz: + utf8 = sentence.getBytes(UTF-8) + for each consecutive pair (a, b) in utf8: + counts[a * 256 + b]++ + +for each row a in 0..255: + rowTotal = Σ (counts[a * 256 + b] + 1) for b in 0..255 // Laplace add-1 + for each b in 0..255: + table[a * 256 + b] = log((counts[a * 256 + b] + 1) / rowTotal) +---- + +Laplace (add-1) smoothing is applied per row: every possible next byte is +given a pseudocount of 1, preventing log(0) for unseen bigrams and providing +a small but nonzero probability for novel byte sequences. + +=== Calibration + +For each sentence in `{script}.dev.gz`: + +[source] +---- +meanLogProb = Σ table[bigram] / (bytes - 1) +---- + +The calibration statistics are the mean (μ) and standard deviation (σ) of +`meanLogProb` across all dev sentences. At inference: + +[source] +---- +zScore = (meanLogProb - μ) / σ +---- + +A z-score of 0 means "exactly as likely as average clean text for this script." +Negative scores indicate text that is less likely than clean — i.e., garbled. 
+ +=== Running training + +[source,bash] +---- +java -cp tika-ml-junkdetect-*-tools.jar \ + org.apache.tika.ml.junkdetect.tools.TrainJunkModel \ + --data-dir ~/datasets/madlad/junkdetect \ + --output ~/datasets/madlad/junkdetect/junkdetect.bin +---- + +After training, copy the model to the classpath resource location: + +[source,bash] +---- +cp ~/datasets/madlad/junkdetect/junkdetect.bin \ + tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin +---- + +== Stage 3: Evaluation (`EvalJunkDetector`) + +The evaluator measures how well the model separates clean text from +corrupted text across scripts, distortion types, and string lengths. + +=== Distortion modes + +[cols="1,3"] +|=== +| Mode | Description + +| `inject` +| Random bytes (0x80–0xFF) are substituted at rate `r` of positions. + Tests from 1% injection (subtle corruption) to 90% (nearly all garbage). + +| `char-reverse` +| Codepoints are reversed (Unicode-aware, preserving surrogate pairs). + Produces valid UTF-8 but in nonsensical reading order. + Most meaningful for RTL scripts (Arabic, Hebrew) where reversed text + is a realistic failure mode; LTR script bigrams are nearly symmetric, + so detection is harder. + +| `byte-shuffle` +| All bytes are randomly shuffled (Fisher-Yates). + The most extreme corruption — destroys all sequential structure. +|=== + +=== Output files + +`detail.tsv`:: + One row per `(script, distortion, param, length)` cell, with columns: + `script`, `distortion`, `param`, `length`, `n_clean`, `n_corrupt`, + `mean_clean_z`, `mean_corrupt_z`, `cohens_d`, `fpr`, `tpr`. + +`summary.tsv`:: + Macro-averaged across scripts per `(distortion, param, length)`. + The `macro_cohens_d` column is the headline comparison metric. + +=== Key metrics + +**Cohen's d** (primary metric):: + Effect size separating clean from corrupted z-scores: ++ +[source] +---- +d = (mean_clean_z - mean_corrupt_z) / pooled_std +---- ++ +Higher is better.
A value of 1.0 means the distributions are separated by +one pooled standard deviation. Values above 2.0 indicate strong, reliable +discrimination. + +**True positive rate (TPR)**:: + Fraction of corrupted samples with z < threshold (−2.0 by default). + Higher is better. + +**False positive rate (FPR)**:: + Fraction of clean samples with z < threshold. Should stay near 2–5%. + A well-calibrated model will have FPR ≈ 2.3% (if clean-text z-scores are + roughly standard normal, P(z < −2.0) ≈ 2.3%). + +=== Running evaluation + +[source,bash] +---- +# During development: use the dev split +java -cp tika-ml-junkdetect-*-tools.jar \ + org.apache.tika.ml.junkdetect.tools.EvalJunkDetector \ + --data-dir ~/datasets/madlad/junkdetect \ + --split dev \ + --output-dir ~/datasets/madlad/junkdetect/eval + +# Final reporting only: use the held-out test split +java -cp tika-ml-junkdetect-*-tools.jar \ + org.apache.tika.ml.junkdetect.tools.EvalJunkDetector \ + --data-dir ~/datasets/madlad/junkdetect \ + --split test \ + --output-dir ~/datasets/madlad/junkdetect/eval-final +---- + +IMPORTANT: Use `--split test` only once, for final reporting. The test split +is completely held out and should never inform model or threshold decisions. + +=== Tracking improvement + +To compare two model versions: + +1. Train model A, run `EvalJunkDetector --split dev`, save `summary.tsv` as + `summary-A.tsv`. +2. Retrain as model B, run eval again, save as `summary-B.tsv`. +3. Diff the `macro_cohens_d` column. A positive change (B − A) = improvement. + +The `# OVERALL` line at the bottom of `summary.tsv` gives a single-number +summary of model quality. + +== Model binary format (JUNKDET1) + +The model is stored as a gzipped binary file. Auto-detection of the gzip +wrapper is done by inspecting the first two bytes (magic `0x1f 0x8b`).
+ +[source] +---- +[8 bytes] magic "JUNKDET1" (ASCII) +[1 byte] version = 1 +[4 bytes] num_scripts (int32 big-endian) + +For each script (sorted by name): + [2 bytes] name length (uint16 big-endian) + [N bytes] script name (UTF-8) + [4 bytes] μ — mean of dev-set mean_bigram_logprob (float32 big-endian) + [4 bytes] σ — std deviation (float32 big-endian) + [65536×4 bytes] log-prob table (float32 big-endian, index = a*256+b) +---- + +The default classpath resource is +`org/apache/tika/ml/junkdetect/junkdetect.bin`. + +== Known limitations and improvement paths + +=== Baltic and closely related Latin scripts + +The LATIN script pools ~322 languages from Latin, Basic Latin, and extended +Latin alphabets. Baltic languages (Lithuanian, Latvian) use distinctive +diacritics encoded differently in cp1257 vs. cp1252, but these bigrams are +diluted by the large shared Latin vocabulary. The model correctly identifies +the winner but with low delta (< 0.5), below the production confidence +threshold of 1.0. + +**Possible improvements:** + +* Retrain with Baltic languages weighted more heavily within the LATIN group. +* Split LATIN into LATIN-WEST and LATIN-EAST sub-models, where LATIN-EAST + receives its own dedicated bigram table trained primarily on Baltic, Slavic + Latin (Polish, Czech, Slovak), and Romanian. + +=== RTL script reversal + +For Arabic and Hebrew, codepoint-reversal is a realistic failure mode (text +stored in the wrong visual order). The model detects this with moderate +Cohen's d at lengths ≥ 50 characters. Shorter strings (15–30 characters) +show weaker separation because there are too few bigrams to be statistically +reliable. + +**Possible improvement:** train a secondary short-text specialist model for +RTL scripts using finer-grained features (trigrams or unigram frequency +distributions). + +=== Scaling up + +The default 50 MB byte budget is a proof-of-concept setting. For production: + +* Increase `--total-budget-bytes` to 500 MB or more. 
+* Larger budgets improve calibration quality (tighter σ, more accurate μ) + and reduce variance on infrequent bigrams. +* The model binary does not grow (each script's 256×256 table is the same size + regardless of training set size) — only the quality of the estimates improves. + +== Smoke tests + +Five smoke tests in `JunkDetectorSmokeTest` verify the bundled model: + +[cols="1,3"] +|=== +| Test | What it checks + +| `cleanVsGarbage` +| Clean English z-score > random high-byte garbage z-score. + +| `forwardVsReversedArabic` +| Forward Arabic z-score > codepoint-reversed Arabic z-score. + +| `cp1252VsCp1257OnBalticText` +| `compare()` picks cp1257 as the correct encoding for Lithuanian text. + Delta > 0.1 (weak; Baltic limitation documented above). + +| `cp1252VsCp1251OnRussianText` +| `compare()` picks cp1251 as the correct encoding for Russian text. + Delta > 1.0 (strong; Cyrillic bigrams are highly distinctive). + +| `cleanVsShuffledCjk` +| Clean Japanese UTF-8 z-score > byte-shuffled Japanese z-score. +|=== + +NOTE: Codepoint reversal of LTR scripts (Russian, Latin) is **not** a useful +smoke test — LTR byte-bigram distributions are nearly symmetric, so the model +cannot reliably distinguish forward from reversed text. The Russian test uses +codec comparison (cp1251 vs. cp1252) instead, which is the actual real-world +failure mode for Cyrillic text. diff --git a/docs/modules/ROOT/pages/advanced/junk-detection.adoc b/docs/modules/ROOT/pages/advanced/junk-detection.adoc new file mode 100644 index 0000000000..6c6f06037f --- /dev/null +++ b/docs/modules/ROOT/pages/advanced/junk-detection.adoc @@ -0,0 +1,216 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + += Text Quality Scoring (Junk Detection) + +The `tika-ml-junkdetect` module provides a language-agnostic scorer that +distinguishes clean natural-language text from garbled, corrupted, or +mis-decoded content — without needing to know the language in advance. + +== What it detects + +* **Mojibake** — text decoded with the wrong character set (e.g., a Windows-1251 + Russian document decoded as Windows-1252, producing Latin lookalike garbage) +* **Byte-level corruption** — random or partially-overwritten byte sequences that + produce structurally invalid UTF-8 +* **Reversed or shuffled text** — text that contains valid characters but in + nonsensical order, as can occur in bidirectional rendering failures or corrupted + OCR streams +* **OCR garbage** — low-confidence OCR output full of symbol noise + +It does _not_ detect incorrect language (e.g., an English document mistakenly +labeled as French) — use xref:advanced/language-detection.adoc[Language Detection] +for that. + +== How it works + +The scorer uses a per-script byte-bigram log-probability model trained on clean +Wikipedia and MADLAD-400 text. For each input it: + +1. **Identifies the dominant Unicode script** (Latin, Cyrillic, Arabic, Han, etc.) + by histogramming `Character.UnicodeScript` over all codepoints. +2. 
**Looks up the script's bigram table** — a 256×256 matrix of + `log P(byte_b | byte_a)` values trained on clean text for that script. +3. **Computes a mean log-probability** across all consecutive byte pairs in the + UTF-8 encoding of the input. +4. **Z-scores the result** against calibration statistics (mean and standard + deviation measured on a held-out set of clean text for the same script). + +The z-score is the primary output: a score of 0 means "exactly as expected for +clean text of this script"; a score of −3 means "three standard deviations worse +than average clean text"; a score of −10 means "almost certainly garbled." + +== Using the API + +Add the dependency to your project: + +[source,xml] +---- +<dependency> + <groupId>org.apache.tika</groupId> + <artifactId>tika-ml-junkdetect</artifactId> + <version>${tika.version}</version> +</dependency> +---- + +=== Scoring a string or byte array + +[source,java] +---- +JunkDetector detector = JunkDetector.loadFromClasspath(); + +// Score a string directly +JunkScore score = detector.score("The quick brown fox jumps over the lazy dog."); +System.out.println(score.getZScore()); // e.g. -0.74 — within normal range +System.out.println(score.getPClean()); // e.g. 0.32 — P(clean) via sigmoid + +// Score raw UTF-8 bytes (same result; use when you already have bytes) +byte[] utf8 = text.getBytes(StandardCharsets.UTF_8); +JunkScore score2 = detector.score(utf8); +---- + +`JunkDetector` is **immutable and thread-safe** after construction. Load it once +at application startup. + +=== Interpreting the score + +[cols="1,3"] +|=== +| Z-score range | Interpretation + +| > 0 +| Better than average clean text — high-quality, well-formed natural language. + +| −1 to 0 +| Within normal range for clean text. Together with the > 0 band, this is where most clean text falls. + +| −2 to −1 +| Mildly degraded. May indicate noisy OCR, code-heavy text, or unusual domain + language. Not necessarily junk. + +| −5 to −2 +| Two or more standard deviations below clean. Worth investigating.
+ A reasonable threshold for triggering re-OCR or re-encoding. + +| < −5 +| Almost certainly garbled. Wrong charset decoding, byte-reversed content, + or heavy corruption. +|=== + +The `JunkScore` also carries: + +* `getPClean()` — `sigmoid(z)`, a rough probability estimate in [0, 1] that the + text is clean. Useful for ranking candidates; the absolute value is not + calibrated as a true probability. +* `getCiLow()` / `getCiHigh()` — 95% confidence interval on the z-score. Narrow + on long texts, wide on short ones. Use these when making threshold decisions on + short strings. +* `getDominantScript()` — the Unicode script name used for scoring (e.g. `"LATIN"`, + `"CYRILLIC"`, `"ARABIC"`, `"HAN"`). If `isUnknown()` is true, the dominant + script had no model and scoring was not possible. + +=== Comparing two charset interpretations + +The `compare()` method is the primary use case for charset detection: +given the same raw bytes decoded two different ways, which decoding +looks more like natural language? + +[source,java] +---- +byte[] rawBytes = ...; // bytes from an unknown-encoding file + +JunkDetector.CompareResult result = + detector.compare(rawBytes, "cp1252", "cp1251"); + +System.out.println(result.winner()); // "A" or "B" +System.out.println(result.delta()); // z-score separation between the two + +if (result.winner().equals("B") && result.delta() > 1.0) { + // cp1251 is confidently the better decoding +} +---- + +The `delta()` is the absolute difference in z-scores between the two decodings. +As a rough guide: + +[cols="1,3"] +|=== +| Delta | Confidence + +| < 0.5 +| Very uncertain — both decodings look similar to the model. Fall back to + other heuristics. + +| 0.5 – 1.0 +| Weak signal — winner is likely correct but not assured. + +| 1.0 – 3.0 +| Useful signal. Trust the winner for most production purposes. + +| > 3.0 +| High confidence. One decoding is clearly more language-like. 
+|=== + +=== Listing known scripts + +[source,java] +---- +detector.knownScripts(); // returns Set<String> +// e.g. [ARABIC, ARMENIAN, BENGALI, CYRILLIC, DEVANAGARI, GEORGIAN, +// GREEK, GUJARATI, GURMUKHI, HAN, HANGUL, HEBREW, HIRAGANA, +// KANNADA, KHMER, LAO, LATIN, MALAYALAM, MYANMAR, ORIYA, +// SINHALA, TAMIL, TELUGU, THAANA, THAI, TIBETAN, ...] +---- + +If the dominant script of an input is not in this set, `score()` returns a +`JunkScore` where `isUnknown()` is true and no z-score is available. + +== Thresholds and operating points + +There is no universally correct threshold. The right cutoff depends on your +content and tolerance for false positives (flagging good text as junk). + +**Starting points:** + +* **Trigger re-OCR**: z < −2.0 (catches ~95% of severe corruption while flagging + ~2–5% of legitimate text on average, more for short strings). +* **Charset tiebreaking**: prefer the candidate with the higher z-score when + `delta() > 1.0`; abstain if `delta() < 0.5`. +* **Training data filtering**: z < −1.5 to remove mojibake and bot-generated + noise from NLP corpora. + +For short text (under ~50 UTF-8 bytes), use `getCiLow()` rather than `getZScore()` +for threshold decisions, since the confidence interval widens substantially. + +== Limitations + +* **Script coverage**: only scripts with a trained model can be scored. Unknown + scripts return `isUnknown() = true`. +* **Short text**: scoring is unreliable below ~15 UTF-8 bytes. The model needs + at least a few bigrams to produce a stable estimate. +* **Closely related charsets in the same script pool**: the LATIN model is trained + across hundreds of languages, which dilutes the signal for closely related + Western European and Baltic encodings (e.g., cp1252 vs. cp1257 on Lithuanian + text). The winner is usually correct, but delta may be small (< 0.5). +* **Deliberately obfuscated text**: content designed to look like natural language + (e.g. by adversarial padding) is not detected. 
+ +== Further reading + +For training methodology, model format, evaluation harness, and guidance on +improving the model, see +xref:advanced/junk-detection-build.adoc[Building the Junk Detector]. diff --git a/tika-ml/pom.xml b/tika-ml/pom.xml index 3af2057aa8..5c9cf03af1 100644 --- a/tika-ml/pom.xml +++ b/tika-ml/pom.xml @@ -34,6 +34,7 @@ <modules> <module>tika-ml-core</module> <module>tika-ml-chardetect</module> + <module>tika-ml-junkdetect</module> </modules> <build> diff --git a/tika-ml/tika-ml-junkdetect/pom.xml b/tika-ml/tika-ml-junkdetect/pom.xml new file mode 100644 index 0000000000..15abcdffee --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/pom.xml @@ -0,0 +1,148 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> + <parent> + <artifactId>tika-ml</artifactId> + <groupId>org.apache.tika</groupId> + <version>${revision}</version> + </parent> + <modelVersion>4.0.0</modelVersion> + + <artifactId>tika-ml-junkdetect</artifactId> + <name>Apache Tika ML junk detector — runtime and training tools</name> + <description> + Language-agnostic text quality scorer that discriminates between clean UTF-8 text and + mojibake, reversed text, wrong-codec decodings, and other corruption forms. + Provides a standalone "languageyness" score suitable for re-OCR triggering and + charset-decoding arbitration. + + Runtime classes (JunkDetector, ScriptDetector, feature extractors) and bundled model + resources live here. Training and evaluation CLI tools live in the tools subpackage. + </description> + + <dependencies> + <dependency> + <groupId>org.apache.tika</groupId> + <artifactId>tika-ml-core</artifactId> + <version>${revision}</version> + </dependency> + + <!-- Test dependencies --> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-api</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-engine</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <archive> + <manifestEntries> + <Automatic-Module-Name>org.apache.tika.ml.junkdetect</Automatic-Module-Name> + </manifestEntries> + </archive> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.rat</groupId> + <artifactId>apache-rat-plugin</artifactId> + <configuration> + <inputExcludes> + <inputExclude>**/*.bin</inputExclude> + <inputExclude>**/*.txt</inputExclude> 
+ </inputExcludes> + </configuration> + </plugin> + <!-- Tools package uses System.out/printf freely; note this skips the check for the whole module --> + <plugin> + <groupId>de.thetaphi</groupId> + <artifactId>forbiddenapis</artifactId> + <configuration> + <skip>true</skip> + </configuration> + </plugin> + </plugins> + </build> + + <profiles> + <profile> + <!-- + Build a self-contained fat JAR for model training and evaluation. + Usage: + ./mvnw package -pl tika-ml/tika-ml-junkdetect -am -Ptrain -DskipTests \ + -Dmaven.repo.local=.local_m2_repo + java -cp tika-ml/tika-ml-junkdetect/target/tika-ml-junkdetect-*-tools.jar \ + org.apache.tika.ml.junkdetect.tools.[BuildJunkTrainingData|TrainJunkModel|EvalJunkDetector] \ + [args...] (java -jar always runs the manifest Main-Class, TrainJunkModel) + --> + <id>train</id> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <executions> + <execution> + <phase>package</phase> + <goals><goal>shade</goal></goals> + <configuration> + <shadedArtifactAttached>true</shadedArtifactAttached> + <shadedClassifierName>tools</shadedClassifierName> + <transformers> + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> + <mainClass>org.apache.tika.ml.junkdetect.tools.TrainJunkModel</mainClass> + </transformer> + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/> + <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheLicenseResourceTransformer"/> + <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheNoticeResourceTransformer"/> + </transformers> + <filters> + <filter> + <artifact>*:*</artifact> + <excludes> + <exclude>META-INF/*.SF</exclude> + <exclude>META-INF/*.DSA</exclude> + <exclude>META-INF/*.RSA</exclude> + </excludes> + </filter> + </filters> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + </profiles> + + <scm> + <tag>3.0.0-rc1</tag> + </scm> +</project> diff --git
a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java
new file mode 100644
index 0000000000..091ad1d04b
--- /dev/null
+++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkDetector.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tika.ml.junkdetect;
+
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.zip.GZIPInputStream;
+
+/**
+ * Language-agnostic text quality scorer. Discriminates clean UTF-8 text from
+ * mojibake, reversed text, wrong-codec decodings, and other corruption forms.
+ *
+ * <p>Scoring is based on a per-script byte-bigram log-probability model: a 256×256
+ * table of {@code log P(b|a)} values trained on clean Wikipedia and MADLAD-400 text.
+ * The per-sentence mean bigram log-prob is z-scored against the calibration statistics
+ * (mean and stddev measured on held-out clean text) to produce a dimensionless quality
+ * score. Negative z-score = worse than average clean text for that script;
+ * more negative = worse.
+ *
+ * <p>Instances are immutable and thread-safe after construction.
+ *
+ * <p>Typical usage:
+ * <pre>{@code
+ * JunkDetector detector = JunkDetector.loadFromClasspath();
+ * JunkScore score = detector.score("some text");
+ * if (score.getZScore() < -2.0) { ... re-OCR or flag ... }
+ *
+ * // Compare two charset interpretations of the same bytes
+ * JunkDetector.CompareResult result = detector.compare(rawBytes, "cp1252", "cp1257");
+ * String winner = result.winner(); // "A" or "B"
+ * }</pre>
+ */
+public final class JunkDetector {
+
+    /** Classpath resource path for the bundled production model. */
+    public static final String DEFAULT_MODEL_RESOURCE =
+            "org/apache/tika/ml/junkdetect/junkdetect.bin";
+
+    static final String MAGIC = "JUNKDET1";
+
+    // Per-script model data
+    private final Map<String, float[]> tables;       // script → float[65536] log-prob table
+    private final Map<String, float[]> calibrations; // script → float[2] {mu, sigma}
+
+    private JunkDetector(Map<String, float[]> tables, Map<String, float[]> calibrations) {
+        this.tables = Collections.unmodifiableMap(tables);
+        this.calibrations = Collections.unmodifiableMap(calibrations);
+    }
+
+    // -----------------------------------------------------------------------
+    // Factory methods
+    // -----------------------------------------------------------------------
+
+    /**
+     * Loads the bundled model from the classpath.
+     *
+     * @throws IOException if the model resource is missing or malformed
+     */
+    public static JunkDetector loadFromClasspath() throws IOException {
+        InputStream is = JunkDetector.class.getClassLoader()
+                .getResourceAsStream(DEFAULT_MODEL_RESOURCE);
+        if (is == null) {
+            throw new IOException("Model resource not found on classpath: "
+                    + DEFAULT_MODEL_RESOURCE);
+        }
+        try (InputStream wrapped = is) {
+            return load(wrapped);
+        }
+    }
+
+    /**
+     * Loads a model from the given file path. The file may be gzipped or raw.
+     */
+    public static JunkDetector loadFromPath(Path path) throws IOException {
+        try (InputStream is = Files.newInputStream(path)) {
+            return load(is);
+        }
+    }
+
+    /**
+     * Loads a model from an {@link InputStream} (gzip is auto-detected); closes it on success.
+     */
+    public static JunkDetector load(InputStream rawIs) throws IOException {
+        // Peek at the first two bytes for the gzip magic (0x1f 0x8b), then push them back
+        byte[] peek = rawIs.readNBytes(2);
+        InputStream rest = new java.io.SequenceInputStream(
+                new java.io.ByteArrayInputStream(peek), rawIs);
+        boolean gzipped = peek.length == 2
+                && (peek[0] & 0xFF) == 0x1f && (peek[1] & 0xFF) == 0x8b;
+        InputStream in = gzipped ? new GZIPInputStream(rest) : rest;
+
+        try (DataInputStream dis = new DataInputStream(in)) {
+            byte[] magic = dis.readNBytes(8);
+            if (!new String(magic, StandardCharsets.UTF_8).equals(MAGIC)) {
+                throw new IOException("Not a JunkDetector model file (bad magic)");
+            }
+            int version = dis.readUnsignedByte();
+            if (version != 1) {
+                throw new IOException("Unsupported model version: " + version);
+            }
+
+            int numScripts = dis.readInt();
+            if (numScripts < 0) {
+                throw new IOException("Corrupt model: negative script count " + numScripts);
+            }
+            // No capacity hint: a corrupt (huge) count must not reach the HashMap constructor
+            Map<String, float[]> tables = new HashMap<>();
+            Map<String, float[]> calibrations = new HashMap<>();
+
+            for (int s = 0; s < numScripts; s++) {
+                int nameLen = dis.readUnsignedShort();
+                String script = new String(dis.readNBytes(nameLen), StandardCharsets.UTF_8);
+
+                float mu = dis.readFloat();
+                float sigma = dis.readFloat();
+                calibrations.put(script, new float[]{mu, sigma});
+
+                // readFully (unlike readNBytes) throws EOFException on a truncated table
+                byte[] tableBytes = new byte[65536 * 4];
+                dis.readFully(tableBytes);
+                float[] table = new float[65536];
+                ByteBuffer buf = ByteBuffer.wrap(tableBytes).order(ByteOrder.BIG_ENDIAN);
+                buf.asFloatBuffer().get(table);
+                tables.put(script, table);
+            }
+
+            return new JunkDetector(tables, calibrations);
+        }
+    }
+
+    // -----------------------------------------------------------------------
+    // Scoring API
+    // -----------------------------------------------------------------------
+
+    /**
+     * Scores a UTF-8 string for text quality.
+     *
+     * @param text the string to score (will be encoded to UTF-8 internally)
+     * @return a {@link JunkScore}; use {@link JunkScore#isUnknown()} to check
+     *         whether scoring was possible
+     */
+    public JunkScore score(String text) {
+        if (text == null || text.isEmpty()) {
+            return unknownScore("UNKNOWN");
+        }
+        return scoreBytes(text.getBytes(StandardCharsets.UTF_8), text);
+    }
+
+    /**
+     * Scores a byte array assumed to be UTF-8 text.
+     *
+     * @param utf8 raw UTF-8 bytes
+     * @return a {@link JunkScore}
+     */
+    public JunkScore score(byte[] utf8) {
+        if (utf8 == null || utf8.length == 0) {
+            return unknownScore("UNKNOWN");
+        }
+        String text = new String(utf8, StandardCharsets.UTF_8);
+        return scoreBytes(utf8, text);
+    }
+
+    /**
+     * Compares two charset interpretations of the same raw bytes and returns
+     * which decoding scores higher (is more likely to be clean natural language).
+     *
+     * @param rawBytes the raw bytes to decode
+     * @param charsetA first charset name (e.g. {@code "cp1252"})
+     * @param charsetB second charset name (e.g. {@code "cp1257"})
+     * @return the winner (ties, including two unscorable decodings, go to A) and z-score gap
+     */
+    public CompareResult compare(byte[] rawBytes, String charsetA, String charsetB) {
+        JunkScore scoreA = decodeAndScore(rawBytes, charsetA);
+        JunkScore scoreB = decodeAndScore(rawBytes, charsetB);
+
+        float zA = scoreA.isUnknown() ? Float.NEGATIVE_INFINITY : scoreA.getZScore();
+        float zB = scoreB.isUnknown() ? Float.NEGATIVE_INFINITY : scoreB.getZScore();
+
+        String winner = zA >= zB ? "A" : "B";
+        float delta = zA == zB ? 0f : Math.abs(zA - zB); // avoids NaN from -inf - -inf
+
+        return new CompareResult(winner, delta, scoreA, scoreB, charsetA, charsetB);
+    }
+
+    /** Returns the (unmodifiable) set of script names this model knows about. */
+    public java.util.Set<String> knownScripts() {
+        return tables.keySet();
+    }
+
+    // -----------------------------------------------------------------------
+    // Internal scoring
+    // -----------------------------------------------------------------------
+
+    private JunkScore scoreBytes(byte[] utf8, String text) {
+        String script = detectDominantScript(text);
+
+        float[] table = tables.get(script);
+        if (table == null) {
+            // Script not in model — return unknown with script name for diagnostics
+            return unknownScore(script);
+        }
+
+        if (utf8.length < 2) {
+            return unknownScore(script);
+        }
+
+        // Mean byte-bigram log-prob
+        double sum = 0;
+        int count = 0;
+        for (int i = 0; i + 1 < utf8.length; i++) {
+            sum += table[((utf8[i] & 0xFF) << 8) | (utf8[i + 1] & 0xFF)];
+            count++;
+        }
+        float meanLogProb = (float) (sum / count);
+
+        // Z-score against calibration
+        float[] cal = calibrations.get(script);
+        float mu = cal[0];
+        float sigma = cal[1];
+        float zScore = (meanLogProb - mu) / sigma;
+
+        // ~95% CI half-width. NOTE(review): sigma is in log-prob units, zScore in sigma units
+        float uncertainty = (float) (1.96 * sigma / Math.sqrt(count));
+        float ciLow = zScore - uncertainty;
+        float ciHigh = zScore + uncertainty;
+
+        // P(clean): sigmoid of z-score (simple calibration-free estimate)
+        float pClean = (float) (1.0 / (1.0 + Math.exp(-zScore)));
+
+        return new JunkScore(zScore, pClean, ciLow, ciHigh, script);
+    }
+
+    private JunkScore decodeAndScore(byte[] raw, String charsetName) {
+        try {
+            Charset cs = Charset.forName(charsetName);
+            byte[] utf8 = new String(raw, cs).getBytes(StandardCharsets.UTF_8);
+            return score(utf8);
+        } catch (Exception ignored) { // e.g. unknown/illegal charset name: report as unscorable
+            return unknownScore(charsetName);
+        }
+    }
+
+    private static JunkScore unknownScore(String script) {
+        return new JunkScore(JunkScore.UNKNOWN, Float.NaN, Float.NaN, Float.NaN, script);
+    }
+
+    /**
+     * Returns the dominant {@link Character.UnicodeScript} name of {@code text}, ignoring COMMON,
+     * INHERITED and UNKNOWN; "LATIN" if no codepoint has a real script (e.g. digits only).
+     * Ties resolve deterministically because the EnumMap iterates in enum declaration order.
+     */
+    static String detectDominantScript(String text) {
+        Map<Character.UnicodeScript, Integer> counts =
+                new java.util.EnumMap<>(Character.UnicodeScript.class);
+        for (int i = 0; i < text.length(); ) {
+            int cp = text.codePointAt(i);
+            Character.UnicodeScript s = Character.UnicodeScript.of(cp);
+            if (s != Character.UnicodeScript.COMMON
+                    && s != Character.UnicodeScript.INHERITED
+                    && s != Character.UnicodeScript.UNKNOWN) {
+                counts.merge(s, 1, Integer::sum);
+            }
+            i += Character.charCount(cp);
+        }
+        if (counts.isEmpty()) {
+            return "LATIN"; // ASCII-only → use Latin model
+        }
+        return counts.entrySet().stream()
+                .max(java.util.Map.Entry.comparingByValue())
+                .map(e -> e.getKey().name())
+                .orElse("LATIN");
+    }
+
+    // -----------------------------------------------------------------------
+    // Result type for compare()
+    // -----------------------------------------------------------------------
+
+    /**
+     * Result of comparing two charset decodings of the same raw bytes.
+     */
+    public static final class CompareResult {
+        private final String winner;
+        private final float delta;
+        private final JunkScore scoreA;
+        private final JunkScore scoreB;
+        private final String charsetA;
+        private final String charsetB;
+
+        CompareResult(String winner, float delta,
+                      JunkScore scoreA, JunkScore scoreB,
+                      String charsetA, String charsetB) {
+            this.winner = winner;
+            this.delta = delta;
+            this.scoreA = scoreA;
+            this.scoreB = scoreB;
+            this.charsetA = charsetA;
+            this.charsetB = charsetB;
+        }
+
+        /** "A" if charsetA decodes to cleaner text, "B" otherwise (ties go to "A"). */
+        public String winner() {
+            return winner;
+        }
+
+        /**
+         * Absolute z-score gap: 0 if neither decoding was scorable, +Infinity if only one was.
+         * Rough guide: small = uncertain, > 1.0 is useful signal, > 3.0 is confident.
+         */
+        public float delta() {
+            return delta;
+        }
+
+        public JunkScore scoreA() {
+            return scoreA;
+        }
+
+        public JunkScore scoreB() {
+            return scoreB;
+        }
+
+        public String charsetA() {
+            return charsetA;
+        }
+
+        public String charsetB() {
+            return charsetB;
+        }
+
+        @Override
+        public String toString() {
+            return String.format(java.util.Locale.ROOT, "CompareResult[winner=%s(%s) delta=%.3f A=%s B=%s]",
+                    winner, winner.equals("A") ? charsetA : charsetB,
+                    delta, scoreA, scoreB);
+        }
+    }
+}
diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkScore.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkScore.java
new file mode 100644
index 0000000000..393976c127
--- /dev/null
+++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/JunkScore.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tika.ml.junkdetect;
+
+/**
+ * Result of scoring a UTF-8 string for text quality.
+ * <p>
+ * {@code zScore} is the primary output: the signed distance, in calibration standard
+ * deviations, of this string's mean bigram log-prob from typical clean text of its
+ * dominant script. Negative means worse than average clean text; more negative = worse.
+ * <p>
+ * {@code pClean}, as produced by {@link JunkDetector}, is the logistic sigmoid of
+ * {@code zScore} (0.5 at z = 0): a monotone rescaling, not a calibrated probability.
+ * <p>
+ * {@code ciLow} / {@code ciHigh} are set by {@link JunkDetector} to
+ * {@code zScore -/+ 1.96 * sigma / sqrt(n)} for n byte bigrams: wide for short strings,
+ * narrow for long ones. Prefer {@code ciLow < threshold} over {@code zScore < threshold}
+ * when triggering actions, to avoid false positives on short strings.
+ */
+public final class JunkScore {
+
+    /** Sentinel z-score (NaN) for unscorable input; test with {@link #isUnknown()}, not {@code ==}. */
+    public static final float UNKNOWN = Float.NaN;
+
+    private final float zScore;
+    private final float pClean;
+    private final float ciLow;
+    private final float ciHigh;
+    private final String dominantScript;
+
+    public JunkScore(float zScore, float pClean, float ciLow, float ciHigh, String dominantScript) {
+        this.zScore = zScore;
+        this.pClean = pClean;
+        this.ciLow = ciLow;
+        this.ciHigh = ciHigh;
+        this.dominantScript = dominantScript;
+    }
+
+    /** Z-score relative to clean text for the detected script. 0 = average clean; negative = worse. */
+    public float getZScore() {
+        return zScore;
+    }
+
+    /** Value in [0,1]; {@link JunkDetector} computes it as a sigmoid of the z-score (NaN if unknown). */
+    public float getPClean() {
+        return pClean;
+    }
+
+    /** Lower bound of the approximate 95% interval on zScore (NaN if unknown). */
+    public float getCiLow() {
+        return ciLow;
+    }
+
+    /** Upper bound of the approximate 95% interval on zScore (NaN if unknown). */
+    public float getCiHigh() {
+        return ciHigh;
+    }
+
+    /** Dominant script name, e.g. "LATIN"; unknown results may carry a diagnostic label instead. */
+    public String getDominantScript() {
+        return dominantScript;
+    }
+
+    /** Returns true when no score could be computed, i.e. {@link #getZScore()} is NaN. */
+    public boolean isUnknown() {
+        return Float.isNaN(zScore);
+    }
+
+    @Override
+    public String toString() {
+        if (isUnknown()) {
+            return "JunkScore[UNKNOWN script=" + dominantScript + "]";
+        }
+        return String.format(java.util.Locale.ROOT, "JunkScore[z=%.3f p=%.3f ci=(%.3f,%.3f) script=%s]",
+                zScore, pClean, ciLow, ciHigh, dominantScript);
+    }
+}
diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/BuildJunkTrainingData.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/BuildJunkTrainingData.java
new file mode 100644
index 0000000000..77d9283f9a
--- /dev/null
+++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/BuildJunkTrainingData.java
@@ -0,0 +1,559 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.junkdetect.tools; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.text.Normalizer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.TreeMap; +import java.util.zip.GZIPOutputStream; + +/** + * Builds per-script positive training data for the junk detector from MADLAD-400 + * and Wikipedia sentence files. + * + * <p>Script groups are derived entirely from the data: for each language directory + * the dominant Unicode script is detected by histogramming {@link Character.UnicodeScript} + * over a sample of sentences (COMMON, INHERITED, and UNKNOWN pseudo-scripts excluded). + * Languages that share the same dominant script are pooled. No script groups are + * hardcoded. + * + * <p>The total byte budget is distributed across script groups proportionally to + * each group's empirical byte-bigram entropy, measured from a small sample. 
+ * Scripts with high entropy (e.g. CJK, which has thousands of distinct 3-byte + * codepoints) receive a proportionally larger allocation than low-entropy scripts + * (e.g. Arabic, whose UTF-8 high bytes cluster in a narrow 0xD8-0xDB range). + * This ensures every script's bigram table is estimated with comparable statistical + * quality regardless of character-set size. + * + * <p>Within each script group the byte budget is distributed evenly across its + * member languages, ensuring diversity (no single language dominates). + * + * <p>Input format ({@code sentences_madlad.txt} and {@code sentences_wikipedia.txt}): + * {@code lineNum TAB text}, UTF-8. MADLAD records contain literal {@code \n} escape + * sequences as sub-sentence separators (full scraped documents); Wikipedia records + * are individual sentences. Both are split/cleaned to sentence-level strings. + * + * <p>Output: + * <pre> + * output-dir/ + * {script}.train.gz — 80% split, one NFC-normalised sentence per line + * {script}.dev.gz — 10% split, used for calibration (mu/sigma) + * {script}.test.gz — 10% split, held out for final evaluation only + * manifest.tsv — per-script stats: entropy, budget, bytes written, languages + * </pre> + * + * <p>Usage: + * <pre> + * java BuildJunkTrainingData \ + * --data-dir ~/datasets/madlad/data \ + * --output-dir ~/datasets/madlad/junkdetect \ + * [--total-budget-bytes 50000000] + * </pre> + */ +public class BuildJunkTrainingData { + + // ----------------------------------------------------------------------- + // Defaults + // ----------------------------------------------------------------------- + + /** Lines read per language to determine dominant script. */ + private static final int DEFAULT_SCRIPT_SAMPLE_LINES = 2_000; + + /** + * UTF-8 bytes loaded per script group for entropy estimation. + * Budget is spread evenly across languages in the group. + * 200KB is enough to observe the bigram distribution reliably. 
+     */
+    private static final long ENTROPY_SAMPLE_BYTES = 200_000L;
+
+    /**
+     * Total UTF-8 byte budget across all script groups. Divided proportionally
+     * by bigram entropy after the sampling phase. 50MB gives ~1–3MB per script
+     * on average across 34 groups; scale up for production runs.
+     */
+    private static final long DEFAULT_TOTAL_BUDGET_BYTES = 50_000_000L;
+
+    /** Minimum UTF-8 byte length for a sentence to pass the quality filter. */
+    private static final int DEFAULT_MIN_BYTES = 50;
+
+    /** Maximum fraction of codepoints that may be ASCII punctuation/digits. */
+    private static final double DEFAULT_MAX_PUNC_FRAC = 0.30;
+
+    /** Fraction of sentences written to each split (train / dev / test = 80/10/10). */
+    private static final double TRAIN_FRAC = 0.80;
+    private static final double DEV_FRAC = 0.10;
+    // remaining (1 - TRAIN_FRAC - DEV_FRAC) goes to the test split
+
+    // -----------------------------------------------------------------------
+    // Entry point
+    // -----------------------------------------------------------------------
+
+    public static void main(String[] args) throws IOException {
+        Path dataDir = Paths.get(System.getProperty("user.home"), "datasets", "madlad", "data");
+        Path outputDir = Paths.get(System.getProperty("user.home"), "datasets", "madlad", "junkdetect");
+        int scriptSampleLines = DEFAULT_SCRIPT_SAMPLE_LINES;
+        long totalBudgetBytes = DEFAULT_TOTAL_BUDGET_BYTES;
+        int minBytes = DEFAULT_MIN_BYTES;
+        double maxPuncFrac = DEFAULT_MAX_PUNC_FRAC;
+        int seed = 42;
+        boolean dryRun = false;
+
+        for (int i = 0; i < args.length; i++) {
+            switch (args[i]) {
+                case "--data-dir":
+                    dataDir = Paths.get(args[++i]);
+                    break;
+                case "--output-dir":
+                    outputDir = Paths.get(args[++i]);
+                    break;
+                case "--script-sample-lines":
+                    scriptSampleLines = Integer.parseInt(args[++i]);
+                    break;
+                case "--total-budget-bytes":
+                    totalBudgetBytes = Long.parseLong(args[++i]);
+                    break;
+                case "--min-bytes":
+                    minBytes = Integer.parseInt(args[++i]);
+                    break;
+                case "--max-punc-frac":
+                    maxPuncFrac = Double.parseDouble(args[++i]);
+                    break;
+                case "--seed":
+                    seed = Integer.parseInt(args[++i]);
+                    break;
+                case "--dry-run":
+                    dryRun = true;
+                    break;
+                default:
+                    System.err.println("Unknown argument: " + args[i]);
+                    printUsage();
+                    System.exit(1);
+            }
+        }
+
+        System.out.println("=== BuildJunkTrainingData ===");
+        System.out.println(" data-dir: " + dataDir);
+        System.out.println(" output-dir: " + outputDir);
+        System.out.printf(" total-budget-bytes: %,d (%.1f MB)%n",
+                totalBudgetBytes, totalBudgetBytes / 1_000_000.0);
+        System.out.printf(" min-bytes: %d%n", minBytes);
+        System.out.printf(" max-punc-frac: %.2f%n", maxPuncFrac);
+        System.out.println(" dry-run: " + dryRun);
+
+        if (!Files.isDirectory(dataDir)) {
+            System.err.println("ERROR: data-dir not found: " + dataDir);
+            System.exit(1);
+        }
+
+        // -----------------------------------------------------------------------
+        // Phase 1: Detect dominant script per language, group languages
+        // -----------------------------------------------------------------------
+
+        System.out.println("\n--- Phase 1: Detecting dominant script per language ---");
+
+        Map<String, List<Path>> scriptGroups = new TreeMap<>();
+        Map<String, String> langToScript = new LinkedHashMap<>();
+
+        try (var dirStream = Files.list(dataDir)) {
+            List<Path> langDirs = dirStream.filter(Files::isDirectory).sorted().toList();
+            for (Path langDir : langDirs) {
+                String lang = langDir.getFileName().toString();
+                String script = detectDominantScript(langDir, scriptSampleLines);
+                langToScript.put(lang, script);
+                scriptGroups.computeIfAbsent(script, k -> new ArrayList<>()).add(langDir);
+                System.out.printf(" %-12s → %s%n", lang, script);
+            }
+        }
+        System.out.printf("%n → %d languages, %d script groups%n",
+                langToScript.size(), scriptGroups.size());
+
+        // -----------------------------------------------------------------------
+        // Phase 2: Load small sample per script, compute byte-bigram entropy
+        // -----------------------------------------------------------------------
+
+        System.out.println("\n--- Phase 2: Estimating byte-bigram entropy per script ---");
+
+        Map<String, Double> scriptEntropy = new TreeMap<>();
+        for (Map.Entry<String, List<Path>> entry : scriptGroups.entrySet()) {
+            String script = entry.getKey();
+            List<Path> langDirs = entry.getValue();
+
+            long perLangSampleBytes = Math.max(ENTROPY_SAMPLE_BYTES / langDirs.size(), 2_000L);
+            List<String> sample = new ArrayList<>();
+            for (Path langDir : langDirs) {
+                loadSentences(langDir, perLangSampleBytes, minBytes, maxPuncFrac, sample);
+            }
+
+            double entropy = computeBigramEntropy(sample);
+            scriptEntropy.put(script, entropy);
+            System.out.printf(" %-20s H=%.3f bits (%d sentences)%n",
+                    script, entropy, sample.size());
+        }
+
+        // -----------------------------------------------------------------------
+        // Phase 3: Allocate byte budget proportional to entropy
+        // -----------------------------------------------------------------------
+
+        System.out.println("\n--- Phase 3: Allocating byte budget ---");
+
+        double totalEntropy = scriptEntropy.values().stream()
+                .mapToDouble(Double::doubleValue).sum();
+
+        Map<String, Long> scriptBudget = new TreeMap<>();
+        for (Map.Entry<String, Double> e : scriptEntropy.entrySet()) {
+            long budget = (long) (totalBudgetBytes * e.getValue() / totalEntropy);
+            scriptBudget.put(e.getKey(), budget);
+            System.out.printf(" %-20s H=%.3f → %,d bytes (%.1f MB)%n",
+                    e.getKey(), e.getValue(), budget, budget / 1_000_000.0);
+        }
+
+        if (dryRun) {
+            System.out.println("\nDry-run: stopping before writing files.");
+            return;
+        }
+
+        // -----------------------------------------------------------------------
+        // Phase 4: Collect data, write train/dev/test splits
+        // -----------------------------------------------------------------------
+
+        Files.createDirectories(outputDir);
+        System.out.println("\n--- Phase 4: Collecting and writing per-script files ---");
+
+        Random rng = new Random(seed);
+
+        // Per-script stats {written_bytes, sentences, train, dev, test} = manifest.tsv columns 4-8
+        Map<String, long[]> manifestStats = new TreeMap<>();
+
+        for (Map.Entry<String, Long> budgetEntry : scriptBudget.entrySet()) {
+            String script = budgetEntry.getKey();
+            long budget = budgetEntry.getValue();
+            List<Path> langDirs = scriptGroups.get(script);
+
+            long perLangBytes = Math.max(budget / langDirs.size(), 1L);
+
+            List<String> sentences = new ArrayList<>();
+            long totalBytesLoaded = 0;
+
+            for (Path langDir : langDirs) {
+                long remaining = budget - totalBytesLoaded;
+                if (remaining <= 0) {
+                    break;
+                }
+                long langBytes = loadSentences(langDir,
+                        Math.min(perLangBytes, remaining),
+                        minBytes, maxPuncFrac, sentences);
+                totalBytesLoaded += langBytes;
+                if (langBytes > 0) {
+                    System.out.printf(" %-12s %-20s +%,d bytes%n",
+                            script, langDir.getFileName(), langBytes);
+                }
+            }
+
+            if (sentences.isEmpty()) {
+                System.out.printf(" SKIP %-12s — no sentences collected%n", script);
+                manifestStats.put(script, new long[]{0, 0, 0, 0, 0});
+                continue;
+            }
+
+            Collections.shuffle(sentences, rng);
+
+            int nTrain = (int) (sentences.size() * TRAIN_FRAC);
+            int nDev = (int) (sentences.size() * DEV_FRAC);
+            List<String> train = sentences.subList(0, nTrain);
+            List<String> dev = sentences.subList(nTrain, nTrain + nDev);
+            List<String> test = sentences.subList(nTrain + nDev, sentences.size());
+
+            String baseName = script.toLowerCase(java.util.Locale.ROOT);
+            writeGzipped(outputDir.resolve(baseName + ".train.gz"), train);
+            writeGzipped(outputDir.resolve(baseName + ".dev.gz"), dev);
+            writeGzipped(outputDir.resolve(baseName + ".test.gz"), test);
+
+            manifestStats.put(script,
+                    new long[]{totalBytesLoaded, sentences.size(), nTrain, nDev, test.size()});
+            System.out.printf(
+                    " WROTE %-12s — %,d bytes, %,d sentences (train=%,d dev=%,d test=%,d)%n",
+                    script, totalBytesLoaded, sentences.size(),
+                    nTrain, nDev, test.size());
+        }
+
+        // -----------------------------------------------------------------------
+        // Phase 5: Write manifest
+        // -----------------------------------------------------------------------
+
+        Path manifest = outputDir.resolve("manifest.tsv");
+        try (BufferedWriter w = Files.newBufferedWriter(manifest, StandardCharsets.UTF_8)) {
+            w.write("script\tentropy_bits\tbudget_bytes\twritten_bytes\tsentences"
+                    + "\ttrain_sentences\tdev_sentences\ttest_sentences\tlanguages\n");
+            for (Map.Entry<String, long[]> e : manifestStats.entrySet()) {
+                String script = e.getKey();
+                long[] stats = e.getValue();
+                double entropy = scriptEntropy.getOrDefault(script, 0.0);
+                long budget = scriptBudget.getOrDefault(script, 0L);
+                String langs = scriptGroups.get(script).stream()
+                        .map(p -> p.getFileName().toString())
+                        .collect(java.util.stream.Collectors.joining(","));
+                w.write(String.format(java.util.Locale.ROOT, "%s\t%.3f\t%d\t%d\t%d\t%d\t%d\t%d\t%s\n",
+                        script, entropy, budget,
+                        stats[0], stats[1], stats[2], stats[3], stats[4], langs));
+            }
+        }
+
+        System.out.println("\nWrote manifest: " + manifest);
+        System.out.println("Done.");
+    }
+
+    // -----------------------------------------------------------------------
+    // Script detection
+    // -----------------------------------------------------------------------
+
+    /**
+     * Detects the dominant Unicode script for a language by histogramming
+     * {@link Character.UnicodeScript} over a sample of its sentences.
+     * COMMON, INHERITED, and UNKNOWN pseudo-scripts are excluded from voting.
+     * Returns "COMMON" if no script reaches at least 1% of codepoints.
+     */
+    static String detectDominantScript(Path langDir, int sampleLines) {
+        Map<Character.UnicodeScript, Long> counts =
+                new java.util.EnumMap<>(Character.UnicodeScript.class);
+        long total = 0;
+
+        for (String filename : new String[]{"sentences_wikipedia.txt", "sentences_madlad.txt"}) {
+            Path file = langDir.resolve(filename);
+            if (!Files.exists(file)) {
+                continue;
+            }
+            try (BufferedReader r = new BufferedReader(
+                    new InputStreamReader(Files.newInputStream(file), StandardCharsets.UTF_8))) {
+                String line;
+                int linesRead = 0;
+                while ((line = r.readLine()) != null && linesRead < sampleLines) {
+                    String text = extractText(line);
+                    for (int i = 0; i < text.length(); ) {
+                        int cp = text.codePointAt(i);
+                        Character.UnicodeScript s = Character.UnicodeScript.of(cp);
+                        if (s != Character.UnicodeScript.COMMON
+                                && s != Character.UnicodeScript.INHERITED
+                                && s != Character.UnicodeScript.UNKNOWN) {
+                            counts.merge(s, 1L, Long::sum);
+                            total++;
+                        }
+                        i += Character.charCount(cp);
+                    }
+                    linesRead++;
+                }
+            } catch (IOException e) {
+                System.err.println("WARNING: could not read " + file + ": " + e.getMessage());
+            }
+            if (total >= sampleLines * 10L) {
+                break; // sufficient signal
+            }
+        }
+
+        if (total == 0) {
+            return "COMMON";
+        }
+
+        // Plurality vote with the documented 1% floor; EnumMap order makes ties deterministic
+        Character.UnicodeScript best = Character.UnicodeScript.COMMON;
+        long bestCount = total / 100;
+        for (Map.Entry<Character.UnicodeScript, Long> e : counts.entrySet()) {
+            if (e.getValue() > bestCount) {
+                bestCount = e.getValue();
+                best = e.getKey();
+            }
+        }
+        return best.name();
+    }
+
+    // -----------------------------------------------------------------------
+    // Entropy estimation
+    // -----------------------------------------------------------------------
+
+    /**
+     * Computes the empirical byte-bigram Shannon entropy (bits) of a list of
+     * UTF-8 sentences.
+     *
+     * <p>Every adjacent byte pair is tallied into one of 256×256 = 65,536 bins; the result
+     * is {@code -sum p(a,b) * log2(p(a,b))} over non-empty bins, so it never exceeds 16 bits.
+     * Returns 0 when no sentence is at least two bytes long. Rough expected ranges (not
+     * checked here): Latin ~8–11 bits, Arabic ~9–12, CJK ~13–15.
+     */
+    static double computeBigramEntropy(List<String> sentences) {
+        long[] pairCounts = new long[65536];
+        long pairTotal = 0;
+        for (String sentence : sentences) {
+            byte[] utf8 = sentence.getBytes(StandardCharsets.UTF_8);
+            for (int pos = 1; pos < utf8.length; pos++) {
+                pairCounts[((utf8[pos - 1] & 0xFF) << 8) | (utf8[pos] & 0xFF)]++;
+            }
+            pairTotal += Math.max(utf8.length - 1, 0);
+        }
+        if (pairTotal == 0) {
+            return 0.0;
+        }
+        double bits = 0.0;
+        for (long pairCount : pairCounts) {
+            if (pairCount != 0) {
+                double p = (double) pairCount / pairTotal;
+                bits -= p * (Math.log(p) / Math.log(2.0));
+            }
+        }
+        return bits;
+    }
+
+    // -----------------------------------------------------------------------
+    // Sentence loading and filtering
+    // -----------------------------------------------------------------------
+
+    /**
+     * Loads filtered, NFC-normalised sentences from {@code langDir} until
+     * {@code maxBytes} UTF-8 bytes have been accumulated, and appends them
+     * to {@code result}.
+     *
+     * <p>Reads {@code sentences_wikipedia.txt} before {@code sentences_madlad.txt}.
+     * MADLAD records contain literal {@code \n} escape sequences as sub-sentence
+     * separators (full scraped documents) and are split accordingly.
+     *
+     * @return UTF-8 bytes appended to {@code result}; may overshoot {@code maxBytes} by one sentence
+     */
+    static long loadSentences(Path langDir, long maxBytes, int minBytes,
+                              double maxPuncFrac, List<String> result) {
+        long accepted = 0;
+        String[] sources = {"sentences_wikipedia.txt", "sentences_madlad.txt"};
+        // Wikipedia first, then MADLAD; stop opening files once the budget is met
+        for (int f = 0; f < sources.length && accepted < maxBytes; f++) {
+            Path file = langDir.resolve(sources[f]);
+            if (!Files.exists(file)) {
+                continue;
+            }
+            try (BufferedReader reader = new BufferedReader(
+                    new InputStreamReader(Files.newInputStream(file), StandardCharsets.UTF_8))) {
+                for (String line = reader.readLine();
+                        line != null && accepted < maxBytes;
+                        line = reader.readLine()) {
+                    // A record may hold several sub-sentences joined by literal "\n" escapes
+                    for (String part : extractText(line).split("\\\\n")) {
+                        // Drop literal "\r", turn literal "\t" into a space, collapse whitespace
+                        String text = part.replace("\\r", "").replace("\\t", " ").strip()
+                                .replaceAll("\\s+", " ");
+                        // Empty fragments never reach the quality filter
+                        String filtered = text.isEmpty()
+                                ? null : filterSentence(text, minBytes, maxPuncFrac);
+                        if (filtered == null) {
+                            continue;
+                        }
+                        result.add(filtered);
+                        // The budget counts UTF-8 bytes of the filtered (NFC) sentence
+                        accepted += filtered.getBytes(StandardCharsets.UTF_8).length;
+                        if (accepted >= maxBytes) {
+                            break;
+                        }
+                    }
+                }
+            } catch (IOException e) {
+                // Best effort: report the unreadable file and carry on with the next one
+                System.err.println("WARNING: could not read " + file + ": " + e.getMessage());
+            }
+        }
+        return accepted;
+    }
+
+    /**
+     * Applies quality filters to a single sentence and NFC-normalises it.
+ * + * @return the normalised sentence, or {@code null} if it should be discarded + */ + static String filterSentence(String text, int minBytes, double maxPuncFrac) { + if (text.indexOf('\uFFFD') >= 0) { + return null; + } + text = Normalizer.normalize(text, Normalizer.Form.NFC); + if (text.getBytes(StandardCharsets.UTF_8).length < minBytes) { + return null; + } + int cpCount = 0; + int puncCount = 0; + for (int i = 0; i < text.length(); ) { + int cp = text.codePointAt(i); + cpCount++; + if (cp >= 0x21 && cp <= 0x7E && !Character.isLetter(cp)) { + puncCount++; + } + i += Character.charCount(cp); + } + if (cpCount > 0 && (double) puncCount / cpCount > maxPuncFrac) { + return null; + } + return text; + } + + // ----------------------------------------------------------------------- + // I/O helpers + // ----------------------------------------------------------------------- + + private static String extractText(String line) { + int tab = line.indexOf('\t'); + String text = (tab >= 0) ? line.substring(tab + 1) : line; + return text.replace("\uFEFF", ""); + } + + private static void writeGzipped(Path path, List<String> lines) throws IOException { + try (BufferedWriter w = new BufferedWriter( + new OutputStreamWriter( + new GZIPOutputStream(Files.newOutputStream(path)), + StandardCharsets.UTF_8))) { + for (String line : lines) { + w.write(line); + w.newLine(); + } + } + } + + private static void printUsage() { + System.err.println("Usage: BuildJunkTrainingData [options]"); + System.err.println(" --data-dir <path> MADLAD data root" + + " (default: ~/datasets/madlad/data)"); + System.err.println(" --output-dir <path> Output directory" + + " (default: ~/datasets/madlad/junkdetect)"); + System.err.println(" --script-sample-lines N Lines per language for script" + + " detection (default: 2000)"); + System.err.println(" --total-budget-bytes N Total UTF-8 bytes across all" + + " scripts (default: 50000000)"); + System.err.println(" --min-bytes N Min UTF-8 bytes per sentence" 
+ + " (default: 50)"); + System.err.println(" --max-punc-frac F Max ASCII punct fraction" + + " (default: 0.30)"); + System.err.println(" --seed N Random seed (default: 42)"); + System.err.println(" --dry-run Detect scripts + show budget," + + " skip file writing"); + } +} diff --git a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java new file mode 100644 index 0000000000..de2494816a --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/EvalJunkDetector.java @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.tika.ml.junkdetect.tools; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; + +import org.apache.tika.ml.junkdetect.JunkDetector; +import org.apache.tika.ml.junkdetect.JunkScore; + +/** + * Ablation evaluation for the junk detector. + * + * <p>For each script's dev set, scores clean sentences alongside three corruption + * modes — random-byte injection, codepoint-reversal, and byte-shuffle — at several + * injection rates and string lengths. Computes per-cell Cohen's d (discrimination + * power) and TPR/FPR at a fixed z-score threshold. + * + * <p>Output: two TSV files. + * <ul> + * <li><b>detail.tsv</b> — one row per (script, distortion, rate, length): + * {@code script, distortion, param, length, n_clean, n_corrupt, + * mean_clean_z, mean_corrupt_z, cohens_d, fpr, tpr} + * <li><b>summary.tsv</b> — macro-averaged Cohen's d and FPR/TPR per + * (distortion, rate, length) across all scripts. + * </ul> + * + * <p>Cohen's d = (mean_clean_z − mean_corrupt_z) / pooled_std. + * Higher d = better discrimination. FPR = fraction of clean text falsely flagged; + * TPR = fraction of corrupted text correctly flagged. Both use threshold = −2.0. + * + * <p>To compare two model versions: run eval before and after, then diff the + * summary TSVs. The "macro_d" column in summary.tsv is the single headline metric. 
+ * + * <p>Usage: + * <pre> + * java EvalJunkDetector \ + * --model /path/to/junkdetect.bin (default: classpath) + * --data-dir ~/datasets/madlad/junkdetect + * --output-dir /path/to/results (default: data-dir/eval) + * --split dev|test (default: dev — use test only for final reporting) + * --samples 200 + * --seed 42 + * --lengths 15,30,50,100,200 + * --rates 0.01,0.05,0.10,0.25,0.50,0.90 + * --threshold -2.0 + * </pre> + * + * <p><b>Which split to use:</b> Use {@code --split dev} during iterative development + * (dev data is seen by the calibration step, so numbers are slightly optimistic for + * calibration quality, but still valid for relative comparisons between model versions). + * Use {@code --split test} only when reporting final numbers — the test split is + * completely held out and was never used to make any model or threshold decision. + */ +public class EvalJunkDetector { + + public static void main(String[] args) throws Exception { + // Defaults + Path modelPath = null; + Path dataDir = Paths.get(System.getProperty("user.home"), + "datasets", "madlad", "junkdetect"); + Path outputDir = null; + String split = "dev"; // dev during development; test for final reporting + int samplesPerCell = 200; + long seed = 42L; + int[] lengths = {15, 30, 50, 100, 200}; + double[] rates = {0.01, 0.05, 0.10, 0.25, 0.50, 0.90}; + float threshold = -2.0f; + + for (int i = 0; i < args.length; i++) { + switch (args[i]) { + case "--model": + modelPath = Paths.get(args[++i]); + break; + case "--data-dir": + dataDir = Paths.get(args[++i]); + break; + case "--output-dir": + outputDir = Paths.get(args[++i]); + break; + case "--split": + split = args[++i]; + if (!split.equals("dev") && !split.equals("test")) { + System.err.println("--split must be 'dev' or 'test'"); + System.exit(1); + } + break; + case "--samples": + samplesPerCell = Integer.parseInt(args[++i]); + break; + case "--seed": + seed = Long.parseLong(args[++i]); + break; + case "--lengths": + lengths = 
Arrays.stream(args[++i].split(",")) + .mapToInt(Integer::parseInt).toArray(); + break; + case "--rates": + rates = Arrays.stream(args[++i].split(",")) + .mapToDouble(Double::parseDouble).toArray(); + break; + case "--threshold": + threshold = Float.parseFloat(args[++i]); + break; + default: + System.err.println("Unknown argument: " + args[i]); + System.exit(1); + } + } + + if (outputDir == null) { + outputDir = dataDir.resolve("eval"); + } + Files.createDirectories(outputDir); + + JunkDetector detector = modelPath != null + ? JunkDetector.loadFromPath(modelPath) + : JunkDetector.loadFromClasspath(); + + System.err.println("=== EvalJunkDetector ==="); + System.err.println(" data-dir: " + dataDir); + System.err.println(" output-dir: " + outputDir); + System.err.println(" split: " + split + + (split.equals("test") ? " [FINAL REPORTING MODE]" : "")); + System.err.println(" scripts in model: " + detector.knownScripts().size()); + System.err.println(" threshold: " + threshold); + + String suffix = "." 
+ split + ".gz"; + List<Path> devFiles; + try (var stream = Files.list(dataDir)) { + devFiles = stream + .filter(p -> p.getFileName().toString().endsWith(suffix)) + .sorted() + .collect(Collectors.toList()); + } + + if (devFiles.isEmpty()) { + System.err.println("ERROR: no *" + suffix + " files found in " + dataDir); + System.exit(1); + } + + Path detailPath = outputDir.resolve("detail.tsv"); + Path summaryPath = outputDir.resolve("summary.tsv"); + + // Accumulate all rows for summary aggregation + List<Row> allRows = new ArrayList<>(); + + try (PrintWriter detail = new PrintWriter( + Files.newBufferedWriter(detailPath, StandardCharsets.UTF_8))) { + + detail.println("script\tdistortion\tparam\tlength" + + "\tn_clean\tn_corrupt" + + "\tmean_clean_z\tmean_corrupt_z" + + "\tcohens_d\tfpr\ttpr"); + + for (Path devFile : devFiles) { + String filename = devFile.getFileName().toString(); + String script = filename + .substring(0, filename.length() - suffix.length()) + .toUpperCase(); + + System.err.printf("%n--- %s ---%n", script); + + List<String> sentences = loadSentences(devFile, samplesPerCell * 20); + if (sentences.size() < 10) { + System.err.printf(" Skipping — only %d sentences%n", sentences.size()); + continue; + } + + Random rng = new Random(seed); + + // Score clean baseline once per (script, length) + // Reuse the same clean scores for all distortion comparisons at that length + for (int len : lengths) { + List<Float> cleanZ = scoreClean(detector, sentences, len, + samplesPerCell, new Random(seed)); + + // --- injection --- + for (double rate : rates) { + List<Float> corruptZ = scoreWithInjection(detector, sentences, len, + rate, samplesPerCell, new Random(seed + 1)); + Row row = new Row(script, "inject", + String.format("%.2f", rate), len, + cleanZ, corruptZ, threshold); + allRows.add(row); + detail.println(row.toTsv()); + } + + // --- codepoint reversal --- + { + List<Float> corruptZ = scoreReversed(detector, sentences, len, + samplesPerCell, new Random(seed 
+ 2)); + Row row = new Row(script, "char-reverse", "-", len, + cleanZ, corruptZ, threshold); + allRows.add(row); + detail.println(row.toTsv()); + } + + // --- byte shuffle --- + { + List<Float> corruptZ = scoreShuffled(detector, sentences, len, + samplesPerCell, new Random(seed + 3)); + Row row = new Row(script, "byte-shuffle", "-", len, + cleanZ, corruptZ, threshold); + allRows.add(row); + detail.println(row.toTsv()); + } + + detail.flush(); + rng = new Random(seed); // reset between lengths for reproducibility + } + } + } + + writeSummary(summaryPath, allRows, lengths, rates, threshold); + + System.err.println("\nWrote " + detailPath); + System.err.println("Wrote " + summaryPath); + System.err.println("Done."); + } + + // ----------------------------------------------------------------------- + // Summary aggregation + // ----------------------------------------------------------------------- + + private static void writeSummary(Path summaryPath, List<Row> rows, + int[] lengths, double[] rates, float threshold) + throws IOException { + try (PrintWriter out = new PrintWriter( + Files.newBufferedWriter(summaryPath, StandardCharsets.UTF_8))) { + + out.println("distortion\tparam\tlength\tn_scripts" + + "\tmacro_cohens_d\tmacro_fpr\tmacro_tpr"); + + // For each unique (distortion, param, length), average across scripts + // Build groups: inject@rate, char-reverse, byte-shuffle + List<String[]> conditions = new ArrayList<>(); + for (double rate : rates) { + conditions.add(new String[]{"inject", String.format("%.2f", rate)}); + } + conditions.add(new String[]{"char-reverse", "-"}); + conditions.add(new String[]{"byte-shuffle", "-"}); + + for (String[] cond : conditions) { + String distortion = cond[0]; + String param = cond[1]; + for (int len : lengths) { + List<Row> matching = rows.stream() + .filter(r -> r.distortion.equals(distortion) + && r.param.equals(param) + && r.length == len) + .collect(Collectors.toList()); + if (matching.isEmpty()) { + continue; + } + double 
macroCohensD = matching.stream() + .filter(r -> !Double.isNaN(r.cohensD)) + .mapToDouble(r -> r.cohensD) + .average().orElse(Double.NaN); + double macroFpr = matching.stream() + .mapToDouble(r -> r.fpr) + .average().orElse(Double.NaN); + double macroTpr = matching.stream() + .mapToDouble(r -> r.tpr) + .average().orElse(Double.NaN); + + out.printf("%s\t%s\t%d\t%d\t%.3f\t%.3f\t%.3f%n", + distortion, param, len, matching.size(), + macroCohensD, macroFpr, macroTpr); + } + } + + // Overall headline: macro-average Cohen's d across everything + double overallD = rows.stream() + .filter(r -> !Double.isNaN(r.cohensD)) + .mapToDouble(r -> r.cohensD) + .average().orElse(Double.NaN); + double overallFpr = rows.stream() + .mapToDouble(r -> r.fpr) + .average().orElse(Double.NaN); + double overallTpr = rows.stream() + .mapToDouble(r -> r.tpr) + .average().orElse(Double.NaN); + out.println(); + out.printf("# OVERALL macro_cohens_d=%.3f macro_fpr=%.3f macro_tpr=%.3f%n", + overallD, overallFpr, overallTpr); + + System.err.printf("%nOVERALL: macro_cohens_d=%.3f macro_fpr=%.3f macro_tpr=%.3f%n", + overallD, overallFpr, overallTpr); + } + } + + // ----------------------------------------------------------------------- + // Row (one evaluation cell) + // ----------------------------------------------------------------------- + + private static final class Row { + final String script; + final String distortion; + final String param; + final int length; + final int nClean; + final int nCorrupt; + final double meanCleanZ; + final double meanCorruptZ; + final double cohensD; + final double fpr; + final double tpr; + + Row(String script, String distortion, String param, int length, + List<Float> cleanZ, List<Float> corruptZ, float threshold) { + this.script = script; + this.distortion = distortion; + this.param = param; + this.length = length; + this.nClean = cleanZ.size(); + this.nCorrupt = corruptZ.size(); + this.meanCleanZ = mean(cleanZ); + this.meanCorruptZ = mean(corruptZ); + 
this.cohensD = computeCohensD(cleanZ, corruptZ); + this.fpr = fractionBelow(cleanZ, threshold); + this.tpr = fractionBelow(corruptZ, threshold); + } + + String toTsv() { + return String.format("%s\t%s\t%s\t%d\t%d\t%d\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f", + script, distortion, param, length, + nClean, nCorrupt, + meanCleanZ, meanCorruptZ, + cohensD, fpr, tpr); + } + } + + // ----------------------------------------------------------------------- + // Statistics + // ----------------------------------------------------------------------- + + /** + * Cohen's d = (mean_clean − mean_corrupt) / pooled_std. + * Positive = clean scores higher than corrupt (desirable). + * Higher absolute value = better discrimination. + */ + private static double computeCohensD(List<Float> clean, List<Float> corrupt) { + if (clean.isEmpty() || corrupt.isEmpty()) { + return Double.NaN; + } + double mc = mean(clean); + double mj = mean(corrupt); + double vc = variance(clean, mc); + double vj = variance(corrupt, mj); + double pooledStd = Math.sqrt((vc + vj) / 2.0); + if (pooledStd < 1e-9) { + return Double.NaN; + } + return (mc - mj) / pooledStd; + } + + private static double mean(List<Float> xs) { + return xs.stream().mapToDouble(Float::floatValue).average().orElse(0); + } + + private static double variance(List<Float> xs, double mu) { + return xs.stream().mapToDouble(x -> (x - mu) * (x - mu)).average().orElse(0); + } + + private static double fractionBelow(List<Float> zs, float threshold) { + if (zs.isEmpty()) { + return Double.NaN; + } + long count = zs.stream().filter(z -> z < threshold).count(); + return (double) count / zs.size(); + } + + // ----------------------------------------------------------------------- + // Scoring helpers + // ----------------------------------------------------------------------- + + private static List<Float> scoreClean(JunkDetector detector, List<String> sentences, + int targetLen, int n, Random rng) { + List<Float> results = new ArrayList<>(n); + for (int i = 
0; i < n; i++) { + String s = pickSubstring(sentences, targetLen, rng); + JunkScore score = detector.score(s); + if (!score.isUnknown()) { + results.add(score.getZScore()); + } + } + return results; + } + + private static List<Float> scoreWithInjection(JunkDetector detector, + List<String> sentences, int targetLen, + double rate, int n, Random rng) { + List<Float> results = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + String s = pickSubstring(sentences, targetLen, rng); + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); + injectRandomBytes(bytes, rate, rng); + JunkScore score = detector.score(bytes); + if (!score.isUnknown()) { + results.add(score.getZScore()); + } + } + return results; + } + + private static List<Float> scoreReversed(JunkDetector detector, List<String> sentences, + int targetLen, int n, Random rng) { + List<Float> results = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + String s = reverseCodepoints(pickSubstring(sentences, targetLen, rng)); + JunkScore score = detector.score(s); + if (!score.isUnknown()) { + results.add(score.getZScore()); + } + } + return results; + } + + private static List<Float> scoreShuffled(JunkDetector detector, List<String> sentences, + int targetLen, int n, Random rng) { + List<Float> results = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + String s = pickSubstring(sentences, targetLen, rng); + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); + shuffleBytes(bytes, rng); + JunkScore score = detector.score(bytes); + if (!score.isUnknown()) { + results.add(score.getZScore()); + } + } + return results; + } + + // ----------------------------------------------------------------------- + // Distortion primitives + // ----------------------------------------------------------------------- + + static void injectRandomBytes(byte[] bytes, double rate, Random rng) { + for (int i = 0; i < bytes.length; i++) { + if (rng.nextDouble() < rate) { + bytes[i] = (byte) (0x80 | rng.nextInt(128)); + } + } + } + + 
static void shuffleBytes(byte[] bytes, Random rng) { + for (int i = bytes.length - 1; i > 0; i--) { + int j = rng.nextInt(i + 1); + byte tmp = bytes[i]; + bytes[i] = bytes[j]; + bytes[j] = tmp; + } + } + + static String reverseCodepoints(String s) { + int[] codepoints = s.codePoints().toArray(); + for (int lo = 0, hi = codepoints.length - 1; lo < hi; lo++, hi--) { + int tmp = codepoints[lo]; + codepoints[lo] = codepoints[hi]; + codepoints[hi] = tmp; + } + return new String(codepoints, 0, codepoints.length); + } + + // ----------------------------------------------------------------------- + // Sentence sampling + // ----------------------------------------------------------------------- + + private static String pickSubstring(List<String> sentences, int targetLen, Random rng) { + String s = sentences.get(rng.nextInt(sentences.size())); + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); + if (bytes.length <= targetLen) { + return s; + } + // Pick a random window of targetLen bytes, aligned to a codepoint boundary + int start = rng.nextInt(bytes.length - targetLen); + while (start > 0 && (bytes[start] & 0xC0) == 0x80) { + start--; + } + int end = Math.min(start + targetLen, bytes.length); + while (end < bytes.length && (bytes[end] & 0xC0) == 0x80) { + end++; + } + return new String(bytes, start, end - start, StandardCharsets.UTF_8); + } + + private static List<String> loadSentences(Path devGz, int maxSentences) throws IOException { + List<String> result = new ArrayList<>(); + try (BufferedReader r = new BufferedReader( + new InputStreamReader( + new GZIPInputStream(Files.newInputStream(devGz)), + StandardCharsets.UTF_8))) { + String line; + while ((line = r.readLine()) != null && result.size() < maxSentences) { + String trimmed = line.strip(); + if (!trimmed.isEmpty() + && trimmed.getBytes(StandardCharsets.UTF_8).length >= 15) { + result.add(trimmed); + } + } + } + return result; + } +} diff --git 
a/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java new file mode 100644 index 0000000000..34ecffb533 --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/src/main/java/org/apache/tika/ml/junkdetect/tools/TrainJunkModel.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.junkdetect.tools; + +import java.io.BufferedReader; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.TreeMap; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +/** + * Trains the junk detector model from per-script corpus files produced by + * {@link BuildJunkTrainingData}. 
+ * + * <p>For each script group (identified by a {@code {script}.train.gz} file): + * <ol> + * <li>Accumulates byte-bigram counts from the training sentences.</li> + * <li>Applies add-1 (Laplace) smoothing per row, converts to natural + * log-probabilities.</li> + * <li>Computes calibration statistics (mean and stddev of per-sentence mean + * bigram log-prob) from the dev split ({@code {script}.dev.gz}).</li> + * </ol> + * + * <p>Output: a single gzipped binary model file ({@code junkdetect.bin}) in the + * following format: + * <pre> + * [8 bytes] magic "JUNKDET1" + * [1 byte] version = 1 + * [4 bytes] num_scripts (big-endian int) + * for each script (sorted by name): + * [2 bytes] name length (big-endian ushort) + * [N bytes] script name (UTF-8) + * [4 bytes] mu — mean of mean_bigram_logprob over dev sentences (float) + * [4 bytes] sigma — stddev (float) + * [65536 * 4 bytes] float32 log-prob table, row a*256+b = log P(b|a) + * </pre> + * + * <p>Usage: + * <pre> + * java TrainJunkModel \ + * --data-dir ~/datasets/madlad/junkdetect \ + * --output ~/datasets/madlad/junkdetect/junkdetect.bin + * </pre> + */ +public class TrainJunkModel { + + static final String MAGIC = "JUNKDET1"; + static final byte VERSION = 1; + + public static void main(String[] args) throws IOException { + Path dataDir = Paths.get(System.getProperty("user.home"), + "datasets", "madlad", "junkdetect"); + Path output = dataDir.resolve("junkdetect.bin"); + + for (int i = 0; i < args.length; i++) { + switch (args[i]) { + case "--data-dir": + dataDir = Paths.get(args[++i]); + break; + case "--output": + output = Paths.get(args[++i]); + break; + default: + System.err.println("Unknown argument: " + args[i]); + printUsage(); + System.exit(1); + } + } + + System.out.println("=== TrainJunkModel ==="); + System.out.println(" data-dir: " + dataDir); + System.out.println(" output: " + output); + + if (!Files.isDirectory(dataDir)) { + System.err.println("ERROR: data-dir not found: " + dataDir); + 
System.exit(1); + } + + // Collect all script names by finding *.train.gz files + TreeMap<String, float[]> tables = new TreeMap<>(); + TreeMap<String, float[]> calibrations = new TreeMap<>(); + + try (var stream = Files.list(dataDir)) { + List<Path> trainFiles = stream + .filter(p -> p.getFileName().toString().endsWith(".train.gz")) + .sorted() + .toList(); + + if (trainFiles.isEmpty()) { + System.err.println("ERROR: no *.train.gz files found in " + dataDir); + System.exit(1); + } + + for (Path trainFile : trainFiles) { + String filename = trainFile.getFileName().toString(); + String script = filename.substring(0, filename.length() - ".train.gz".length()) + .toUpperCase(); + + Path devFile = trainFile.getParent().resolve( + filename.replace(".train.gz", ".dev.gz")); + + System.out.printf("%n--- %s ---%n", script); + + System.out.print(" Training bigram table... "); + long t0 = System.currentTimeMillis(); + float[] table = trainBigramTable(trainFile); + System.out.printf("done (%dms)%n", System.currentTimeMillis() - t0); + + float[] cal = new float[]{0f, 1f}; + if (Files.exists(devFile)) { + System.out.print(" Calibrating on dev set... 
"); + t0 = System.currentTimeMillis(); + cal = computeCalibration(devFile, table); + System.out.printf("done — mu=%.4f sigma=%.4f (%dms)%n", + cal[0], cal[1], System.currentTimeMillis() - t0); + } else { + System.out.println(" WARNING: no dev file found, using uncalibrated defaults"); + } + + tables.put(script, table); + calibrations.put(script, cal); + } + } + + System.out.printf("%nWriting model (%d scripts) → %s%n", tables.size(), output); + saveModel(tables, calibrations, output); + System.out.printf("Model size: %,d bytes (%.1f MB)%n", + Files.size(output), Files.size(output) / 1_000_000.0); + System.out.println("Done."); + } + + // ----------------------------------------------------------------------- + // Training + // ----------------------------------------------------------------------- + + /** + * Trains a 256×256 byte-bigram log-probability table from a gzipped + * sentence file (one UTF-8 sentence per line). + * + * <p>All 256×256 consecutive byte-pair counts are accumulated, then + * add-1 (Laplace) smoothing is applied per row before converting to + * natural log-probabilities: {@code log P(b|a) = log((C[a][b]+1) / sum_b(C[a][b]+1))}. 
+ * + * @return float[65536] table where index {@code a*256+b} = log P(b|a) + */ + static float[] trainBigramTable(Path trainGz) throws IOException { + long[] counts = new long[65536]; + long totalBigrams = 0; + long sentences = 0; + + try (BufferedReader r = openGzipped(trainGz)) { + String line; + while ((line = r.readLine()) != null) { + byte[] bytes = line.getBytes(StandardCharsets.UTF_8); + for (int i = 0; i + 1 < bytes.length; i++) { + counts[((bytes[i] & 0xFF) << 8) | (bytes[i + 1] & 0xFF)]++; + totalBigrams++; + } + sentences++; + } + } + + System.out.printf(" %,d sentences, %,d bigrams%n", sentences, totalBigrams); + + // Add-1 smoothing per row, then log-prob + float[] table = new float[65536]; + for (int a = 0; a < 256; a++) { + long rowTotal = 256; // add 1 for each of the 256 possible next bytes + for (int b = 0; b < 256; b++) { + rowTotal += counts[a * 256 + b]; + } + for (int b = 0; b < 256; b++) { + table[a * 256 + b] = (float) Math.log((counts[a * 256 + b] + 1.0) / rowTotal); + } + } + return table; + } + + /** + * Computes calibration statistics for a script by scoring each sentence + * in the dev set with the given bigram table. + * + * <p>For each sentence, the per-sentence score is the mean log-probability + * of its byte bigrams. The mean (mu) and stddev (sigma) of those scores + * across all dev sentences are returned. At inference, z-score = + * (score - mu) / sigma. 
+ * + * @return float[2] = {mu, sigma} + */ + static float[] computeCalibration(Path devGz, float[] table) throws IOException { + List<Double> scores = new ArrayList<>(); + + try (BufferedReader r = openGzipped(devGz)) { + String line; + while ((line = r.readLine()) != null) { + byte[] bytes = line.getBytes(StandardCharsets.UTF_8); + if (bytes.length < 2) { + continue; + } + double sum = 0; + for (int i = 0; i + 1 < bytes.length; i++) { + sum += table[((bytes[i] & 0xFF) << 8) | (bytes[i + 1] & 0xFF)]; + } + scores.add(sum / (bytes.length - 1)); + } + } + + System.out.printf(" %,d dev sentences%n", scores.size()); + + if (scores.isEmpty()) { + return new float[]{0f, 1f}; + } + + double mu = scores.stream().mapToDouble(Double::doubleValue).average().orElse(0); + double variance = scores.stream() + .mapToDouble(s -> (s - mu) * (s - mu)) + .average().orElse(1.0); + double sigma = Math.sqrt(variance); + if (sigma < 1e-9) { + sigma = 1.0; + } + return new float[]{(float) mu, (float) sigma}; + } + + // ----------------------------------------------------------------------- + // Model serialisation + // ----------------------------------------------------------------------- + + /** + * Writes the trained model to a gzipped binary file. + * + * <p>Format: {@code [magic:8][version:1][num_scripts:4] + * ([name_len:2][name:N][mu:4][sigma:4][table:65536*4])*} + * All multi-byte integers are big-endian. Floats are IEEE 754 big-endian. 
+ */ + static void saveModel(TreeMap<String, float[]> tables, + TreeMap<String, float[]> calibrations, + Path output) throws IOException { + try (DataOutputStream dos = new DataOutputStream( + new GZIPOutputStream(Files.newOutputStream(output)))) { + + // Magic + version + count + dos.write(MAGIC.getBytes(StandardCharsets.UTF_8)); + dos.writeByte(VERSION); + dos.writeInt(tables.size()); + + for (var entry : tables.entrySet()) { + String script = entry.getKey(); + float[] table = entry.getValue(); + float[] cal = calibrations.getOrDefault(script, new float[]{0f, 1f}); + + byte[] nameBytes = script.getBytes(StandardCharsets.UTF_8); + dos.writeShort(nameBytes.length); + dos.write(nameBytes); + + dos.writeFloat(cal[0]); // mu + dos.writeFloat(cal[1]); // sigma + + // Write 65536 float32 values in big-endian + ByteBuffer buf = ByteBuffer.allocate(65536 * 4).order(ByteOrder.BIG_ENDIAN); + for (float v : table) { + buf.putFloat(v); + } + dos.write(buf.array()); + } + } + } + + // ----------------------------------------------------------------------- + // Helpers + // ----------------------------------------------------------------------- + + static BufferedReader openGzipped(Path path) throws IOException { + return new BufferedReader( + new InputStreamReader( + new GZIPInputStream(Files.newInputStream(path)), + StandardCharsets.UTF_8)); + } + + private static void printUsage() { + System.err.println("Usage: TrainJunkModel [options]"); + System.err.println(" --data-dir <path> Directory with {script}.train.gz / .dev.gz files"); + System.err.println(" (default: ~/datasets/madlad/junkdetect)"); + System.err.println(" --output <path> Output model file"); + System.err.println(" (default: {data-dir}/junkdetect.bin)"); + } + +} diff --git a/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin new file mode 100644 index 0000000000..623b60df16 Binary 
files /dev/null and b/tika-ml/tika-ml-junkdetect/src/main/resources/org/apache/tika/ml/junkdetect/junkdetect.bin differ diff --git a/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java new file mode 100644 index 0000000000..822eae19cf --- /dev/null +++ b/tika-ml/tika-ml-junkdetect/src/test/java/org/apache/tika/ml/junkdetect/JunkDetectorSmokeTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tika.ml.junkdetect; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Random; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/** + * Smoke tests corresponding to Phase 5.1 target cases in the design doc. + * These should pass once the model is trained; failures indicate the model + * needs more data or the feature extraction is wrong. 
+ */
+public class JunkDetectorSmokeTest {
+
+    private static JunkDetector detector;
+
+    @BeforeAll
+    static void loadModel() throws Exception {
+        detector = JunkDetector.loadFromClasspath();
+    }
+
+    /**
+     * Clean English should score higher than random high-byte garbage.
+     * Also serves as the byte-reversal baseline: garbage bytes ~ byte-reversed text.
+     */
+    @Test
+    void cleanVsGarbage() {
+        JunkScore clean = detector.score("The quick brown fox jumps over the lazy dog. "
+                + "Pack my box with five dozen liquor jugs.");
+        byte[] garbageBytes = new byte[80];
+        new Random(42).nextBytes(garbageBytes);
+        // Force every byte into 0x80-0xFF. Decoding those bytes as ISO-8859-1 and
+        // re-encoding as UTF-8 yields *valid* UTF-8 made of C1 controls and Latin-1
+        // symbols/letters, i.e. high-byte mojibake rather than malformed UTF-8.
+        for (int i = 0; i < garbageBytes.length; i++) {
+            garbageBytes[i] = (byte) (0x80 | (garbageBytes[i] & 0x7F));
+        }
+        JunkScore garbage = detector.score(new String(garbageBytes, StandardCharsets.ISO_8859_1)
+                .getBytes(StandardCharsets.UTF_8));
+
+        System.out.println("clean: " + clean);
+        System.out.println("garbage: " + garbage);
+
+        assertTrue(clean.getZScore() > garbage.getZScore(),
+                "Clean text should score higher than garbage");
+    }
+
+    /**
+     * Forward Arabic should score higher than character-reversed Arabic.
+     * Character (codepoint) reversal is a realistic distortion: it produces
+     * valid UTF-8 but wrong reading order — analogous to bidirectional rendering
+     * failures or incorrectly stored RTL text.
+     */
+    @Test
+    void forwardVsReversedArabic() {
+        String arabic = "اللغة العربية جميلة وغنية بالمفردات والتعبيرات";
+        byte[] forward = arabic.getBytes(StandardCharsets.UTF_8);
+        byte[] reversed = reverseString(arabic).getBytes(StandardCharsets.UTF_8);
+
+        JunkScore fwd = detector.score(forward);
+        JunkScore rev = detector.score(reversed);
+
+        System.out.println("arabic forward: " + fwd);
+        System.out.println("arabic reversed: " + rev);
+
+        assertTrue(fwd.getZScore() > rev.getZScore(),
+                "Forward Arabic should score higher than character-reversed Arabic");
+    }
+
+    /**
+     * cp1257 (Baltic) decoding of Lithuanian text should win over cp1252.
+     *
+     * <p>This tests the {@link JunkDetector#compare} API: given raw bytes that were
+     * encoded as cp1257, scoring both decodings should prefer the correct one.
+     * A low delta is expected because the LATIN model is trained across ~322 languages
+     * and Baltic-specific bigrams are diluted.
+     *
+     * <p>TODO: improve separation by adding a Baltic sub-model or Baltic-weighted retraining.
+     */
+    @Test
+    void cp1252VsCp1257OnBalticText() throws Exception {
+        String lithuanian = "Lietuvių kalba yra labai graži ir turtinga";
+        byte[] cp1257bytes = lithuanian.getBytes("cp1257");
+
+        JunkDetector.CompareResult result = detector.compare(cp1257bytes, "cp1252", "cp1257");
+
+        System.out.println("Baltic comparison: " + result);
+
+        // "B" names the second candidate charset passed to compare(), here cp1257.
+        assertEquals("B", result.winner(),
+                "cp1257 should be identified as the correct encoding for Lithuanian text");
+        // Delta is weak (pooled LATIN model dilutes Baltic-specific bigrams).
+        // Production threshold is delta > 1.0; PoC floor is 0.1.
+        assertTrue(result.delta() > 0.1,
+                "Should have some separation: delta=" + result.delta());
+    }
+
+    /**
+     * cp1251 decoding of Russian text should win over cp1252.
+     *
+     * <p>This is the canonical Cyrillic mojibake scenario: Windows-1251-encoded
+     * Russian text misinterpreted as Windows-1252 (Western European). The cp1252
+     * decoding turns the Cyrillic letters into accented Latin letters plus a few
+     * symbols (e.g. "Русский" becomes "Ðóññêèé", and "ч" becomes "÷"), while
+     * cp1251 produces proper Cyrillic. The model should strongly prefer cp1251.
+     *
+     * <p>Note: character-reversal of LTR Cyrillic is NOT a useful test here —
+     * byte-bigram statistics are nearly identical forward and backward for LTR scripts,
+     * so the model cannot distinguish them. Use codec-confusion tests for LTR scripts.
+     */
+    @Test
+    void cp1252VsCp1251OnRussianText() throws Exception {
+        String russian = "Русский язык является одним из восточнославянских языков";
+        byte[] cp1251bytes = russian.getBytes("cp1251");
+
+        JunkDetector.CompareResult result = detector.compare(cp1251bytes, "cp1252", "cp1251");
+
+        System.out.println("Russian Cyrillic comparison: " + result);
+
+        // "B" names the second candidate charset passed to compare(), here cp1251.
+        assertEquals("B", result.winner(),
+                "cp1251 should be identified as the correct encoding for Russian text");
+        assertTrue(result.delta() > 1.0,
+                "Cyrillic codec separation should be strong: delta=" + result.delta());
+    }
+
+    /**
+     * Clean Japanese (CJK) should score higher than the same UTF-8 bytes shuffled.
+     * Shuffling bytes of multi-byte UTF-8 characters breaks the lead/continuation
+     * byte sequences, so the shuffled input is generally not valid UTF-8.
+     */
+    @Test
+    void cleanVsShuffledCjk() {
+        String japanese = "日本語は美しい言語であり、世界中で約1億3千万人が話している。";
+        byte[] clean = japanese.getBytes(StandardCharsets.UTF_8);
+        byte[] shuffled = shuffled(clean, 42);
+
+        JunkScore cleanScore = detector.score(clean);
+        JunkScore shuffledScore = detector.score(shuffled);
+
+        System.out.println("Japanese clean: " + cleanScore);
+        System.out.println("Japanese shuffled: " + shuffledScore);
+
+        assertTrue(cleanScore.getZScore() > shuffledScore.getZScore(),
+                "Clean Japanese should score higher than shuffled bytes");
+    }
+
+    // -----------------------------------------------------------------------
+
+    /**
+     * Reverses the string at codepoint granularity (not char granularity), so
+     * surrogate pairs are kept intact. This produces valid Unicode text in
+     * reverse reading order — a realistic distortion for RTL-language tests.
+     */
+    static String reverseString(String s) {
+        int[] codepoints = s.codePoints().toArray();
+        for (int i = 0, j = codepoints.length - 1; i < j; i++, j--) {
+            int tmp = codepoints[i];
+            codepoints[i] = codepoints[j];
+            codepoints[j] = tmp;
+        }
+        return new String(codepoints, 0, codepoints.length);
+    }
+
+    /**
+     * Returns a copy of {@code bytes} permuted by a seeded Fisher–Yates shuffle;
+     * the input array is not modified, and the same seed gives the same permutation.
+     */
+    private static byte[] shuffled(byte[] bytes, long seed) {
+        byte[] copy = bytes.clone();
+        Random rng = new Random(seed);
+        for (int i = copy.length - 1; i > 0; i--) {
+            int j = rng.nextInt(i + 1);
+            byte tmp = copy[i];
+            copy[i] = copy[j];
+            copy[j] = tmp;
+        }
+        return copy;
+    }
+}
