This is an automated email from the ASF dual-hosted git repository.

ycai pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-analytics.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 0966b2e3 CASSANALYTICS-49: Support UDTs inside collections (#109)
0966b2e3 is described below

commit 0966b2e3d804d80834d744c5f0649fca37395f83
Author: Shailaja Koppu <s_ko...@apple.com>
AuthorDate: Thu May 22 18:11:42 2025 +0100

    CASSANALYTICS-49: Support UDTs inside collections (#109)
    
    Patch by Shailaja Koppu; Reviewed by Francisco Guerrero, Yifan Cai for CASSANALYTICS-49
---
 CHANGES.txt                                        |   1 +
 .../org/apache/cassandra/spark/data/CqlTable.java  |  33 ++
 .../cassandra/spark/bulkwriter/RecordWriter.java   |  84 +++--
 .../cassandra/analytics/BulkWriteUdtTest.java      | 414 ++++++++++++++++++++-
 .../SharedClusterSparkIntegrationTestBase.java     |  57 ++-
 5 files changed, 547 insertions(+), 42 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index f35c4971..7092f15e 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 1.0.0
+ * Add support for nested UDT in collections for bulk write (CASSANALYTICS-49)
  * Add Sidecar Client (CASSANALYTICS-30)
  * Add support for vnodes (CASSANALYTICS-50)
  * Add CDC Kafka and Avro codecs module to translate CDC mutations into Avro format for publication over Kafka (CASSANALYTICS-9)
diff --git a/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/data/CqlTable.java b/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/data/CqlTable.java
index 18bc60fc..ebb5c2ab 100644
--- a/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/data/CqlTable.java
+++ b/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/data/CqlTable.java
@@ -22,6 +22,7 @@ package org.apache.cassandra.spark.data;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -50,6 +51,7 @@ public class CqlTable implements Serializable
     private final Set<CqlField.CqlUdt> udts;
 
     private final Map<String, CqlField> fieldsMap;
+    private final Set<String> columnsWithUdts;
     private final List<CqlField> partitionKeys;
     private final List<CqlField> clusteringKeys;
     private final List<CqlField> staticColumns;
@@ -105,6 +107,8 @@ public class CqlTable implements Serializable
         {
             columns.put(column.name(), column);
         }
+
+        this.columnsWithUdts = determineColumnsWithUdts();
     }
 
     public TableIdentifier tableIdentifier()
@@ -242,6 +246,35 @@ public class CqlTable implements Serializable
         return indexCount;
     }
 
+    /**
+     * Check each column of the table for UDT type somewhere nested inside it and
+     * create set of columns containing UDT types
+     * @return set of columns containing UDT types
+     */
+    private Set<String> determineColumnsWithUdts()
+    {
+        Set<String> columnsWithUdts = new HashSet<>();
+        for (Map.Entry<String, CqlField> field : fieldsMap.entrySet())
+        {
+            if (!field.getValue().type().udts().isEmpty())
+            {
+                columnsWithUdts.add(field.getKey());
+            }
+        }
+
+        return columnsWithUdts;
+    }
+
+    /**
+     * Determines if a column has UDT type somewhere nested inside it
+     * @param fieldName name of the column
+     * @return true if the column has UDT type, false otherwise
+     */
+    public boolean containsUdt(String fieldName)
+    {
+        return columnsWithUdts.contains(fieldName);
+    }
+
     @Override
     public int hashCode()
     {
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
index 633e7efa..a33ff642 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
@@ -23,9 +23,11 @@ import java.io.IOException;
 import java.math.BigInteger;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -79,7 +81,7 @@ public class RecordWriter
     private final ExecutorService executorService;
     private final Path baseDir;
 
-    private volatile CqlTable cqlTable;
+    private final CqlTable cqlTable;
     private StreamSession<?> streamSession = null;
 
     public RecordWriter(BulkWriterContext writerContext, String[] columnNames)
@@ -107,21 +109,12 @@ public class RecordWriter
                                                                taskContextSupplier.get());
 
         writerContext.cluster().startupValidate();
