the-other-tim-brown commented on code in PR #6135:
URL: https://github.com/apache/hudi/pull/6135#discussion_r957524622


##########
hudi-utilities/src/main/java/org/apache/hudi/utilities/sources/helpers/ProtoConversionUtil.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.utilities.sources.helpers;
+
+import org.apache.hudi.common.util.collection.Pair;
+import org.apache.hudi.exception.HoodieException;
+
+import com.google.protobuf.BoolValue;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.DescriptorProtos;
+import com.google.protobuf.Descriptors;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.FloatValue;
+import com.google.protobuf.Int32Value;
+import com.google.protobuf.Int64Value;
+import com.google.protobuf.Message;
+import com.google.protobuf.StringValue;
+import com.google.protobuf.UInt32Value;
+import com.google.protobuf.UInt64Value;
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericFixed;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.util.Utf8;
+
+import java.io.File;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
+
+/**
+ * A utility class to help translate from Proto to Avro.
+ */
+public class ProtoConversionUtil {
+
+  /**
+   * Returns the Avro {@link Schema} that corresponds to the provided protobuf {@link Message} class.
+   * The lookup is delegated to the shared {@code AvroSupport} instance.
+   * @param clazz the protobuf class to derive the schema from
+   * @param flattenWrappedPrimitives when true, wrapped primitives are modelled as nullable fields instead of nested messages
+   * @return the Avro schema for {@code clazz}
+   */
+  public static Schema getAvroSchemaForMessageClass(Class clazz, boolean flattenWrappedPrimitives) {
+    final AvroSupport support = AvroSupport.get();
+    return support.getSchema(clazz, flattenWrappedPrimitives);
+  }
+
+  /**
+   * Translates a protobuf {@link Message} into an Avro {@link GenericRecord} shaped by the given schema.
+   * The conversion is delegated to the shared {@code AvroSupport} instance.
+   * @param schema the Avro schema the resulting record must follow
+   * @param message the protobuf message to read values from
+   * @return the message contents as an Avro GenericRecord
+   */
+  public static GenericRecord convertToAvro(Schema schema, Message message) {
+    final AvroSupport support = AvroSupport.get();
+    return support.convert(schema, message);
+  }
+
+  /**
+   * This class provides support for generating schemas and converting from proto to avro. We don't directly use Avro's ProtobufData class so we can:
+   * 1. Customize how schemas are generated for protobufs. We treat Enums as strings and provide an option to treat wrapped primitives like {@link Int32Value} and {@link StringValue} as messages
+   * (default behavior) or as nullable versions of those primitives.
+   * 2. Convert directly from a protobuf {@link Message} to a {@link GenericRecord} while properly handling enums and wrapped primitives mentioned above.
+   */
+  private static class AvroSupport {
+    // Single shared instance; callers obtain it through get().
+    private static final AvroSupport INSTANCE = new AvroSupport();
+    // Schemas already built by getSchema, keyed by (class, flattenWrappedPrimitives).
+    private static final Map<Pair<Class, Boolean>, Schema> SCHEMA_CACHE = new 
ConcurrentHashMap<>();
+    // Proto field descriptors indexed by Avro field position, keyed by (schema, message descriptor); filled by getOrderedFields.
+    private static final Map<Pair<Schema, Descriptors.Descriptor>, 
Descriptors.FieldDescriptor[]> FIELD_CACHE = new ConcurrentHashMap<>();
+
+
+    // Schema used when converting map keys in convertObject.
+    private static final Schema STRINGS = Schema.create(Schema.Type.STRING);
+
+    // Null schema placed first in the unions created for nullable fields.
+    private static final Schema NULL = Schema.create(Schema.Type.NULL);
+    // Descriptors of the protobuf wrapper messages (StringValue, Int32Value, ...) mapped to an Avro type.
+    private static final Map<Descriptors.Descriptor, Schema.Type> 
WRAPPER_DESCRIPTORS_TO_TYPE = getWrapperDescriptorsToType();
+
+    private static Map<Descriptors.Descriptor, Schema.Type> 
getWrapperDescriptorsToType() {
+      Map<Descriptors.Descriptor, Schema.Type> wrapperDescriptorsToType = new 
HashMap<>();
+      wrapperDescriptorsToType.put(StringValue.getDescriptor(), 
Schema.Type.STRING);
+      wrapperDescriptorsToType.put(Int32Value.getDescriptor(), 
Schema.Type.INT);
+      wrapperDescriptorsToType.put(UInt32Value.getDescriptor(), 
Schema.Type.INT);
+      wrapperDescriptorsToType.put(Int64Value.getDescriptor(), 
Schema.Type.LONG);
+      wrapperDescriptorsToType.put(UInt64Value.getDescriptor(), 
Schema.Type.LONG);
+      wrapperDescriptorsToType.put(BoolValue.getDescriptor(), 
Schema.Type.BOOLEAN);
+      wrapperDescriptorsToType.put(BytesValue.getDescriptor(), 
Schema.Type.BYTES);
+      wrapperDescriptorsToType.put(DoubleValue.getDescriptor(), 
Schema.Type.DOUBLE);
+      wrapperDescriptorsToType.put(FloatValue.getDescriptor(), 
Schema.Type.FLOAT);
+      return wrapperDescriptorsToType;
+    }
+
+    // Private constructor: the only instance is INSTANCE, which is handed out through get().
+    private AvroSupport() {
+    }
+
+    /**
+     * Returns the shared {@code AvroSupport} instance held in {@code INSTANCE}.
+     * @return the shared instance
+     */
+    public static AvroSupport get() {
+      return INSTANCE;
+    }
+
+    /**
+     * Converts a protobuf message into an Avro record that follows {@code schema}.
+     * @param schema the Avro schema of the record to produce
+     * @param message the protobuf message to convert
+     * @return the result of {@code convertObject} cast to a {@link GenericRecord}
+     */
+    public GenericRecord convert(Schema schema, Message message) {
+      final Object converted = convertObject(schema, message);
+      return (GenericRecord) converted;
+    }
+
+    /**
+     * Returns the Avro schema for a protobuf class, building it on first use and caching it in {@code SCHEMA_CACHE}.
+     * The class' static {@code getDescriptor} method is invoked reflectively to obtain the proto descriptor.
+     * @param c the protobuf class (message or enum)
+     * @param flattenWrappedPrimitives whether wrapped primitives become nullable fields instead of nested records
+     * @return the cached or newly generated schema
+     * @throws RuntimeException wrapping any exception raised while looking up the descriptor or generating the schema
+     */
+    public Schema getSchema(Class c, boolean flattenWrappedPrimitives) {
+      final Pair<Class, Boolean> cacheKey = Pair.of(c, flattenWrappedPrimitives);
+      return SCHEMA_CACHE.computeIfAbsent(cacheKey, ignored -> {
+        try {
+          final Object descriptor = c.getMethod("getDescriptor").invoke(null);
+          if (!c.isEnum()) {
+            return getMessageSchema((Descriptors.Descriptor) descriptor, new HashMap<>(), flattenWrappedPrimitives);
+          }
+          return getEnumSchema((Descriptors.EnumDescriptor) descriptor);
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+      });
+    }
+
+    private Schema getEnumSchema(Descriptors.EnumDescriptor d) {
+      List<String> symbols = new ArrayList<>(d.getValues().size());
+      for (Descriptors.EnumValueDescriptor e : d.getValues()) {
+        symbols.add(e.getName());
+      }
+      return Schema.createEnum(d.getName(), null, getNamespace(d.getFile(), 
d.getContainingType()), symbols);
+    }
+
+    private Schema getMessageSchema(Descriptors.Descriptor descriptor, 
Map<Descriptors.Descriptor, Schema> seen, boolean flattenWrappedPrimitives) {
+      if (seen.containsKey(descriptor)) {
+        return seen.get(descriptor);
+      }
+      Schema result = Schema.createRecord(descriptor.getName(), null,
+          getNamespace(descriptor.getFile(), descriptor.getContainingType()), 
false);
+
+      seen.put(descriptor, result);
+
+      List<Schema.Field> fields = new 
ArrayList<>(descriptor.getFields().size());
+      for (Descriptors.FieldDescriptor f : descriptor.getFields()) {
+        fields.add(new Schema.Field(f.getName(), getFieldSchema(f, seen, 
flattenWrappedPrimitives), null, getDefault(f)));
+      }
+      result.setFields(fields);
+      return result;
+    }
+
+    private Schema getFieldSchema(Descriptors.FieldDescriptor f, 
Map<Descriptors.Descriptor, Schema> seen, boolean flattenWrappedPrimitives) {
+      Function<Schema, Schema> schemaFinalizer =  f.isRepeated() ? 
Schema::createArray : Function.identity();
+      switch (f.getType()) {
+        case BOOL:
+          return schemaFinalizer.apply(Schema.create(Schema.Type.BOOLEAN));
+        case FLOAT:
+          return schemaFinalizer.apply(Schema.create(Schema.Type.FLOAT));
+        case DOUBLE:
+          return schemaFinalizer.apply(Schema.create(Schema.Type.DOUBLE));
+        case ENUM:
+          return schemaFinalizer.apply(getEnumSchema(f.getEnumType()));
+        case STRING:
+          Schema s = Schema.create(Schema.Type.STRING);
+          GenericData.setStringType(s, GenericData.StringType.String);
+          return schemaFinalizer.apply(s);
+        case BYTES:
+          return schemaFinalizer.apply(Schema.create(Schema.Type.BYTES));
+        case INT32:
+        case SINT32:
+        case FIXED32:
+        case SFIXED32:
+          return schemaFinalizer.apply(Schema.create(Schema.Type.INT));
+        case UINT32:
+        case INT64:
+        case UINT64:
+        case SINT64:
+        case FIXED64:
+        case SFIXED64:
+          return schemaFinalizer.apply(Schema.create(Schema.Type.LONG));
+        case MESSAGE:
+          if (flattenWrappedPrimitives && 
WRAPPER_DESCRIPTORS_TO_TYPE.containsKey(f.getMessageType())) {
+            // all wrapper types have a single field, so we can get the first 
field in the message's schema
+            return 
schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL, 
getFieldSchema(f.getMessageType().getFields().get(0), seen, 
flattenWrappedPrimitives))));
+          }
+          // if message field is repeated (like a list), elements are non-null
+          if (f.isRepeated()) {
+            return schemaFinalizer.apply(getMessageSchema(f.getMessageType(), 
seen, flattenWrappedPrimitives));
+          }
+          // otherwise we create a nullable field schema
+          return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL, 
getMessageSchema(f.getMessageType(), seen, flattenWrappedPrimitives))));
+        case GROUP: // groups are deprecated
+        default:
+          throw new RuntimeException("Unexpected type: " + f.getType());
+      }
+    }
+
+    private Object getDefault(Descriptors.FieldDescriptor f) {
+      if (f.isRepeated()) { // empty array as repeated fields' default value
+        return Collections.emptyList();
+      }
+
+      switch (f.getType()) { // generate default for type
+        case BOOL:
+          return false;
+        case FLOAT:
+          return 0.0F;
+        case DOUBLE:
+          return 0.0D;
+        case INT32:
+        case UINT32:
+        case SINT32:
+        case FIXED32:
+        case SFIXED32:
+        case INT64:
+        case UINT64:
+        case SINT64:
+        case FIXED64:
+        case SFIXED64:
+          return 0;
+        case STRING:
+        case BYTES:
+          return "";
+        case ENUM:
+          return f.getEnumType().getValues().get(0).getName();
+        case MESSAGE:
+          return Schema.Field.NULL_VALUE;
+        case GROUP: // groups are deprecated
+        default:
+          throw new RuntimeException("Unexpected type: " + f.getType());
+      }
+    }
+
+    private Descriptors.FieldDescriptor[] getOrderedFields(Schema schema, 
Message message) {
+      Descriptors.Descriptor descriptor = message.getDescriptorForType();
+      return FIELD_CACHE.computeIfAbsent(Pair.of(schema, descriptor), key -> {
+        Descriptors.FieldDescriptor[] fields = new 
Descriptors.FieldDescriptor[key.getLeft().getFields().size()];
+        for (Schema.Field f : key.getLeft().getFields()) {
+          fields[f.pos()] = key.getRight().findFieldByName(f.name());
+        }
+        return fields;
+      });
+    }
+
+    /**
+     * Recursively converts a protobuf value into the Avro representation required by {@code schema}.
+     * @param schema the Avro schema describing the target type
+     * @param value the protobuf value to convert (message, wrapper message, list, map or scalar); may be null
+     * @return the converted value, or null when {@code value} is null or the schema type is NULL
+     * @throws HoodieException when the schema type is not handled
+     */
+    private Object convertObject(Schema schema, Object value) {
+      if (value == null) {
+        return null;
+      }
+
+      switch (schema.getType()) {
+        case ARRAY:
+          List<Object> arrayValue = (List<Object>) value;
+          List<Object> arrayCopy = new GenericData.Array<>(arrayValue.size(), schema);
+          for (Object obj : arrayValue) {
+            arrayCopy.add(convertObject(schema.getElementType(), obj));
+          }
+          return arrayCopy;
+        case BYTES:
+          ByteBuffer byteBufferValue;
+          if (value instanceof ByteString) {
+            byteBufferValue = ((ByteString) value).asReadOnlyByteBuffer();
+          } else if (value instanceof Message) {
+            // wrapper message: unwrap its single field first
+            byteBufferValue = ((ByteString) getWrappedValue(value)).asReadOnlyByteBuffer();
+          } else {
+            byteBufferValue = (ByteBuffer) value;
+          }
+          // copy the remaining bytes into a fresh array, then restore the source buffer's position
+          int start = byteBufferValue.position();
+          int length = byteBufferValue.limit() - start;
+          byte[] bytesCopy = new byte[length];
+          byteBufferValue.get(bytesCopy, 0, length);
+          byteBufferValue.position(start);
+          return ByteBuffer.wrap(bytesCopy, 0, length);
+        case ENUM:
+          return GenericData.get().createEnum(value.toString(), schema);
+        case FIXED:
+          return GenericData.get().createFixed(null, ((GenericFixed) value).bytes(), schema);
+        case BOOLEAN:
+        case DOUBLE:
+        case FLOAT:
+        case INT:
+          if (value instanceof Message) {
+            return getWrappedValue(value);
+          }
+          return value; // immutable
+        case LONG:
+          Object tmpValue = value;
+          if (value instanceof Message) {
+            tmpValue = getWrappedValue(value);
+          }
+          // An Integer arriving for a LONG schema is an unsigned int: among the int-backed proto types,
+          // getFieldSchema maps only UINT32 to LONG. Protobuf keeps the top bit of a uint32 in the sign bit
+          // of the Java int, so widen it as unsigned; a sign-extending conversion would turn values above
+          // Integer.MAX_VALUE into negative longs.
+          if (tmpValue instanceof Integer) {
+            tmpValue = Integer.toUnsignedLong((Integer) tmpValue);
+          }
+          return tmpValue;
+        case MAP:
+          Map<Object, Object> mapValue = (Map) value;
+          Map<Object, Object> mapCopy = new HashMap<>(mapValue.size());
+          for (Map.Entry<Object, Object> entry : mapValue.entrySet()) {
+            mapCopy.put(convertObject(STRINGS, entry.getKey()), convertObject(schema.getValueType(), entry.getValue()));
+          }
+          return mapCopy;
+        case NULL:
+          return null;
+        case RECORD:
+          GenericData.Record newRecord = new GenericData.Record(schema);
+          Message messageValue = (Message) value;
+          // the ordered descriptors are the same for every field of this record, so look them up once
+          Descriptors.FieldDescriptor[] orderedFields = getOrderedFields(schema, messageValue);
+          for (Schema.Field f : schema.getFields()) {
+            int position = f.pos();
+            Descriptors.FieldDescriptor fieldDescriptor = orderedFields[position];
+            Object convertedValue;
+            if (fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE && !fieldDescriptor.isRepeated() && !messageValue.hasField(fieldDescriptor)) {
+              // a singular message field that was never set is written as null
+              convertedValue = null;
+            } else {
+              convertedValue = convertObject(f.schema(), messageValue.getField(fieldDescriptor));
+            }
+            newRecord.put(position, convertedValue);
+          }
+          return newRecord;
+        case STRING:
+          if (value instanceof String) {
+            return value;
+          } else if (value instanceof StringValue) {
+            return ((StringValue) value).getValue();
+          } else {
+            return new Utf8(value.toString());
+          }
+        case UNION:
+          // Unions only occur for nullable fields when working with proto + avro and null is the first schema in the union
+          return convertObject(schema.getTypes().get(1), value);
+        default:
+          throw new HoodieException("Proto to Avro conversion failed for schema \"" + schema + "\" and value \"" + value + "\"");
+      }
+    }
+
+    /**
+     * Returns the wrapped field, assumes all wrapped fields have a single 
value
+     * @param value wrapper message like {@link Int32Value} or {@link 
StringValue}
+     * @return the wrapped object
+     */
+    private Object getWrappedValue(Object value) {
+      Message valueAsMessage = (Message) value;
+      return 
valueAsMessage.getField(valueAsMessage.getDescriptorForType().getFields().get(0));
+    }
+
+    private String getNamespace(Descriptors.FileDescriptor fd, 
Descriptors.Descriptor containing) {
+      DescriptorProtos.FileOptions filedOptions = fd.getOptions();
+      String classPackage = filedOptions.hasJavaPackage() ? 
filedOptions.getJavaPackage() : fd.getPackage();
+      String outer = "";
+      if (!filedOptions.getJavaMultipleFiles()) {
+        if (filedOptions.hasJavaOuterClassname()) {
+          outer = filedOptions.getJavaOuterClassname();
+        } else {
+          outer = new File(fd.getName()).getName();
+          outer = outer.substring(0, outer.lastIndexOf('.'));
+          outer = toCamelCase(outer);

Review Comment:
   Now that you mention it, I don't think this is really necessary, since we're 
not persisting these schemas anywhere. It was mainly there to make the names match 
Java naming conventions when the provided proto wasn't already following them.
   
   I think I could simplify all of this by just using the `classPackage` variable 
above. What do you think? The main goal is to have some namespacing so that schemas 
don't collide, unless I'm missing something.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to