Hello 👋 At my company we have a Flink SQL job that needs the `array_agg` function introduced in version 1.20 (this is the only documentation I could find: <https://nightlies.apache.org/flink/flink-docs-release-1.20/docs/dev/table/functions/systemfunctions/#aggregate-functions>). Due to constraints with some connectors, we cannot upgrade to 1.20 at the moment, so we have been looking into backporting this function to 1.18.
Do you have any general advice about this? What we've done is essentially copy the upstream code into our own JAR: <https://github.com/apache/flink/blob/2d17f6148796c30890d38c830f772f8f38bfd495/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java>. There are some slight differences in our version (AFAIU we cannot use the same constructor). I've attached our edited function and the diff with upstream below.

The problem we face is an exception thrown as soon as the function receives input (the job itself is submitted successfully):

```
java.lang.ClassCastException: class org.apache.flink.types.Row cannot be cast to class org.apache.flink.table.data.RowData (org.apache.flink.types.Row and org.apache.flink.table.data.RowData are in unnamed module of loader 'app')
```

Can anybody help us backport this function? Any input is greatly appreciated 🙏

Best,
Daniele
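P.S. For context, this is roughly how we wire the backported function into the job. This is a minimal sketch, not our real setup: the registered name `ARRAY_AGG_BP`, the `orders` table, its columns, and the `datagen` source are all illustrative.

```java
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

import eu.spaziodati.cp.flink.funcs.ArrayAggFunction;

public class ArrayAggBackportUsage {
    public static void main(String[] args) {
        TableEnvironment tEnv =
                TableEnvironment.create(EnvironmentSettings.newInstance().inStreamingMode().build());

        // Register the backported function under its own name
        // (ARRAY_AGG does not exist as a built-in in 1.18).
        tEnv.createTemporarySystemFunction("ARRAY_AGG_BP", new ArrayAggFunction<>());

        // Illustrative source; the real job reads from our connectors.
        tEnv.executeSql(
                "CREATE TABLE orders (customer_id STRING, item STRING) "
                        + "WITH ('connector' = 'datagen', 'rows-per-second' = '1')");

        // The aggregation shape we use; with a ROW-typed argument this is
        // where the job fails at runtime once rows start flowing.
        tEnv.executeSql(
                        "SELECT customer_id, ARRAY_AGG_BP(ROW(customer_id, item)) AS items "
                                + "FROM orders GROUP BY customer_id")
                .print();
    }
}
```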
The diff with upstream:

```diff
2,17c2,3
<  * Licensed to the Apache Software Foundation (ASF) under one
<  * or more contributor license agreements. See the NOTICE file
<  * distributed with this work for additional information
<  * regarding copyright ownership. The ASF licenses this file
<  * to you under the Apache License, Version 2.0 (the
<  * "License"); you may not use this file except in compliance
<  * with the License. You may obtain a copy of the License at
<  *
<  *     http://www.apache.org/licenses/LICENSE-2.0
<  *
<  * Unless required by applicable law or agreed to in writing, software
<  * distributed under the License is distributed on an "AS IS" BASIS,
<  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
<  * See the License for the specific language governing permissions and
<  * limitations under the License.
<  */
---
> This is Flink's `array_agg` built-in function from version 1.20.
> We backported it to Flink 1.18.
19c5,19
< package org.apache.flink.table.runtime.functions.aggregate;
---
> https://nightlies.apache.org/flink/flink-docs-release-1.20/docs/dev/table/functions/systemfunctions/#collection-functions:~:text=2%203%204-,ARRAY_AGG,-(%5B%20ALL%20%7C%20DISTINCT%20%5D%20expression
> https://github.com/apache/flink/blob/d2c241a3f84a32c98b3848825b999c1b24860455/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java#L41
>
> There are some slight differences from the 1.20 version:
> - type inference was changed (we can't rely on the planner directly, like
>   `array_agg` does)
> - if no items are accumulated, this returns an empty array instead of `null`
> - this is the signature accepted by array_agg:
>   `ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ])`
>   we accept only a subset:
>   `ARRAY_AGG(expression)`
> - `null`s in the input are ignored
> */
>
> package eu.spaziodati.cp.flink.funcs;
23a24
> import org.apache.flink.table.catalog.DataTypeFactory;
26a28
> import org.apache.flink.table.runtime.typeutils.ExternalSerializer;
28a31,32
> import org.apache.flink.table.types.inference.InputTypeStrategies;
> import org.apache.flink.table.types.inference.TypeInference;
30a35,36
> import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
> import org.apache.flink.table.types.logical.RowType;
35a42
> import java.util.Optional;
40d46
< @Internal
46,47d51
<     private final transient DataType elementDataType;
<
49a54,57
>     public ArrayAggFunction() {
>         this.ignoreNulls = true;
>     }
>
51d58
<         this.elementDataType = toInternalDataType(elementType);
60,82c67,88
<     public List<DataType> getArgumentDataTypes() {
<         return Collections.singletonList(elementDataType);
<     }
<
<     @Override
<     public DataType getAccumulatorDataType() {
<         DataType linkedListType = getLinkedListType();
<         return DataTypes.STRUCTURED(
<                 ArrayAggAccumulator.class,
<                 DataTypes.FIELD("list", linkedListType),
<                 DataTypes.FIELD("retractList", linkedListType));
<     }
<
<     @Override
<     public DataType getOutputDataType() {
<         return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class);
<     }
<
<     @SuppressWarnings({"unchecked", "rawtypes"})
<     private DataType getLinkedListType() {
<         TypeSerializer<T> serializer = InternalSerializers.create(elementDataType.getLogicalType());
<         return DataTypes.RAW(
<                 LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
---
>     public TypeInference getTypeInference(DataTypeFactory typeFactory) {
>         return TypeInference.newBuilder()
>                 .inputTypeStrategy(InputTypeStrategies.sequence(InputTypeStrategies.ANY))
>                 // let the accumulator data type depend on the first input argument
>                 .accumulatorTypeStrategy(
>                         callContext -> {
>                             final DataType argDataType = callContext.getArgumentDataTypes().get(0);
>                             TypeSerializer<T> serializer = InternalSerializers.create(toInternalDataType(argDataType).getLogicalType());
>                             DataType linkedListType = DataTypes.RAW(
>                                     LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
>                             return Optional.of(DataTypes.STRUCTURED(
>                                     ArrayAggAccumulator.class,
>                                     DataTypes.FIELD("list", linkedListType),
>                                     DataTypes.FIELD("retractList", linkedListType)));
>                         })
>                 // let the output data type depend on the first input argument
>                 .outputTypeStrategy(
>                         callContext -> {
>                             final DataType argDataType = callContext.getArgumentDataTypes().get(0);
>                             return Optional.of(DataTypes.ARRAY(argDataType).bridgedTo(ArrayData.class));
>                         })
>                 .build();
160c166
<         if (accList == null || accList.isEmpty()) {
---
>         if (accList == null) {
```
The full stack trace:

```
java.lang.ClassCastException: class org.apache.flink.types.Row cannot be cast to class org.apache.flink.table.data.RowData (org.apache.flink.types.Row and org.apache.flink.table.data.RowData are in unnamed module of loader 'app')
    at org.apache.flink.table.runtime.typeutils.RowDataSerializer.serialize(RowDataSerializer.java:48)
    at org.apache.flink.table.runtime.typeutils.LinkedListSerializer.serialize(LinkedListSerializer.java:137)
    at org.apache.flink.table.runtime.typeutils.LinkedListSerializer.serialize(LinkedListSerializer.java:40)
    at org.apache.flink.util.InstantiationUtil.serializeToByteArray(InstantiationUtil.java:498)
    at org.apache.flink.table.data.binary.BinaryRawValueData.materialize(BinaryRawValueData.java:113)
    at org.apache.flink.table.data.binary.LazyBinaryFormat.ensureMaterialized(LazyBinaryFormat.java:126)
    at org.apache.flink.table.data.writer.AbstractBinaryWriter.writeRawValue(AbstractBinaryWriter.java:135)
    at org.apache.flink.table.data.writer.BinaryRowWriter.writeRawValue(BinaryRowWriter.java:27)
    at org.apache.flink.table.data.writer.BinaryWriter.write(BinaryWriter.java:158)
    at org.apache.flink.table.runtime.typeutils.RowDataSerializer.toBinaryRow(RowDataSerializer.java:204)
    at org.apache.flink.table.data.writer.AbstractBinaryWriter.writeRow(AbstractBinaryWriter.java:147)
    at org.apache.flink.table.data.writer.BinaryRowWriter.writeRow(BinaryRowWriter.java:27)
    at org.apache.flink.table.data.writer.BinaryWriter.write(BinaryWriter.java:155)
    at org.apache.flink.table.runtime.typeutils.RowDataSerializer.toBinaryRow(RowDataSerializer.java:204)
    at org.apache.flink.table.runtime.typeutils.RowDataSerializer.serialize(RowDataSerializer.java:103)
    at org.apache.flink.table.runtime.typeutils.RowDataSerializer.serialize(RowDataSerializer.java:48)
    at org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer.serialize(StreamElementSerializer.java:165)
    at org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer.serialize(StreamElementSerializer.java:43)
    at org.apache.flink.runtime.plugable.SerializationDelegate.write(SerializationDelegate.java:54)
    at org.apache.flink.runtime.io.network.api.writer.RecordWriter.serializeRecord(RecordWriter.java:141)
    at org.apache.flink.runtime.io.network.api.writer.RecordWriter.emit(RecordWriter.java:107)
    at org.apache.flink.runtime.io.network.api.writer.ChannelSelectorRecordWriter.emit(ChannelSelectorRecordWriter.java:55)
    at org.apache.flink.streaming.runtime.io.RecordWriterOutput.pushToRecordWriter(RecordWriterOutput.java:134)
    at org.apache.flink.streaming.runtime.io.RecordWriterOutput.collectAndCheckIfChained(RecordWriterOutput.java:114)
    at org.apache.flink.streaming.runtime.io.RecordWriterOutput.collect(RecordWriterOutput.java:95)
    at org.apache.flink.streaming.runtime.io.RecordWriterOutput.collect(RecordWriterOutput.java:48)
    at org.apache.flink.streaming.api.operators.CountingOutput.collect(CountingOutput.java:59)
    at org.apache.flink.streaming.api.operators.CountingOutput.collect(CountingOutput.java:31)
    at org.apache.flink.table.runtime.util.StreamRecordCollector.collect(StreamRecordCollector.java:44)
    at org.apache.flink.table.runtime.operators.aggregate.MiniBatchLocalGroupAggFunction.finishBundle(MiniBatchLocalGroupAggFunction.java:90)
    at org.apache.flink.table.runtime.operators.bundle.AbstractMapBundleOperator.finishBundle(AbstractMapBundleOperator.java:136)
    at org.apache.flink.table.runtime.operators.bundle.AbstractMapBundleOperator.processWatermark(AbstractMapBundleOperator.java:144)
    at org.apache.flink.streaming.runtime.tasks.ChainingOutput.emitWatermark(ChainingOutput.java:121)
    at org.apache.flink.streaming.api.operators.AbstractStreamOperator.processWatermark(AbstractStreamOperator.java:604)
    at org.apache.flink.table.runtime.operators.TableStreamOperator.processWatermark(TableStreamOperator.java:57)
    at org.apache.flink.streaming.runtime.tasks.ChainingOutput.emitWatermark(ChainingOutput.java:121)
    at org.apache.flink.streaming.api.operators.AbstractStreamOperator.processWatermark(AbstractStreamOperator.java:604)
    at org.apache.flink.streaming.api.operators.AbstractStreamOperator.processWatermark(AbstractStreamOperator.java:609)
    at org.apache.flink.streaming.api.operators.AbstractStreamOperator.processWatermark1(AbstractStreamOperator.java:614)
    at org.apache.flink.streaming.runtime.io.StreamTwoInputProcessorFactory$StreamTaskNetworkOutput.emitWatermark(StreamTwoInputProcessorFactory.java:262)
    at org.apache.flink.streaming.runtime.watermarkstatus.StatusWatermarkValve.findAndOutputNewMinWatermarkAcrossAlignedChannels(StatusWatermarkValve.java:200)
    at org.apache.flink.streaming.runtime.watermarkstatus.StatusWatermarkValve.inputWatermark(StatusWatermarkValve.java:115)
    at org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput.processElement(AbstractStreamTaskNetworkInput.java:148)
    at org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput.emitNext(AbstractStreamTaskNetworkInput.java:110)
    at org.apache.flink.streaming.runtime.io.StreamOneInputProcessor.processInput(StreamOneInputProcessor.java:65)
    at org.apache.flink.streaming.runtime.io.StreamMultipleInputProcessor.processInput(StreamMultipleInputProcessor.java:85)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.processInput(StreamTask.java:562)
    at org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor.runMailboxLoop(MailboxProcessor.java:231)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.runMailboxLoop(StreamTask.java:858)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:807)
    at org.apache.flink.runtime.taskmanager.Task.runWithSystemExitMonitoring(Task.java:953)
    at org.apache.flink.runtime.taskmanager.Task.restoreAndInvoke(Task.java:932)
    at org.apache.flink.runtime.taskmanager.Task.doRun(Task.java:746)
    at org.apache.flink.runtime.taskmanager.Task.run(Task.java:562)
    at java.base/java.lang.Thread.run(Unknown Source)
```
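Our current (unverified) reading of the trace: the accumulator's RAW `LinkedList` is serialized with an *internal* serializer built in our `accumulatorTypeStrategy`, so for a ROW-typed argument it expects `RowData`, while the values handed to `accumulate()` appear to be external `Row` objects. The sketch below only illustrates that suspicion; it is not part of our job:

```java
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.types.logical.LogicalType;

public class SerializerMismatchSketch {
    public static void main(String[] args) {
        // The logical type of a hypothetical ROW argument.
        LogicalType rowType =
                DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.INT())).getLogicalType();

        // This mirrors what our accumulatorTypeStrategy does internally.
        TypeSerializer<?> serializer = InternalSerializers.create(rowType);

        // Prints org.apache.flink.table.runtime.typeutils.RowDataSerializer:
        // it serializes RowData only, which matches the ClassCastException
        // if the accumulated values are org.apache.flink.types.Row instead.
        System.out.println(serializer.getClass().getName());
    }
}
```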
Our edited function:

```java
/*
This is Flink's `array_agg` built-in function from version 1.20.
We backported it to Flink 1.18.

https://nightlies.apache.org/flink/flink-docs-release-1.20/docs/dev/table/functions/systemfunctions/#collection-functions:~:text=2%203%204-,ARRAY_AGG,-(%5B%20ALL%20%7C%20DISTINCT%20%5D%20expression
https://github.com/apache/flink/blob/d2c241a3f84a32c98b3848825b999c1b24860455/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java#L41

There are some slight differences from the 1.20 version:
- type inference was changed (we can't rely on the planner directly, like
  `array_agg` does)
- if no items are accumulated, this returns an empty array instead of `null`
- this is the signature accepted by array_agg:
  `ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ])`
  we accept only a subset:
  `ARRAY_AGG(expression)`
- `null`s in the input are ignored
*/

package eu.spaziodati.cp.flink.funcs;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.data.ArrayData;
import org.apache.flink.table.data.GenericArrayData;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.ExternalSerializer;
import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.logical.RowType;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType;

/** Built-in ARRAY_AGG aggregate function. */
public final class ArrayAggFunction<T>
        extends BuiltInAggregateFunction<ArrayData, ArrayAggFunction.ArrayAggAccumulator<T>> {

    private static final long serialVersionUID = -5860934997657147836L;

    private final boolean ignoreNulls;

    public ArrayAggFunction() {
        this.ignoreNulls = true;
    }

    public ArrayAggFunction(LogicalType elementType, boolean ignoreNulls) {
        this.ignoreNulls = ignoreNulls;
    }

    // --------------------------------------------------------------------------------------------
    // Planning
    // --------------------------------------------------------------------------------------------

    @Override
    public TypeInference getTypeInference(DataTypeFactory typeFactory) {
        return TypeInference.newBuilder()
                .inputTypeStrategy(InputTypeStrategies.sequence(InputTypeStrategies.ANY))
                // let the accumulator data type depend on the first input argument
                .accumulatorTypeStrategy(
                        callContext -> {
                            final DataType argDataType = callContext.getArgumentDataTypes().get(0);
                            TypeSerializer<T> serializer = InternalSerializers.create(toInternalDataType(argDataType).getLogicalType());
                            DataType linkedListType = DataTypes.RAW(
                                    LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
                            return Optional.of(DataTypes.STRUCTURED(
                                    ArrayAggAccumulator.class,
                                    DataTypes.FIELD("list", linkedListType),
                                    DataTypes.FIELD("retractList", linkedListType)));
                        })
                // let the output data type depend on the first input argument
                .outputTypeStrategy(
                        callContext -> {
                            final DataType argDataType = callContext.getArgumentDataTypes().get(0);
                            return Optional.of(DataTypes.ARRAY(argDataType).bridgedTo(ArrayData.class));
                        })
                .build();
    }

    // --------------------------------------------------------------------------------------------
    // Runtime
    // --------------------------------------------------------------------------------------------

    /** Accumulator for ARRAY_AGG with retraction. */
    public static class ArrayAggAccumulator<T> {
        public LinkedList<T> list;
        public LinkedList<T> retractList;

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            ArrayAggAccumulator<?> that = (ArrayAggAccumulator<?>) o;
            return Objects.equals(list, that.list)
                    && Objects.equals(retractList, that.retractList);
        }

        @Override
        public int hashCode() {
            return Objects.hash(list, retractList);
        }
    }

    @Override
    public ArrayAggAccumulator<T> createAccumulator() {
        final ArrayAggAccumulator<T> acc = new ArrayAggAccumulator<>();
        acc.list = new LinkedList<>();
        acc.retractList = new LinkedList<>();
        return acc;
    }

    public void accumulate(ArrayAggAccumulator<T> acc, T value) throws Exception {
        if (value != null || !ignoreNulls) {
            acc.list.add(value);
        }
    }

    public void retract(ArrayAggAccumulator<T> acc, T value) throws Exception {
        if (value != null || !ignoreNulls) {
            if (!acc.list.remove(value)) {
                acc.retractList.add(value);
            }
        }
    }

    public void merge(ArrayAggAccumulator<T> acc, Iterable<ArrayAggAccumulator<T>> its)
            throws Exception {
        List<T> newRetractBuffer = new LinkedList<>();
        for (ArrayAggAccumulator<T> otherAcc : its) {
            if (!otherAcc.list.iterator().hasNext()
                    && !otherAcc.retractList.iterator().hasNext()) {
                // otherAcc is empty, skip it
                continue;
            }
            acc.list.addAll(otherAcc.list);
            acc.retractList.addAll(otherAcc.retractList);
        }
        for (T element : acc.retractList) {
            if (!acc.list.remove(element)) {
                newRetractBuffer.add(element);
            }
        }
        acc.retractList.clear();
        acc.retractList.addAll(newRetractBuffer);
    }

    @Override
    public ArrayData getValue(ArrayAggAccumulator<T> acc) {
        try {
            List<T> accList = acc.list;
            if (accList == null) {
                // Defensive only: createAccumulator() always initializes the list.
                // Unlike the 1.20 built-in (which returns null when there are no
                // input rows), we return an empty array for an empty list.
                return null;
            } else {
                return new GenericArrayData(accList.toArray());
            }
        } catch (Exception e) {
            throw new FlinkRuntimeException(e);
        }
    }

    public void resetAccumulator(ArrayAggAccumulator<T> acc) {
        acc.list.clear();
        acc.retractList.clear();
    }
}
```
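For completeness, a quick sanity check against the class directly, outside the planner (a sketch; class and variable names are ours). Driven this way with simple values, accumulate/retract/getValue behave as expected, so the problem seems confined to how the planner bridges our types:

```java
import org.apache.flink.table.data.ArrayData;

import eu.spaziodati.cp.flink.funcs.ArrayAggFunction;

public class ArrayAggDirectCheck {
    public static void main(String[] args) throws Exception {
        ArrayAggFunction<String> func = new ArrayAggFunction<>();
        ArrayAggFunction.ArrayAggAccumulator<String> acc = func.createAccumulator();

        func.accumulate(acc, "a");
        func.accumulate(acc, "b");
        func.accumulate(acc, null); // ignored: the no-arg constructor sets ignoreNulls = true
        func.retract(acc, "a");

        ArrayData result = func.getValue(acc);
        System.out.println(result.size()); // 1, i.e. ["b"]
    }
}
```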