-    }
-
-    private CqlTable cqlTable()
-    {
-        if (cqlTable == null)
-        {
-            cqlTable = writerContext.bridge()
-                                    .buildSchema(writerContext.schema().getTableSchema().createStatement,
-                                                 writerContext.job().qualifiedTableName().keyspace(),
-                                                 IGNORED_REPLICATION_FACTOR,
-                                                 writerContext.cluster().getPartitioner(),
-                                                 writerContext.schema().getUserDefinedTypeStatements());
-        }
-
-        return cqlTable;
+        cqlTable = writerContext.bridge()
+                                .buildSchema(writerContext.schema().getTableSchema().createStatement,
+                                             writerContext.job().qualifiedTableName().keyspace(),
+                                             IGNORED_REPLICATION_FACTOR,
+                                             writerContext.cluster().getPartitioner(),
+                                             writerContext.schema().getUserDefinedTypeStatements());
     }
 
     /**
@@ -380,35 +373,82 @@ public class RecordWriter
     {
         Preconditions.checkArgument(values.length == columnNames.length,
                                     "Number of values does not match the 
number of columns " + values.length + ", " + columnNames.length);
+
         for (int i = 0; i < columnNames.length; i++)
         {
-            map.put(columnNames[i], maybeConvertUdt(values[i]));
+            if (cqlTable.containsUdt(columnNames[i]))
+            {
+                map.put(columnNames[i], maybeConvertUdt(values[i]));
+            }
+            else
+            {
+                map.put(columnNames[i], values[i]);
+            }
         }
         return map;
     }
 
+    /**
+     * A column can have UDTs somewhere nested inside collections/UDTs. All occurrences of BridgeUdtValue need to be
+     * recursively converted to UDTValue to be able to write to CQL.
+     * @param value column value
+     * @return column value after converting all occurrences of BridgeUdtValue to UDTValue
+     */
     private Object maybeConvertUdt(Object value)
     {
+        if (value instanceof List && !((List<?>) value).isEmpty())
+        {
+            List<Object> resultList = new ArrayList<>();
+            for (Object entry : (List<?>) value)
+            {
+                resultList.add(maybeConvertUdt(entry));
+            }
+
+            return resultList;
+        }
+
+        if (value instanceof Set && !((Set<?>) value).isEmpty())
+        {
+            Set<Object> resultSet = new HashSet<>();
+            for (Object entry : (Set<?>) value)
+            {
+                resultSet.add(maybeConvertUdt(entry));
+            }
+
+            return resultSet;
+        }
+
+        if (value instanceof Map && !((Map<?, ?>) value).isEmpty())
+        {
+            Map<Object, Object> resultMap = new HashMap<>();
+            for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet())
+            {
+                resultMap.put(maybeConvertUdt(entry.getKey()), maybeConvertUdt(entry.getValue()));
+            }
+
+            return resultMap;
+        }
+
         if (value instanceof BridgeUdtValue)
         {
             BridgeUdtValue udtValue = (BridgeUdtValue) value;
             // Depth-first replacement of BridgeUdtValue instances to their appropriate Cql types
             for (Map.Entry<String, Object> entry : udtValue.udtMap.entrySet())
             {
-                if (entry.getValue() instanceof BridgeUdtValue)
-                {
-                    udtValue.udtMap.put(entry.getKey(), maybeConvertUdt(entry.getValue()));
-                }
+                // udt can have complex types like nested udt, list, set or map with embedded UDTs in them
+                // convert each entry recursively until we see basic datatype
+                udtValue.udtMap.put(entry.getKey(), maybeConvertUdt(entry.getValue()));
             }
             return getUdt(udtValue.name).convertForCqlWriter(udtValue.udtMap, writerContext.bridge().getVersion(), false);
         }
