Github user fhueske commented on a diff in the pull request:

    https://github.com/apache/flink/pull/5043#discussion_r152535711

--- Diff: flink-connectors/flink-orc/src/main/java/org/apache/flink/orc/OrcRowInputFormat.java ---
@@ -0,0 +1,747 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.orc;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.io.FileInputFormat;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.core.fs.FileInputSplit;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.types.Row;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+
+import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgument;
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
+import org.apache.orc.OrcConf;
+import org.apache.orc.OrcFile;
+import org.apache.orc.Reader;
+import org.apache.orc.RecordReader;
+import org.apache.orc.StripeInformation;
+import org.apache.orc.TypeDescription;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.sql.Timestamp;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.flink.orc.OrcUtils.fillRows;
+
+/**
+ * InputFormat to read ORC files.
+ */
+public class OrcRowInputFormat extends FileInputFormat<Row> implements ResultTypeQueryable<Row> {
+
+	private static final Logger LOG = LoggerFactory.getLogger(OrcRowInputFormat.class);
+	// the default number of rows read in a batch
+	private static final int DEFAULT_BATCH_SIZE = 1000;
+
+	// the number of rows to read in a batch
+	private int batchSize;
+	// the configuration to read with
+	private Configuration conf;
+	// the schema of the ORC files to read
+	private TypeDescription schema;
+
+	// the fields of the ORC schema that the returned Rows are composed of.
+	private int[] selectedFields;
+	// the type information of the Rows returned by this InputFormat.
+	private transient RowTypeInfo rowType;
+
+	// the ORC reader
+	private transient RecordReader orcRowsReader;
+	// the vectorized row data to be read in a batch
+	private transient VectorizedRowBatch rowBatch;
+	// the vector of rows that is read in a batch
+	private transient Row[] rows;
+
+	// the number of rows in the current batch
+	private transient int rowsInBatch;
+	// the index of the next row to return
+	private transient int nextRow;
+
+	private ArrayList<Predicate> conjunctPredicates = new ArrayList<>();
+
+	/**
+	 * Creates an OrcRowInputFormat.
+	 *
+	 * @param path The path to read ORC files from.
+	 * @param schemaString The schema of the ORC files as String.
+	 * @param orcConfig The configuration to read the ORC files with.
+	 */
+	public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig) {
+		this(path, TypeDescription.fromString(schemaString), orcConfig, DEFAULT_BATCH_SIZE);
+	}
+
+	/**
+	 * Creates an OrcRowInputFormat.
+	 *
+	 * @param path The path to read ORC files from.
+	 * @param schemaString The schema of the ORC files as String.
+	 * @param orcConfig The configuration to read the ORC files with.
+	 * @param batchSize The number of Row objects to read in a batch.
+	 */
+	public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig, int batchSize) {
+		this(path, TypeDescription.fromString(schemaString), orcConfig, batchSize);
+	}
+
+	/**
+	 * Creates an OrcRowInputFormat.
+	 *
+	 * @param path The path to read ORC files from.
+	 * @param orcSchema The schema of the ORC files as ORC TypeDescription.
+	 * @param orcConfig The configuration to read the ORC files with.
+	 * @param batchSize The number of Row objects to read in a batch.
+	 */
+	public OrcRowInputFormat(String path, TypeDescription orcSchema, Configuration orcConfig, int batchSize) {
+		super(new Path(path));
+
+		// configure OrcInputFormat
+		this.schema = orcSchema;
+		this.rowType = (RowTypeInfo) OrcUtils.schemaToTypeInfo(schema);
+		this.conf = orcConfig;
+		this.batchSize = batchSize;
+
+		// set default selection mask, i.e., all fields.
+		this.selectedFields = new int[this.schema.getChildren().size()];
+		for (int i = 0; i < selectedFields.length; i++) {
+			this.selectedFields[i] = i;
+		}
+	}
+
+	/**
+	 * Adds a filter predicate to reduce the number of rows to be returned by the input format.
+	 * Multiple conjunctive predicates can be added by calling this method multiple times.
+	 *
+	 * <p>Note: Predicates can significantly reduce the amount of data that is read.
+	 * However, the OrcRowInputFormat does not guarantee that all returned rows qualify the
+	 * predicates. Moreover, predicates are only applied if the referenced field is among the
+	 * selected fields.</p>
+	 *
+	 * @param predicate The filter predicate.
+	 */
+	public void addPredicate(Predicate predicate) {
+		// validate
+		validatePredicate(predicate);
+		// add predicate
+		this.conjunctPredicates.add(predicate);
+	}
+
+	private void validatePredicate(Predicate pred) {
+		if (pred instanceof ColumnPredicate) {
+			// check column name
+			String colName = ((ColumnPredicate) pred).columnName;
+			if (!this.schema.getFieldNames().contains(colName)) {
+				throw new IllegalArgumentException("Predicate cannot be applied. " +
" + + "Column '" + colName + "' does not exist in ORC schema."); + } + } else if (pred instanceof Not) { + validatePredicate(((Not) pred).child()); + } else if (pred instanceof Or) { + for (Predicate p : ((Or) pred).children()) { + validatePredicate(p); + } + } + } + + /** + * Selects the fields from the ORC schema that are returned by InputFormat. + * + * @param selectedFields The indices of the fields of the ORC schema that are returned by the InputFormat. + */ + public void selectFields(int... selectedFields) { + // set field mapping + this.selectedFields = selectedFields; + // adapt result type + this.rowType = RowTypeInfo.projectFields(this.rowType, selectedFields); + } + + /** + * Computes the ORC projection mask of the fields to include from the selected fields.rowOrcInputFormat.nextRecord(null). + * + * @return The ORC projection mask. + */ + private boolean[] computeProjectionMask() { + // mask with all fields of the schema + boolean[] projectionMask = new boolean[schema.getMaximumId() + 1]; + // for each selected field + for (int inIdx : selectedFields) { + // set all nested fields of a selected field to true + TypeDescription fieldSchema = schema.getChildren().get(inIdx); + for (int i = fieldSchema.getId(); i <= fieldSchema.getMaximumId(); i++) { + projectionMask[i] = true; + } + } + return projectionMask; + } + + @Override + public void openInputFormat() throws IOException { + super.openInputFormat(); + // create and initialize the row batch + this.rows = new Row[batchSize]; + for (int i = 0; i < batchSize; i++) { + rows[i] = new Row(selectedFields.length); + } + } + + @Override + public void open(FileInputSplit fileSplit) throws IOException { + + LOG.debug("Opening ORC file {}", fileSplit.getPath()); + + // open ORC file and create reader + org.apache.hadoop.fs.Path hPath = new org.apache.hadoop.fs.Path(fileSplit.getPath().getPath()); + Reader orcReader = OrcFile.createReader(hPath, OrcFile.readerOptions(conf)); + + // get offset and length for the stripes that start in the split + Tuple2<Long, Long> offsetAndLength = getOffsetAndLengthForSplit(fileSplit, getStripes(orcReader)); + + // create ORC row reader configuration + Reader.Options options = getOptions(orcReader) + .schema(schema) + .range(offsetAndLength.f0, offsetAndLength.f1) + .useZeroCopy(OrcConf.USE_ZEROCOPY.getBoolean(conf)) + .skipCorruptRecords(OrcConf.SKIP_CORRUPT_DATA.getBoolean(conf)) + .tolerateMissingSchema(OrcConf.TOLERATE_MISSING_SCHEMA.getBoolean(conf)); + + // configure filters + if (!conjunctPredicates.isEmpty()) { + SearchArgument.Builder b = SearchArgumentFactory.newBuilder(); + b = b.startAnd(); + for (Predicate predicate : conjunctPredicates) { + predicate.add(b); + } + b = b.end(); + options.searchArgument(b.build(), new String[]{}); + } + + // configure selected fields + options.include(computeProjectionMask()); + + // create ORC row reader + this.orcRowsReader = orcReader.rows(options); + + // assign ids + this.schema.getId(); + // create row batch + this.rowBatch = schema.createRowBatch(batchSize); + rowsInBatch = 0; + nextRow = 0; + } + + @VisibleForTesting + Reader.Options getOptions(Reader orcReader) { + return orcReader.options(); + } + + @VisibleForTesting + List<StripeInformation> getStripes(Reader orcReader) { + return orcReader.getStripes(); + } + + private Tuple2<Long, Long> getOffsetAndLengthForSplit(FileInputSplit split, List<StripeInformation> stripes) { + long splitStart = split.getStart(); + long splitEnd = splitStart + split.getLength(); + + long readStart = Long.MAX_VALUE; + 
+		long readEnd = Long.MIN_VALUE;
+
+		for (StripeInformation s : stripes) {
+			if (splitStart <= s.getOffset() && s.getOffset() < splitEnd) {
+				// stripe starts in split, so it is included
+				readStart = Math.min(readStart, s.getOffset());
+				readEnd = Math.max(readEnd, s.getOffset() + s.getLength());
+			}
+		}
+
+		if (readStart < Long.MAX_VALUE) {
+			// at least one stripe is included
+			return Tuple2.of(readStart, readEnd - readStart);
+		} else {
+			return Tuple2.of(0L, 0L);
+		}
+	}
+
+	@Override
+	public void close() throws IOException {
+		if (orcRowsReader != null) {
+			this.orcRowsReader.close();
+		}
+		this.orcRowsReader = null;
+	}
+
+	@Override
+	public void closeInputFormat() throws IOException {
+		this.rows = null;
+		this.schema = null;
+		this.rowBatch = null;
+	}
+
+	@Override
+	public boolean reachedEnd() throws IOException {
+		return !ensureBatch();
+	}
+
+	/**
+	 * Checks if there is at least one row left in the batch to return.
+	 * If no more rows are available, it reads another batch of rows.
+	 *
+	 * @return Returns true if there is one more row to return, false otherwise.
+	 * @throws IOException thrown if an exception happens while reading a batch.
+	 */
+	private boolean ensureBatch() throws IOException {
+
+		if (nextRow >= rowsInBatch) {
+			// No more rows available in the Rows array.
+			nextRow = 0;
+			// Try to read the next batch of rows from the ORC file.
+			boolean moreRows = orcRowsReader.nextBatch(rowBatch);
+
+			if (moreRows) {
+				// Load the data into the Rows array.
+				rowsInBatch = fillRows(rows, schema, rowBatch, selectedFields);
+			}
+			return moreRows;
+		}
+		// there is at least one Row left in the Rows array.
+		return true;
+	}
+
+	@Override
+	public Row nextRecord(Row reuse) throws IOException {
+		// return the next row
+		return rows[this.nextRow++];
+	}
+
+	@Override
+	public TypeInformation<Row> getProducedType() {
+		return rowType;
+	}
+
+	// --------------------------------------------------------------------------------------------
+	// Custom serialization methods
+	// --------------------------------------------------------------------------------------------
+
+	private void writeObject(ObjectOutputStream out) throws IOException {
+		out.writeInt(batchSize);
+		this.conf.write(out);
+		out.writeUTF(schema.toString());
+
+		out.writeInt(selectedFields.length);
+		for (int f : selectedFields) {
+			out.writeInt(f);
+		}
+
+		out.writeInt(conjunctPredicates.size());
+		for (Predicate p : conjunctPredicates) {
+			out.writeObject(p);
+		}
+	}
+
+	@SuppressWarnings("unchecked")
+	private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+		batchSize = in.readInt();
+		org.apache.hadoop.conf.Configuration configuration = new org.apache.hadoop.conf.Configuration();
+		configuration.readFields(in);
+
+		if (this.conf == null) {
+			this.conf = configuration;
+		}
+		this.schema = TypeDescription.fromString(in.readUTF());
+
+		this.selectedFields = new int[in.readInt()];
+		for (int i = 0; i < selectedFields.length; i++) {
+			this.selectedFields[i] = in.readInt();
+		}
+
+		this.conjunctPredicates = new ArrayList<>();
+		int numPreds = in.readInt();
+		for (int i = 0; i < numPreds; i++) {
+			conjunctPredicates.add((Predicate) in.readObject());
+		}
+	}
+
+	// --------------------------------------------------------------------------------------------
+	// Classes to define predicates
+	// --------------------------------------------------------------------------------------------
+
+	/**
+	 * A filter predicate that can be evaluated by the OrcRowInputFormat.
+	 */
+	public abstract static class Predicate implements Serializable {
+		protected abstract SearchArgument.Builder add(SearchArgument.Builder builder);
+	}
+
+	abstract static class ColumnPredicate extends Predicate {
+		final String columnName;
+		final PredicateLeaf.Type literalType;
+
+		ColumnPredicate(String columnName, PredicateLeaf.Type literalType) {
+			this.columnName = columnName;
+			this.literalType = literalType;
+		}
+
+		Object castLiteral(Serializable literal) {
+
+			switch (literalType) {
+				case LONG:
+					if (literal instanceof Byte) {
+						return new Long((Byte) literal);
+					} else if (literal instanceof Short) {
+						return new Long((Short) literal);
+					} else if (literal instanceof Integer) {
+						return new Long((Integer) literal);
+					} else if (literal instanceof Long) {
+						return literal;
+					} else {
+						throw new IllegalArgumentException("A predicate on a LONG column requires an integer " +
+							"literal, i.e., Byte, Short, Integer, or Long.");
+					}
+				case FLOAT:
+					if (literal instanceof Float) {
+						return new Double((Float) literal);
+					} else if (literal instanceof Double) {
+						return literal;
+					} else if (literal instanceof BigDecimal) {
+						return ((BigDecimal) literal).doubleValue();
+					} else {
+						throw new IllegalArgumentException("A predicate on a FLOAT column requires a floating " +
+							"literal, i.e., Float or Double.");
+					}
+				case STRING:
+					if (literal instanceof String) {
+						return literal;
+					} else {
+						throw new IllegalArgumentException("A predicate on a STRING column requires a String literal.");
+					}
+				case BOOLEAN:
+					if (literal instanceof Boolean) {
+						return literal;
+					} else {
+						throw new IllegalArgumentException("A predicate on a BOOLEAN column requires a Boolean literal.");
+					}
+				case DATE:
+					if (literal instanceof Date) {
+						return literal;
+					} else {
+						throw new IllegalArgumentException("A predicate on a DATE column requires a java.sql.Date literal.");
+					}
+				case TIMESTAMP:
+					if (literal instanceof Timestamp) {
+						return literal;
+					} else {
+						throw new IllegalArgumentException("A predicate on a TIMESTAMP column requires a java.sql.Timestamp literal.");
+					}
+				case DECIMAL:
+					if (literal instanceof BigDecimal) {
+						return new HiveDecimalWritable(HiveDecimal.create((BigDecimal) literal));
+					} else {
+						throw new IllegalArgumentException("A predicate on a DECIMAL column requires a BigDecimal literal.");
+					}
+				default:
+					throw new IllegalArgumentException("Unknown literal type " + literalType);
+			}
+		}
+	}
+
+	abstract static class BinaryPredicate extends ColumnPredicate {
+		final Serializable literal;
+
+		BinaryPredicate(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
+			super(columnName, literalType);
+			this.literal = literal;
+		}
+	}
+
+	/**
+	 * An EQUALS predicate that can be evaluated by the OrcRowInputFormat.
+	 */
+	public static class Equals extends BinaryPredicate {
+		/**
+		 * Creates an EQUALS predicate.
+		 *
+		 * @param columnName The column to check.
+		 * @param literalType The type of the literal.
+		 * @param literal The literal value to check the column against.
+		 */
+		public Equals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
+			super(columnName, literalType, literal);
+		}
+
+		@Override
+		protected SearchArgument.Builder add(SearchArgument.Builder builder) {
+			return builder.equals(columnName, literalType, castLiteral(literal));
+		}
+
+		@Override
+		public String toString() {
+			return columnName + " = " + literal;
+		}
+	}
+
+	/**
+	 * An EQUALS predicate that can be evaluated with null safety by the OrcRowInputFormat.
+	 */
+	public static class NullSafeEquals extends BinaryPredicate {
+		/**
+		 * Creates a null-safe EQUALS predicate.
+		 *
+		 * @param columnName The column to check.
+		 * @param literalType The type of the literal.
+		 * @param literal The literal value to check the column against.
+		 */
+		public NullSafeEquals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
--- End diff --

Yes, but it's an interface class for users who use the `OrcRowInputFormat` directly.
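
For illustration, here is a rough sketch of how a user might wire these public predicate classes into a job. The path, schema string, field indices, and literal below are invented for the example and are not part of this PR:

```java
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.orc.OrcRowInputFormat;
import org.apache.flink.types.Row;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;

public class OrcReadExample {
	public static void main(String[] args) throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		// Made-up path and ORC schema, just to show the constructor arguments.
		OrcRowInputFormat orcFormat = new OrcRowInputFormat(
			"hdfs:///path/to/orc/data",
			"struct<name:string,age:int,salary:double>",
			new Configuration());

		// Return only the 'name' and 'age' fields of each row.
		orcFormat.selectFields(0, 1);

		// Push an EQUALS filter down to the ORC reader. As the Javadoc of
		// addPredicate() notes, filtering is best-effort, so rows that do not
		// match the predicate may still be returned.
		orcFormat.addPredicate(
			new OrcRowInputFormat.Equals("name", PredicateLeaf.Type.STRING, "Alice"));

		DataSet<Row> rows = env.createInput(orcFormat);
		rows.print();
	}
}
```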
---