This is an automated email from the ASF dual-hosted git repository. placave pushed a commit to branch fix-build-cpc-cms in repository https://gitbox.apache.org/repos/asf/datasketches-java.git
commit c0912144c578d0e6b2ff5608f6d74410c106bd9e Author: Pierre Lacave <[email protected]> AuthorDate: Thu Jul 24 12:04:16 2025 +0200 Fix build following cms/cpc recent PR --- .../apache/datasketches/count/CountMinSketch.java | 71 +++++++++++++++------- .../cpc/CpcSketchCrossLanguageTest.java | 3 +- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index 36bea38cf..4ef6a8ec5 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -22,8 +22,8 @@ package org.apache.datasketches.count; import org.apache.datasketches.common.Family; import org.apache.datasketches.common.SketchesArgumentException; import org.apache.datasketches.common.SketchesException; +import org.apache.datasketches.common.Util; import org.apache.datasketches.hash.MurmurHash3; -import org.apache.datasketches.tuple.Util; import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; @@ -39,6 +39,9 @@ public class CountMinSketch { private final long[] sketchArray_; private long totalWeight_; + // Thread-local ByteBuffer to avoid allocations in hot paths + private static final ThreadLocal<ByteBuffer> LONG_BUFFER = + ThreadLocal.withInitial(() -> ByteBuffer.allocate(8)); private enum Flag { IS_EMPTY; @@ -57,30 +60,59 @@ public class CountMinSketch { * @param seed The base hash seed */ CountMinSketch(final byte numHashes, final int numBuckets, final long seed) { - numHashes_ = numHashes; - numBuckets_ = numBuckets; - seed_ = seed; - hashSeeds_ = new long[numHashes]; - sketchArray_ = new long[numHashes * numBuckets]; - totalWeight_ = 0; + // Validate numHashes + if (numHashes <= 0) { + throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes); + } + // Validate numBuckets with clear mathematical justification + if (numBuckets <= 0) { + throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets); + } if (numBuckets < 3) { - throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1."); + throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " + + "With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets)); + } + + // Check for potential overflow in array size calculation + // Use long arithmetic to detect overflow before casting + final long totalSize = (long) numHashes * (long) numBuckets; + if (totalSize > Integer.MAX_VALUE) { + throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets + + " = " + totalSize + " > " + Integer.MAX_VALUE); } // This check is to ensure later compatibility with a Java implementation whose maximum size can only // be 2^31-1. We check only against 2^30 for simplicity. - if (numBuckets * numHashes >= 1 << 30) { - throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" + - "Try reducing either the number of buckets or the number of hash functions."); + if (totalSize >= (1L << 30)) { + throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets + + " = " + totalSize + " elements (~" + String.format("%.1f", totalSize * 8.0 / (1024 * 1024 * 1024)) + " GB). " + + "Consider reducing numHashes or numBuckets."); } + numHashes_ = numHashes; + numBuckets_ = numBuckets; + seed_ = seed; + hashSeeds_ = new long[numHashes]; + sketchArray_ = new long[(int) totalSize]; + totalWeight_ = 0; + Random rand = new Random(seed); for (int i = 0; i < numHashes; i++) { hashSeeds_[i] = rand.nextLong(); } } + /** + * Efficiently converts a long to byte array using thread-local buffer to avoid allocations. + */ + private static byte[] longToBytes(final long value) { + final ByteBuffer buffer = LONG_BUFFER.get(); + buffer.clear(); + buffer.putLong(value); + return buffer.array(); + } + private long[] getHashes(byte[] item) { long[] updateLocations = new long[numHashes_]; @@ -171,8 +203,7 @@ public class CountMinSketch { * @param weight The weight of the item. */ public void update(final long item, final long weight) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - update(longByte, weight); + update(longToBytes(item), weight); } /** @@ -211,8 +242,7 @@ public class CountMinSketch { * @return Estimated frequency. */ public long getEstimate(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getEstimate(longByte); + return getEstimate(longToBytes(item)); } /** @@ -241,8 +271,9 @@ public class CountMinSketch { long[] hashLocations = getHashes(item); long res = sketchArray_[(int) hashLocations[0]]; - for (long h : hashLocations) { - res = Math.min(res, sketchArray_[(int) h]); + // Start from index 1 to avoid processing first element twice + for (int i = 1; i < hashLocations.length; i++) { + res = Math.min(res, sketchArray_[(int) hashLocations[i]]); } return res; @@ -254,8 +285,7 @@ public class CountMinSketch { * @return Upper bound of estimated frequency. */ public long getUpperBound(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getUpperBound(longByte); + return getUpperBound(longToBytes(item)); } /** @@ -291,8 +321,7 @@ public class CountMinSketch { * @return Lower bound of estimated frequency. */ public long getLowerBound(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getLowerBound(longByte); + return getLowerBound(longToBytes(item)); } /** diff --git a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java index 16cc55db9..2346ec918 100644 --- a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java +++ b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java @@ -31,6 +31,7 @@ import java.lang.foreign.MemorySegment; import java.io.IOException; import java.nio.file.Files; +import org.apache.datasketches.memory.Memory; import org.testng.annotations.Test; /** @@ -89,7 +90,7 @@ public class CpcSketchCrossLanguageTest { int flavorIdx = 0; for (int n: nArr) { final byte[] bytes = Files.readAllBytes(goPath.resolve("cpc_n" + n + "_go.sk")); - final CpcSketch sketch = CpcSketch.heapify(Memory.wrap(bytes)); + final CpcSketch sketch = CpcSketch.heapify(MemorySegment.ofArray(bytes)); assertEquals(sketch.getFlavor(), flavorArr[flavorIdx++]); assertEquals(sketch.getEstimate(), n, n * 0.02); } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