+
         return value;
     }
 
     private synchronized CqlField.CqlType getUdt(String udtName)
     {
         return udtCache.computeIfAbsent(udtName, name -> {
-            for (CqlField.CqlUdt udt1 : cqlTable().udts())
+            for (CqlField.CqlUdt udt1 : cqlTable.udts())
             {
                 if (udt1.cqlName().equals(name))
                 {
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java
index abd8beb1..82b27446 100644
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java
@@ -19,8 +19,10 @@
 
 package org.apache.cassandra.analytics;
 
+import java.util.Objects;
 import java.util.function.Predicate;
-
+import com.datastax.driver.core.UDTValue;
+import org.apache.cassandra.distributed.api.ICoordinator;
 import org.junit.jupiter.api.Test;
 
 import org.apache.cassandra.distributed.api.ConsistencyLevel;
@@ -59,6 +61,61 @@ class BulkWriteUdtTest extends SharedClusterSparkIntegrationTestBase
                                                      + "           id BIGINT PRIMARY KEY,\n"
                                                      + "           nested " + NESTED_FIELD_UDT_NAME + ");";
 
+    // UDT with list, set and map in it
+    public static final String UDT_WITH_COLLECTIONS_TYPE_NAME = "udt_with_collections";
+    public static final String UDT_WITH_COLLECTIONS_TYPE_CREATE = "CREATE TYPE " + TEST_KEYSPACE + "." + UDT_WITH_COLLECTIONS_TYPE_NAME +
+            " (f1 list<text>, f2 set<int>, f3 map<int, text>);";
+
+    // table with list of UDTs, and UDT itself has collections in it
+    public static final QualifiedName LIST_OF_UDT_SOURCE_TABLE = new QualifiedName(TEST_KEYSPACE, "list_of_udt_src");
+    public static final QualifiedName LIST_OF_UDT_DEST_TABLE = new QualifiedName(TEST_KEYSPACE, "list_of_udt_dest");
+    public static final String LIST_OF_UDT_TABLE_CREATE = "CREATE TABLE %s.%s (\n"
+            + "            id BIGINT PRIMARY KEY,\n"
+            + "            udtlist frozen<list<frozen<" + UDT_WITH_COLLECTIONS_TYPE_NAME + ">>>)";
+
+    // table with set of UDTs, and UDT itself has collections in it
+    public static final QualifiedName SET_OF_UDT_SOURCE_TABLE = new QualifiedName(TEST_KEYSPACE, "set_of_udt_src");
+    public static final QualifiedName SET_OF_UDT_DEST_TABLE = new QualifiedName(TEST_KEYSPACE, "set_of_udt_dest");
+    public static final String SET_OF_UDT_TABLE_CREATE = "CREATE TABLE %s.%s (\n"
+            + "            id BIGINT PRIMARY KEY,\n"
+            + "            udtset frozen<set<frozen<" + UDT_WITH_COLLECTIONS_TYPE_NAME + ">>>)";
+
+    // table with map of UDTs, and UDT itself has collections in it
+    public static final QualifiedName MAP_OF_UDT_SOURCE_TABLE = new QualifiedName(TEST_KEYSPACE, "map_of_udt_src");
+    public static final QualifiedName MAP_OF_UDT_DEST_TABLE = new QualifiedName(TEST_KEYSPACE, "map_of_udt_dest");
+    public static final String MAP_OF_UDT_TABLE_CREATE = "CREATE TABLE %s.%s (\n"
+            + "            id BIGINT PRIMARY KEY,\n"
+            + "            udtmap frozen<map<frozen<" + UDT_WITH_COLLECTIONS_TYPE_NAME + ">, frozen<" + UDT_WITH_COLLECTIONS_TYPE_NAME + ">>>)";
+
+    // udt with list of UDTs inside it
+    public static final String UDT_WITH_LIST_OF_UDT_TYPE_NAME = "udt_with_list_of_udt_type";
+    public static final String UDT_WITH_LIST_OF_UDT_TYPE_CREATE = "CREATE TYPE " + TEST_KEYSPACE + "." + UDT_WITH_LIST_OF_UDT_TYPE_NAME +
+            " (innerudt list<frozen<" + TWO_FIELD_UDT_NAME + ">>);";
+    public static final QualifiedName UDT_WITH_LIST_OF_UDT_SOURCE_TABLE = new QualifiedName(TEST_KEYSPACE, "udt_with_list_of_udt_src");
+    public static final QualifiedName UDT_WITH_LIST_OF_UDT_DEST_TABLE = new QualifiedName(TEST_KEYSPACE, "udt_with_list_of_udt_dest");
+
+    // udt with set of UDTs inside it
+    public static final String UDT_WITH_SET_OF_UDT_TYPE_NAME = "udt_with_set_of_udt_type";
+    public static final String UDT_WITH_SET_OF_UDT_TYPE_CREATE = "CREATE TYPE " + TEST_KEYSPACE + "." + UDT_WITH_SET_OF_UDT_TYPE_NAME +
+            " (innerudt set<frozen<" + TWO_FIELD_UDT_NAME + ">>);";
+    public static final QualifiedName UDT_WITH_SET_OF_UDT_SOURCE_TABLE = new QualifiedName(TEST_KEYSPACE, "udt_with_set_of_udt_src");
+    public static final QualifiedName UDT_WITH_SET_OF_UDT_DEST_TABLE = new QualifiedName(TEST_KEYSPACE, "udt_with_set_of_udt_dest");
+
+    // udt with map of UDTs inside it
+    public static final String UDT_WITH_MAP_OF_UDT_TYPE_NAME = "udt_with_map_of_udt_type";
+    public static final String UDT_WITH_MAP_OF_UDT_TYPE_CREATE = "CREATE TYPE " + TEST_KEYSPACE + "." + UDT_WITH_MAP_OF_UDT_TYPE_NAME +
+            " (innerudt map<frozen<" + TWO_FIELD_UDT_NAME + ">, frozen<" + TWO_FIELD_UDT_NAME + ">>);";
+    public static final QualifiedName UDT_WITH_MAP_OF_UDT_SOURCE_TABLE = new QualifiedName(TEST_KEYSPACE, "udt_with_map_of_udt_src");
+    public static final QualifiedName UDT_WITH_MAP_OF_UDT_DEST_TABLE = new QualifiedName(TEST_KEYSPACE, "udt_with_map_of_udt_dest");
+
+    // Table with UDT which contains either a list or set or map of UDTs inside it
+    public static final String UDT_WITH_COLLECTION_OF_UDT_TABLE_CREATE = "CREATE TABLE %s.%s (\n"
+            + "            id BIGINT PRIMARY KEY,\n"
+            + "            outerudt frozen<%s>)";
+
+    private ICoordinator coordinator;
+
+
     @Test
     void testWriteWithUdt()
     {
@@ -68,11 +125,11 @@ class BulkWriteUdtTest extends SharedClusterSparkIntegrationTestBase
 
         bulkWriterDataFrameWriter(df, UDT_TABLE_NAME).save();
 
-        SimpleQueryResult result = cluster.coordinator(1).executeWithResult("SELECT * FROM " + UDT_TABLE_NAME, ConsistencyLevel.ALL);
+        SimpleQueryResult result = coordinator.executeWithResult("SELECT * FROM " + UDT_TABLE_NAME, ConsistencyLevel.ALL);
         assertThat(result.hasNext()).isTrue();
         validateWritesWithDriverResultSet(df.collectAsList(),
                                           queryAllDataWithDriver(UDT_TABLE_NAME),
-                                          BulkWriteUdtTest::defaultRowFormatter);
+                                          BulkWriteUdtTest::udtRowFormatter);
     }
 
     @Test
@@ -84,21 +141,310 @@ class BulkWriteUdtTest extends SharedClusterSparkIntegrationTestBase
 
         bulkWriterDataFrameWriter(df, NESTED_TABLE_NAME).save();
 
-        SimpleQueryResult result = cluster.coordinator(1).executeWithResult("SELECT * FROM " + NESTED_TABLE_NAME, ConsistencyLevel.ALL);
+        SimpleQueryResult result = coordinator.executeWithResult("SELECT * FROM " + NESTED_TABLE_NAME, ConsistencyLevel.ALL);
         assertThat(result.hasNext()).isTrue();
         validateWritesWithDriverResultSet(df.collectAsList(),
                                           queryAllDataWithDriver(NESTED_TABLE_NAME),
-                                          BulkWriteUdtTest::defaultRowFormatter);
+                                          BulkWriteUdtTest::udtRowFormatter);
+    }
+
+    @Test
+    void testListOfUdts()
+    {
+        int numRowsInserted = populateListOfUdts();
+
+        // Create a spark frame with the data inserted during the setup
+        Dataset<Row> sourceData = bulkReaderDataFrame(LIST_OF_UDT_SOURCE_TABLE).load();
+        assertThat(sourceData.count()).isEqualTo(numRowsInserted);
+
+        // Insert the dataset containing list of UDTs, and UDT itself has collections in it
+        bulkWriterDataFrameWriter(sourceData, LIST_OF_UDT_DEST_TABLE).save();
+        validateWritesWithDriverResultSet(sourceData.collectAsList(),
+                queryAllDataWithDriver(LIST_OF_UDT_DEST_TABLE),
+                BulkWriteUdtTest::listOfUdtRowFormatter);
+    }
+
+    private int populateListOfUdts()
+    {
+        // table(id, list<udt(list<>, set<>, map<>)>)
+        // insert list of UDTs, and each UDT has a list, set and map
+        String insertIntoListOfUdts = "INSERT INTO %s (id, udtlist) VALUES (%d, [{f1:['value %d'], f2:{%d}, f3:{%d : 'value %d'}}])";
+
+        int i = 0;
+        for (; i < ROW_COUNT; i++)
+        {
+            coordinator.execute(String.format(insertIntoListOfUdts, LIST_OF_UDT_SOURCE_TABLE, i, i, i, i, i), ConsistencyLevel.ALL);
+        }
+
+        // test null cases
+        coordinator.execute(String.format("insert into %s (id) values (%d)",
+                                          LIST_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, udtlist) values (%d, null)",
+                                          LIST_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, udtlist) values (%d, [{f1:null, f2:null, f3:null}])",
+                                          LIST_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+
+        return i;
+    }
+
+    @Test
+    void testSetOfUdts()
+    {
+        int numRowsInserted = populateSetOfUdts();
+        // Create a spark frame with the data inserted during the setup
+        Dataset<Row> sourceData = bulkReaderDataFrame(SET_OF_UDT_SOURCE_TABLE).load();
+        assertThat(sourceData.count()).isEqualTo(numRowsInserted);
+
+        // Insert the dataset containing set of UDTs, and UDT itself has collections in it
+        bulkWriterDataFrameWriter(sourceData, SET_OF_UDT_DEST_TABLE).save();
+        validateWritesWithDriverResultSet(sourceData.collectAsList(),
+                queryAllDataWithDriver(SET_OF_UDT_DEST_TABLE),
+                BulkWriteUdtTest::setOfUdtRowFormatter);
+    }
+
+    private int populateSetOfUdts()
+    {
+        // table(id, set<udt(list<>, set<>, map<>)>)
+        // insert set of UDTs, and UDT has a list, set and map inside it
+        String insertIntoSetOfUdts = "INSERT INTO %s (id, udtset) VALUES (%d, " +
+                "{{f1:['value %d'], f2:{%d}, f3:{%d : 'value %d'}}})";
+
+        int i = 0;
+        for (; i < ROW_COUNT; i++)
+        {
+            cluster.schemaChangeIgnoringStoppedInstances(String.format(insertIntoSetOfUdts, SET_OF_UDT_SOURCE_TABLE,
+                    i, i, i, i, i));
+        }
+
+        // test null cases
+        coordinator.execute(String.format("insert into %s (id) values (%d)",
+                                          SET_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, udtset) values (%d, null)",
+                                          SET_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, udtset) values (%d, {{f1:null, f2:null, f3:null}})",
+                                          SET_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+
+        return i;
+    }
+
+    @Test
+    void testMapOfUdts()
+    {
+        int numRowsInserted = populateMapOfUdts();
+        // Create a spark frame with the data inserted during the setup
+        Dataset<Row> sourceData = bulkReaderDataFrame(MAP_OF_UDT_SOURCE_TABLE).load();
+        assertThat(sourceData.count()).isEqualTo(numRowsInserted);
+
+        // Insert the dataset containing map of UDTs, and UDT itself has collections in it
+        bulkWriterDataFrameWriter(sourceData, MAP_OF_UDT_DEST_TABLE).save();
+        validateWritesWithDriverResultSet(sourceData.collectAsList(),
+                queryAllDataWithDriver(MAP_OF_UDT_DEST_TABLE),
+                BulkWriteUdtTest::mapOfUdtRowFormatter);
+    }
+
+    private int populateMapOfUdts()
+    {
+        // table(id, map<udt(list<>, set<>, map<>), udt(list<>, set<>, map<>)>)
+        // insert map of UDTs, and UDT has a list, set and map inside it
+        String insertIntoMapOfUdts = "INSERT INTO %s (id, udtmap) VALUES (%d, " +
+                "{{f1:['value %d'], f2:{%d}, f3:{%d : 'value %d'}} : {f1:['value %d'], f2:{%d}, f3:{%d : 'value %d'}}})";
+
+        int i = 0;
+        for (; i < ROW_COUNT; i++)
+        {
+            cluster.schemaChangeIgnoringStoppedInstances(String.format(insertIntoMapOfUdts, MAP_OF_UDT_SOURCE_TABLE,
+                    i, i, i, i, i, i, i, i, i));
+        }
+
+        coordinator.execute(String.format("insert into %s (id) values (%d)",
+                                          MAP_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, udtmap) values (%d, null)",
+                                          MAP_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, udtmap) values (%d, {{f1:null, f2:null, f3:null} : {f1:null, f2:null, f3:null}})",
+                                          MAP_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+
+        return i;
+    }
+
+    @Test
+    void testUdtWithListOfUdts()
+    {
+        int numRowsInserted = populateUdtWithListOfUdts();
+
+        // Create a spark frame with the data inserted during the setup
+        Dataset<Row> sourceData = bulkReaderDataFrame(UDT_WITH_LIST_OF_UDT_SOURCE_TABLE).load();
+        assertThat(sourceData.count()).isEqualTo(numRowsInserted);
+
+        // Insert the dataset containing list of UDTs, and UDT itself has collections in it
+        bulkWriterDataFrameWriter(sourceData, UDT_WITH_LIST_OF_UDT_DEST_TABLE).save();
+        validateWritesWithDriverResultSet(sourceData.collectAsList(),
+                queryAllDataWithDriver(UDT_WITH_LIST_OF_UDT_DEST_TABLE),
+                BulkWriteUdtTest::udtRowFormatter);
+    }
+
+    private int populateUdtWithListOfUdts()
+    {
+        // table(id, udt<list<udt(f1 text, f2 int)>>)
+        String insertIntoUdtWithListOfUdts = "INSERT INTO %s (id, outerudt) VALUES (%d, {innerudt:[{f1:'value %d', f2:%d}]})";
+
+        int i = 0;
+        for (; i < ROW_COUNT; i++)
+        {
+            cluster.schemaChangeIgnoringStoppedInstances(String.format(insertIntoUdtWithListOfUdts, UDT_WITH_LIST_OF_UDT_SOURCE_TABLE, i, i, i, i, i));
+        }
+
+        // test null cases
+        coordinator.execute(String.format("insert into %s (id) values (%d)",
+                                          UDT_WITH_LIST_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, null)",
+                                          UDT_WITH_LIST_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, {innerudt:[]})",
+                                          UDT_WITH_LIST_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, {innerudt:[{f1:null, f2:null}]})",
+                                          UDT_WITH_LIST_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+
+        return i;
+    }
+
+    @Test
+    void testUdtWithSetOfUdts()
+    {
+        int numRowsInserted = populateUdtWithSetOfUdts();
+
+        // Create a spark frame with the data inserted during the setup
+        Dataset<Row> sourceData = bulkReaderDataFrame(UDT_WITH_SET_OF_UDT_SOURCE_TABLE).load();
+        assertThat(sourceData.count()).isEqualTo(numRowsInserted);
+
+        // Insert the dataset containing list of UDTs, and UDT itself has collections in it
+        bulkWriterDataFrameWriter(sourceData, UDT_WITH_SET_OF_UDT_DEST_TABLE).save();
+        validateWritesWithDriverResultSet(sourceData.collectAsList(),
+                queryAllDataWithDriver(UDT_WITH_SET_OF_UDT_DEST_TABLE),
+                BulkWriteUdtTest::udtRowFormatter);
+    }
+
+    private int populateUdtWithSetOfUdts()
+    {
+        // table(id, udt<set<udt(f1 text, f2 int)>>)
+        String insertIntoUdtWithSetOfUdts = "INSERT INTO %s (id, outerudt) VALUES (%d, {innerudt:{{f1:'value %d', f2:%d}}})";
+
+        int i = 0;
+        for (; i < ROW_COUNT; i++)
+        {
+            cluster.schemaChangeIgnoringStoppedInstances(String.format(insertIntoUdtWithSetOfUdts, UDT_WITH_SET_OF_UDT_SOURCE_TABLE, i, i, i, i, i));
+        }
+
+        // test null cases
+        coordinator.execute(String.format("insert into %s (id) values (%d)",
+                                          UDT_WITH_SET_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, null)",
+                                          UDT_WITH_SET_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, {innerudt:{}})",
+                                          UDT_WITH_SET_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, {innerudt:{{f1:null, f2:null}}})",
+                                          UDT_WITH_SET_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+
+        return i;
+    }
+
+    @Test
+    void testUdtWithMapOfUdts()
+    {
+        int numRowsInserted = populateUdtWithMapOfUdts();
+
+        // Create a spark frame with the data inserted during the setup
+        Dataset<Row> sourceData = bulkReaderDataFrame(UDT_WITH_MAP_OF_UDT_SOURCE_TABLE).load();
+        assertThat(sourceData.count()).isEqualTo(numRowsInserted);
+
+        // Insert the dataset containing list of UDTs, and UDT itself has collections in it
+        bulkWriterDataFrameWriter(sourceData, UDT_WITH_MAP_OF_UDT_DEST_TABLE).save();
+        validateWritesWithDriverResultSet(sourceData.collectAsList(),
+                queryAllDataWithDriver(UDT_WITH_MAP_OF_UDT_DEST_TABLE),
+                BulkWriteUdtTest::udtRowFormatter);
+    }
+
+    private int populateUdtWithMapOfUdts()
+    {
+        // table(id, udt<map<udt(f1 text, f2 int), udt(f1 text, f2 int)>>)
+        String insertIntoUdtWithMapOfUdts = "INSERT INTO %s (id, outerudt) VALUES (%d, {innerudt:{{f1:'valueA %d', f2:%d}: {f1:'valueB %d', f2:%d}}})";
+
+        int i = 0;
+        for (; i < ROW_COUNT; i++)
+        {
+            cluster.schemaChangeIgnoringStoppedInstances(String.format(insertIntoUdtWithMapOfUdts, UDT_WITH_MAP_OF_UDT_SOURCE_TABLE, i, i, i, i, i));
+        }
+
+        // test null cases
+        coordinator.execute(String.format("insert into %s (id) values (%d)",
+                                          UDT_WITH_MAP_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, null)",
+                                          UDT_WITH_MAP_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+        coordinator.execute(String.format("insert into %s (id, outerudt) values (%d, {innerudt:{{f1:null, f2:null}: {f1:null, f2:null}}})",
+                                          UDT_WITH_MAP_OF_UDT_SOURCE_TABLE, i++), ConsistencyLevel.ALL);
+
+        return i;
     }
 
     @NotNull
-    public static String defaultRowFormatter(com.datastax.driver.core.Row row)
+    public static String udtRowFormatter(com.datastax.driver.core.Row row)
+    {
+        UDTValue udt = row.getUDTValue(1);
+        return row.getLong(0) +
+               ":" +
+               Objects.requireNonNullElse(udt, "null").toString()
+                      // driver writes lists as [] and sets as {},
+                      // whereas spark entries have the same type Seq for both lists and sets
+                      .replace('[', '{')
+                      .replace(']', '}');
+    }
+
+    @NotNull
+    public static String listOfUdtRowFormatter(com.datastax.driver.core.Row row)
+    {
+        return row.getLong(0) +
+               ":" +
+               row.getList(1, UDTValue.class).toString()
+                  // empty collections have different formatting between driver and spark
+                  .replace("{}", "null")
+                  .replace("[]", "null")
+                  // driver writes lists as [] and sets as {},
+                  // whereas spark entries have the same type Seq for both lists and sets
+                  .replace('[', '{')
+                  .replace(']', '}');
+    }
+
+    @NotNull
+    public static String setOfUdtRowFormatter(com.datastax.driver.core.Row row)
     {
         // Formats as field:value with no whitespaces, and strings quoted
         // Driver Codec writes "NULL" for null value. Spark DF writes "null".
         return row.getLong(0) +
                ":" +
-               row.getUDTValue(1).toString().replace("NULL", "null");
+               row.getSet(1, UDTValue.class).toString()
+                  // empty collections have different formatting between driver and spark
+                  .replace("{}", "null")
+                  .replace("[]", "null")
+                  // driver writes lists as [] and sets as {},
+                  // whereas spark entries have the same type Seq for both lists and sets
+                  .replace('[', '{')
+                  .replace(']', '}');
+    }
+
+    @NotNull
+    public static String mapOfUdtRowFormatter(com.datastax.driver.core.Row row)
+    {
+        // Formats as field:value with no whitespaces, and strings quoted
+        // Driver Codec writes "NULL" for null value. Spark DF writes "null".
+        return row.getLong(0) +
+               ":" +
+               row.getMap(1, UDTValue.class, UDTValue.class).toString()
+                  // empty collections have different formatting between driver and spark
+                  .replace("{}", "null")
+                  .replace("[]", "null")
+                  .replace("=", ":")
+                  // driver writes lists as [] and sets as {},
+                  // whereas spark entries have the same type Seq for both lists and sets
+                  .replace('[', '{')
+                  .replace(']', '}');
     }
 
     @Override
@@ -111,11 +457,65 @@ class BulkWriteUdtTest extends SharedClusterSparkIntegrationTestBase
     @Override
     protected void initializeSchemaForTest()
     {
+        coordinator = cluster.getFirstRunningInstance().coordinator();
+
         createTestKeyspace(UDT_TABLE_NAME, DC1_RF3);
 
         cluster.schemaChangeIgnoringStoppedInstances(TWO_FIELD_UDT_DEF);
         cluster.schemaChangeIgnoringStoppedInstances(NESTED_UDT_DEF);
         cluster.schemaChangeIgnoringStoppedInstances(UDT_TABLE_CREATE);
         cluster.schemaChangeIgnoringStoppedInstances(NESTED_TABLE_CREATE);
+        cluster.schemaChangeIgnoringStoppedInstances(UDT_WITH_COLLECTIONS_TYPE_CREATE);
+        cluster.schemaChangeIgnoringStoppedInstances(UDT_WITH_LIST_OF_UDT_TYPE_CREATE);
+        cluster.schemaChangeIgnoringStoppedInstances(UDT_WITH_SET_OF_UDT_TYPE_CREATE);
+        cluster.schemaChangeIgnoringStoppedInstances(UDT_WITH_MAP_OF_UDT_TYPE_CREATE);
+
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(LIST_OF_UDT_TABLE_CREATE,
+                LIST_OF_UDT_SOURCE_TABLE.keyspace(),
+                LIST_OF_UDT_SOURCE_TABLE.table()));
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(LIST_OF_UDT_TABLE_CREATE,
+                LIST_OF_UDT_DEST_TABLE.keyspace(),
+                LIST_OF_UDT_DEST_TABLE.table()));
+
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(SET_OF_UDT_TABLE_CREATE,
+                SET_OF_UDT_SOURCE_TABLE.keyspace(),
+                SET_OF_UDT_SOURCE_TABLE.table()));
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(SET_OF_UDT_TABLE_CREATE,
+                SET_OF_UDT_DEST_TABLE.keyspace(),
+                SET_OF_UDT_DEST_TABLE.table()));
+
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(MAP_OF_UDT_TABLE_CREATE,
+                MAP_OF_UDT_SOURCE_TABLE.keyspace(),
+                MAP_OF_UDT_SOURCE_TABLE.table()));
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(MAP_OF_UDT_TABLE_CREATE,
+                MAP_OF_UDT_DEST_TABLE.keyspace(),
+                MAP_OF_UDT_DEST_TABLE.table()));
+
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(UDT_WITH_COLLECTION_OF_UDT_TABLE_CREATE,
+                UDT_WITH_LIST_OF_UDT_SOURCE_TABLE.keyspace(),
+                UDT_WITH_LIST_OF_UDT_SOURCE_TABLE.table(),
+                UDT_WITH_LIST_OF_UDT_TYPE_NAME));
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(UDT_WITH_COLLECTION_OF_UDT_TABLE_CREATE,
+                UDT_WITH_LIST_OF_UDT_DEST_TABLE.keyspace(),
+                UDT_WITH_LIST_OF_UDT_DEST_TABLE.table(),
+                UDT_WITH_LIST_OF_UDT_TYPE_NAME));
+
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(UDT_WITH_COLLECTION_OF_UDT_TABLE_CREATE,
+                UDT_WITH_SET_OF_UDT_SOURCE_TABLE.keyspace(),
+                UDT_WITH_SET_OF_UDT_SOURCE_TABLE.table(),
+                UDT_WITH_SET_OF_UDT_TYPE_NAME));
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(UDT_WITH_COLLECTION_OF_UDT_TABLE_CREATE,
+                UDT_WITH_SET_OF_UDT_DEST_TABLE.keyspace(),
+                UDT_WITH_SET_OF_UDT_DEST_TABLE.table(),
+                UDT_WITH_SET_OF_UDT_TYPE_NAME));
+
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(UDT_WITH_COLLECTION_OF_UDT_TABLE_CREATE,
+                UDT_WITH_MAP_OF_UDT_SOURCE_TABLE.keyspace(),
+                UDT_WITH_MAP_OF_UDT_SOURCE_TABLE.table(),
+                UDT_WITH_MAP_OF_UDT_TYPE_NAME));
+        cluster.schemaChangeIgnoringStoppedInstances(String.format(UDT_WITH_COLLECTION_OF_UDT_TABLE_CREATE,
+                UDT_WITH_MAP_OF_UDT_DEST_TABLE.keyspace(),
+                UDT_WITH_MAP_OF_UDT_DEST_TABLE.table(),
+                UDT_WITH_MAP_OF_UDT_TYPE_NAME));
     }
 }
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java
index daf74693..465cbc9a 100644
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java
@@ -45,6 +45,8 @@ import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructField;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -171,28 +173,31 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste
         }
     }
 
