This is an automated email from the ASF dual-hosted git repository. rzo1 pushed a commit to branch OPENNLP-1220 in repository https://gitbox.apache.org/repos/asf/opennlp.git
commit d056dd27604502bd24bea9c1c1fd4d5050bcb77e Author: Richard Zowalla <[email protected]> AuthorDate: Sun Mar 22 21:11:32 2026 +0100 OPENNLP-1220 - Add support for Byte Pair Encoding (BPE) --- .../main/java/opennlp/tools/tokenize/BPEModel.java | 148 ++++++++++ .../java/opennlp/tools/tokenize/BPETokenizer.java | 263 +++++++++++++++++ .../tools/tokenize/BPETokenizerFactory.java | 178 +++++++++++ .../tools/tokenize/BPETokenizerTrainer.java | 201 +++++++++++++ .../java/opennlp/tools/tokenize/BPEModelTest.java | 159 ++++++++++ .../tools/tokenize/BPETokenizerFactoryTest.java | 158 ++++++++++ .../tools/tokenize/BPETokenizerRealisticTest.java | 325 +++++++++++++++++++++ .../opennlp/tools/tokenize/BPETokenizerTest.java | 230 +++++++++++++++ .../tools/tokenize/BPETokenizerTrainerTest.java | 190 ++++++++++++ 9 files changed, 1852 insertions(+) diff --git a/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPEModel.java b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPEModel.java new file mode 100644 index 00000000..8f7e8f4f --- /dev/null +++ b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPEModel.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package opennlp.tools.tokenize; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import opennlp.tools.tokenize.BPETokenizer.SymbolPair; +import opennlp.tools.util.BaseToolFactory; +import opennlp.tools.util.InvalidFormatException; +import opennlp.tools.util.model.BaseModel; + +/** + * The {@link BPEModel} stores learned BPE merge operations and can be + * serialized and deserialized for reuse. + * <p> + * A model is created by the {@link BPETokenizerTrainer} and contains an ordered + * list of {@link BPETokenizer.SymbolPair} merge operations that define the BPE + * vocabulary. The model is persisted as a standard OpenNLP ZIP package with a + * {@code bpe.merges} artifact containing the merge rules. + * <p> + * <b>Usage:</b> + * <pre>{@code + * // Create via training + * BPETokenizerTrainer trainer = new BPETokenizerTrainer(); + * BPEModel model = trainer.train(corpus, 10000, "en"); + * + * // Save to disk + * model.serialize(Path.of("bpe-en.bin")); + * + * // Load from disk + * BPEModel loaded = new BPEModel(Path.of("bpe-en.bin")); + * + * // Use for tokenization + * BPETokenizer tokenizer = new BPETokenizer(loaded); + * }</pre> + * + * @see BPETokenizer + * @see BPETokenizerTrainer + * @see BPETokenizerFactory + */ +public final class BPEModel extends BaseModel { + + private static final long serialVersionUID = 1L; + private static final String COMPONENT_NAME = "BPETokenizer"; + + /** + * Creates a {@link BPEModel} from trained merge rules. + * + * @param merges The ordered list of BPE merge operations. + * @param manifestInfoEntries Additional information kept in the manifest. + * @param factory The {@link BPETokenizerFactory} to use. 
+ */ + public BPEModel(List<SymbolPair> merges, Map<String, String> manifestInfoEntries, + BPETokenizerFactory factory) { + super(COMPONENT_NAME, factory.getLanguageCode(), manifestInfoEntries, factory); + checkArtifactMap(); + } + + /** + * Initializes a {@link BPEModel} from an {@link InputStream}. + * + * @param in The {@link InputStream} used for loading the model. + * @throws IOException Thrown if IO errors occurred during initialization. + */ + public BPEModel(InputStream in) throws IOException { + super(COMPONENT_NAME, in); + } + + /** + * Initializes a {@link BPEModel} from a {@link File}. + * + * @param modelFile The {@link File} used for loading the model. + * @throws IOException Thrown if IO errors occurred during initialization. + */ + public BPEModel(File modelFile) throws IOException { + super(COMPONENT_NAME, modelFile); + } + + /** + * Initializes a {@link BPEModel} from a {@link Path}. + * + * @param modelPath The {@link Path} used for loading the model. + * @throws IOException Thrown if IO errors occurred during initialization. + */ + public BPEModel(Path modelPath) throws IOException { + super(COMPONENT_NAME, modelPath); + } + + /** + * Initializes a {@link BPEModel} from a {@link URL}. + * + * @param modelURL The {@link URL} used for loading the model. + * @throws IOException Thrown if IO errors occurred during initialization. + */ + public BPEModel(URL modelURL) throws IOException { + super(COMPONENT_NAME, modelURL); + } + + @Override + protected void validateArtifactMap() throws InvalidFormatException { + super.validateArtifactMap(); + + Object mergesArtifact = artifactMap.get(BPETokenizerFactory.MERGES_ENTRY_NAME); + if (!(mergesArtifact instanceof List<?>)) { + throw new InvalidFormatException("BPE model is incomplete: missing merge rules!"); + } + } + + @Override + protected Class<? extends BaseToolFactory> getDefaultFactory() { + return BPETokenizerFactory.class; + } + + /** + * @return The active {@link BPETokenizerFactory}. 
+ */ + public BPETokenizerFactory getFactory() { + return (BPETokenizerFactory) this.toolFactory; + } + + /** + * @return The ordered list of BPE merge operations stored in this model. + */ + public List<SymbolPair> getMerges() { + return getFactory().getMerges(); + } +} diff --git a/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizer.java b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizer.java new file mode 100644 index 00000000..272d5c2f --- /dev/null +++ b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizer.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.tokenize; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Objects; + +import opennlp.tools.util.Span; + +/** + * A {@link Tokenizer} implementation that performs subword tokenization + * using Byte Pair Encoding (BPE). + * <p> + * BPE iteratively merges the most frequent pair of adjacent symbols, + * starting from a character-level representation of each word. This allows + * the tokenizer to handle out-of-vocabulary words by decomposing them into + * known subword units. 
+ * <p> + * <b>Usage:</b> + * <pre>{@code + * // Train a BPE model from a corpus + * BPETokenizerTrainer trainer = new BPETokenizerTrainer(); + * BPEModel model = trainer.train(corpus, 10000, "en"); + * + * // Save the model for later reuse + * model.serialize(Path.of("bpe-en.bin")); + * + * // Load and tokenize + * BPEModel loaded = new BPEModel(Path.of("bpe-en.bin")); + * BPETokenizer tokenizer = new BPETokenizer(loaded); + * String[] tokens = tokenizer.tokenize("unseen words are split into subwords"); + * }</pre> + * <p> + * The tokenizer first splits text on whitespace, then applies learned merge + * operations to each word independently. Words are decomposed into characters + * with an {@link #END_OF_WORD} marker on the final character, then merges are + * applied in priority order (as learned during training) until no more merges + * are applicable. The resulting subword units are returned as tokens. + * <p> + * For reference see: + * <ul> + * <li>Sennrich, R., Haddow, B., & Birch, A. (2016). + * Neural Machine Translation of Rare Words with Subword Units. + * <a href="https://arxiv.org/abs/1508.07909">https://arxiv.org/abs/1508.07909</a> + * </li> + * </ul> + * + * @see BPEModel + * @see BPETokenizerTrainer + * @see WordpieceTokenizer + */ +public class BPETokenizer implements Tokenizer { + + /** + * Suffix appended to the last symbol of each word during BPE encoding + * to distinguish word-final characters from word-internal ones. + * <p> + * Users constructing {@link SymbolPair} merge rules manually must use this + * constant to mark word-final symbols + * (e.g., {@code new SymbolPair("a", "b" + END_OF_WORD)}). + */ + public static final String END_OF_WORD = "</w>"; + + private final LinkedHashMap<SymbolPair, Integer> mergeRanks; + + /** + * Initializes a {@link BPETokenizer} from a trained {@link BPEModel}. + * + * @param model The trained BPE model containing merge rules. Must not be {@code null}. 
+ * @throws NullPointerException if {@code model} is {@code null}. + */ + public BPETokenizer(BPEModel model) { + Objects.requireNonNull(model, "model must not be null"); + final List<SymbolPair> merges = model.getMerges(); + this.mergeRanks = new LinkedHashMap<>(); + for (int i = 0; i < merges.size(); i++) { + mergeRanks.put(merges.get(i), i); + } + } + + /** + * {@inheritDoc} + * <p> + * Splits the input text on whitespace, then applies BPE merge operations + * to each word to produce subword tokens. Words not fully covered by + * learned merges are decomposed into individual characters. + */ + @Override + public String[] tokenize(final String text) { + if (text == null || text.isEmpty()) { + return new String[0]; + } + + final String[] words = WhitespaceTokenizer.INSTANCE.tokenize(text); + final List<String> allTokens = new ArrayList<>(); + + for (final String word : words) { + allTokens.addAll(encodeToBPE(word)); + } + + return allTokens.toArray(new String[0]); + } + + /** + * {@inheritDoc} + * <p> + * Returns {@link Span} offsets into the original text for each subword token. + * Each span maps back to the exact character range in the input string. + */ + @Override + public Span[] tokenizePos(final String text) { + if (text == null || text.isEmpty()) { + return new Span[0]; + } + + final Span[] wordSpans = WhitespaceTokenizer.INSTANCE.tokenizePos(text); + final List<Span> allSpans = new ArrayList<>(); + + for (final Span wordSpan : wordSpans) { + final String word = wordSpan.getCoveredText(text).toString(); + final List<String> symbols = splitToSymbols(word); + final List<String> merged = applyMerges(symbols); + + int offset = wordSpan.getStart(); + for (final String token : merged) { + String clean = token.endsWith(END_OF_WORD) + ? 
token.substring(0, token.length() - END_OF_WORD.length()) + : token; + int len = clean.length(); + allSpans.add(new Span(offset, offset + len)); + offset += len; + } + } + + return allSpans.toArray(new Span[0]); + } + + /** + * Splits a word into its initial character-level BPE symbol sequence. + * Each character becomes its own symbol, with {@link #END_OF_WORD} appended + * to the final character. + * + * @param word The word to split. Must not be {@code null} or empty. + * @return A mutable list of character symbols. + */ + private List<String> splitToSymbols(final String word) { + final List<String> symbols = new ArrayList<>(word.length()); + for (int i = 0; i < word.length(); i++) { + if (i == word.length() - 1) { + symbols.add(word.charAt(i) + END_OF_WORD); + } else { + symbols.add(String.valueOf(word.charAt(i))); + } + } + return symbols; + } + + private List<String> encodeToBPE(final String word) { + if (word.isEmpty()) { + return List.of(); + } + + final List<String> symbols = splitToSymbols(word); + final List<String> merged = applyMerges(symbols); + + // Strip end-of-word markers and collect final tokens + final List<String> result = new ArrayList<>(); + for (final String token : merged) { + if (token.endsWith(END_OF_WORD)) { + result.add(token.substring(0, token.length() - END_OF_WORD.length())); + } else { + result.add(token); + } + } + + return result; + } + + private List<String> applyMerges(final List<String> symbols) { + if (symbols.size() <= 1) { + return symbols; + } + + List<String> current = new ArrayList<>(symbols); + + while (current.size() > 1) { + int bestRank = Integer.MAX_VALUE; + SymbolPair bestPair = null; + + for (int i = 0; i < current.size() - 1; i++) { + final SymbolPair pair = new SymbolPair(current.get(i), current.get(i + 1)); + final Integer rank = mergeRanks.get(pair); + if (rank != null && rank < bestRank) { + bestRank = rank; + bestPair = pair; + } + } + + if (bestPair == null) { + break; + } + + final List<String> next = new 
ArrayList<>(); + int i = 0; + while (i < current.size()) { + if (i < current.size() - 1 + && current.get(i).equals(bestPair.left()) + && current.get(i + 1).equals(bestPair.right())) { + next.add(bestPair.left() + bestPair.right()); + i += 2; + } else { + next.add(current.get(i)); + i++; + } + } + current = next; + } + + return current; + } + + /** + * Represents a pair of adjacent symbols in BPE. + * + * @param left The left symbol. + * @param right The right symbol. + */ + public record SymbolPair(String left, String right) { + + /** + * Creates a new {@link SymbolPair}. + * + * @param left The left symbol. Must not be {@code null}. + * @param right The right symbol. Must not be {@code null}. + */ + public SymbolPair { + Objects.requireNonNull(left, "left must not be null"); + Objects.requireNonNull(right, "right must not be null"); + } + + @Override + public String toString() { + return left + " " + right; + } + } +} diff --git a/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizerFactory.java b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizerFactory.java new file mode 100644 index 00000000..43182dfb --- /dev/null +++ b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizerFactory.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.tokenize; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import opennlp.tools.tokenize.BPETokenizer.SymbolPair; +import opennlp.tools.util.BaseToolFactory; +import opennlp.tools.util.InvalidFormatException; +import opennlp.tools.util.model.ArtifactSerializer; + +/** + * A {@link BaseToolFactory} for BPE tokenization that manages the BPE merge rules artifact + * and its serialization within a {@link BPEModel}. + * <p> + * This factory is responsible for: + * <ul> + * <li>Providing the {@link BPEMergesSerializer} that reads and writes BPE merge rules + * as a text-based artifact ({@code bpe.merges}) inside the model ZIP package.</li> + * <li>Supplying the merge rules to the {@link BPEModel} via {@link #createArtifactMap()}.</li> + * <li>Validating that a loaded model contains valid merge rules.</li> + * </ul> + * <p> + * This class is typically not used directly. It is instantiated internally by + * {@link BPETokenizerTrainer} during training and by {@link BPEModel} during + * model loading. 
+ * + * @see BPEModel + * @see BPETokenizer + * @see BPETokenizerTrainer + */ +public class BPETokenizerFactory extends BaseToolFactory { + + static final String MERGES_ENTRY_NAME = "bpe.merges"; + + private String languageCode; + private List<SymbolPair> merges; + + /** + * Creates a {@link BPETokenizerFactory}. Required empty constructor for model loading. + */ + public BPETokenizerFactory() { + } + + /** + * Creates a {@link BPETokenizerFactory} with the specified parameters. + * + * @param languageCode The ISO language code. Must not be {@code null}. + * @param merges The ordered list of BPE merge operations. Must not be {@code null}. + */ + public BPETokenizerFactory(String languageCode, List<SymbolPair> merges) { + this.languageCode = Objects.requireNonNull(languageCode, "languageCode must not be null"); + this.merges = Objects.requireNonNull(merges, "merges must not be null"); + } + + /** {@inheritDoc} */ + @Override + public Map<String, ArtifactSerializer<?>> createArtifactSerializersMap() { + Map<String, ArtifactSerializer<?>> serializers = super.createArtifactSerializersMap(); + serializers.put("merges", new BPEMergesSerializer()); + return serializers; + } + + /** {@inheritDoc} */ + @Override + public Map<String, Object> createArtifactMap() { + Map<String, Object> artifactMap = super.createArtifactMap(); + if (merges != null) { + artifactMap.put(MERGES_ENTRY_NAME, new ArrayList<>(merges)); + } + return artifactMap; + } + + /** {@inheritDoc} */ + @Override + public Map<String, String> createManifestEntries() { + Map<String, String> entries = super.createManifestEntries(); + return entries; + } + + /** {@inheritDoc} */ + @Override + public void validateArtifactMap() throws InvalidFormatException { + Object mergesArtifact = this.artifactProvider.getArtifact(MERGES_ENTRY_NAME); + if (!(mergesArtifact instanceof List<?>)) { + throw new InvalidFormatException("Missing or invalid BPE merges artifact!"); + } + } + + /** + * @return The ISO language code for this 
factory. + */ + public String getLanguageCode() { + return languageCode; + } + + /** + * Retrieves the BPE merge rules from the loaded model artifacts. + * + * @return The ordered list of {@link SymbolPair} merge operations. + */ + @SuppressWarnings("unchecked") + public List<SymbolPair> getMerges() { + if (merges != null) { + return merges; + } + return (List<SymbolPair>) this.artifactProvider.getArtifact(MERGES_ENTRY_NAME); + } + + /** + * An {@link ArtifactSerializer} for BPE merge rules. + * <p> + * Serializes merge rules as a text file with one merge pair per line, + * in the format: {@code left right}. + */ + static class BPEMergesSerializer implements ArtifactSerializer<List<SymbolPair>> { + + @Override + public List<SymbolPair> create(InputStream in) throws IOException { + final List<SymbolPair> merges = new ArrayList<>(); + final BufferedReader reader = new BufferedReader( + new InputStreamReader(in, StandardCharsets.UTF_8)); + String line; + while ((line = reader.readLine()) != null) { + line = line.trim(); + if (line.isEmpty()) { + continue; + } + final int space = line.indexOf(' '); + if (space < 0) { + throw new IOException("Invalid BPE merge line (expected 'left right'): " + line); + } + merges.add(new SymbolPair(line.substring(0, space), line.substring(space + 1))); + } + return merges; + } + + @Override + public void serialize(List<SymbolPair> artifact, OutputStream out) throws IOException { + final BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(out, StandardCharsets.UTF_8)); + for (final SymbolPair merge : artifact) { + writer.write(merge.left()); + writer.write(' '); + writer.write(merge.right()); + writer.newLine(); + } + writer.flush(); + } + } +} diff --git a/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizerTrainer.java b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizerTrainer.java new file mode 100644 index 00000000..a39bb4b7 --- /dev/null +++ 
b/opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/tokenize/BPETokenizerTrainer.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.tokenize; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import opennlp.tools.tokenize.BPETokenizer.SymbolPair; + +/** + * Learns BPE merge operations from a training corpus and produces a {@link BPEModel}. + * <p> + * Implements the BPE learning algorithm from Sennrich et al. (2016): + * <ol> + * <li>Build a vocabulary of character-level symbol sequences from the corpus, + * where each word is split into individual characters with an end-of-word marker.</li> + * <li>Count all adjacent symbol pairs across the vocabulary, weighted by word frequency.</li> + * <li>Merge the most frequent pair into a single new symbol.</li> + * <li>Repeat until the desired number of merges ({@code numMerges}) is reached.</li> + * </ol> + * <p> + * The number of merges controls the granularity of the resulting vocabulary: + * fewer merges produce finer-grained (more character-level) tokens, while more + * merges produce coarser (more word-level) tokens. 
A typical value ranges from + * a few thousand to tens of thousands, depending on the corpus size and language. + * <p> + * <b>Usage:</b> + * <pre>{@code + * List<String> corpus = List.of( + * "the cat sat on the mat", + * "the dog sat on the log" + * ); + * + * BPETokenizerTrainer trainer = new BPETokenizerTrainer(); + * BPEModel model = trainer.train(corpus, 10000, "en"); + * + * // Persist the model + * model.serialize(Path.of("bpe-en.bin")); + * + * // Use it for tokenization + * BPETokenizer tokenizer = new BPETokenizer(model); + * String[] tokens = tokenizer.tokenize("the cat"); + * }</pre> + * <p> + * For reference see: + * <ul> + * <li>Sennrich, R., Haddow, B., & Birch, A. (2016). + * Neural Machine Translation of Rare Words with Subword Units. + * <a href="https://arxiv.org/abs/1508.07909">https://arxiv.org/abs/1508.07909</a> + * </li> + * </ul> + * + * @see BPETokenizer + * @see BPEModel + */ +public final class BPETokenizerTrainer { + + /** + * Creates a new {@link BPETokenizerTrainer}. + */ + public BPETokenizerTrainer() { + } + + /** + * Learns BPE merge operations from a training corpus and returns a {@link BPEModel}. + * + * @param corpus An iterable of text strings (e.g., sentences or documents). + * Must not be {@code null}. + * @param numMerges The number of merge operations to learn. Must be positive. + * @param languageCode The ISO language code (e.g., "en", "de"). Must not be {@code null}. + * @return A trained {@link BPEModel} containing the learned merge operations. + * @throws IllegalArgumentException if {@code numMerges} is not positive. + * @throws NullPointerException if {@code corpus} or {@code languageCode} is {@code null}. 
+ */ + public BPEModel train(final Iterable<String> corpus, + final int numMerges, + final String languageCode) { + Objects.requireNonNull(corpus, "corpus must not be null"); + Objects.requireNonNull(languageCode, "languageCode must not be null"); + if (numMerges <= 0) { + throw new IllegalArgumentException("numMerges must be positive, got: " + numMerges); + } + + final List<SymbolPair> merges = learnMerges(corpus, numMerges); + final BPETokenizerFactory factory = new BPETokenizerFactory(languageCode, merges); + + return new BPEModel(merges, new HashMap<>(), factory); + } + + private List<SymbolPair> learnMerges(final Iterable<String> corpus, final int numMerges) { + // Step 1: Build word frequency map from corpus + final Map<String, Integer> wordFreqs = new HashMap<>(); + for (final String line : corpus) { + final String[] words = WhitespaceTokenizer.INSTANCE.tokenize(line); + for (final String word : words) { + wordFreqs.merge(word, 1, Integer::sum); + } + } + + // Step 2: Convert to symbol sequences with frequencies + final Map<List<String>, Integer> vocab = new HashMap<>(); + for (final Map.Entry<String, Integer> entry : wordFreqs.entrySet()) { + final List<String> symbols = splitToSymbols(entry.getKey()); + vocab.put(symbols, entry.getValue()); + } + + // Step 3: Iteratively learn merges + final List<SymbolPair> merges = new ArrayList<>(); + + for (int step = 0; step < numMerges; step++) { + // Count all adjacent pairs + final Map<SymbolPair, Integer> pairCounts = new HashMap<>(); + for (final Map.Entry<List<String>, Integer> entry : vocab.entrySet()) { + final List<String> symbols = entry.getKey(); + final int freq = entry.getValue(); + for (int i = 0; i < symbols.size() - 1; i++) { + final SymbolPair pair = new SymbolPair(symbols.get(i), symbols.get(i + 1)); + pairCounts.merge(pair, freq, Integer::sum); + } + } + + if (pairCounts.isEmpty()) { + break; + } + + // Find most frequent pair + SymbolPair bestPair = null; + int bestCount = 0; + for (final 
Map.Entry<SymbolPair, Integer> entry : pairCounts.entrySet()) { + if (entry.getValue() > bestCount) { + bestCount = entry.getValue(); + bestPair = entry.getKey(); + } + } + + if (bestPair == null || bestCount < 1) { + break; + } + + merges.add(bestPair); + + // Apply merge to vocabulary + final Map<List<String>, Integer> newVocab = new HashMap<>(); + for (final Map.Entry<List<String>, Integer> entry : vocab.entrySet()) { + final List<String> merged = applyMerge(entry.getKey(), bestPair); + newVocab.merge(merged, entry.getValue(), Integer::sum); + } + vocab.clear(); + vocab.putAll(newVocab); + } + + return merges; + } + + private List<String> splitToSymbols(final String word) { + final List<String> symbols = new ArrayList<>(word.length()); + for (int i = 0; i < word.length(); i++) { + if (i == word.length() - 1) { + symbols.add(word.charAt(i) + BPETokenizer.END_OF_WORD); + } else { + symbols.add(String.valueOf(word.charAt(i))); + } + } + return symbols; + } + + private List<String> applyMerge(final List<String> symbols, final SymbolPair pair) { + final List<String> result = new ArrayList<>(); + int i = 0; + while (i < symbols.size()) { + if (i < symbols.size() - 1 + && symbols.get(i).equals(pair.left()) + && symbols.get(i + 1).equals(pair.right())) { + result.add(pair.left() + pair.right()); + i += 2; + } else { + result.add(symbols.get(i)); + i++; + } + } + return result; + } +} diff --git a/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPEModelTest.java b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPEModelTest.java new file mode 100644 index 00000000..f9aafc6d --- /dev/null +++ b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPEModelTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.tokenize;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.List;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.tokenize.BPETokenizer.SymbolPair;
+
+/**
+ * Tests for the {@link BPEModel} class.
+ * <p>
+ * Most tests serialize a freshly trained model into memory, read it back and
+ * compare the restored model against the original.
+ *
+ * @see BPEModel
+ */
+public class BPEModelTest {
+
+  private static final List<String> CORPUS = List.of(
+      "low low low low low",
+      "lower lower lower",
+      "newest newest newest newest",
+      "widest widest widest"
+  );
+
+  /**
+   * Trains a model on {@link #CORPUS} with the language code {@code "en"}.
+   *
+   * @param numMerges The number of merge operations requested from the trainer.
+   * @return The trained {@link BPEModel}.
+   */
+  private BPEModel trainModel(int numMerges) {
+    return new BPETokenizerTrainer().train(CORPUS, numMerges, "en");
+  }
+
+  /**
+   * Serializes {@code model} into an in-memory buffer and deserializes it again.
+   *
+   * @param model The {@link BPEModel} to round-trip.
+   * @return A new {@link BPEModel} read back from the serialized bytes.
+   * @throws IOException Thrown if serialization or deserialization fails.
+   */
+  private static BPEModel roundTrip(BPEModel model) throws IOException {
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    model.serialize(out);
+    return new BPEModel(new ByteArrayInputStream(out.toByteArray()));
+  }
+
+  /**
+   * Tests that a model can be serialized and deserialized without data loss.
+   */
+  @Test
+  void testBPEModelSerialization() throws IOException {
+    final BPEModel model = trainModel(10);
+    Assertions.assertFalse(model.isLoadedFromSerialized());
+
+    final BPEModel restored = roundTrip(model);
+    Assertions.assertNotNull(restored);
+    Assertions.assertTrue(restored.isLoadedFromSerialized());
+    // "Without data loss": the learned merges and the language must survive the roundtrip.
+    Assertions.assertEquals(model.getMerges(), restored.getMerges());
+    Assertions.assertEquals(model.getLanguage(), restored.getLanguage());
+  }
+
+  /**
+   * Tests that merge rules are preserved after serialization roundtrip.
+   */
+  @Test
+  void testMergesPreservedAfterSerialization() throws IOException {
+    final BPEModel original = trainModel(10);
+    final BPEModel restored = roundTrip(original);
+
+    final List<SymbolPair> originalMerges = original.getMerges();
+    final List<SymbolPair> restoredMerges = restored.getMerges();
+
+    Assertions.assertEquals(originalMerges.size(), restoredMerges.size());
+    for (int i = 0; i < originalMerges.size(); i++) {
+      Assertions.assertEquals(originalMerges.get(i), restoredMerges.get(i),
+          "Merge rule at index " + i + " differs after roundtrip");
+    }
+  }
+
+  /**
+   * Tests that merge order is preserved — order determines priority.
+   */
+  @Test
+  void testMergeOrderPreserved() throws IOException {
+    final BPEModel model = trainModel(5);
+    final BPEModel restored = roundTrip(model);
+
+    // List equality is element-wise and positional, so this also verifies the exact order.
+    Assertions.assertEquals(model.getMerges(), restored.getMerges());
+  }
+
+  /**
+   * Tests that a deserialized model can be used to tokenize text.
+   */
+  @Test
+  void testDeserializedModelCanTokenize() throws IOException {
+    final BPEModel loaded = roundTrip(trainModel(10));
+    final BPETokenizer tokenizer = new BPETokenizer(loaded);
+
+    final String[] tokens = tokenizer.tokenize("low");
+    Assertions.assertTrue(tokens.length >= 1);
+    Assertions.assertEquals("low", String.join("", tokens));
+  }
+
+  /**
+   * Tests that the language code is preserved in the model.
+   */
+  @Test
+  void testLanguagePreserved() throws IOException {
+    final BPEModel model = new BPETokenizerTrainer().train(CORPUS, 5, "de");
+
+    final BPEModel restored = roundTrip(model);
+    Assertions.assertEquals("de", restored.getLanguage());
+  }
+
+  /**
+   * Tests that the factory is accessible from a deserialized model.
+   */
+  @Test
+  void testFactoryAccessibleAfterDeserialization() throws IOException {
+    final BPEModel restored = roundTrip(trainModel(5));
+
+    Assertions.assertNotNull(restored.getFactory());
+    Assertions.assertInstanceOf(BPETokenizerFactory.class, restored.getFactory());
+  }
+
+  /**
+   * Tests that getMerges() returns the same result whether accessed via
+   * the model or the factory.
+   */
+  @Test
+  void testGetMergesConsistency() {
+    final BPEModel model = trainModel(5);
+
+    Assertions.assertEquals(model.getMerges(), model.getFactory().getMerges());
+  }
+}
diff --git a/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerFactoryTest.java b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerFactoryTest.java
new file mode 100644
index 00000000..86e9b580
--- /dev/null
+++ b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerFactoryTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.tokenize;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.List;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.tokenize.BPETokenizer.SymbolPair;
+
+/**
+ * Tests for the {@link BPETokenizerFactory} class.
+ * <p>
+ * Verifies that the factory correctly manages BPE merge rule artifacts and
+ * serializers, and that its properties survive model serialization roundtrips.
+ *
+ * @see BPETokenizerFactory
+ * @see BPEModel
+ */
+public class BPETokenizerFactoryTest {
+
+  private static final List<String> CORPUS = List.of(
+      "low low low low low",
+      "lower lower lower",
+      "newest newest newest newest"
+  );
+
+  /**
+   * Serializes {@code model} into an in-memory buffer and deserializes it again.
+   *
+   * @param model The {@link BPEModel} to round-trip.
+   * @return A new {@link BPEModel} read back from the serialized bytes.
+   * @throws IOException Thrown if serialization or deserialization fails.
+   */
+  private static BPEModel roundTrip(BPEModel model) throws IOException {
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    model.serialize(out);
+    return new BPEModel(new ByteArrayInputStream(out.toByteArray()));
+  }
+
+  /**
+   * Tests that the factory provides merge rules after training.
+   */
+  @Test
+  void testFactoryProvidesMerges() {
+    final BPEModel model = new BPETokenizerTrainer().train(CORPUS, 10, "en");
+    final BPETokenizerFactory factory = model.getFactory();
+
+    Assertions.assertNotNull(factory);
+    Assertions.assertNotNull(factory.getMerges());
+    Assertions.assertFalse(factory.getMerges().isEmpty());
+  }
+
+  /**
+   * Tests that the factory language code is set correctly.
+   */
+  @Test
+  void testFactoryLanguageCode() {
+    final List<SymbolPair> merges = List.of(new SymbolPair("a", "b"));
+    final BPETokenizerFactory factory = new BPETokenizerFactory("de", merges);
+
+    Assertions.assertEquals("de", factory.getLanguageCode());
+  }
+
+  /**
+   * Tests that merge rules are accessible from the factory after
+   * model serialization and deserialization.
+   */
+  @Test
+  void testFactorySurvivesSerialization() throws IOException {
+    final BPEModel original = new BPETokenizerTrainer().train(CORPUS, 10, "en");
+
+    final BPEModel restored = roundTrip(original);
+    final BPETokenizerFactory factory = restored.getFactory();
+
+    Assertions.assertNotNull(factory);
+    Assertions.assertNotNull(factory.getMerges());
+    Assertions.assertEquals(original.getMerges().size(), factory.getMerges().size());
+  }
+
+  /**
+   * Tests that the factory merges are consistent between direct construction
+   * and deserialized access.
+   */
+  @Test
+  void testMergesConsistentAfterRoundtrip() throws IOException {
+    final BPEModel original = new BPETokenizerTrainer().train(CORPUS, 5, "en");
+    final List<SymbolPair> originalMerges = original.getFactory().getMerges();
+
+    final BPEModel restored = roundTrip(original);
+    final List<SymbolPair> restoredMerges = restored.getFactory().getMerges();
+
+    Assertions.assertEquals(originalMerges, restoredMerges);
+  }
+
+  /**
+   * Tests that the factory creates the correct artifact serializer map.
+   */
+  @Test
+  void testArtifactSerializersMapContainsMergesSerializer() {
+    final BPETokenizerFactory factory = new BPETokenizerFactory("en", List.of());
+
+    Assertions.assertTrue(factory.createArtifactSerializersMap().containsKey("merges"));
+  }
+
+  /**
+   * Tests that the factory creates an artifact map containing the merges entry.
+   */
+  @Test
+  void testArtifactMapContainsMergesEntry() {
+    final List<SymbolPair> merges = List.of(
+        new SymbolPair("a", "b"),
+        new SymbolPair("ab", "c" + BPETokenizer.END_OF_WORD)
+    );
+    final BPETokenizerFactory factory = new BPETokenizerFactory("en", merges);
+
+    Assertions.assertTrue(factory.createArtifactMap().containsKey(BPETokenizerFactory.MERGES_ENTRY_NAME));
+  }
+
+  /**
+   * Tests that the empty constructor creates a valid factory (for model loading).
+   */
+  @Test
+  void testEmptyConstructor() {
+    final BPETokenizerFactory factory = new BPETokenizerFactory();
+
+    // The factory used during model loading must be able to build its serializer map;
+    // assertDoesNotThrow reports an unexpected exception explicitly.
+    Assertions.assertNotNull(
+        Assertions.assertDoesNotThrow(() -> factory.createArtifactSerializersMap()));
+  }
+
+  /**
+   * Tests that a {@code null} language code is rejected.
+   */
+  @Test
+  void testNullLanguageCodeThrows() {
+    Assertions.assertThrows(NullPointerException.class,
+        () -> new BPETokenizerFactory(null, List.of()));
+  }
+
+  /**
+   * Tests that a {@code null} list of merge rules is rejected.
+   */
+  @Test
+  void testNullMergesThrows() {
+    Assertions.assertThrows(NullPointerException.class,
+        () -> new BPETokenizerFactory("en", null));
+  }
+}
diff --git a/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerRealisticTest.java b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerRealisticTest.java
new file mode 100644
index 00000000..b5770bee
--- /dev/null
+++ b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerRealisticTest.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.tokenize;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.util.Span;
+
+/**
+ * Integration tests for the BPE tokenization pipeline.
+ * <p>
+ * This test trains a BPE tokenizer from a realistic English corpus,
+ * serializes and deserializes the model, then verifies tokenization
+ * behavior end-to-end. Mirrors the structure of {@link TokenizerMETest}.
+ *
+ * @see BPETokenizer
+ * @see BPETokenizerTrainer
+ * @see BPEModel
+ */
+public class BPETokenizerRealisticTest {
+
+  /**
+   * A small but realistic training corpus for BPE.
+   */
+  private static final List<String> TRAINING_CORPUS = List.of(
+      "Last September I tried to find out the address of an old school friend",
+      "whom I had not seen for 15 years",
+      "I just knew his name Alan McKennedy and I had heard the rumour",
+      "that he had moved to Scotland the country of his ancestors",
+      "So I called Julie a friend who is still in contact with him",
+      "She told me that he lived in Edinburgh Worcesterstreet 12",
+      "I wrote him a letter right away and he answered soon",
+      "sounding very happy and delighted",
+      "Last year I wanted to write a letter to my grandaunt",
+      "Her 86th birthday was on October 6 and I no longer wanted",
+      "to be hesitant to get in touch with her",
+      "I did not know her face to face and so it was not easy",
+      "for me to find out her address",
+      "As she had two apartments in different countries",
+      "I decided to write to both",
+      "The first was in Paris in Rue de Grandes Illusions 5",
+      "But Marie Clara as my aunt is called preferred her apartment in Berlin",
+      "She lived there in beautiful Kaiserstrasse 13 particularly in summer",
+      "Hi my name is Michael Graf how much is a taxi",
+      "from Ostbahnhof to Hauptbahnhof",
+      "About 10 Euro I reckon",
+      "That sounds good",
+      "So please call a driver to Leonardstrasse 112 near the Ostbahnhof",
+      "I would like to be at Silberhornstrasse 12 as soon as possible",
+      "Thank you very much"
+  );
+
+  /** Model trained once on {@link #TRAINING_CORPUS} with 100 merges and shared by the tests. */
+  private static BPEModel trainedModel;
+
+  @BeforeAll
+  static void setUpClass() {
+    trainedModel = new BPETokenizerTrainer().train(TRAINING_CORPUS, 100, "en");
+  }
+
+  /**
+   * Tests basic tokenization of a simple sentence with the trained model.
+   * All words appear in the training corpus and should be fully merged.
+   */
+  @Test
+  void testTokenizerSimpleModel() {
+    final BPETokenizer tokenizer = new BPETokenizer(trainedModel);
+    final String text = "I wrote a letter";
+
+    final String[] tokens = tokenizer.tokenize(text);
+    final Span[] spans = tokenizer.tokenizePos(text);
+    Assertions.assertEquals(tokens.length, spans.length);
+
+    // All four words are common in training corpus — assert exact reconstruction
+    final String[] words = reconstructWords(spans, text);
+    Assertions.assertArrayEquals(new String[] {"I", "wrote", "a", "letter"}, words);
+  }
+
+  /**
+   * Tests tokenization of a sentence with words seen during training.
+   * Frequent words should be tokenized into fewer subword pieces.
+   */
+  @Test
+  void testFrequentWordsTokenizeEfficiently() {
+    final BPETokenizer tokenizer = new BPETokenizer(trainedModel);
+
+    // "the" and "in" appear very frequently in the training corpus
+    final String[] theTokens = tokenizer.tokenize("the");
+    final String[] inTokens = tokenizer.tokenize("in");
+
+    // With 100 merges on this corpus, these common words should be single tokens
+    Assertions.assertEquals(1, theTokens.length, "Expected 'the' as single token");
+    Assertions.assertEquals("the", theTokens[0]);
+    Assertions.assertEquals(1, inTokens.length, "Expected 'in' as single token");
+    Assertions.assertEquals("in", inTokens[0]);
+  }
+
+  /**
+   * Tests tokenization of unseen words — they should be split into subword pieces
+   * but concatenation must still reconstruct the original.
+   */
+  @Test
+  void testUnseenWordsTokenization() {
+    final BPETokenizer tokenizer = new BPETokenizer(trainedModel);
+
+    final String[] tokens = tokenizer.tokenize("unbelievable");
+
+    // An unseen word will be split into multiple subword pieces
+    Assertions.assertTrue(tokens.length > 1,
+        "Unseen word 'unbelievable' should be split into multiple subword tokens");
+    Assertions.assertEquals("unbelievable", String.join("", tokens),
+        "Concatenation of subword tokens must reconstruct the original word");
+  }
+
+  /**
+   * Tests that the tokenizePos spans, together with the whitespace between them,
+   * reconstruct the full input text without overlaps, and that the reconstructed
+   * words match the original sentence.
+   */
+  @Test
+  void testTokenizePosSpanCoverage() {
+    final BPETokenizer tokenizer = new BPETokenizer(trainedModel);
+    final String text = "She lived in Edinburgh";
+    final Span[] spans = tokenizer.tokenizePos(text);
+
+    // Verify all spans extract non-empty substrings
+    for (final Span span : spans) {
+      final CharSequence covered = span.getCoveredText(text);
+      Assertions.assertNotNull(covered);
+      Assertions.assertFalse(covered.toString().isEmpty());
+    }
+
+    // Verify that spans + whitespace fully reconstruct the original text
+    final StringBuilder sb = new StringBuilder();
+    int lastEnd = 0;
+    for (final Span span : spans) {
+      if (span.getStart() > lastEnd) {
+        sb.append(text, lastEnd, span.getStart());
+      }
+      sb.append(span.getCoveredText(text));
+      lastEnd = span.getEnd();
+    }
+    Assertions.assertEquals(text, sb.toString());
+
+    // Verify reconstructed words match expected
+    final String[] words = reconstructWords(spans, text);
+    Assertions.assertArrayEquals(new String[] {"She", "lived", "in", "Edinburgh"}, words);
+  }
+
+  /**
+   * Tests that the BPE tokenizer handles multi-word input correctly,
+   * similar to {@link TokenizerMETest#testTokenizer()}.
+   */
+  @Test
+  void testTokenizer() {
+    final BPETokenizer tokenizer = new BPETokenizer(trainedModel);
+    final String sentence = "I had not seen him for years";
+    final String[] tokens = tokenizer.tokenize(sentence);
+
+    // Each word produces at least one token
+    final String[] words = sentence.split(" ");
+    Assertions.assertTrue(tokens.length >= words.length);
+
+    // Reconstruct each word from its subword tokens via spans
+    final Span[] spans = tokenizer.tokenizePos(sentence);
+    final String[] reconstructed = reconstructWords(spans, sentence);
+    Assertions.assertArrayEquals(words, reconstructed);
+  }
+
+  /**
+   * Tests the full pipeline: train, serialize, deserialize, tokenize.
+   * Similar to {@link TokenizerModelTest#testTokenizerModelSerialization()}.
+   */
+  @Test
+  void testTrainSerializeDeserializeTokenize() throws IOException {
+    // Serialize
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    trainedModel.serialize(out);
+    out.close();
+
+    // Deserialize
+    final BPEModel loaded = new BPEModel(new ByteArrayInputStream(out.toByteArray()));
+
+    // Tokenize with both original and deserialized model — results should match
+    final BPETokenizer original = new BPETokenizer(trainedModel);
+    final BPETokenizer restored = new BPETokenizer(loaded);
+
+    final String sentence = "I wrote him a letter right away";
+    Assertions.assertArrayEquals(
+        original.tokenize(sentence),
+        restored.tokenize(sentence));
+  }
+
+  /**
+   * Tests that the BPE tokenizer fulfills the {@link Tokenizer} contract:
+   * tokenize() and tokenizePos() must be consistent.
+   */
+  @Test
+  void testTokenizeAndTokenizePosConsistency() {
+    final BPETokenizer tokenizer = new BPETokenizer(trainedModel);
+    final String text = "She told me that he lived in Edinburgh";
+
+    final String[] tokens = tokenizer.tokenize(text);
+    final Span[] spans = tokenizer.tokenizePos(text);
+
+    Assertions.assertEquals(tokens.length, spans.length);
+
+    for (int i = 0; i < tokens.length; i++) {
+      Assertions.assertEquals(tokens[i], spans[i].getCoveredText(text).toString(),
+          "Token at index " + i + " should match span-covered text");
+    }
+  }
+
+  /**
+   * Tests tokenization with a model trained on a German corpus.
+   * Frequent German words should be fully merged and reconstructed correctly.
+   */
+  @Test
+  void testTrainWithDifferentLanguage() {
+    final List<String> germanCorpus = List.of(
+        "Ich wähle den auf Seite 183 mitgeteilten Traum",
+        "von der botanischen Monographie",
+        "Der Traum von der botanischen Monographie",
+        "Ich wähle den Traum von der botanischen Monographie"
+    );
+
+    final BPEModel model = new BPETokenizerTrainer().train(germanCorpus, 50, "de");
+    Assertions.assertEquals("de", model.getLanguage());
+
+    final BPETokenizer tokenizer = new BPETokenizer(model);
+    final String text = "der botanischen Monographie";
+    final Span[] spans = tokenizer.tokenizePos(text);
+
+    // Assert words are reconstructed correctly
+    final String[] words = reconstructWords(spans, text);
+    Assertions.assertArrayEquals(new String[] {"der", "botanischen", "Monographie"}, words);
+  }
+
+  /**
+   * Tests that the BPE tokenizer handles punctuation mixed with words.
+   * BPE treats punctuation as characters — they stay attached to the word
+   * since BPE splits on whitespace first.
+   */
+  @Test
+  void testPunctuationHandling() {
+    final BPETokenizer tokenizer = new BPETokenizer(trainedModel);
+    final String text = "Hello, world!";
+
+    final Span[] spans = tokenizer.tokenizePos(text);
+
+    // "Hello," and "world!" are separate whitespace tokens, each may be subword-split
+    final String[] words = reconstructWords(spans, text);
+    Assertions.assertEquals(2, words.length);
+    Assertions.assertEquals("Hello,", words[0]);
+    Assertions.assertEquals("world!", words[1]);
+  }
+
+  /**
+   * Tests that training with a larger number of merges produces
+   * coarser tokenization (fewer subword tokens per word).
+   */
+  @Test
+  void testMoreMergesProducesCoarserTokens() {
+    final BPEModel fewMerges = new BPETokenizerTrainer().train(TRAINING_CORPUS, 5, "en");
+    final BPEModel manyMerges = new BPETokenizerTrainer().train(TRAINING_CORPUS, 100, "en");
+
+    final BPETokenizer fewTokenizer = new BPETokenizer(fewMerges);
+    final BPETokenizer manyTokenizer = new BPETokenizer(manyMerges);
+
+    // With more merges, the same text should produce fewer or equal tokens
+    final String text = "I wanted to write a letter to my grandaunt";
+    final int fewCount = fewTokenizer.tokenize(text).length;
+    final int manyCount = manyTokenizer.tokenize(text).length;
+
+    Assertions.assertTrue(manyCount <= fewCount,
+        "More merges (" + manyCount + " tokens) should produce fewer or equal tokens "
+            + "than fewer merges (" + fewCount + " tokens)");
+  }
+
+  /**
+   * Groups token spans back into words.
+   * <p>
+   * Spans are expected in text order. A span that starts after the end of the
+   * previous span, leaving a gap, begins a new word. Otherwise its covered text
+   * is appended to the current word.
+   *
+   * @param spans The token spans, as returned by {@link BPETokenizer#tokenizePos(String)}.
+   * @param text  The text the spans refer to.
+   * @return The words obtained by concatenating the covered text of adjacent spans.
+   */
+  private static String[] reconstructWords(Span[] spans, String text) {
+    final List<String> words = new ArrayList<>();
+    final StringBuilder currentWord = new StringBuilder();
+    int lastWordEnd = -1;
+
+    for (final Span span : spans) {
+      if (lastWordEnd >= 0 && span.getStart() > lastWordEnd) {
+        // Gap between spans means a whitespace boundary — new word
+        words.add(currentWord.toString());
+        currentWord.setLength(0);
+      }
+      currentWord.append(span.getCoveredText(text));
+      lastWordEnd = span.getEnd();
+    }
+    if (!currentWord.isEmpty()) {
+      words.add(currentWord.toString());
+    }
+
+    return words.toArray(new String[0]);
+  }
+}
diff --git a/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerTest.java b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerTest.java
new file mode 100644
index 00000000..b4e72b2a
--- /dev/null
+++ b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerTest.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.tokenize;
+
+import java.util.HashMap;
+import java.util.List;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.tokenize.BPETokenizer.SymbolPair;
+import opennlp.tools.util.Span;
+
+/**
+ * Tests for the {@link BPETokenizer} class.
+ * <p>
+ * Verifies that BPE tokenization correctly splits text into subword tokens
+ * based on learned merge operations, and that span positions map back to
+ * the original text.
+ *
+ * @see BPETokenizer
+ */
+public class BPETokenizerTest {
+
+  /**
+   * Builds an English {@link BPEModel} directly from the given merge rules,
+   * without running the trainer.
+   *
+   * @param merges The ordered merge rules for the model.
+   * @return A {@link BPEModel} backed by a {@link BPETokenizerFactory} for {@code "en"}.
+   */
+  private static BPEModel createModel(List<SymbolPair> merges) {
+    final BPETokenizerFactory factory = new BPETokenizerFactory("en", merges);
+    return new BPEModel(merges, new HashMap<>(), factory);
+  }
+
+  /**
+   * Tests that a fully merged word produces a single token.
+   */
+  @Test
+  void testBasicBPETokenization() {
+    final List<SymbolPair> merges = List.of(
+        new SymbolPair("l", "o"),
+        new SymbolPair("lo", "w" + BPETokenizer.END_OF_WORD),
+        new SymbolPair("e", "r" + BPETokenizer.END_OF_WORD)
+    );
+
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(merges));
+    final String[] tokens = tokenizer.tokenize("low");
+
+    Assertions.assertArrayEquals(new String[]{"low"}, tokens);
+  }
+
+  /**
+   * Tests that a word not fully covered by merges is split into subword tokens.
+   */
+  @Test
+  void testSubwordSplitting() {
+    final List<SymbolPair> merges = List.of(
+        new SymbolPair("l", "o"),
+        new SymbolPair("lo", "w" + BPETokenizer.END_OF_WORD)
+    );
+
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(merges));
+    final String[] tokens = tokenizer.tokenize("lower");
+
+    // "lower" cannot fully merge since "w" is not word-final here
+    Assertions.assertTrue(tokens.length > 1);
+    Assertions.assertEquals("lower", String.join("", tokens));
+  }
+
+  /**
+   * Tests tokenization of multiple whitespace-separated words.
+   */
+  @Test
+  void testMultipleWords() {
+    final List<SymbolPair> merges = List.of(
+        new SymbolPair("l", "o"),
+        new SymbolPair("lo", "w" + BPETokenizer.END_OF_WORD)
+    );
+
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(merges));
+    final String[] tokens = tokenizer.tokenize("low low");
+
+    Assertions.assertEquals(2, tokens.length);
+    Assertions.assertEquals("low", tokens[0]);
+    Assertions.assertEquals("low", tokens[1]);
+  }
+
+  /**
+   * Tests that empty and null input produce empty arrays.
+   */
+  @Test
+  void testEmptyInput() {
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(List.of()));
+
+    Assertions.assertArrayEquals(new String[0], tokenizer.tokenize(""));
+    Assertions.assertArrayEquals(new String[0], tokenizer.tokenize(null));
+    Assertions.assertArrayEquals(new Span[0], tokenizer.tokenizePos(""));
+    Assertions.assertArrayEquals(new Span[0], tokenizer.tokenizePos(null));
+  }
+
+  /**
+   * Tests that with no merges, each character becomes a separate token.
+   */
+  @Test
+  void testNoMergesProducesCharacterTokens() {
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(List.of()));
+    final String[] tokens = tokenizer.tokenize("hi");
+
+    Assertions.assertArrayEquals(new String[]{"h", "i"}, tokens);
+  }
+
+  /**
+   * Tests single-character word tokenization.
+   */
+  @Test
+  void testSingleCharacterWord() {
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(List.of()));
+    final String[] tokens = tokenizer.tokenize("a");
+
+    Assertions.assertArrayEquals(new String[]{"a"}, tokens);
+  }
+
+  /**
+   * Tests that {@link BPETokenizer#tokenizePos(String)} returns correct spans
+   * that map back to the original text.
+   */
+  @Test
+  void testTokenizePos() {
+    final List<SymbolPair> merges = List.of(
+        new SymbolPair("l", "o"),
+        new SymbolPair("lo", "w" + BPETokenizer.END_OF_WORD)
+    );
+
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(merges));
+    final String text = "low hi";
+    final Span[] spans = tokenizer.tokenizePos(text);
+
+    // "low" -> 1 token, "hi" -> 2 tokens (no merges for h, i).
+    // getCoveredText returns a CharSequence; compare its String form so the check does
+    // not depend on which CharSequence implementation is returned.
+    Assertions.assertEquals(3, spans.length);
+    Assertions.assertEquals(0, spans[0].getStart());
+    Assertions.assertEquals(3, spans[0].getEnd());
+    Assertions.assertEquals("low", spans[0].getCoveredText(text).toString());
+    // "h"
+    Assertions.assertEquals(4, spans[1].getStart());
+    Assertions.assertEquals(5, spans[1].getEnd());
+    Assertions.assertEquals("h", spans[1].getCoveredText(text).toString());
+    // "i"
+    Assertions.assertEquals(5, spans[2].getStart());
+    Assertions.assertEquals(6, spans[2].getEnd());
+    Assertions.assertEquals("i", spans[2].getCoveredText(text).toString());
+  }
+
+  /**
+   * Tests that span offsets are correct for subword-split words.
+   */
+  @Test
+  void testTokenizePosWithSubwords() {
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(List.of()));
+    final String text = "ab cd";
+    final Span[] spans = tokenizer.tokenizePos(text);
+
+    // "ab" -> a, b; "cd" -> c, d
+    Assertions.assertEquals(4, spans.length);
+    Assertions.assertEquals("a", spans[0].getCoveredText(text).toString());
+    Assertions.assertEquals("b", spans[1].getCoveredText(text).toString());
+    Assertions.assertEquals("c", spans[2].getCoveredText(text).toString());
+    Assertions.assertEquals("d", spans[3].getCoveredText(text).toString());
+  }
+
+  /**
+   * Tests that concatenating all tokens reconstructs the original word.
+   */
+  @Test
+  void testTokenConcatenationEqualsOriginal() {
+    final List<SymbolPair> merges = List.of(
+        new SymbolPair("l", "o"),
+        new SymbolPair("lo", "w" + BPETokenizer.END_OF_WORD)
+    );
+
+    final BPETokenizer tokenizer = new BPETokenizer(createModel(merges));
+    final String[] tokens = tokenizer.tokenize("lower");
+
+    Assertions.assertEquals("lower", String.join("", tokens));
+  }
+
+  /**
+   * Tests that a null model throws NullPointerException.
+   */
+  @Test
+  void testNullModelThrows() {
+    Assertions.assertThrows(NullPointerException.class, () -> new BPETokenizer(null));
+  }
+
+  /**
+   * Tests that a {@link SymbolPair} rejects a {@code null} left symbol.
+   */
+  @Test
+  void testSymbolPairNullLeftThrows() {
+    Assertions.assertThrows(NullPointerException.class, () -> new SymbolPair(null, "b"));
+  }
+
+  /**
+   * Tests that a {@link SymbolPair} rejects a {@code null} right symbol.
+   */
+  @Test
+  void testSymbolPairNullRightThrows() {
+    Assertions.assertThrows(NullPointerException.class, () -> new SymbolPair("a", null));
+  }
+
+  /**
+   * Tests that pairs with equal symbols are equal and share a hash code, while
+   * pairs that split the same characters differently are not equal.
+   */
+  @Test
+  void testSymbolPairEquality() {
+    final SymbolPair a = new SymbolPair("lo", "w");
+    final SymbolPair b = new SymbolPair("lo", "w");
+    final SymbolPair c = new SymbolPair("l", "ow");
+
+    Assertions.assertEquals(a, b);
+    Assertions.assertEquals(a.hashCode(), b.hashCode());
+    Assertions.assertNotEquals(a, c);
+  }
+
+  /**
+   * Tests that {@link SymbolPair#toString()} joins both symbols with a single space.
+   */
+  @Test
+  void testSymbolPairToString() {
+    final SymbolPair pair = new SymbolPair("lo", "w");
+    Assertions.assertEquals("lo w", pair.toString());
+  }
+}
diff --git a/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerTrainerTest.java b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerTrainerTest.java
new file mode 100644
index 00000000..f849798a
--- /dev/null
+++ b/opennlp-core/opennlp-runtime/src/test/java/opennlp/tools/tokenize/BPETokenizerTrainerTest.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.tokenize;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for the {@link BPETokenizerTrainer} class.
+ * <p>
+ * Verifies that BPE merge operations are learned correctly from
+ * a training corpus and that the resulting model can be used for tokenization.
+ *
+ * @see BPETokenizerTrainer
+ * @see BPEModel
+ */
+public class BPETokenizerTrainerTest {
+
+  private BPETokenizerTrainer trainer;
+
+  @BeforeEach
+  void setUp() {
+    trainer = new BPETokenizerTrainer();
+  }
+
+  /**
+   * Tests that training produces a non-null model with merge rules.
+   */
+  @Test
+  void testTrainProducesModel() {
+    final List<String> corpus = List.of(
+        "low low low low low",
+        "lower lower lower",
+        "newest newest newest newest",
+        "widest widest widest"
+    );
+
+    final BPEModel model = trainer.train(corpus, 10, "en");
+
+    Assertions.assertNotNull(model);
+    Assertions.assertFalse(model.getMerges().isEmpty());
+    Assertions.assertTrue(model.getMerges().size() <= 10);
+  }
+
+  /**
+   * Tests that the first merge is the most frequent adjacent pair.
+   * For the corpus {@code "ab ab ab ..."}, the most frequent pair is
+   * {@code ("a", "b</w>")}.
+   */
+  @Test
+  void testFirstMergeIsMostFrequentPair() {
+    final List<String> corpus = List.of(
+        "ab ab ab ab ab ab ab ab ab ab"
+    );
+
+    final BPEModel model = trainer.train(corpus, 1, "en");
+
+    Assertions.assertEquals(1, model.getMerges().size());
+    Assertions.assertEquals("a", model.getMerges().getFirst().left());
+    Assertions.assertEquals("b" + BPETokenizer.END_OF_WORD, model.getMerges().getFirst().right());
+  }
+
+  /**
+   * Tests that requesting more merges than possible stops gracefully.
+   */
+  @Test
+  void testMoreMergesThanPossible() {
+    final List<String> corpus = List.of("ab");
+
+    // "ab" has only one possible pair: ("a", "b</w>")
+    final BPEModel model = trainer.train(corpus, 100, "en");
+
+    // Should stop after exhausting all possible merges: once ("a", "b</w>") is merged,
+    // the word is a single symbol and no further pair exists.
+    Assertions.assertEquals(1, model.getMerges().size());
+  }
+
+  /**
+   * Tests that frequent words get merged into fewer tokens.
+   */
+  @Test
+  void testFrequentWordsProduceFewerTokens() {
+    final List<String> corpus = List.of(
+        "the the the the the the the the the the",
+        "the the the the the the the the the the",
+        "xyzzy"
+    );
+
+    final BPEModel model = trainer.train(corpus, 20, "en");
+    final BPETokenizer tokenizer = new BPETokenizer(model);
+
+    final String[] theTokens = tokenizer.tokenize("the");
+    final String[] xyzzyTokens = tokenizer.tokenize("xyzzy");
+
+    // "the" (very frequent) should have fewer or equal tokens compared to "xyzzy" (rare)
+    Assertions.assertTrue(theTokens.length <= xyzzyTokens.length,
+        "Expected 'the' (" + Arrays.toString(theTokens) + ") to have no more tokens than 'xyzzy' ("
+            + Arrays.toString(xyzzyTokens) + ")");
+  }
+
+  /**
+   * Tests that the trained model produces a tokenizer that reconstructs
+   * the original words when tokens are concatenated.
+   */
+  @Test
+  void testTrainAndTokenizeRoundtrip() {
+    final List<String> corpus = List.of(
+        "the cat sat on the mat",
+        "the cat sat on the mat",
+        "the cat sat on the mat",
+        "the dog sat on the log",
+        "the dog sat on the log"
+    );
+
+    final BPEModel model = trainer.train(corpus, 20, "en");
+    final BPETokenizer tokenizer = new BPETokenizer(model);
+
+    // Verify token concatenation restores the original word
+    for (final String word : new String[]{"the", "cat", "sat", "dog"}) {
+      final String[] tokens = tokenizer.tokenize(word);
+      Assertions.assertEquals(word, String.join("", tokens),
+          "Token concatenation should reconstruct '" + word + "'");
+    }
+  }
+
+  /**
+   * Tests that training with an empty corpus produces a model with no merges.
+   */
+  @Test
+  void testEmptyCorpus() {
+    final BPEModel model = trainer.train(List.of(), 10, "en");
+
+    Assertions.assertNotNull(model);
+    Assertions.assertTrue(model.getMerges().isEmpty());
+  }
+
+  /**
+   * Tests that the language code is set on the produced model.
+   */
+  @Test
+  void testLanguageCodePreserved() {
+    final BPEModel model = trainer.train(List.of("hello world"), 5, "de");
+
+    Assertions.assertEquals("de", model.getLanguage());
+  }
+
+  /**
+   * Tests that a {@code null} corpus is rejected.
+   */
+  @Test
+  void testNullCorpusThrows() {
+    Assertions.assertThrows(NullPointerException.class,
+        () -> trainer.train(null, 10, "en"));
+  }
+
+  /**
+   * Tests that a {@code null} language code is rejected.
+   */
+  @Test
+  void testNullLanguageThrows() {
+    Assertions.assertThrows(NullPointerException.class,
+        () -> trainer.train(List.of("hello"), 10, null));
+  }
+
+  /**
+   * Tests that a merge count of zero is rejected.
+   */
+  @Test
+  void testZeroMergesThrows() {
+    Assertions.assertThrows(IllegalArgumentException.class,
+        () -> trainer.train(List.of("hello"), 0, "en"));
+  }
+
+  /**
+   * Tests that a negative merge count is rejected.
+   */
+  @Test
+  void testNegativeMergesThrows() {
+    Assertions.assertThrows(IllegalArgumentException.class,
+        () -> trainer.train(List.of("hello"), -1, "en"));
+  }
+}
