This is an automated email from the ASF dual-hosted git repository.

ahmedabualsaud pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 359c9ebec95 Introduce Schema Registry Functionality to Managed KafkaIO Write. (#35644)
359c9ebec95 is described below

commit 359c9ebec955db02eea8170178816dbfb4255caf
Author: fozzie15 <[email protected]>
AuthorDate: Tue Sep 9 13:23:00 2025 -0400

    Introduce Schema Registry Functionality to Managed KafkaIO Write. (#35644)
    
    * Add testing for Managed Schema Registry support
    
    * Add testing that runs on Dataflow
    
    * Clean up test and use apache beam testing resources
    
    * Spotless
    
    * Trigger GitHub Actions. No Code Changes
    
    * Push Write changes for testing
    
    * add test to read and write with Managed KafkaIO using SR
    
    * Testing Write Transform
    
    * Add @ignore for faster testing. WILL REMOVE BEFORE MERGE.
    
    * Finish the Write schema transform provider and add tests
    
    * Refactor write class to use a generic method for the conversion function
    
    * Add extra logging and clean up variable names to address comments.
---
 sdks/java/io/kafka/build.gradle                    |   3 +
 sdks/java/io/kafka/kafka-integration-test.gradle   |   1 +
 .../kafka/KafkaWriteSchemaTransformProvider.java   | 132 +++++++++++++++------
 .../KafkaWriteSchemaTransformProviderTest.java     |  44 ++++++-
 4 files changed, 146 insertions(+), 34 deletions(-)
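
Below is a minimal sketch of how a pipeline could exercise the new schema registry
write path through Managed KafkaIO. It is not taken from this commit: the topic,
broker address and registry URL are placeholders, and the managed config keys
("topic", "bootstrap_servers", "format", "producer_config_updates") are assumed to
mirror the provider's configuration fields.

import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;

public class ManagedKafkaSchemaRegistryWriteExample {
  public static void main(String[] args) {
    Schema schema = Schema.builder().addStringField("name").build();
    Pipeline pipeline = Pipeline.create();

    PCollection<Row> rows =
        pipeline.apply(
            Create.of(Row.withSchema(schema).addValue("a").build()).withRowSchema(schema));

    // Supplying "schema.registry.url" under the producer config updates is what routes
    // the write through the new GenericRecord + KafkaAvroSerializer branch.
    PCollectionRowTuple.of("input", rows)
        .apply(
            Managed.write(Managed.KAFKA)
                .withConfig(
                    Map.of(
                        "topic", "my-topic",
                        "bootstrap_servers", "broker:9092",
                        "format", "AVRO",
                        "producer_config_updates",
                            Map.of("schema.registry.url", "https://example-registry:8081"))));

    pipeline.run().waitUntilFinish();
  }
}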

diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index 6e9b5aec093..ba25078b64e 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -74,6 +74,9 @@ dependencies {
   implementation (group: 'com.google.cloud.hosted.kafka', name: 'managed-kafka-auth-login-handler', version: '1.0.5') {
     // "kafka-clients" has to be provided since user can use its own version.
     exclude group: 'org.apache.kafka', module: 'kafka-clients'
+    // "kafka-schema-registry-client" must be excluded per the Google Cloud documentation:
+    // https://cloud.google.com/managed-service-for-apache-kafka/docs/quickstart-avro#configure_and_run_the_producer
+    exclude group: "io.confluent", module: "kafka-schema-registry-client"
   }
   implementation ("io.confluent:kafka-avro-serializer:${confluentVersion}") {
     // zookeeper depends on "spotbugs-annotations:3.1.9" which clashes with current
diff --git a/sdks/java/io/kafka/kafka-integration-test.gradle b/sdks/java/io/kafka/kafka-integration-test.gradle
index 3bbab72ff77..14d90349ded 100644
--- a/sdks/java/io/kafka/kafka-integration-test.gradle
+++ b/sdks/java/io/kafka/kafka-integration-test.gradle
@@ -33,6 +33,7 @@ dependencies {
     // instead, rely on io/kafka/build.gradle's custom configurations with forced kafka-client resolutionStrategy
     testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
     testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
+    testImplementation library.java.avro
 }
 
 configurations.create("kafkaVersion$undelimited")
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
index d6f46b11cb7..e2a4f394ccd 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
@@ -21,6 +21,7 @@ import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
 
 import com.google.auto.service.AutoService;
 import com.google.auto.value.AutoValue;
+import io.confluent.kafka.serializers.KafkaAvroSerializer;
 import java.io.Serializable;
 import java.util.Collections;
 import java.util.HashMap;
@@ -28,7 +29,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.avro.generic.GenericRecord;
 import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
 import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.metrics.Counter;
@@ -74,6 +79,8 @@ public class KafkaWriteSchemaTransformProvider
   public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
   public static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
       new TupleTag<KV<byte[], byte[]>>() {};
+  public static final TupleTag<KV<byte[], GenericRecord>> RECORD_OUTPUT_TAG =
+      new TupleTag<KV<byte[], GenericRecord>>() {};
   private static final Logger LOG =
       LoggerFactory.getLogger(KafkaWriteSchemaTransformProvider.class);
 
@@ -118,29 +125,32 @@ public class KafkaWriteSchemaTransformProvider
       }
     }
 
-    public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
-      private final SerializableFunction<Row, byte[]> toBytesFn;
+    public abstract static class BaseKafkaWriterFn<T> extends DoFn<Row, KV<byte[], T>> {
+      private final SerializableFunction<Row, T> conversionFn;
       private final Counter errorCounter;
       private Long errorsInBundle = 0L;
       private final boolean handleErrors;
       private final Schema errorSchema;
+      private final TupleTag<KV<byte[], T>> successTag;
 
-      public ErrorCounterFn(
+      public BaseKafkaWriterFn(
           String name,
-          SerializableFunction<Row, byte[]> toBytesFn,
+          SerializableFunction<Row, T> conversionFn,
           Schema errorSchema,
-          boolean handleErrors) {
-        this.toBytesFn = toBytesFn;
+          boolean handleErrors,
+          TupleTag<KV<byte[], T>> successTag) {
+        this.conversionFn = conversionFn;
         this.errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name);
         this.handleErrors = handleErrors;
         this.errorSchema = errorSchema;
+        this.successTag = successTag;
       }
 
       @ProcessElement
       public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) {
-        KV<byte[], byte[]> output = null;
+        KV<byte[], T> output = null;
         try {
-          output = KV.of(new byte[1], toBytesFn.apply(row));
+          output = KV.of(new byte[1], conversionFn.apply(row));
         } catch (Exception e) {
           if (!handleErrors) {
             throw new RuntimeException(e);
@@ -150,7 +160,7 @@ public class KafkaWriteSchemaTransformProvider
           receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, row, e));
         }
         if (output != null) {
-          receiver.get(OUTPUT_TAG).output(output);
+          receiver.get(successTag).output(output);
         }
       }
 
@@ -161,13 +171,35 @@ public class KafkaWriteSchemaTransformProvider
       }
     }
 
+    public static class ErrorCounterFn extends BaseKafkaWriterFn<byte[]> {
+      public ErrorCounterFn(
+          String name,
+          SerializableFunction<Row, byte[]> toBytesFn,
+          Schema errorSchema,
+          boolean handleErrors) {
+        super(name, toBytesFn, errorSchema, handleErrors, OUTPUT_TAG);
+      }
+    }
+
+    public static class GenericRecordErrorCounterFn extends BaseKafkaWriterFn<GenericRecord> {
+      public GenericRecordErrorCounterFn(
+          String name,
+          SerializableFunction<Row, GenericRecord> toGenericRecordsFn,
+          Schema errorSchema,
+          boolean handleErrors) {
+        super(name, toGenericRecordsFn, errorSchema, handleErrors, RECORD_OUTPUT_TAG);
+      }
+    }
+
     @SuppressWarnings({
       "nullness" // TODO(https://github.com/apache/beam/issues/20497)
     })
     @Override
     public PCollectionRowTuple expand(PCollectionRowTuple input) {
       Schema inputSchema = input.get("input").getSchema();
+      org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(inputSchema);
       final SerializableFunction<Row, byte[]> toBytesFn;
+      SerializableFunction<Row, GenericRecord> toGenericRecordsFn = null;
       if (configuration.getFormat().equals("RAW")) {
         int numFields = inputSchema.getFields().size();
         if (numFields != 1) {
@@ -198,36 +230,70 @@ public class KafkaWriteSchemaTransformProvider
           throw new IllegalArgumentException(
               "At least a descriptorPath or a proto Schema is required.");
         }
-
       } else {
-        toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema);
+        if (configuration.getProducerConfigUpdates() != null
+            && configuration.getProducerConfigUpdates().containsKey("schema.registry.url")) {
+          toGenericRecordsFn = AvroUtils.getRowToGenericRecordFunction(avroSchema);
+          toBytesFn = null;
+        } else {
+          toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema);
+        }
       }
 
       boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
       final Map<String, String> configOverrides = configuration.getProducerConfigUpdates();
       Schema errorSchema = ErrorHandling.errorSchema(inputSchema);
-      PCollectionTuple outputTuple =
-          input
-              .get("input")
-              .apply(
-                  "Map rows to Kafka messages",
-                  ParDo.of(
-                          new ErrorCounterFn(
-                              "Kafka-write-error-counter", toBytesFn, errorSchema, handleErrors))
-                      .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
-
-      outputTuple
-          .get(OUTPUT_TAG)
-          .apply(
-              KafkaIO.<byte[], byte[]>write()
-                  .withTopic(configuration.getTopic())
-                  .withBootstrapServers(configuration.getBootstrapServers())
-                  .withProducerConfigUpdates(
-                      configOverrides == null
-                          ? new HashMap<>()
-                          : new HashMap<String, Object>(configOverrides))
-                  .withKeySerializer(ByteArraySerializer.class)
-                  .withValueSerializer(ByteArraySerializer.class));
+      PCollectionTuple outputTuple;
+      if (toGenericRecordsFn != null) {
+        LOG.info("Convert to GenericRecord with schema {}", avroSchema);
+        outputTuple =
+            input
+                .get("input")
+                .apply(
+                    "Map rows to Kafka messages",
+                    ParDo.of(
+                            new GenericRecordErrorCounterFn(
+                                "Kafka-write-error-counter",
+                                toGenericRecordsFn,
+                                errorSchema,
+                                handleErrors))
+                        .withOutputTags(RECORD_OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+        HashMap<String, Object> producerConfig = new HashMap<>(configOverrides);
+        outputTuple
+            .get(RECORD_OUTPUT_TAG)
+            .setCoder(KvCoder.of(ByteArrayCoder.of(), AvroCoder.of(avroSchema)))
+            .apply(
+                "Map Rows to GenericRecords",
+                KafkaIO.<byte[], GenericRecord>write()
+                    .withTopic(configuration.getTopic())
+                    .withBootstrapServers(configuration.getBootstrapServers())
+                    .withProducerConfigUpdates(producerConfig)
+                    .withKeySerializer(ByteArraySerializer.class)
+                    .withValueSerializer((Class) KafkaAvroSerializer.class));
+      } else {
+        outputTuple =
+            input
+                .get("input")
+                .apply(
+                    "Map rows to Kafka messages",
+                    ParDo.of(
+                            new ErrorCounterFn(
+                                "Kafka-write-error-counter", toBytesFn, errorSchema, handleErrors))
+                        .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+        outputTuple
+            .get(OUTPUT_TAG)
+            .apply(
+                KafkaIO.<byte[], byte[]>write()
+                    .withTopic(configuration.getTopic())
+                    .withBootstrapServers(configuration.getBootstrapServers())
+                    .withProducerConfigUpdates(
+                        configOverrides == null
+                            ? new HashMap<>()
+                            : new HashMap<String, Object>(configOverrides))
+                    .withKeySerializer(ByteArraySerializer.class)
+                    .withValueSerializer(ByteArraySerializer.class));
+      }
 
       // TODO: include output from KafkaIO Write once updated from PDone
       PCollection<Row> errorOutput =
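
For readers following the branch added above: when "schema.registry.url" is supplied,
the provider effectively assembles a KafkaIO write along the lines of the simplified
sketch below. This is an approximation for illustration only, not code from the
commit; the topic, broker address and registry URL are placeholders.

import io.confluent.kafka.serializers.KafkaAvroSerializer;
import java.util.HashMap;
import java.util.Map;
import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.kafka.common.serialization.ByteArraySerializer;

public class SchemaRegistryWriteSketch {
  @SuppressWarnings({"unchecked", "rawtypes"})
  static KafkaIO.Write<byte[], GenericRecord> buildWrite() {
    Map<String, Object> producerConfig = new HashMap<>();
    // KafkaAvroSerializer uses this endpoint to register/look up the value schema.
    producerConfig.put("schema.registry.url", "https://example-registry:8081");

    return KafkaIO.<byte[], GenericRecord>write()
        .withTopic("my-topic")
        .withBootstrapServers("broker:9092")
        .withProducerConfigUpdates(producerConfig)
        .withKeySerializer(ByteArraySerializer.class)
        // The raw Class cast mirrors the provider code above: KafkaAvroSerializer
        // implements Serializer<Object>, not Serializer<GenericRecord>.
        .withValueSerializer((Class) KafkaAvroSerializer.class);
  }
}
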
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
index dffa6ece9d1..b63a9334239 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
@@ -24,9 +24,16 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericRecord;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
+import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.ErrorCounterFn;
+import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.GenericRecordErrorCounterFn;
 import org.apache.beam.sdk.managed.Managed;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
@@ -53,6 +60,8 @@ public class KafkaWriteSchemaTransformProviderTest {
 
   private static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
       KafkaWriteSchemaTransformProvider.OUTPUT_TAG;
+  private static final TupleTag<KV<byte[], GenericRecord>> RECORD_OUTPUT_TAG =
+      KafkaWriteSchemaTransformProvider.RECORD_OUTPUT_TAG;
   private static final TupleTag<Row> ERROR_TAG = KafkaWriteSchemaTransformProvider.ERROR_TAG;
 
   private static final Schema BEAMSCHEMA =
@@ -126,7 +135,8 @@ public class KafkaWriteSchemaTransformProviderTest {
                   getClass().getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
               .getPath(),
           "MyMessage");
-
+  final SerializableFunction<Row, GenericRecord> recordValueMapper =
+      AvroUtils.getRowToGenericRecordFunction(AvroUtils.toAvroSchema(BEAMSCHEMA));
   @Rule public transient TestPipeline p = TestPipeline.create();
 
   @Test
@@ -198,6 +208,38 @@ public class KafkaWriteSchemaTransformProviderTest {
           + "  bool active = 3;\n"
           + "}";
 
+  @Test
+  public void testKafkaRecordErrorFnSuccess() throws Exception {
+    org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(BEAMSCHEMA);
+
+    GenericRecord record1 = new GenericData.Record(avroSchema);
+    GenericRecord record2 = new GenericData.Record(avroSchema);
+    GenericRecord record3 = new GenericData.Record(avroSchema);
+    record1.put("name", "a");
+    record2.put("name", "b");
+    record3.put("name", "c");
+
+    List<KV<byte[], GenericRecord>> msg =
+        Arrays.asList(
+            KV.of(new byte[1], record1), KV.of(new byte[1], record2), KV.of(new byte[1], record3));
+
+    PCollection<Row> input = p.apply(Create.of(ROWS));
+    Schema errorSchema = ErrorHandling.errorSchema(BEAMSCHEMA);
+    PCollectionTuple output =
+        input.apply(
+            ParDo.of(
+                    new GenericRecordErrorCounterFn(
+                        "Kafka-write-error-counter", recordValueMapper, errorSchema, true))
+                .withOutputTags(RECORD_OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+    output.get(ERROR_TAG).setRowSchema(errorSchema);
+    output
+        .get(RECORD_OUTPUT_TAG)
+        .setCoder(KvCoder.of(ByteArrayCoder.of(), AvroCoder.of(avroSchema)));
+    PAssert.that(output.get(RECORD_OUTPUT_TAG)).containsInAnyOrder(msg);
+    p.run().waitUntilFinish();
+  }
+
   @Test
   public void testBuildTransformWithManaged() {
     List<String> configs =
