This is an automated email from the ASF dual-hosted git repository.
ahmedabualsaud pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 359c9ebec95 Introduce Schema Registry Functionality to Managed KafkaIO Write. (#35644)
359c9ebec95 is described below
commit 359c9ebec955db02eea8170178816dbfb4255caf
Author: fozzie15 <[email protected]>
AuthorDate: Tue Sep 9 13:23:00 2025 -0400
Introduce Schema Registry Functionality to Managed KafkaIO Write. (#35644)
* Add testing for Managed Schema Registry support
* Add testing that runs on Dataflow
* Clean up test and use apache beam testing resources
* Spotless
* Trigger GitHub Actions. No Code Changes
* Push Write changes for testing
* add test to read and write with Managed KafkaIO using SR
* Testing Write Transform
* Add @ignore for faster testing. WILL REMOVE BEFORE MERGE.
* Finish the Write schema transform provider and add tests
* Refactor write class to use a generic method for the conversion function
* Add extra logging and clean up variable names to address comments.
---
sdks/java/io/kafka/build.gradle | 3 +
sdks/java/io/kafka/kafka-integration-test.gradle | 1 +
.../kafka/KafkaWriteSchemaTransformProvider.java | 132 +++++++++++++++------
.../KafkaWriteSchemaTransformProviderTest.java | 44 ++++++-
4 files changed, 146 insertions(+), 34 deletions(-)
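
Editor's note (sketch, not part of this patch): the new code path is selected when a pipeline passes "schema.registry.url" through the producer config updates; the transform then converts Rows to Avro GenericRecords and delegates value serialization to Confluent's KafkaAvroSerializer instead of writing pre-serialized Avro bytes. Below is a minimal, hypothetical example of driving this through the Managed API in Java; the Managed config keys ("topic", "bootstrap_servers", "format", "producer_config_updates") are assumed from the provider's field names, and the broker, topic, and registry URL are placeholders.

// Hypothetical sketch; config keys and endpoint values are assumptions,
// not taken from this patch.
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;

public class ManagedKafkaSchemaRegistryWriteSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();

    Schema schema = Schema.builder().addStringField("name").build();
    PCollection<Row> rows =
        p.apply(
            Create.of(Row.withSchema(schema).withFieldValue("name", "a").build())
                .withRowSchema(schema));

    // Presence of "schema.registry.url" is what triggers the GenericRecord +
    // KafkaAvroSerializer branch added in this commit.
    Map<String, Object> producerConfig = new HashMap<>();
    producerConfig.put("schema.registry.url", "https://example-registry:8081"); // placeholder

    Map<String, Object> config = new HashMap<>();
    config.put("topic", "example-topic");           // placeholder
    config.put("bootstrap_servers", "broker:9092"); // placeholder
    config.put("format", "AVRO");
    config.put("producer_config_updates", producerConfig);

    PCollectionRowTuple.of("input", rows)
        .apply(Managed.write(Managed.KAFKA).withConfig(config));
    p.run().waitUntilFinish();
  }
}
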
diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index 6e9b5aec093..ba25078b64e 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -74,6 +74,9 @@ dependencies {
implementation (group: 'com.google.cloud.hosted.kafka', name: 'managed-kafka-auth-login-handler', version: '1.0.5') {
// "kafka-clients" has to be provided since user can use its own version.
exclude group: 'org.apache.kafka', module: 'kafka-clients'
+ // "kafka-schema-registry-client must be excluded per the Google Cloud
documentation:
+ //
https://cloud.google.com/managed-service-for-apache-kafka/docs/quickstart-avro#configure_and_run_the_producer
+ exclude group: "io.confluent", module: "kafka-schema-registry-client"
}
implementation ("io.confluent:kafka-avro-serializer:${confluentVersion}") {
// zookeeper depends on "spotbugs-annotations:3.1.9" which clashes with current
diff --git a/sdks/java/io/kafka/kafka-integration-test.gradle b/sdks/java/io/kafka/kafka-integration-test.gradle
index 3bbab72ff77..14d90349ded 100644
--- a/sdks/java/io/kafka/kafka-integration-test.gradle
+++ b/sdks/java/io/kafka/kafka-integration-test.gradle
@@ -33,6 +33,7 @@ dependencies {
// instead, rely on io/kafka/build.gradle's custom configurations with forced kafka-client resolutionStrategy
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
+ testImplementation library.java.avro
}
configurations.create("kafkaVersion$undelimited")
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
index d6f46b11cb7..e2a4f394ccd 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
@@ -21,6 +21,7 @@ import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
+import io.confluent.kafka.serializers.KafkaAvroSerializer;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
@@ -28,7 +29,11 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
+import org.apache.avro.generic.GenericRecord;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
import org.apache.beam.sdk.metrics.Counter;
@@ -74,6 +79,8 @@ public class KafkaWriteSchemaTransformProvider
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
public static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
new TupleTag<KV<byte[], byte[]>>() {};
+ public static final TupleTag<KV<byte[], GenericRecord>> RECORD_OUTPUT_TAG =
+ new TupleTag<KV<byte[], GenericRecord>>() {};
private static final Logger LOG =
LoggerFactory.getLogger(KafkaWriteSchemaTransformProvider.class);
@@ -118,29 +125,32 @@ public class KafkaWriteSchemaTransformProvider
}
}
- public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
- private final SerializableFunction<Row, byte[]> toBytesFn;
+ public abstract static class BaseKafkaWriterFn<T> extends DoFn<Row, KV<byte[], T>> {
+ private final SerializableFunction<Row, T> conversionFn;
private final Counter errorCounter;
private Long errorsInBundle = 0L;
private final boolean handleErrors;
private final Schema errorSchema;
+ private final TupleTag<KV<byte[], T>> successTag;
- public ErrorCounterFn(
+ public BaseKafkaWriterFn(
String name,
- SerializableFunction<Row, byte[]> toBytesFn,
+ SerializableFunction<Row, T> conversionFn,
Schema errorSchema,
- boolean handleErrors) {
- this.toBytesFn = toBytesFn;
+ boolean handleErrors,
+ TupleTag<KV<byte[], T>> successTag) {
+ this.conversionFn = conversionFn;
this.errorCounter =
Metrics.counter(KafkaWriteSchemaTransformProvider.class, name);
this.handleErrors = handleErrors;
this.errorSchema = errorSchema;
+ this.successTag = successTag;
}
@ProcessElement
public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) {
- KV<byte[], byte[]> output = null;
+ KV<byte[], T> output = null;
try {
- output = KV.of(new byte[1], toBytesFn.apply(row));
+ output = KV.of(new byte[1], conversionFn.apply(row));
} catch (Exception e) {
if (!handleErrors) {
throw new RuntimeException(e);
@@ -150,7 +160,7 @@ public class KafkaWriteSchemaTransformProvider
receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, row, e));
}
if (output != null) {
- receiver.get(OUTPUT_TAG).output(output);
+ receiver.get(successTag).output(output);
}
}
@@ -161,13 +171,35 @@ public class KafkaWriteSchemaTransformProvider
}
}
+ public static class ErrorCounterFn extends BaseKafkaWriterFn<byte[]> {
+ public ErrorCounterFn(
+ String name,
+ SerializableFunction<Row, byte[]> toBytesFn,
+ Schema errorSchema,
+ boolean handleErrors) {
+ super(name, toBytesFn, errorSchema, handleErrors, OUTPUT_TAG);
+ }
+ }
+
+ public static class GenericRecordErrorCounterFn extends BaseKafkaWriterFn<GenericRecord> {
+ public GenericRecordErrorCounterFn(
+ String name,
+ SerializableFunction<Row, GenericRecord> toGenericRecordsFn,
+ Schema errorSchema,
+ boolean handleErrors) {
+ super(name, toGenericRecordsFn, errorSchema, handleErrors, RECORD_OUTPUT_TAG);
+ }
+ }
+
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
Schema inputSchema = input.get("input").getSchema();
+ org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(inputSchema);
final SerializableFunction<Row, byte[]> toBytesFn;
+ SerializableFunction<Row, GenericRecord> toGenericRecordsFn = null;
if (configuration.getFormat().equals("RAW")) {
int numFields = inputSchema.getFields().size();
if (numFields != 1) {
@@ -198,36 +230,70 @@ public class KafkaWriteSchemaTransformProvider
throw new IllegalArgumentException(
"At least a descriptorPath or a proto Schema is required.");
}
-
} else {
- toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema);
+ if (configuration.getProducerConfigUpdates() != null
+ &&
configuration.getProducerConfigUpdates().containsKey("schema.registry.url")) {
+ toGenericRecordsFn =
AvroUtils.getRowToGenericRecordFunction(avroSchema);
+ toBytesFn = null;
+ } else {
+ toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema);
+ }
}
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
final Map<String, String> configOverrides = configuration.getProducerConfigUpdates();
Schema errorSchema = ErrorHandling.errorSchema(inputSchema);
- PCollectionTuple outputTuple =
- input
- .get("input")
- .apply(
- "Map rows to Kafka messages",
- ParDo.of(
- new ErrorCounterFn(
- "Kafka-write-error-counter", toBytesFn,
errorSchema, handleErrors))
- .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
-
- outputTuple
- .get(OUTPUT_TAG)
- .apply(
- KafkaIO.<byte[], byte[]>write()
- .withTopic(configuration.getTopic())
- .withBootstrapServers(configuration.getBootstrapServers())
- .withProducerConfigUpdates(
- configOverrides == null
- ? new HashMap<>()
- : new HashMap<String, Object>(configOverrides))
- .withKeySerializer(ByteArraySerializer.class)
- .withValueSerializer(ByteArraySerializer.class));
+ PCollectionTuple outputTuple;
+ if (toGenericRecordsFn != null) {
+ LOG.info("Convert to GenericRecord with schema {}", avroSchema);
+ outputTuple =
+ input
+ .get("input")
+ .apply(
+ "Map rows to Kafka messages",
+ ParDo.of(
+ new GenericRecordErrorCounterFn(
+ "Kafka-write-error-counter",
+ toGenericRecordsFn,
+ errorSchema,
+ handleErrors))
+ .withOutputTags(RECORD_OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+ HashMap<String, Object> producerConfig = new
HashMap<>(configOverrides);
+ outputTuple
+ .get(RECORD_OUTPUT_TAG)
+ .setCoder(KvCoder.of(ByteArrayCoder.of(), AvroCoder.of(avroSchema)))
+ .apply(
+ "Map Rows to GenericRecords",
+ KafkaIO.<byte[], GenericRecord>write()
+ .withTopic(configuration.getTopic())
+ .withBootstrapServers(configuration.getBootstrapServers())
+ .withProducerConfigUpdates(producerConfig)
+ .withKeySerializer(ByteArraySerializer.class)
+ .withValueSerializer((Class) KafkaAvroSerializer.class));
+ } else {
+ outputTuple =
+ input
+ .get("input")
+ .apply(
+ "Map rows to Kafka messages",
+ ParDo.of(
+ new ErrorCounterFn(
+ "Kafka-write-error-counter", toBytesFn,
errorSchema, handleErrors))
+ .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+ outputTuple
+ .get(OUTPUT_TAG)
+ .apply(
+ KafkaIO.<byte[], byte[]>write()
+ .withTopic(configuration.getTopic())
+ .withBootstrapServers(configuration.getBootstrapServers())
+ .withProducerConfigUpdates(
+ configOverrides == null
+ ? new HashMap<>()
+ : new HashMap<String, Object>(configOverrides))
+ .withKeySerializer(ByteArraySerializer.class)
+ .withValueSerializer(ByteArraySerializer.class));
+ }
// TODO: include output from KafkaIO Write once updated from PDone
PCollection<Row> errorOutput =
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
index dffa6ece9d1..b63a9334239 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
@@ -24,9 +24,16 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
+import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.ErrorCounterFn;
+import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.GenericRecordErrorCounterFn;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
@@ -53,6 +60,8 @@ public class KafkaWriteSchemaTransformProviderTest {
private static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
KafkaWriteSchemaTransformProvider.OUTPUT_TAG;
+ private static final TupleTag<KV<byte[], GenericRecord>> RECORD_OUTPUT_TAG =
+ KafkaWriteSchemaTransformProvider.RECORD_OUTPUT_TAG;
private static final TupleTag<Row> ERROR_TAG =
KafkaWriteSchemaTransformProvider.ERROR_TAG;
private static final Schema BEAMSCHEMA =
@@ -126,7 +135,8 @@ public class KafkaWriteSchemaTransformProviderTest {
getClass().getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
.getPath(),
"MyMessage");
-
+ final SerializableFunction<Row, GenericRecord> recordValueMapper =
+ AvroUtils.getRowToGenericRecordFunction(AvroUtils.toAvroSchema(BEAMSCHEMA));
@Rule public transient TestPipeline p = TestPipeline.create();
@Test
@@ -198,6 +208,38 @@ public class KafkaWriteSchemaTransformProviderTest {
+ " bool active = 3;\n"
+ "}";
+ @Test
+ public void testKafkaRecordErrorFnSuccess() throws Exception {
+ org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(BEAMSCHEMA);
+
+ GenericRecord record1 = new GenericData.Record(avroSchema);
+ GenericRecord record2 = new GenericData.Record(avroSchema);
+ GenericRecord record3 = new GenericData.Record(avroSchema);
+ record1.put("name", "a");
+ record2.put("name", "b");
+ record3.put("name", "c");
+
+ List<KV<byte[], GenericRecord>> msg =
+ Arrays.asList(
+ KV.of(new byte[1], record1), KV.of(new byte[1], record2), KV.of(new byte[1], record3));
+
+ PCollection<Row> input = p.apply(Create.of(ROWS));
+ Schema errorSchema = ErrorHandling.errorSchema(BEAMSCHEMA);
+ PCollectionTuple output =
+ input.apply(
+ ParDo.of(
+ new GenericRecordErrorCounterFn(
+ "Kafka-write-error-counter", recordValueMapper,
errorSchema, true))
+ .withOutputTags(RECORD_OUTPUT_TAG,
TupleTagList.of(ERROR_TAG)));
+
+ output.get(ERROR_TAG).setRowSchema(errorSchema);
+ output
+ .get(RECORD_OUTPUT_TAG)
+ .setCoder(KvCoder.of(ByteArrayCoder.of(), AvroCoder.of(avroSchema)));
+ PAssert.that(output.get(RECORD_OUTPUT_TAG)).containsInAnyOrder(msg);
+ p.run().waitUntilFinish();
+ }
+
@Test
public void testBuildTransformWithManaged() {
List<String> configs =