This is an automated email from the ASF dual-hosted git repository.

kriskras99 pushed a commit to branch feat/union_builder
in repository https://gitbox.apache.org/repos/asf/avro-rs.git

commit d4caf2e83f28b011bcbb896adecdb91bf5b23b0a
Author: Kriskras99 <[email protected]>
AuthorDate: Wed Feb 25 23:01:45 2026 +0100

    feat: Add a `UnionSchemaBuilder`
    
    This also fixes an issue with the original `new` implementation where
    it would insert named types in the `variant_index` and then
    `find_schema_with_known_schemata` would use the fast path without
    checking the schema.
    
    `find_schema_with_known_schemata` has also been simplified to use
    `known_schemata` directly instead of rebuilding it with the current
    schema, as this would cause duplicate schema errors after the incorrect
    fast path was removed.
    
    The `UnionSchemaBuilder::variant_ignore_duplicates` and
    `UnionSchemaBuilder::contains` methods are needed for `avro_derive`
    to implement full support for enums.
---
 avro/src/error.rs        |  14 ++-
 avro/src/schema/name.rs  |   2 +-
 avro/src/schema/union.rs | 320 ++++++++++++++++++++++++++++++++++++++---------
 3 files changed, 276 insertions(+), 60 deletions(-)

diff --git a/avro/src/error.rs b/avro/src/error.rs
index 50a09af..12bee1e 100644
--- a/avro/src/error.rs
+++ b/avro/src/error.rs
@@ -300,8 +300,18 @@ pub enum Details {
     #[error("Unions may not directly contain a union")]
     GetNestedUnion,
 
-    #[error("Unions cannot contain duplicate types")]
-    GetUnionDuplicate,
+    #[error(
+        "Found two different maps while building Union: Schema::Map({0:?}), 
Schema::Map({1:?})"
+    )]
+    GetUnionDuplicateMap(Schema, Schema),
+
+    #[error(
+        "Found two different arrays while building Union: 
Schema::Array({0:?}), Schema::Array({1:?})"
+    )]
+    GetUnionDuplicateArray(Schema, Schema),
+
+    #[error("Unions cannot contain duplicate types, found at least two {0:?}")]
+    GetUnionDuplicate(SchemaKind),
 
     #[error("Unions cannot contain more than one named schema with the same 
name: {0}")]
     GetUnionDuplicateNamedSchemas(String),