-    public void validateWritesWithDriverResultSet(List<Row> sourceData, ResultSet queriedData,
-                                                  Function<com.datastax.driver.core.Row, String> rowFormatter)
+    public void validateWritesWithDriverResultSet(List<Row> sparkData, ResultSet driverData,
+                                                  Function<com.datastax.driver.core.Row, String> driverRowFormatter)
     {
-        Set<String> actualEntries = new HashSet<>();
-        queriedData.forEach(row -> actualEntries.add(rowFormatter.apply(row)));
+        Set<String> driverEntries = new HashSet<>();
+        driverData.forEach(row -> driverEntries.add(driverRowFormatter
+                .apply(row)
+                // Driver Codec writes "NULL" for null value. Spark DF writes "null".
+                .replace("NULL", "null")));
 
         // Number of entries in Cassandra must match the original datasource
-        assertThat(actualEntries.size()).isEqualTo(sourceData.size());
+        assertThat(driverEntries.size()).isEqualTo(sparkData.size());
 
         // remove from actual entries to make sure that the data read is the same as the data written
-        Set<String> sourceEntries = sourceData.stream().map(this::formattedSourceEntry)
-                                              .collect(Collectors.toSet());
-        assertThat(actualEntries).as("All entries are expected to be read from database")
-                                 .containsExactlyInAnyOrderElementsOf(sourceEntries);
+        Set<String> sparkEntries = sparkData.stream().map(this::formattedSparkRow)
+                .collect(Collectors.toSet());
+        assertThat(driverEntries).as("All entries are expected to be read from database")
+                .containsExactlyInAnyOrderElementsOf(sparkEntries);
     }
 
