LuciferYang commented on code in PR #49601:
URL: https://github.com/apache/spark/pull/49601#discussion_r1925237252


##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -448,98 +454,149 @@ private[ml] object MLUtils {
   // Since we're using reflection way to get the attribute, in order not to
   // leave a security hole, we define an allowed attribute list that can be accessed.
   // The attributes could be retrieved from the corresponding python class
-  private lazy val ALLOWED_ATTRIBUTES = HashSet(
-    "mean", // StandardScalerModel
-    "std", // StandardScalerModel
-    "maxAbs", // MaxAbsScalerModel
-    "originalMax", // MinMaxScalerModel
-    "originalMin", // MinMaxScalerModel
-    "range", // RobustScalerModel
-    "median", // RobustScalerModel
-    "toString",
-    "toDebugString",
-    "numFeatures",
-    "predict", // PredictionModel
-    "predictLeaf", // Tree models
-    "numClasses",
-    "depth", // DecisionTreeClassificationModel
-    "numNodes", // Tree models
-    "totalNumNodes", // Tree models
-    "javaTreeWeights", // Tree models
-    "treeWeights", // Tree models
-    "featureImportances", // Tree models
-    "predictRaw", // ClassificationModel
-    "predictProbability", // ProbabilisticClassificationModel
-    "scale", // LinearRegressionModel
-    "coefficients",
-    "intercept",
-    "coefficientMatrix",
-    "interceptVector", // LogisticRegressionModel
-    "summary",
-    "hasSummary",
-    "evaluate", // LogisticRegressionModel
-    "evaluateEachIteration", // GBTClassificationModel
-    "predictions",
-    "predictionCol",
-    "labelCol",
-    "weightCol",
-    "labels", // _ClassificationSummary
-    "truePositiveRateByLabel",
-    "falsePositiveRateByLabel", // _ClassificationSummary
-    "precisionByLabel",
-    "recallByLabel",
-    "fMeasureByLabel",
-    "accuracy", // _ClassificationSummary
-    "weightedTruePositiveRate",
-    "weightedFalsePositiveRate", // _ClassificationSummary
-    "weightedRecall",
-    "weightedPrecision",
-    "weightedFMeasure", // _ClassificationSummary
-    "scoreCol",
-    "roc",
-    "areaUnderROC",
-    "pr",
-    "fMeasureByThreshold", // _BinaryClassificationSummary
-    "precisionByThreshold",
-    "recallByThreshold", // _BinaryClassificationSummary
-    "probabilityCol",
-    "featuresCol", // LogisticRegressionSummary
-    "objectiveHistory",
-    "coefficientStandardErrors", // _TrainingSummary
-    "degreesOfFreedom", // LinearRegressionSummary
-    "devianceResiduals", // LinearRegressionSummary
-    "explainedVariance", // LinearRegressionSummary
-    "meanAbsoluteError", // LinearRegressionSummary
-    "meanSquaredError", // LinearRegressionSummary
-    "numInstances", // LinearRegressionSummary
-    "pValues", // LinearRegressionSummary
-    "r2", // LinearRegressionSummary
-    "r2adj", // LinearRegressionSummary
-    "residuals", // LinearRegressionSummary
-    "rootMeanSquaredError", // LinearRegressionSummary
-    "tValues", // LinearRegressionSummary
-    "totalIterations", // LinearRegressionSummary
-    "k", // KMeansSummary
-    "numIter", // KMeansSummary
-    "clusterSizes", // KMeansSummary
-    "trainingCost", // KMeansSummary
-    "cluster", // KMeansSummary
-    "computeCost", // BisectingKMeansModel
-    "rank", // ALSModel
-    "itemFactors", // ALSModel
-    "userFactors", // ALSModel
-    "recommendForAllUsers", // ALSModel
-    "recommendForAllItems", // ALSModel
-    "recommendForUserSubset", // ALSModel
-    "recommendForItemSubset", // ALSModel
-    "associationRules", // FPGrowthModel
-    "freqItemsets" // FPGrowthModel
-  )
+  private lazy val ALLOWED_ATTRIBUTES = Seq(
+    (classOf[Identifiable], Array("toString")),
+
+    // Model Traits
+    (classOf[PredictionModel[_, _]], Array("predict", "numFeatures")),
+    (classOf[ClassificationModel[_, _]], Array("predictRaw", "numClasses")),
+    (classOf[ProbabilisticClassificationModel[_, _]], 
Array("predictProbability")),
+
+    // Summary Traits
+    (classOf[HasTrainingSummary[_]], Array("hasSummary", "summary")),
+    (classOf[TrainingSummary], Array("objectiveHistory", "totalIterations")),
+    (
+      classOf[ClassificationSummary],
+      Array(
+        "predictions",
+        "predictionCol",
+        "labelCol",
+        "weightCol",
+        "labels",
+        "truePositiveRateByLabel",
+        "falsePositiveRateByLabel",
+        "precisionByLabel",
+        "recallByLabel",
+        "fMeasureByLabel",
+        "accuracy",
+        "weightedTruePositiveRate",
+        "weightedFalsePositiveRate",
+        "weightedRecall",
+        "weightedPrecision",
+        "weightedFMeasure",
+        "weightedFMeasure")),
+    (
+      classOf[BinaryClassificationSummary],
+      Array(
+        "scoreCol",
+        "roc",
+        "areaUnderROC",
+        "pr",
+        "fMeasureByThreshold",
+        "precisionByThreshold",
+        "recallByThreshold")),
+    (
+      classOf[ClusteringSummary],
+      Array(
+        "predictions",
+        "predictionCol",
+        "featuresCol",
+        "k",
+        "numIter",
+        "cluster",
+        "clusterSizes")),
+
+    // Tree Models
+    (classOf[DecisionTreeModel], Array("predictLeaf", "numNodes", "depth", 
"toDebugString")),
+    (
+      classOf[TreeEnsembleModel[_]],
+      Array(
+        "predictLeaf",
+        "trees",
+        "treeWeights",
+        "javaTreeWeights",
+        "getNumTrees",
+        "totalNumNodes",
+        "toDebugString")),
+    (classOf[DecisionTreeClassificationModel], Array("featureImportances")),
+    (classOf[RandomForestClassificationModel], Array("featureImportances", 
"evaluate")),
+    (classOf[GBTClassificationModel], Array("featureImportances", 
"evaluateEachIteration")),
+    (classOf[DecisionTreeRegressionModel], Array("featureImportances")),
+    (classOf[RandomForestRegressionModel], Array("featureImportances")),
+    (classOf[GBTRegressionModel], Array("featureImportances", 
"evaluateEachIteration")),
+
+    // Classification Models
+    (
+      classOf[LogisticRegressionModel],
+      Array("intercept", "coefficients", "interceptVector", 
"coefficientMatrix", "evaluate")),
+    (classOf[LogisticRegressionSummary], Array("probabilityCol", 
"featuresCol")),
+    (classOf[BinaryLogisticRegressionSummary], Array("scoreCol")),
+
+    // Regression Models
+    (classOf[LinearRegressionModel], Array("intercept", "coefficients", 
"scale", "evaluate")),
+    (
+      classOf[LinearRegressionSummary],
+      Array(
+        "predictions",
+        "predictionCol",
+        "labelCol",
+        "featuresCol",
+        "explainedVariance",
+        "meanAbsoluteError",
+        "meanSquaredError",
+        "rootMeanSquaredError",
+        "r2",
+        "r2adj",
+        "residuals",
+        "numInstances",
+        "degreesOfFreedom",
+        "devianceResiduals",
+        "coefficientStandardErrors",
+        "tValues",
+        "pValues")),
+    (classOf[LinearRegressionTrainingSummary], Array("objectiveHistory", 
"totalIterations")),
+
+    // Clustering Models
+    (classOf[KMeansModel], Array("predict", "numFeatures", "clusterCenters")),
+    (classOf[KMeansSummary], Array("trainingCost")),
+    (
+      classOf[BisectingKMeansModel],
+      Array("predict", "numFeatures", "clusterCenters", "computeCost")),

Review Comment:
   Use `Set` instead of `Array` for the allowed attribute names?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to