Github user pnowojski commented on a diff in the pull request: https://github.com/apache/flink/pull/6323#discussion_r202335703 --- Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/factories/TableFactoryService.scala --- @@ -18,143 +18,358 @@ package org.apache.flink.table.factories -import java.util.{ServiceConfigurationError, ServiceLoader} +import java.util.{ServiceConfigurationError, ServiceLoader, Map => JMap} import org.apache.flink.table.api._ import org.apache.flink.table.descriptors.ConnectorDescriptorValidator._ import org.apache.flink.table.descriptors.FormatDescriptorValidator._ import org.apache.flink.table.descriptors.MetadataValidator._ import org.apache.flink.table.descriptors.StatisticsValidator._ -import org.apache.flink.table.descriptors.{DescriptorProperties, TableDescriptor, TableDescriptorValidator} +import org.apache.flink.table.descriptors._ import org.apache.flink.table.util.Logging +import org.apache.flink.util.Preconditions import _root_.scala.collection.JavaConverters._ import _root_.scala.collection.mutable /** - * Unified interface to search for TableFactoryDiscoverable of provided type and properties. + * Unified interface to search for a [[TableFactory]] of provided type and properties. */ object TableFactoryService extends Logging { private lazy val defaultLoader = ServiceLoader.load(classOf[TableFactory]) - def find(clz: Class[_], descriptor: TableDescriptor): TableFactory = { - find(clz, descriptor, null) + /** + * Finds a table factory of the given class and descriptor. + * + * @param factoryClass desired factory class + * @param descriptor descriptor describing the factory configuration + * @tparam T factory class type + * @return the matching factory + */ + def find[T](factoryClass: Class[T], descriptor: Descriptor): T = { + Preconditions.checkNotNull(factoryClass) + Preconditions.checkNotNull(descriptor) + + val descriptorProperties = new DescriptorProperties() + descriptor.addProperties(descriptorProperties) + findInternal(factoryClass, descriptorProperties.asMap, None) } - def find(clz: Class[_], descriptor: TableDescriptor, classLoader: ClassLoader) - : TableFactory = { + /** + * Finds a table factory of the given class, descriptor, and classloader. + * + * @param factoryClass desired factory class + * @param descriptor descriptor describing the factory configuration + * @param classLoader classloader for service loading + * @tparam T factory class type + * @return the matching factory + */ + def find[T](factoryClass: Class[T], descriptor: Descriptor, classLoader: ClassLoader): T = { + Preconditions.checkNotNull(factoryClass) + Preconditions.checkNotNull(descriptor) + Preconditions.checkNotNull(classLoader) - val properties = new DescriptorProperties() - descriptor.addProperties(properties) - find(clz, properties.asMap.asScala.toMap, classLoader) + val descriptorProperties = new DescriptorProperties() + descriptor.addProperties(descriptorProperties) + findInternal(factoryClass, descriptorProperties.asMap, None) } - def find(clz: Class[_], properties: Map[String, String]): TableFactory = { - find(clz: Class[_], properties, null) + /** + * Finds a table factory of the given class and property map. + * + * @param factoryClass desired factory class + * @param propertyMap properties that describe the factory configuration + * @tparam T factory class type + * @return the matching factory + */ + def find[T](factoryClass: Class[T], propertyMap: JMap[String, String]): T = { + Preconditions.checkNotNull(factoryClass) + Preconditions.checkNotNull(propertyMap) + + findInternal(factoryClass, propertyMap, None) } - def find(clz: Class[_], properties: Map[String, String], - classLoader: ClassLoader): TableFactory = { + /** + * Finds a table factory of the given class, property map, and classloader. + * + * @param factoryClass desired factory class + * @param propertyMap properties that describe the factory configuration + * @param classLoader classloader for service loading + * @tparam T factory class type + * @return the matching factory + */ + def find[T]( + factoryClass: Class[T], + propertyMap: JMap[String, String], + classLoader: ClassLoader) + : T = { + Preconditions.checkNotNull(factoryClass) + Preconditions.checkNotNull(propertyMap) + Preconditions.checkNotNull(classLoader) + + findInternal(factoryClass, propertyMap, Some(classLoader)) + } + + /** + * Finds a table factory of the given class, property map, and classloader. + * + * @param factoryClass desired factory class + * @param propertyMap properties that describe the factory configuration + * @param classLoader classloader for service loading + * @tparam T factory class type + * @return the matching factory + */ + private def findInternal[T]( + factoryClass: Class[T], + propertyMap: JMap[String, String], + classLoader: Option[ClassLoader]) + : T = { + + val properties = propertyMap.asScala.toMap + + // discover table factories + val foundFactories = discoverFactories(classLoader) - var matchingFactory: Option[(TableFactory, Seq[String])] = None + // filter by factory class + val classFactories = filterByFactoryClass( + factoryClass, + properties, + foundFactories) + + // find matching context + val contextFactories = filterByContext( + factoryClass, + properties, + foundFactories, + classFactories) + + // filter by supported keys + filterBySupportedProperties( + factoryClass, + properties, + foundFactories, + contextFactories) + } + + /** + * Searches for factories using Java service providers. + * + * @return all factories in the classpath + */ + private def discoverFactories[T](classLoader: Option[ClassLoader]): Seq[TableFactory] = { + val foundFactories = mutable.ArrayBuffer[TableFactory]() try { - val iter = if (classLoader == null) { - defaultLoader.iterator() - } else { - val customLoader = ServiceLoader.load(classOf[TableFactory], classLoader) - customLoader.iterator() + val iterator = classLoader match { + case Some(customClassLoader) => + val customLoader = ServiceLoader.load(classOf[TableFactory], customClassLoader) + customLoader.iterator() + case None => + defaultLoader.iterator() } - while (iter.hasNext) { - val factory = iter.next() - - if (clz.isAssignableFrom(factory.getClass)) { - val requiredContextJava = try { - factory.requiredContext() - } catch { - case t: Throwable => - throw new TableException( - s"Table source factory '${factory.getClass.getCanonicalName}' caused an exception.", - t) - } - - val requiredContext = if (requiredContextJava != null) { - // normalize properties - requiredContextJava.asScala.map(e => (e._1.toLowerCase, e._2)) - } else { - Map[String, String]() - } - - val plainContext = mutable.Map[String, String]() - plainContext ++= requiredContext - // we remove the versions for now until we have the first backwards compatibility case - // with the version we can provide mappings in case the format changes - plainContext.remove(CONNECTOR_PROPERTY_VERSION) - plainContext.remove(FORMAT_PROPERTY_VERSION) - plainContext.remove(METADATA_PROPERTY_VERSION) - plainContext.remove(STATISTICS_PROPERTY_VERSION) - - if (plainContext.forall(e => properties.contains(e._1) && properties(e._1) == e._2)) { - matchingFactory match { - case Some(_) => throw new AmbiguousTableFactoryException(properties) - case None => matchingFactory = - Some((factory.asInstanceOf[TableFactory], requiredContext.keys.toSeq)) - } - } - } + + while (iterator.hasNext) { + val factory = iterator.next() + foundFactories += factory } + + foundFactories } catch { case e: ServiceConfigurationError => LOG.error("Could not load service provider for table factories.", e) throw new TableException("Could not load service provider for table factories.", e) } + } + + /** + * Filters for factories with matching context. + * + * @return all matching factories + */ + private def filterByContext[T]( --- End diff -- nit: move this method below `filterByFactoryClass`? (to match the order of invocations)
---