-    private String formattedSourceEntry(Row row)
+    private String formattedSparkRow(Row row)
     {
         StringBuilder sb = new StringBuilder();
         for (int i = 0; i < row.size(); i++)
         {
-            maybeFormatUdt(sb, row.get(i));
+            maybeFormatSparkCompositeType(sb, row.get(i));
             if (i != (row.size() - 1))
             {
                 sb.append(":");
@@ -203,7 +208,7 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste
 
     // Format a Spark row to look like what the toString on a UDT looks like
     // Unfortunately not _quite_ json, so we need to do this manually.
-    protected void maybeFormatUdt(StringBuilder sb, Object o)
+    protected void maybeFormatSparkCompositeType(StringBuilder sb, Object o)
     {
         if (o instanceof Row)
         {
@@ -214,7 +219,7 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste
             {
                 sb.append(maybeQuoteFieldName(fields[i]));
                 sb.append(":");
-                maybeFormatUdt(sb, r.get(i));
+                maybeFormatSparkCompositeType(sb, r.get(i));
                 if (i != r.size() - 1)
                 {
                     sb.append(',');
@@ -222,6 +227,32 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste
             }
             sb.append("}");
         }
+        else if (o instanceof Seq) // can't differentiate between scala list and set, both come here as Seq
+        {
+            List<?> entries = JavaConverters.seqAsJavaList((Seq<?>) o);
+            sb.append("{");
+            for (int i = 0; i < entries.size(); i++)
+            {
+                maybeFormatSparkCompositeType(sb, entries.get(i));
+                if (i != (entries.size() - 1))
+                {
+                    sb.append(',');
+                }
+            }
+            sb.append("}");
+        }
+        else if (o instanceof scala.collection.Map)
+        {
+            Map<?, ?> map = JavaConverters.mapAsJavaMap(((scala.collection.Map<?, ?>) o));
+            for (Map.Entry<?, ?> entry : map.entrySet())
+            {
+                sb.append("{");
+                maybeFormatSparkCompositeType(sb, entry.getKey());
+                sb.append(":");
+                maybeFormatSparkCompositeType(sb, entry.getValue());
+                sb.append("}");
+            }
+        }
         else if (o instanceof String)
         {
             sb.append(String.format("'%s'", o));


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