diff --git a/avro/src/schema/name.rs b/avro/src/schema/name.rs
index e572d8b..1eeac0d 100644
--- a/avro/src/schema/name.rs
+++ b/avro/src/schema/name.rs
@@ -38,7 +38,7 @@ use crate::{
 ///
 /// More information about schema names can be found in the
 /// [Avro 
specification](https://avro.apache.org/docs/++version++/specification/#names)
-#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
 pub struct Name {
     pub name: String,
     pub namespace: Namespace,
diff --git a/avro/src/schema/union.rs b/avro/src/schema/union.rs
index 7510a13..428f09d 100644
--- a/avro/src/schema/union.rs
+++ b/avro/src/schema/union.rs
@@ -15,24 +15,32 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::AvroResult;
 use crate::error::Details;
-use crate::schema::{Name, Namespace, ResolvedSchema, Schema, SchemaKind};
+use crate::schema::{
+    DecimalSchema, InnerDecimalSchema, Name, Namespace, Schema, SchemaKind, 
UuidSchema,
+};
 use crate::types;
+use crate::{AvroResult, Error};
 use std::borrow::Borrow;
-use std::collections::{BTreeMap, HashMap, HashSet};
+use std::collections::{BTreeMap, HashMap};
 use std::fmt::Debug;
+use strum::IntoDiscriminant;
 
 /// A description of a Union schema
 #[derive(Debug, Clone)]
 pub struct UnionSchema {
     /// The schemas that make up this union
     pub(crate) schemas: Vec<Schema>,
-    // Used to ensure uniqueness of schema inputs, and provide constant time 
finding of the
-    // schema index given a value.
-    // **NOTE** that this approach does not work for named types, and will 
have to be modified
-    // to support that. A simple solution is to also keep a mapping of the 
names used.
+    /// The indexes (into `schemas`) of unnamed types, keyed by their kind.
+    ///
+    /// Logical types have been reduced to their inner type.
+    /// Used to quickly find the schema index given an unnamed type
+    /// without scanning `schemas`. Must only contain unnamed types.
     variant_index: BTreeMap<SchemaKind, usize>,
+    /// The indexes (into `schemas`) of named types.
+    ///
+    /// The names themselves aren't stored as they aren't needed for lookups.
     named_index: Vec<usize>,
 }
 
 impl UnionSchema {
@@ -42,25 +50,16 @@ impl UnionSchema {
     /// Will return an error if `schemas` has duplicate unnamed schemas or if 
`schemas`
     /// contains a union.
     pub fn new(schemas: Vec<Schema>) -> AvroResult<Self> {
-        let mut named_schemas: HashSet<&Name> = HashSet::default();
-        let mut vindex = BTreeMap::new();
-        for (i, schema) in schemas.iter().enumerate() {
-            if let Schema::Union(_) = schema {
-                return Err(Details::GetNestedUnion.into());
-            } else if !schema.is_named() && 
vindex.insert(SchemaKind::from(schema), i).is_some() {
-                return Err(Details::GetUnionDuplicate.into());
-            } else if schema.is_named() {
-                let name = schema.name().unwrap();
-                if !named_schemas.insert(name) {
-                    return 
Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
-                }
-                vindex.insert(SchemaKind::from(schema), i);
-            }
+        // Validation (nested unions, duplicate kinds, duplicate names) is delegated to
+        // `UnionSchemaBuilder::variant`, which fails on the first offending schema.
+        let mut builder = Self::builder();
+        for schema in schemas {
+            builder.variant(schema)?;
         }
-        Ok(UnionSchema {
-            schemas,
-            variant_index: vindex,
-        })
+        Ok(builder.build())
+    }
+
+    /// Build a `UnionSchema` piece-by-piece.
+    ///
+    /// Returns an empty [`UnionSchemaBuilder`]; equivalent to calling
+    /// [`UnionSchemaBuilder::new`].
+    pub fn builder() -> UnionSchemaBuilder {
+        UnionSchemaBuilder::new()
     }
 
     /// Returns a slice to all variants of this schema.
@@ -70,7 +69,7 @@ impl UnionSchema {
 
     /// Returns true if the any of the variants of this `UnionSchema` is 
`Null`.
     pub fn is_nullable(&self) -> bool {
-        self.schemas.iter().any(|x| matches!(x, Schema::Null))
+        // `Null` is an unnamed type, so when present it is registered in `variant_index`;
+        // a key lookup replaces the scan over all variants.
+        self.variant_index.contains_key(&SchemaKind::Null)
     }
 
     /// Optionally returns a reference to the schema matched by this value, as 
well as its position
@@ -86,39 +85,31 @@ impl UnionSchema {
     ) -> Option<(usize, &Schema)> {
         let schema_kind = SchemaKind::from(value);
         if let Some(&i) = self.variant_index.get(&schema_kind) {
-            // fast path
+            // fast path for unnamed types
             Some((i, &self.schemas[i]))
         } else {
-            // slow path (required for matching logical or named types)
-
-            // first collect what schemas we already know
-            let mut collected_names: HashMap<Name, &Schema> = known_schemata
-                .map(|names| {
-                    names
-                        .iter()
-                        .map(|(name, schema)| (name.clone(), schema.borrow()))
-                        .collect()
+            // slow path required for named types
+            let known_schemata_if_none = HashMap::new();
+            let known_schemata = 
known_schemata.unwrap_or(&known_schemata_if_none);
+
+            self.named_index
+                .iter()
+                .copied()
+                .map(|i| (i, &self.schemas[i]))
+                .filter(|(i, s)| s.discriminant() == schema_kind)
+                .find(|(_i, schema)| {
+                    let namespace = if schema.namespace().is_some() {
+                        &schema.namespace()
+                    } else {
+                        enclosing_namespace
+                    };
+
+                    // TODO: Do this without the clone
+                    value
+                        .clone()
+                        .resolve_internal(schema, known_schemata, namespace, 
&None)
+                        .is_ok()
                 })
-                .unwrap_or_default();
-
-            self.schemas.iter().enumerate().find(|(_, schema)| {
-                let resolved_schema = ResolvedSchema::new_with_known_schemata(
-                    vec![*schema],
-                    enclosing_namespace,
-                    &collected_names,
-                )
-                .expect("Schema didn't successfully parse");
-                let resolved_names = resolved_schema.names_ref;
-
-                // extend known schemas with just resolved names
-                collected_names.extend(resolved_names);
-                let namespace = &schema.namespace().or_else(|| 
enclosing_namespace.clone());
-
-                value
-                    .clone()
-                    .resolve_internal(schema, &collected_names, namespace, 
&None)
-                    .is_ok()
-            })
         }
     }
 }
@@ -130,11 +121,172 @@ impl PartialEq for UnionSchema {
     }
 }
 
+/// Incrementally builds a [`UnionSchema`], validating every variant as it is added.
+///
+/// Use [`UnionSchemaBuilder::variant`] to reject duplicates, or
+/// [`UnionSchemaBuilder::variant_ignore_duplicates`] to silently skip variants that
+/// are already present, then call [`UnionSchemaBuilder::build`].
+#[derive(Debug, Default)]
+pub struct UnionSchemaBuilder {
+    /// The variants added so far, in insertion order.
+    schemas: Vec<Schema>,
+    /// The index into `schemas` of every named variant, keyed by its name.
+    names: BTreeMap<Name, usize>,
+    /// The index into `schemas` of every unnamed variant, keyed by its kind
+    /// (logical types are reduced to their inner type).
+    variant_index: BTreeMap<SchemaKind, usize>,
+}
+
+impl UnionSchemaBuilder {
+    /// Create an empty builder.
+    ///
+    /// See also [`UnionSchema::builder`].
+    pub fn new() -> Self {
+        Self {
+            schemas: Vec::new(),
+            names: BTreeMap::new(),
+            variant_index: BTreeMap::new(),
+        }
+    }
+
+    /// Add a variant to this union, if it already exists ignore it.
+    ///
+    /// What counts as "already exists" depends on the schema:
+    /// - named schemas: a variant with the same name that is equal to `schema`;
+    /// - maps and arrays: an existing map/array that is equal to `schema`;
+    /// - any other schema: an existing variant of the same kind (logical types are
+    ///   reduced to their inner type first). The existing variant is kept even if it
+    ///   is not equal to `schema`.
+    ///
+    /// The builder is left unchanged when an error is returned, so it can still be
+    /// used afterwards.
+    ///
+    /// # Errors
+    /// - [`Details::GetUnionDuplicateNamedSchemas`] if a different schema with the same
+    ///   name was already added.
+    /// - [`Details::GetUnionDuplicateMap`] or [`Details::GetUnionDuplicateArray`] if a
+    ///   different map or array was already added.
+    /// - [`Details::GetNestedUnion`] if `schema` is a union.
+    pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> {
+        if let Some(name) = schema.name() {
+            if let Some(current) = self.names.get(name).copied() {
+                if self.schemas[current] != schema {
+                    return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
+                }
+            } else {
+                self.names.insert(name.clone(), self.schemas.len());
+                self.schemas.push(schema);
+            }
+        } else if let Schema::Map(_) = &schema {
+            if let Some(index) = self.variant_index.get(&SchemaKind::Map).copied() {
+                if self.schemas[index] != schema {
+                    // Clone rather than `remove`: removing would drop the variant and shift
+                    // every later one while `names`/`variant_index` keep the old indexes.
+                    return Err(
+                        Details::GetUnionDuplicateMap(self.schemas[index].clone(), schema).into(),
+                    );
+                }
+            } else {
+                self.variant_index
+                    .insert(SchemaKind::Map, self.schemas.len());
+                self.schemas.push(schema);
+            }
+        } else if let Schema::Array(_) = &schema {
+            if let Some(index) = self.variant_index.get(&SchemaKind::Array).copied() {
+                if self.schemas[index] != schema {
+                    // Same reasoning as for maps: report without mutating the builder.
+                    return Err(
+                        Details::GetUnionDuplicateArray(self.schemas[index].clone(), schema).into(),
+                    );
+                }
+            } else {
+                self.variant_index
+                    .insert(SchemaKind::Array, self.schemas.len());
+                self.schemas.push(schema);
+            }
+        } else {
+            let discriminant = Self::schema_kind_without_logical_type(&schema);
+            if discriminant == SchemaKind::Union {
+                return Err(Details::GetNestedUnion.into());
+            }
+            // An existing variant of the same kind wins; the new one is dropped.
+            if !self.variant_index.contains_key(&discriminant) {
+                self.variant_index.insert(discriminant, self.schemas.len());
+                self.schemas.push(schema);
+            }
+        }
+        Ok(self)
+    }
+
+    /// Add a variant to this union.
+    ///
+    /// The builder is left unchanged when an error is returned.
+    ///
+    /// # Errors
+    /// - [`Details::GetUnionDuplicateNamedSchemas`] if a schema with the same name was
+    ///   already added.
+    /// - [`Details::GetUnionDuplicate`] if an unnamed schema of the same kind (logical
+    ///   types reduced to their inner type) was already added.
+    /// - [`Details::GetNestedUnion`] if `schema` is a union.
+    pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> {
+        if let Some(name) = schema.name() {
+            if self.names.contains_key(name) {
+                return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
+            }
+            self.names.insert(name.clone(), self.schemas.len());
+            self.schemas.push(schema);
+        } else {
+            let discriminant = Self::schema_kind_without_logical_type(&schema);
+            if discriminant == SchemaKind::Union {
+                return Err(Details::GetNestedUnion.into());
+            }
+            if self.variant_index.contains_key(&discriminant) {
+                return Err(Details::GetUnionDuplicate(discriminant).into());
+            }
+            self.variant_index.insert(discriminant, self.schemas.len());
+            self.schemas.push(schema);
+        }
+        Ok(self)
+    }
+
+    /// Check if a schema already exists in this union.
+    ///
+    /// Returns `true` only if the variant registered under the same name (for named
+    /// schemas) or the same kind (for unnamed schemas) is equal to `schema`.
+    pub fn contains(&self, schema: &Schema) -> bool {
+        let index = if let Some(name) = schema.name() {
+            self.names.get(name)
+        } else {
+            self.variant_index
+                .get(&Self::schema_kind_without_logical_type(schema))
+        };
+        index.is_some_and(|&i| &self.schemas[i] == schema)
+    }
+
+    /// Create the `UnionSchema`.
+    ///
+    /// The variants keep their insertion order. The indexes of the named variants are
+    /// collected in the iteration order of the name map (sorted by name).
+    pub fn build(mut self) -> UnionSchema {
+        self.schemas.shrink_to_fit();
+        UnionSchema {
+            variant_index: self.variant_index,
+            named_index: self.names.into_values().collect(),
+            schemas: self.schemas,
+        }
+    }
+
+    /// Get the [`SchemaKind`] of a [`Schema`] converting logical types to their inner type.
+    ///
+    /// Kinds that are not handled explicitly below are returned unchanged.
+    fn schema_kind_without_logical_type(schema: &Schema) -> SchemaKind {
+        let kind = schema.discriminant();
+        match kind {
+            SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int,
+            SchemaKind::TimeMicros
+            | SchemaKind::TimestampMillis
+            | SchemaKind::TimestampMicros
+            | SchemaKind::TimestampNanos
+            | SchemaKind::LocalTimestampMillis
+            | SchemaKind::LocalTimestampMicros
+            | SchemaKind::LocalTimestampNanos => SchemaKind::Long,
+            // The inner type of a UUID depends on its representation.
+            SchemaKind::Uuid => match schema {
+                Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes,
+                Schema::Uuid(UuidSchema::String) => SchemaKind::String,
+                Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed,
+                _ => unreachable!(),
+            },
+            // The inner type of a decimal depends on its representation.
+            SchemaKind::Decimal => match schema {
+                Schema::Decimal(DecimalSchema {
+                    inner: InnerDecimalSchema::Bytes,
+                    ..
+                }) => SchemaKind::Bytes,
+                Schema::Decimal(DecimalSchema {
+                    inner: InnerDecimalSchema::Fixed(_),
+                    ..
+                }) => SchemaKind::Fixed,
+                _ => unreachable!(),
+            },
+            SchemaKind::Duration => SchemaKind::Fixed,
+            _ => kind,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::error::{Details, Error};
-    use crate::schema::RecordSchema;
+    use crate::schema::{EnumSchema, FixedSchema, RecordSchema};
     use apache_avro_test_helper::TestResult;
 
     #[test]
@@ -165,4 +317,58 @@ mod tests {
 
         Ok(())
     }
+
+    // Exercises both insertion methods of the builder with an unnamed type (`Null`)
+    // and with named types that share the name "ABC".
+    #[test]
+    fn avro_rs_xxx_union_schema_builder() -> TestResult {
+        let mut builder = UnionSchema::builder();
+        // Unnamed type: the strict method rejects the duplicate, the lenient one accepts it.
+        builder.variant(Schema::Null)?;
+        assert!(builder.variant(Schema::Null).is_err());
+        builder.variant_ignore_duplicates(Schema::Null)?;
+
+        let enum_schema = Schema::Enum(EnumSchema {
+            name: Name::new("ABC")?,
+            aliases: None,
+            doc: None,
+            symbols: vec!["A".into(), "B".into(), "C".into()],
+            default: None,
+            attributes: Default::default(),
+        });
+        // Same name and kind as `enum_schema`, but with an extra symbol.
+        let enum_schema2 = Schema::Enum(EnumSchema {
+            name: Name::new("ABC")?,
+            aliases: None,
+            doc: None,
+            symbols: vec!["A".into(), "B".into(), "C".into(), "D".into()],
+            default: None,
+            attributes: Default::default(),
+        });
+        // Same name as `enum_schema`, but a different kind of schema.
+        let fixed_schema = Schema::Fixed(FixedSchema {
+            name: Name::new("ABC")?,
+            aliases: None,
+            doc: None,
+            size: 1,
+            attributes: Default::default(),
+        });
+        // Named type: an identical schema is only accepted by the lenient method.
+        builder.variant(enum_schema.clone())?;
+        assert!(builder.variant(enum_schema.clone()).is_err());
+        builder.variant_ignore_duplicates(enum_schema.clone())?;
+        // Name is the same but different schemas, so should always fail
+        assert!(builder.variant(fixed_schema.clone()).is_err());
+        assert!(
+            builder
+                .variant_ignore_duplicates(fixed_schema.clone())
+                .is_err()
+        );
+        // Name and schema type are the same but symbols are different
+        assert!(builder.variant(enum_schema2.clone()).is_err());
+        assert!(
+            builder
+                .variant_ignore_duplicates(enum_schema2.clone())
+                .is_err()
+        );
+
+        // Only the two successfully added variants remain, in insertion order.
+        let union = builder.build();
+        assert_eq!(union.variants(), &[Schema::Null, enum_schema]);
+
+        Ok(())
+    }
 }

Reply via email to