Hi,

It seems to me that there is overhead in the runMiniBatchSGD function of MLlib's GradientDescent. In particular, the "sample" and "treeAggregate" steps can take an order of magnitude longer than the actual gradient computation. For the MNIST dataset of 60K instances with a mini-batch fraction of 0.001 (i.e. 60 samples), sampling takes 0.15 s and aggregation takes 0.3 s in local mode with 1 data partition on a Core i5 processor, while the actual gradient computation takes only 0.002 s. I searched through the Spark JIRA and found a recent update for more efficient sampling (SPARK-3250) that is already included in the Spark codebase. Is there a way to reduce the sampling time and the local treeAggregate time by an order of magnitude?
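For reference, here is roughly how I measured the two steps. This is a minimal sketch, not the actual runMiniBatchSGD code: it uses synthetic 784-feature data standing in for MNIST, and the timing helper, object name, and the summing seqOp/combOp are my own stand-ins for the real gradient reduction:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object MiniBatchTiming {
  // Simple wall-clock timer around a block of code.
  def time[T](label: String)(body: => T): T = {
    val t0 = System.nanoTime()
    val result = body
    println(f"$label: ${(System.nanoTime() - t0) / 1e9}%.4f s")
    result
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("miniBatchTiming").setMaster("local[1]"))

    // Synthetic stand-in for the 60K-instance MNIST set (784 features),
    // cached in a single partition to match the setup described above.
    val data = sc.parallelize(0 until 60000, numSlices = 1).map { i =>
      LabeledPoint((i % 10).toDouble, Vectors.dense(Array.fill(784)(math.random)))
    }.cache()
    data.count() // materialize the cache before timing

    val miniBatchFraction = 0.001
    val sampled = time("sample") {
      val s = data.sample(withReplacement = false, miniBatchFraction, seed = 42)
      s.count() // sample is lazy, so force evaluation inside the timer
      s
    }

    // Mimics the treeAggregate reduction in runMiniBatchSGD; here it just
    // sums the feature vectors instead of accumulating gradients.
    time("treeAggregate") {
      sampled.treeAggregate(new Array[Double](784))(
        seqOp = (acc, p) => {
          var j = 0
          while (j < acc.length) { acc(j) += p.features(j); j += 1 }
          acc
        },
        combOp = (a, b) => {
          var j = 0
          while (j < a.length) { a(j) += b(j); j += 1 }
          a
        }
      )
    }

    sc.stop()
  }
}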
Best regards, Alexander