This is an automated email from the ASF dual-hosted git repository.
leerho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datasketches-rust.git
The following commit(s) were added to refs/heads/main by this push:
new f5143c3 feat: add a generic error type (#47)
f5143c3 is described below
commit f5143c311ae169b9c0bead1e47bd58881526f83e
Author: tison <[email protected]>
AuthorDate: Tue Dec 30 09:53:15 2025 +0800
feat: add a generic error type (#47)
* feat: add a generic error type
Signed-off-by: tison <[email protected]>
* apply
Signed-off-by: tison <[email protected]>
* convenience constructors
Signed-off-by: tison <[email protected]>
* more
Signed-off-by: tison <[email protected]>
* rename error kind
Signed-off-by: tison <[email protected]>
* no need source for now
Signed-off-by: tison <[email protected]>
---------
Signed-off-by: tison <[email protected]>
---
Cargo.lock | 7 ++
Cargo.toml | 1 +
datasketches/Cargo.toml | 1 +
datasketches/src/countmin/sketch.rs | 58 +++++-----
datasketches/src/error.rs | 167 ++++++++++++++++++++++++----
datasketches/src/hll/array4.rs | 10 +-
datasketches/src/hll/array6.rs | 8 +-
datasketches/src/hll/array8.rs | 8 +-
datasketches/src/hll/hash_set.rs | 10 +-
datasketches/src/hll/list.rs | 8 +-
datasketches/src/hll/serialization.rs | 2 +-
datasketches/src/hll/sketch.rs | 42 +++-----
datasketches/src/tdigest/sketch.rs | 197 ++++++++++++++++------------------
13 files changed, 309 insertions(+), 210 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index ba741cd..66212a0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -61,6 +61,12 @@ dependencies = [
"windows-sys",
]
+[[package]]
+name = "anyhow"
+version = "1.0.100"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
+
[[package]]
name = "autocfg"
version = "1.5.0"
@@ -129,6 +135,7 @@ checksum =
"b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
name = "datasketches"
version = "0.1.0"
dependencies = [
+ "anyhow",
"byteorder",
"googletest",
]
diff --git a/Cargo.toml b/Cargo.toml
index e991913..4c93ee2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ rust-version = "1.85.0"
datasketches = { path = "datasketches" }
# Crates.io dependencies
+anyhow = { version = "1.0.100" }
byteorder = { version = "1.5.0" }
clap = { version = "4.5.20", features = ["derive"] }
googletest = { version = "0.14.2" }
diff --git a/datasketches/Cargo.toml b/datasketches/Cargo.toml
index bddcb7e..8a150ff 100644
--- a/datasketches/Cargo.toml
+++ b/datasketches/Cargo.toml
@@ -35,6 +35,7 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"]
[dependencies]
+anyhow = { workspace = true }
byteorder = { workspace = true }
[dev-dependencies]
diff --git a/datasketches/src/countmin/sketch.rs
b/datasketches/src/countmin/sketch.rs
index 3fabfb4..e98e96e 100644
--- a/datasketches/src/countmin/sketch.rs
+++ b/datasketches/src/countmin/sketch.rs
@@ -29,7 +29,7 @@ use crate::countmin::serialization::LONG_SIZE_BYTES;
use crate::countmin::serialization::PREAMBLE_LONGS_SHORT;
use crate::countmin::serialization::SERIAL_VERSION;
use crate::countmin::serialization::compute_seed_hash;
-use crate::error::SerdeError;
+use crate::error::Error;
use crate::hash::MurmurHash3X64128;
const MAX_TABLE_ENTRIES: usize = 1 << 30;
@@ -231,14 +231,14 @@ impl CountMinSketch {
}
/// Deserializes a sketch from bytes using the default seed.
- pub fn deserialize(bytes: &[u8]) -> Result<Self, SerdeError> {
+ pub fn deserialize(bytes: &[u8]) -> Result<Self, Error> {
Self::deserialize_with_seed(bytes, DEFAULT_SEED)
}
/// Deserializes a sketch from bytes using the provided seed.
- pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result<Self,
SerdeError> {
- fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) ->
SerdeError {
- move |_| SerdeError::InsufficientData(tag.to_string())
+ pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result<Self,
Error> {
+ fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) ->
Error {
+ move |_| Error::insufficient_data(tag)
}
let mut cursor = Cursor::new(bytes);
@@ -249,21 +249,23 @@ impl CountMinSketch {
cursor.read_u32::<LE>().map_err(make_error("unused32"))?;
if family_id != COUNTMIN_FAMILY_ID {
- return Err(SerdeError::InvalidFamily(format!(
- "expected {} (CountMinSketch), got {}",
- COUNTMIN_FAMILY_ID, family_id
- )));
+ return Err(Error::invalid_family(
+ COUNTMIN_FAMILY_ID,
+ family_id,
+ "CountMinSketch",
+ ));
}
if serial_version != SERIAL_VERSION {
- return Err(SerdeError::UnsupportedVersion(format!(
- "expected {}, got {}",
- SERIAL_VERSION, serial_version
- )));
+ return Err(Error::unsupported_serial_version(
+ SERIAL_VERSION,
+ serial_version,
+ ));
}
if preamble_longs != PREAMBLE_LONGS_SHORT {
- return Err(SerdeError::MalformedData(format!(
- "unsupported preamble_longs {preamble_longs}"
- )));
+ return Err(Error::invalid_preamble_longs(
+ PREAMBLE_LONGS_SHORT,
+ preamble_longs,
+ ));
}
let num_buckets =
cursor.read_u32::<LE>().map_err(make_error("num_buckets"))?;
@@ -273,9 +275,8 @@ impl CountMinSketch {
let expected_seed_hash = compute_seed_hash(seed);
if seed_hash != expected_seed_hash {
- return Err(SerdeError::InvalidParameter(format!(
- "incompatible seed hash: expected {}, got {}",
- expected_seed_hash, seed_hash
+ return Err(Error::deserial(format!(
+ "incompatible seed hash: expected {expected_seed_hash}, got
{seed_hash}",
)));
}
@@ -329,26 +330,19 @@ fn entries_for_config(num_hashes: u8, num_buckets: u32)
-> usize {
entries
}
-fn entries_for_config_checked(num_hashes: u8, num_buckets: u32) ->
Result<usize, SerdeError> {
+fn entries_for_config_checked(num_hashes: u8, num_buckets: u32) ->
Result<usize, Error> {
if num_hashes == 0 {
- return Err(SerdeError::InvalidParameter(
- "num_hashes must be at least 1".to_string(),
- ));
+ return Err(Error::deserial("num_hashes must be at least 1"));
}
if num_buckets < 3 {
- return Err(SerdeError::InvalidParameter(
- "num_buckets must be at least 3".to_string(),
- ));
+ return Err(Error::deserial("num_buckets must be at least 3"));
}
let entries = (num_hashes as usize)
.checked_mul(num_buckets as usize)
- .ok_or_else(|| {
- SerdeError::InvalidParameter("num_hashes * num_buckets overflows
usize".to_string())
- })?;
+ .ok_or_else(|| Error::deserial("num_hashes * num_buckets overflows
usize"))?;
if entries >= MAX_TABLE_ENTRIES {
- return Err(SerdeError::InvalidParameter(format!(
- "num_hashes * num_buckets must be < {}",
- MAX_TABLE_ENTRIES
+ return Err(Error::deserial(format!(
+ "num_hashes * num_buckets must be < {MAX_TABLE_ENTRIES}",
)));
}
Ok(entries)
diff --git a/datasketches/src/error.rs b/datasketches/src/error.rs
index 88e71d9..624ee0a 100644
--- a/datasketches/src/error.rs
+++ b/datasketches/src/error.rs
@@ -19,31 +19,152 @@
use std::fmt;
-/// Errors that can occur during sketch serialization or deserialization
-#[derive(Debug, Clone)]
-pub enum SerdeError {
- /// Insufficient data in buffer
- InsufficientData(String),
- /// Invalid sketch family identifier
- InvalidFamily(String),
- /// Unsupported serialization version
- UnsupportedVersion(String),
- /// Invalid parameter value
- InvalidParameter(String),
- /// Malformed or corrupt sketch data
- MalformedData(String),
-}
-
-impl fmt::Display for SerdeError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+/// ErrorKind enumerates the kinds of errors that datasketches can produce.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[non_exhaustive]
+pub enum ErrorKind {
+ /// The argument provided is invalid.
+ InvalidArgument,
+ /// The sketch data being deserialized is malformed.
+ InvalidData,
+}
+
+impl ErrorKind {
+ /// Convert this error kind instance into static str.
+ pub const fn into_static(self) -> &'static str {
match self {
- SerdeError::InsufficientData(msg) => write!(f, "insufficient data:
{}", msg),
- SerdeError::InvalidFamily(msg) => write!(f, "invalid family: {}",
msg),
- SerdeError::UnsupportedVersion(msg) => write!(f, "unsupported
version: {}", msg),
- SerdeError::InvalidParameter(msg) => write!(f, "invalid parameter:
{}", msg),
- SerdeError::MalformedData(msg) => write!(f, "malformed data: {}",
msg),
+ ErrorKind::InvalidArgument => "InvalidArgument",
+ ErrorKind::InvalidData => "InvalidData",
+ }
+ }
+}
+
+impl fmt::Display for ErrorKind {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.into_static())
+ }
+}
+
+/// Error is the error struct returned by all datasketches functions.
+pub struct Error {
+ kind: ErrorKind,
+ message: String,
+ context: Vec<(&'static str, String)>,
+}
+
+impl Error {
+ /// Create a new Error with error kind and message.
+ pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
+ Self {
+ kind,
+ message: message.into(),
+ context: vec![],
}
}
+
+ /// Adds a key-value pair of context to the error.
+ pub fn with_context(mut self, key: &'static str, value: impl ToString) ->
Self {
+ self.context.push((key, value.to_string()));
+ self
+ }
+
+ /// Return error's kind.
+ pub fn kind(&self) -> ErrorKind {
+ self.kind
+ }
+
+ /// Return error's message.
+ pub fn message(&self) -> &str {
+ self.message.as_str()
+ }
+}
+
+// Convenience constructors for deserialization errors
+impl Error {
+ pub(crate) fn deserial(msg: impl Into<String>) -> Self {
+ Self::new(ErrorKind::InvalidData, msg)
+ }
+
+ pub(crate) fn insufficient_data(msg: impl fmt::Display) -> Self {
+ Self::deserial(format!("insufficient data: {msg}"))
+ }
+
+ pub(crate) fn insufficient_data_of(context: &'static str, msg: impl
fmt::Display) -> Self {
+ Self::deserial(format!("insufficient data ({context}): {msg}"))
+ }
+
+ pub(crate) fn invalid_family(expected: u8, actual: u8, name: &'static str)
-> Self {
+ Self::deserial(format!(
+ "invalid family: expected {expected} ({name}), got {actual}"
+ ))
+ }
+
+ pub(crate) fn unsupported_serial_version(expected: u8, actual: u8) -> Self
{
+ Self::deserial(format!(
+ "unsupported serial version: expected {expected}, got {actual}"
+ ))
+ }
+
+ pub(crate) fn invalid_preamble_longs(expected: u8, actual: u8) -> Self {
+ Self::deserial(format!(
+ "invalid preamble longs: expected {expected}, got {actual}"
+ ))
+ }
+}
+
+impl fmt::Debug for Error {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ // If the alternate flag (`{:#?}`) is set, print in the struct-style Debug format.
+ if f.alternate() {
+ let mut de = f.debug_struct("Error");
+ de.field("kind", &self.kind);
+ de.field("message", &self.message);
+ de.field("context", &self.context);
+ return de.finish();
+ }
+
+ write!(f, "{}", self.kind)?;
+ if !self.message.is_empty() {
+ write!(f, " => {}", self.message)?;
+ }
+ writeln!(f)?;
+
+ if !self.context.is_empty() {
+ writeln!(f)?;
+ writeln!(f, "Context:")?;
+ for (k, v) in self.context.iter() {
+ writeln!(f, " {k}: {v}")?;
+ }
+ }
+
+ Ok(())
+ }
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.kind)?;
+
+ if !self.context.is_empty() {
+ write!(f, ", context: {{ ")?;
+ write!(
+ f,
+ "{}",
+ self.context
+ .iter()
+ .map(|(k, v)| format!("{k}: {v}"))
+ .collect::<Vec<_>>()
+ .join(", ")
+ )?;
+ write!(f, " }}")?;
+ }
+
+ if !self.message.is_empty() {
+ write!(f, " => {}", self.message)?;
+ }
+
+ Ok(())
+ }
}
-impl std::error::Error for SerdeError {}
+impl std::error::Error for Error {}
diff --git a/datasketches/src/hll/array4.rs b/datasketches/src/hll/array4.rs
index 44707b6..fbef5e4 100644
--- a/datasketches/src/hll/array4.rs
+++ b/datasketches/src/hll/array4.rs
@@ -21,7 +21,7 @@
//! When values exceed 4 bits after cur_min offset, they're stored in an
auxiliary hash map.
use super::aux_map::AuxMap;
-use crate::error::SerdeError;
+use crate::error::Error;
use crate::hll::NumStdDev;
use crate::hll::estimator::HipEstimator;
use crate::hll::get_slot;
@@ -289,13 +289,13 @@ impl Array4 {
lg_config_k: u8,
compact: bool,
ooo: bool,
- ) -> Result<Self, SerdeError> {
+ ) -> Result<Self, Error> {
use crate::hll::get_slot;
use crate::hll::get_value;
use crate::hll::serialization::*;
if bytes.len() < HLL_PREAMBLE_SIZE {
- return Err(SerdeError::InsufficientData(format!(
+ return Err(Error::insufficient_data(format!(
"expected at least {}, got {}",
HLL_PREAMBLE_SIZE,
bytes.len()
@@ -324,7 +324,7 @@ impl Array4 {
};
if bytes.len() < expected_len {
- return Err(SerdeError::InsufficientData(format!(
+ return Err(Error::insufficient_data(format!(
"expected {}, got {}",
expected_len,
bytes.len()
@@ -392,7 +392,7 @@ impl Array4 {
// Write standard header
bytes[PREAMBLE_INTS_BYTE] = HLL_PREINTS;
- bytes[SER_VER_BYTE] = SER_VER;
+ bytes[SER_VER_BYTE] = SERIAL_VER;
bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
bytes[LG_K_BYTE] = lg_config_k;
bytes[LG_ARR_BYTE] = 0; // Not used for HLL mode
diff --git a/datasketches/src/hll/array6.rs b/datasketches/src/hll/array6.rs
index 8c7138b..5d36cb5 100644
--- a/datasketches/src/hll/array6.rs
+++ b/datasketches/src/hll/array6.rs
@@ -21,7 +21,7 @@
//! This is sufficient for most HLL use cases without needing exception
handling or
//! cur_min optimization like Array4.
-use crate::error::SerdeError;
+use crate::error::Error;
use crate::hll::NumStdDev;
use crate::hll::estimator::HipEstimator;
use crate::hll::get_slot;
@@ -173,7 +173,7 @@ impl Array6 {
lg_config_k: u8,
compact: bool,
ooo: bool,
- ) -> Result<Self, SerdeError> {
+ ) -> Result<Self, Error> {
use crate::hll::serialization::*;
let k = 1 << lg_config_k;
@@ -185,7 +185,7 @@ impl Array6 {
};
if bytes.len() < expected_len {
- return Err(SerdeError::InsufficientData(format!(
+ return Err(Error::insufficient_data(format!(
"expected {}, got {}",
expected_len,
bytes.len()
@@ -234,7 +234,7 @@ impl Array6 {
// Write standard header
bytes[PREAMBLE_INTS_BYTE] = HLL_PREINTS;
- bytes[SER_VER_BYTE] = SER_VER;
+ bytes[SER_VER_BYTE] = SERIAL_VER;
bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
bytes[LG_K_BYTE] = lg_config_k;
bytes[LG_ARR_BYTE] = 0; // Not used for HLL mode
diff --git a/datasketches/src/hll/array8.rs b/datasketches/src/hll/array8.rs
index 33e2cca..dd3b556 100644
--- a/datasketches/src/hll/array8.rs
+++ b/datasketches/src/hll/array8.rs
@@ -20,7 +20,7 @@
//! Array8 is the simplest HLL array implementation, storing one byte per slot.
//! This provides the maximum value range (0-255) with no bit-packing
complexity.
-use crate::error::SerdeError;
+use crate::error::Error;
use crate::hll::NumStdDev;
use crate::hll::estimator::HipEstimator;
use crate::hll::get_slot;
@@ -247,7 +247,7 @@ impl Array8 {
lg_config_k: u8,
compact: bool,
ooo: bool,
- ) -> Result<Self, SerdeError> {
+ ) -> Result<Self, Error> {
use crate::hll::serialization::*;
let k = 1 << lg_config_k;
@@ -258,7 +258,7 @@ impl Array8 {
};
if bytes.len() < expected_len {
- return Err(SerdeError::InsufficientData(format!(
+ return Err(Error::insufficient_data(format!(
"expected {}, got {}",
expected_len,
bytes.len()
@@ -306,7 +306,7 @@ impl Array8 {
// Write standard header
bytes[PREAMBLE_INTS_BYTE] = HLL_PREINTS;
- bytes[SER_VER_BYTE] = SER_VER;
+ bytes[SER_VER_BYTE] = SERIAL_VER;
bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
bytes[LG_K_BYTE] = lg_config_k;
bytes[LG_ARR_BYTE] = 0; // Not used for HLL mode
diff --git a/datasketches/src/hll/hash_set.rs b/datasketches/src/hll/hash_set.rs
index 05f5ad2..f6bff05 100644
--- a/datasketches/src/hll/hash_set.rs
+++ b/datasketches/src/hll/hash_set.rs
@@ -20,7 +20,7 @@
//! Uses open addressing with a custom stride function to handle collisions.
//! Provides better performance than List when many coupons are stored.
-use crate::error::SerdeError;
+use crate::error::Error;
use crate::hll::HllType;
use crate::hll::KEY_MASK_26;
use crate::hll::container::COUPON_EMPTY;
@@ -84,7 +84,7 @@ impl HashSet {
}
/// Deserialize a HashSet from bytes
- pub fn deserialize(bytes: &[u8], compact: bool) -> Result<Self,
SerdeError> {
+ pub fn deserialize(bytes: &[u8], compact: bool) -> Result<Self, Error> {
// Read coupon count from bytes 8-11
let coupon_count = read_u32_le(bytes, HASH_SET_COUNT_INT) as usize;
@@ -95,7 +95,7 @@ impl HashSet {
// Compact mode: only couponCount coupons are stored
let expected_len = HASH_SET_INT_ARR_START + (coupon_count * 4);
if bytes.len() < expected_len {
- return Err(SerdeError::InsufficientData(format!(
+ return Err(Error::insufficient_data(format!(
"expected {}, got {}",
expected_len,
bytes.len()
@@ -115,7 +115,7 @@ impl HashSet {
let array_size = 1 << lg_arr;
let expected_len = HASH_SET_INT_ARR_START + (array_size * 4);
if bytes.len() < expected_len {
- return Err(SerdeError::InsufficientData(format!(
+ return Err(Error::insufficient_data(format!(
"expected {}, got {}",
expected_len,
bytes.len()
@@ -153,7 +153,7 @@ impl HashSet {
// Write preamble
bytes[PREAMBLE_INTS_BYTE] = HASH_SET_PREINTS;
- bytes[SER_VER_BYTE] = SER_VER;
+ bytes[SER_VER_BYTE] = SERIAL_VER;
bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
bytes[LG_K_BYTE] = lg_config_k;
bytes[LG_ARR_BYTE] = lg_arr as u8;
diff --git a/datasketches/src/hll/list.rs b/datasketches/src/hll/list.rs
index d01a9b7..c705383 100644
--- a/datasketches/src/hll/list.rs
+++ b/datasketches/src/hll/list.rs
@@ -20,7 +20,7 @@
//! Provides sequential storage with linear search for duplicates.
//! Efficient for small numbers of coupons before transitioning to HashSet.
-use crate::error::SerdeError;
+use crate::error::Error;
use crate::hll::HllType;
use crate::hll::container::COUPON_EMPTY;
use crate::hll::container::Container;
@@ -66,7 +66,7 @@ impl List {
}
/// Deserialize a List from bytes
- pub fn deserialize(bytes: &[u8], empty: bool, compact: bool) ->
Result<Self, SerdeError> {
+ pub fn deserialize(bytes: &[u8], empty: bool, compact: bool) ->
Result<Self, Error> {
// Read coupon count from byte 6
let coupon_count = bytes[LIST_COUNT_BYTE] as usize;
@@ -77,7 +77,7 @@ impl List {
// Validate length
let expected_len = LIST_INT_ARR_START + (array_size * 4);
if bytes.len() < expected_len {
- return Err(SerdeError::InsufficientData(format!(
+ return Err(Error::insufficient_data(format!(
"expected {}, got {}",
expected_len,
bytes.len()
@@ -113,7 +113,7 @@ impl List {
// Write preamble
bytes[PREAMBLE_INTS_BYTE] = LIST_PREINTS;
- bytes[SER_VER_BYTE] = SER_VER;
+ bytes[SER_VER_BYTE] = SERIAL_VER;
bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
bytes[LG_K_BYTE] = lg_config_k;
bytes[LG_ARR_BYTE] = lg_arr as u8;
diff --git a/datasketches/src/hll/serialization.rs
b/datasketches/src/hll/serialization.rs
index e111393..b99a262 100644
--- a/datasketches/src/hll/serialization.rs
+++ b/datasketches/src/hll/serialization.rs
@@ -24,7 +24,7 @@
pub const HLL_FAMILY_ID: u8 = 7;
/// Current serialization version
-pub const SER_VER: u8 = 1;
+pub const SERIAL_VER: u8 = 1;
/// Flag indicating sketch is empty (no values inserted)
pub const EMPTY_FLAG_MASK: u8 = 4;
diff --git a/datasketches/src/hll/sketch.rs b/datasketches/src/hll/sketch.rs
index 6887338..a5f0aac 100644
--- a/datasketches/src/hll/sketch.rs
+++ b/datasketches/src/hll/sketch.rs
@@ -22,7 +22,7 @@
use std::hash::Hash;
-use crate::error::SerdeError;
+use crate::error::Error;
use crate::hll::HllType;
use crate::hll::NumStdDev;
use crate::hll::RESIZE_DENOMINATOR;
@@ -212,16 +212,16 @@ impl HllSketch {
}
/// Deserializes an HLL sketch from bytes
- pub fn deserialize(bytes: &[u8]) -> Result<HllSketch, SerdeError> {
+ pub fn deserialize(bytes: &[u8]) -> Result<HllSketch, Error> {
if bytes.len() < 8 {
- return Err(SerdeError::InsufficientData(
- "sketch data too short (< 8 bytes)".to_string(),
+ return Err(Error::insufficient_data(
+ "sketch data too short (< 8 bytes)",
));
}
// Read and validate preamble
let preamble_ints = bytes[PREAMBLE_INTS_BYTE];
- let ser_ver = bytes[SER_VER_BYTE];
+ let serial_ver = bytes[SER_VER_BYTE];
let family_id = bytes[FAMILY_BYTE];
let lg_config_k = bytes[LG_K_BYTE];
let flags = bytes[FLAGS_BYTE];
@@ -229,25 +229,18 @@ impl HllSketch {
// Verify family ID
if family_id != HLL_FAMILY_ID {
- return Err(SerdeError::InvalidFamily(format!(
- "expected {} (HLL), got {}",
- HLL_FAMILY_ID, family_id
- )));
+ return Err(Error::invalid_family(HLL_FAMILY_ID, family_id, "HLL"));
}
// Verify serialization version
- if ser_ver != SER_VER {
- return Err(SerdeError::UnsupportedVersion(format!(
- "expected {}, got {}",
- SER_VER, ser_ver
- )));
+ if serial_ver != SERIAL_VER {
+ return Err(Error::unsupported_serial_version(SERIAL_VER,
serial_ver));
}
// Verify lg_k range (4-21 are valid)
if !(4..=21).contains(&lg_config_k) {
- return Err(SerdeError::InvalidParameter(format!(
- "lg_k must be in [4; 21], got {}",
- lg_config_k
+ return Err(Error::deserial(format!(
+ "lg_k must be in [4; 21], got {lg_config_k}",
)));
}
@@ -256,10 +249,7 @@ impl HllSketch {
TGT_HLL6 => HllType::Hll6,
TGT_HLL8 => HllType::Hll8,
hll_type => {
- return Err(SerdeError::MalformedData(format!(
- "invalid HLL type: {}",
- hll_type
- )));
+ return Err(Error::deserial(format!("invalid HLL type:
{hll_type}")));
}
};
@@ -272,9 +262,9 @@ impl HllSketch {
match extract_cur_mode(mode_byte) {
CUR_MODE_LIST => {
if preamble_ints != LIST_PREINTS {
- return Err(SerdeError::MalformedData(format!(
+ return Err(Error::deserial(format!(
"LIST mode preamble: expected {}, got {}",
- LIST_PREINTS, preamble_ints
+ LIST_PREINTS, preamble_ints,
)));
}
@@ -283,7 +273,7 @@ impl HllSketch {
}
CUR_MODE_SET => {
if preamble_ints != HASH_SET_PREINTS {
- return Err(SerdeError::MalformedData(format!(
+ return Err(Error::deserial(format!(
"SET mode preamble: expected {}, got {}",
HASH_SET_PREINTS, preamble_ints
)));
@@ -294,7 +284,7 @@ impl HllSketch {
}
CUR_MODE_HLL => {
if preamble_ints != HLL_PREINTS {
- return Err(SerdeError::MalformedData(format!(
+ return Err(Error::deserial(format!(
"HLL mode preamble: expected {}, got {}",
HLL_PREINTS, preamble_ints
)));
@@ -309,7 +299,7 @@ impl HllSketch {
.map(Mode::Array8)?,
}
}
- mode => return Err(SerdeError::MalformedData(format!("invalid
mode: {}", mode))),
+ mode => return Err(Error::deserial(format!("invalid mode:
{mode}"))),
};
Ok(HllSketch { lg_config_k, mode })
diff --git a/datasketches/src/tdigest/sketch.rs
b/datasketches/src/tdigest/sketch.rs
index 13aa6ca..a0f3883 100644
--- a/datasketches/src/tdigest/sketch.rs
+++ b/datasketches/src/tdigest/sketch.rs
@@ -24,7 +24,8 @@ use byteorder::BE;
use byteorder::LE;
use byteorder::ReadBytesExt;
-use crate::error::SerdeError;
+use crate::error::Error;
+use crate::error::ErrorKind;
use crate::tdigest::serialization::*;
/// The default value of K if one is not specified.
@@ -60,9 +61,11 @@ impl Default for TDigestMut {
impl TDigestMut {
/// Creates a tdigest instance with the given value of k.
///
+ /// The fallible version of this method is [`TDigestMut::try_new`].
+ ///
/// # Panics
///
- /// If k is less than 10
+ /// Panics if k is less than 10
pub fn new(k: u16) -> Self {
Self::make(
k,
@@ -75,6 +78,32 @@ impl TDigestMut {
)
}
+ /// Creates a tdigest instance with the given value of k.
+ ///
+ /// The panicking version of this method is [`TDigestMut::new`].
+ ///
+ /// # Errors
+ ///
+ /// Returns an [`Error`] of kind [`ErrorKind::InvalidArgument`] if k is less than 10.
+ pub fn try_new(k: u16) -> Result<Self, Error> {
+ if k < 10 {
+ return Err(Error::new(
+ ErrorKind::InvalidArgument,
+ format!("k must be at least 10, got {k}"),
+ ));
+ }
+
+ Ok(Self::make(
+ k,
+ false,
+ f64::INFINITY,
+ f64::NEG_INFINITY,
+ vec![],
+ 0,
+ vec![],
+ ))
+ }
+
// for deserialization
fn make(
k: u16,
@@ -205,27 +234,7 @@ impl TDigestMut {
}
}
- /// Returns an approximation to the Cumulative Distribution Function
(CDF), which is the
- /// cumulative analog of the PMF, of the input stream given a set of split
points.
- ///
- /// # Arguments
- ///
- /// * `split_points`: An array of _m_ unique, monotonically increasing
values that divide the
- /// input domain into _m+1_ consecutive disjoint intervals.
- ///
- /// # Returns
- ///
- /// An array of m+1 doubles, which are a consecutive approximation to the
CDF of the input
- /// stream given the split points. The value at array position j of the
returned CDF array
- /// is the sum of the returned values in positions 0 through j of the
returned PMF array.
- /// This can be viewed as array of ranks of the given split points plus
one more value that
- /// is always 1.
- ///
- /// Returns `None` if TDigest is empty.
- ///
- /// # Panics
- ///
- /// If `split_points` is not unique, not monotonically increasing, or
contains `NaN` values.
+ /// See [`TDigest::cdf`].
pub fn cdf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
check_split_points(split_points);
@@ -236,24 +245,7 @@ impl TDigestMut {
self.view().cdf(split_points)
}
- /// Returns an approximation to the Probability Mass Function (PMF) of the
input stream
- /// given a set of split points.
- ///
- /// # Arguments
- ///
- /// * `split_points`: An array of _m_ unique, monotonically increasing
values that divide the
- /// input domain into _m+1_ consecutive disjoint intervals (bins).
- ///
- /// # Returns
- ///
- /// An array of m+1 doubles each of which is an approximation to the
fraction of the input
- /// stream values (the mass) that fall into one of those intervals.
- ///
- /// Returns `None` if TDigest is empty.
- ///
- /// # Panics
- ///
- /// If `split_points` is not unique, not monotonically increasing, or
contains `NaN` values.
+ /// See [`TDigest::pmf`].
pub fn pmf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
check_split_points(split_points);
@@ -264,13 +256,7 @@ impl TDigestMut {
self.view().pmf(split_points)
}
- /// Compute approximate normalized rank (from 0 to 1 inclusive) of the
given value.
- ///
- /// Returns `None` if TDigest is empty.
- ///
- /// # Panics
- ///
- /// If the value is `NaN`.
+ /// See [`TDigest::rank`].
pub fn rank(&mut self, value: f64) -> Option<f64> {
assert!(!value.is_nan(), "value must not be NaN");
@@ -291,13 +277,7 @@ impl TDigestMut {
self.view().rank(value)
}
- /// Compute approximate quantile value corresponding to the given
normalized rank.
- ///
- /// Returns `None` if TDigest is empty.
- ///
- /// # Panics
- ///
- /// If rank is not in [0.0, 1.0].
+ /// See [`TDigest::quantile`].
pub fn quantile(&mut self, rank: f64) -> Option<f64> {
assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
@@ -390,9 +370,9 @@ impl TDigestMut {
///
/// [^1]: This is to support reading the `tdigest<float>` format from the
C++ implementation.
/// [^2]: <https://github.com/tdunning/t-digest>
- pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result<Self, SerdeError>
{
- fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) ->
SerdeError {
- move |_| SerdeError::InsufficientData(tag.to_string())
+ pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result<Self, Error> {
+ fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) ->
Error {
+ move |_| Error::insufficient_data(tag)
}
let mut cursor = Cursor::new(bytes);
@@ -401,25 +381,25 @@ impl TDigestMut {
let serial_version =
cursor.read_u8().map_err(make_error("serial_version"))?;
let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
if family_id != TDIGEST_FAMILY_ID {
- if preamble_longs == 0 && serial_version == 0 && family_id == 0 {
- return Self::deserialize_compat(bytes);
- }
- return Err(SerdeError::InvalidFamily(format!(
- "expected {} (TDigest), got {}",
- TDIGEST_FAMILY_ID, family_id
- )));
+ return if preamble_longs == 0 && serial_version == 0 && family_id
== 0 {
+ Self::deserialize_compat(bytes)
+ } else {
+ Err(Error::invalid_family(
+ TDIGEST_FAMILY_ID,
+ family_id,
+ "TDigest",
+ ))
+ };
}
if serial_version != SERIAL_VERSION {
- return Err(SerdeError::UnsupportedVersion(format!(
- "expected {}, got {}",
- SERIAL_VERSION, serial_version
- )));
+ return Err(Error::unsupported_serial_version(
+ SERIAL_VERSION,
+ serial_version,
+ ));
}
let k = cursor.read_u16::<LE>().map_err(make_error("k"))?;
if k < 10 {
- return Err(SerdeError::InvalidParameter(format!(
- "k must be at least 10, got {k}"
- )));
+ return Err(Error::deserial(format!("k must be at least 10, got
{k}")));
}
let flags = cursor.read_u8().map_err(make_error("flags"))?;
let is_empty = (flags & FLAGS_IS_EMPTY) != 0;
@@ -430,10 +410,10 @@ impl TDigestMut {
PREAMBLE_LONGS_MULTIPLE
};
if preamble_longs != expected_preamble_longs {
- return Err(SerdeError::MalformedData(format!(
- "expected preamble_longs to be {}, got {}",
- expected_preamble_longs, preamble_longs
- )));
+ return Err(Error::invalid_preamble_longs(
+ expected_preamble_longs,
+ preamble_longs,
+ ));
}
cursor.read_u16::<LE>().map_err(make_error("<unused>"))?; // unused
if is_empty {
@@ -452,7 +432,7 @@ impl TDigestMut {
.map_err(make_error("single_value"))?
};
check_non_nan(value, "single_value")?;
- check_non_infinite(value, "single_value")?;
+ check_finite(value, "single_value")?;
return Ok(TDigestMut::make(
k,
reverse_merge,
@@ -500,7 +480,7 @@ impl TDigestMut {
)
};
check_non_nan(mean, "centroid mean")?;
- check_non_infinite(mean, "centroid")?;
+ check_finite(mean, "centroid")?;
let weight = check_nonzero(weight, "centroid weight")?;
centroids_weight += weight.get();
centroids.push(Centroid { mean, weight });
@@ -517,7 +497,7 @@ impl TDigestMut {
.map_err(make_error("buffered_value"))?
};
check_non_nan(value, "buffered_value mean")?;
- check_non_infinite(value, "buffered_value mean")?;
+ check_finite(value, "buffered_value mean")?;
buffer.push(value);
}
Ok(TDigestMut::make(
@@ -533,9 +513,9 @@ impl TDigestMut {
// compatibility with the format of the reference implementation
// default byte order of ByteBuffer is used there, which is big endian
- fn deserialize_compat(bytes: &[u8]) -> Result<Self, SerdeError> {
- fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) ->
SerdeError {
- move |_| SerdeError::InsufficientData(format!("{tag} in compat
format"))
+ fn deserialize_compat(bytes: &[u8]) -> Result<Self, Error> {
+ fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) ->
Error {
+ move |_| Error::insufficient_data_of("compat format", tag)
}
let mut cursor = Cursor::new(bytes);
@@ -543,8 +523,8 @@ impl TDigestMut {
let ty = cursor.read_u32::<BE>().map_err(make_error("type"))?;
match ty {
COMPAT_DOUBLE => {
- fn make_error(tag: &'static str) -> impl
FnOnce(std::io::Error) -> SerdeError {
- move |_| SerdeError::InsufficientData(format!("{tag} in
compat double format"))
+ fn make_error(tag: &'static str) -> impl
FnOnce(std::io::Error) -> Error {
+ move |_| Error::insufficient_data_of("compat double
format", tag)
}
// compatibility with asBytes()
let min = cursor.read_f64::<BE>().map_err(make_error("min"))?;
@@ -553,8 +533,8 @@ impl TDigestMut {
check_non_nan(max, "max in compat double format")?;
let k = cursor.read_f64::<BE>().map_err(make_error("k"))? as
u16;
if k < 10 {
- return Err(SerdeError::InvalidParameter(format!(
- "k must be at least 10, got {k} in compat double
format"
+ return Err(Error::deserial(format!(
+ "k must be at least 10 in compat double format, got
{k}"
)));
}
let num_centroids = cursor
@@ -568,7 +548,7 @@ impl TDigestMut {
let mean =
cursor.read_f64::<BE>().map_err(make_error("mean"))?;
let weight = check_nonzero(weight, "centroid weight in
compat double format")?;
check_non_nan(mean, "centroid mean in compat double
format")?;
- check_non_infinite(mean, "centroid mean in compat double
format")?;
+ check_finite(mean, "centroid mean in compat double
format")?;
total_weight += weight.get();
centroids.push(Centroid { mean, weight });
}
@@ -583,8 +563,8 @@ impl TDigestMut {
))
}
COMPAT_FLOAT => {
- fn make_error(tag: &'static str) -> impl
FnOnce(std::io::Error) -> SerdeError {
- move |_| SerdeError::InsufficientData(format!("{tag} in
compat float format"))
+ fn make_error(tag: &'static str) -> impl
FnOnce(std::io::Error) -> Error {
+ move |_| Error::insufficient_data_of("compat float
format", tag)
}
// COMPAT_FLOAT: compatibility with asSmallBytes()
// reference implementation uses doubles for min and max
@@ -594,8 +574,8 @@ impl TDigestMut {
check_non_nan(max, "max in compat float format")?;
let k = cursor.read_f32::<BE>().map_err(make_error("k"))? as
u16;
if k < 10 {
- return Err(SerdeError::InvalidParameter(format!(
- "k must be at least 10, got {k} in compat float format"
+ return Err(Error::deserial(format!(
+ "k must be at least 10 in compat float format, got {k}"
)));
}
// reference implementation stores capacities of the array of
centroids and the
@@ -612,7 +592,7 @@ impl TDigestMut {
let mean =
cursor.read_f32::<BE>().map_err(make_error("mean"))? as f64;
let weight = check_nonzero(weight, "centroid weight in
compat float format")?;
check_non_nan(mean, "centroid mean in compat float
format")?;
- check_non_infinite(mean, "centroid mean in compat float
format")?;
+ check_finite(mean, "centroid mean in compat float
format")?;
total_weight += weight.get();
centroids.push(Centroid { mean, weight });
}
@@ -626,9 +606,7 @@ impl TDigestMut {
vec![],
))
}
- ty => Err(SerdeError::InvalidParameter(format!(
- "unknown TDigest compat type {ty}",
- ))),
+ ty => Err(Error::deserial(format!("unknown TDigest compat type
{ty}"))),
}
}
@@ -786,7 +764,8 @@ impl TDigest {
///
/// # Panics
///
- /// If `split_points` is not unique, not monotonically increasing, or
contains `NaN` values.
+ /// Panics if `split_points` is not unique, not monotonically increasing,
or contains `NaN`
+ /// values.
pub fn cdf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
self.view().cdf(split_points)
}
@@ -808,7 +787,8 @@ impl TDigest {
///
/// # Panics
///
- /// If `split_points` is not unique, not monotonically increasing, or
contains `NaN` values.
+ /// Panics if `split_points` is not unique, not monotonically increasing,
or contains `NaN`
+ /// values.
pub fn pmf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
self.view().pmf(split_points)
}
@@ -819,7 +799,7 @@ impl TDigest {
///
/// # Panics
///
- /// If the value is `NaN`.
+ /// Panics if the value is `NaN`.
pub fn rank(&self, value: f64) -> Option<f64> {
assert!(!value.is_nan(), "value must not be NaN");
self.view().rank(value)
@@ -831,7 +811,7 @@ impl TDigest {
///
/// # Panics
///
- /// If rank is not in [0.0, 1.0].
+ /// Panics if rank is not in [0.0, 1.0].
pub fn quantile(&self, rank: f64) -> Option<f64> {
assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
self.view().quantile(rank)
@@ -1129,24 +1109,29 @@ impl Centroid {
}
}
-fn check_non_nan(value: f64, tag: &'static str) -> Result<(), SerdeError> {
+fn check_non_nan(value: f64, tag: &'static str) -> Result<(), Error> {
if value.is_nan() {
- return Err(SerdeError::MalformedData(format!("{tag} cannot be NaN")));
+ return Err(Error::deserial(format!(
+ "malformed data: {tag} cannot be NaN"
+ )));
}
+
Ok(())
}
-fn check_non_infinite(value: f64, tag: &'static str) -> Result<(), SerdeError>
{
+fn check_finite(value: f64, tag: &'static str) -> Result<(), Error> {
if value.is_infinite() {
- return Err(SerdeError::MalformedData(format!(
- "{tag} cannot be is_infinite"
+ return Err(Error::deserial(format!(
+ "malformed data: {tag} cannot be infinite"
)));
}
+
Ok(())
}
-fn check_nonzero(value: u64, tag: &'static str) -> Result<NonZeroU64,
SerdeError> {
- NonZeroU64::new(value).ok_or_else(||
SerdeError::MalformedData(format!("{tag} cannot be zero")))
+fn check_nonzero(value: u64, tag: &'static str) -> Result<NonZeroU64, Error> {
+ NonZeroU64::new(value)
+ .ok_or_else(|| Error::deserial(format!("malformed data: {tag} cannot
be zero")))
}
/// Generates cluster sizes proportional to `q*(1-q)`.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]