Hi,

I’ve been trying to train the new MultilayerPerceptronClassifier in Spark 1.5 for the MNIST digit recognition task. I’m trying to reproduce the work here:

https://github.com/avulanov/ann-benchmark

The API has changed since this work, so I’m not sure that I’m setting up the task correctly. After I've trained the classifier, it classifies everything as a 1. It even does this for the training set. Am I doing something wrong with the setup? I’m not looking for state-of-the-art performance, just something that looks reasonable. This experiment is meant to be a quick sanity test.

Here is the job:

import org.apache.log4j._
//Logger.getRootLogger.setLevel(Level.OFF)
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineStage
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.SQLContext
import java.io.FileOutputStream
import java.io.ObjectOutputStream

object MNIST {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("MNIST")
    conf.set("spark.driver.extraJavaOptions", "-XX:MaxPermSize=512M")
    val sc = new SparkContext(conf)

    val batchSize = 100
    val numIterations = 5

    // 780 input features, five hidden layers, 10 output classes
    val mlp = new MultilayerPerceptronClassifier
    mlp.setLayers(Array[Int](780, 2500, 2000, 1500, 1000, 500, 10))
    mlp.setMaxIter(numIterations)
    mlp.setBlockSize(batchSize)

    val train = MLUtils.loadLibSVMFile(sc, "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale")
    // Note: repartition returns a new RDD; the result is not assigned here,
    // so train keeps its original partitioning.
    train.repartition(200)

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val df = train.toDF

    // Train, then predict on the training set itself
    val model = mlp.fit(df)
    val trainPredictions = model.transform(df)
    trainPredictions.show(100)

    // Predict on the test set (feature count given explicitly as 780)
    val test = MLUtils.loadLibSVMFile(sc, "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale.t", 780).toDF
    val result = model.transform(test)
    result.show(100)

    val predictionAndLabels = result.select("prediction", "label")
    val evaluator = new MulticlassClassificationEvaluator()
      .setMetricName("precision")
    println("Precision:" + evaluator.evaluate(predictionAndLabels))

    // Serialize the trained model to disk
    val fos = new FileOutputStream("/home/rwaite/mt-work/ann-benchmark/spark_out/spark_model.obj")
    val oos = new ObjectOutputStream(fos)
    oos.writeObject(model)
    oos.close
  }
}

And here is the output:

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  5.0|(780,[152,153,154...|       1.0|
|  0.0|(780,[127,128,129...|       1.0|
|  4.0|(780,[160,161,162...|       1.0|
|  1.0|(780,[158,159,160...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  2.0|(780,[155,156,157...|       1.0|
|  1.0|(780,[124,125,126...|       1.0|
|  3.0|(780,[151,152,153...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  4.0|(780,[134,135,161...|       1.0|
|  3.0|(780,[123,124,125...|       1.0|
|  5.0|(780,[216,217,218...|       1.0|
|  3.0|(780,[143,144,145...|       1.0|
|  6.0|(780,[72,73,74,99...|       1.0|
|  1.0|(780,[151,152,153...|       1.0|
|  7.0|(780,[211,212,213...|       1.0|
|  2.0|(780,[151,152,153...|       1.0|
|  8.0|(780,[159,160,161...|       1.0|
|  6.0|(780,[100,101,102...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  4.0|(780,[129,130,131...|       1.0|
|  0.0|(780,[129,130,131...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  1.0|(780,[158,159,160...|       1.0|
|  1.0|(780,[99,100,101,...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  4.0|(780,[185,186,187...|       1.0|
|  3.0|(780,[150,151,152...|       1.0|
|  2.0|(780,[145,146,147...|       1.0|
|  7.0|(780,[240,241,242...|       1.0|
|  3.0|(780,[201,202,203...|       1.0|
|  8.0|(780,[153,154,155...|       1.0|
|  6.0|(780,[71,72,73,74...|       1.0|
|  9.0|(780,[210,211,212...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  5.0|(780,[188,189,190...|       1.0|
|  6.0|(780,[98,99,100,1...|       1.0|
|  0.0|(780,[127,128,129...|       1.0|
|  7.0|(780,[201,202,203...|       1.0|
|  6.0|(780,[125,126,127...|       1.0|
|  1.0|(780,[154,155,156...|       1.0|
|  8.0|(780,[131,132,133...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  9.0|(780,[181,182,183...|       1.0|
|  3.0|(780,[174,175,176...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  5.0|(780,[186,187,188...|       1.0|
|  9.0|(780,[150,151,152...|       1.0|
|  3.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[122,123,124...|       1.0|
|  0.0|(780,[153,154,155...|       1.0|
|  7.0|(780,[203,204,205...|       1.0|
|  4.0|(780,[212,213,214...|       1.0|
|  9.0|(780,[205,206,207...|       1.0|
|  8.0|(780,[181,182,183...|       1.0|
|  0.0|(780,[151,152,153...|       1.0|
|  9.0|(780,[210,211,212...|       1.0|
|  4.0|(780,[156,157,158...|       1.0|
|  1.0|(780,[129,130,131...|       1.0|
|  4.0|(780,[149,159,160...|       1.0|
|  4.0|(780,[187,188,189...|       1.0|
|  6.0|(780,[127,128,129...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  4.0|(780,[152,153,154...|       1.0|
|  5.0|(780,[219,220,221...|       1.0|
|  6.0|(780,[74,75,101,1...|       1.0|
|  1.0|(780,[150,151,152...|       1.0|
|  0.0|(780,[124,125,126...|       1.0|
|  0.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[97,98,99,12...|       1.0|
|  7.0|(780,[237,238,239...|       1.0|
|  1.0|(780,[124,125,126...|       1.0|
|  6.0|(780,[70,71,72,73...|       1.0|
|  3.0|(780,[149,150,151...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
|  1.0|(780,[127,128,129...|       1.0|
|  7.0|(780,[213,214,215...|       1.0|
|  9.0|(780,[123,124,125...|       1.0|
|  0.0|(780,[153,154,155...|       1.0|
|  2.0|(780,[94,95,96,97...|       1.0|
|  6.0|(780,[72,73,99,10...|       1.0|
|  7.0|(780,[199,200,201...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[171,172,173...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  0.0|(780,[122,123,124...|       1.0|
|  4.0|(780,[189,190,191...|       1.0|
|  6.0|(780,[73,74,75,76...|       1.0|
|  7.0|(780,[238,239,240...|       1.0|
|  4.0|(780,[158,159,177...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  8.0|(780,[154,155,156...|       1.0|
|  0.0|(780,[126,127,128...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[150,151,152...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
+-----+--------------------+----------+
only showing top 100 rows

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  7.0|(780,[202,203,204...|       1.0|
|  2.0|(780,[94,95,96,97...|       1.0|
|  1.0|(780,[128,129,130...|       1.0|
|  0.0|(780,[124,125,126...|       1.0|
|  4.0|(780,[150,151,159...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
|  4.0|(780,[149,150,151...|       1.0|
|  9.0|(780,[179,180,181...|       1.0|
|  5.0|(780,[129,130,131...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  0.0|(780,[123,124,125...|       1.0|
|  6.0|(780,[94,95,96,97...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  0.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[125,126,127...|       1.0|
|  5.0|(780,[124,125,126...|       1.0|
|  9.0|(780,[179,180,181...|       1.0|
|  7.0|(780,[200,201,202...|       1.0|
|  3.0|(780,[118,119,120...|       1.0|
|  4.0|(780,[158,159,185...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  6.0|(780,[96,97,98,99...|       1.0|
|  6.0|(780,[93,94,95,12...|       1.0|
|  5.0|(780,[156,157,158...|       1.0|
|  4.0|(780,[151,152,178...|       1.0|
|  0.0|(780,[125,126,127...|       1.0|
|  7.0|(780,[230,234,235...|       1.0|
|  4.0|(780,[152,153,179...|       1.0|
|  0.0|(780,[149,150,151...|       1.0|
|  1.0|(780,[123,124,125...|       1.0|
|  3.0|(780,[175,176,177...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[148,149,150...|       1.0|
|  4.0|(780,[122,123,150...|       1.0|
|  7.0|(780,[175,176,177...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  7.0|(780,[202,203,204...|       1.0|
|  1.0|(780,[151,152,153...|       1.0|
|  2.0|(780,[125,126,127...|       1.0|
|  1.0|(780,[126,127,128...|       1.0|
|  1.0|(780,[125,126,153...|       1.0|
|  7.0|(780,[207,208,209...|       1.0|
|  4.0|(780,[176,177,178...|       1.0|
|  2.0|(780,[126,127,128...|       1.0|
|  3.0|(780,[121,122,123...|       1.0|
|  5.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[122,123,124...|       1.0|
|  2.0|(780,[65,66,67,68...|       1.0|
|  4.0|(780,[177,178,179...|       1.0|
|  4.0|(780,[147,148,157...|       1.0|
|  6.0|(780,[100,101,102...|       1.0|
|  3.0|(780,[172,173,174...|       1.0|
|  5.0|(780,[163,164,165...|       1.0|
|  5.0|(780,[126,127,128...|       1.0|
|  6.0|(780,[93,94,95,12...|       1.0|
|  0.0|(780,[151,152,153...|       1.0|
|  4.0|(780,[148,149,150...|       1.0|
|  1.0|(780,[155,156,157...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  5.0|(780,[190,191,192...|       1.0|
|  7.0|(780,[198,199,200...|       1.0|
|  8.0|(780,[153,154,155...|       1.0|
|  9.0|(780,[178,179,180...|       1.0|
|  3.0|(780,[95,96,97,98...|       1.0|
|  7.0|(780,[200,201,202...|       1.0|
|  4.0|(780,[156,157,184...|       1.0|
|  6.0|(780,[67,68,69,95...|       1.0|
|  4.0|(780,[160,161,162...|       1.0|
|  3.0|(780,[148,149,150...|       1.0|
|  0.0|(780,[152,153,179...|       1.0|
|  7.0|(780,[206,207,208...|       1.0|
|  0.0|(780,[123,124,125...|       1.0|
|  2.0|(780,[119,120,121...|       1.0|
|  9.0|(780,[180,181,182...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  7.0|(780,[213,214,215...|       1.0|
|  3.0|(780,[124,125,126...|       1.0|
|  2.0|(780,[205,206,207...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  7.0|(780,[205,206,207...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  2.0|(780,[96,97,98,99...|       1.0|
|  7.0|(780,[204,205,206...|       1.0|
|  8.0|(780,[156,157,159...|       1.0|
|  4.0|(780,[147,148,158...|       1.0|
|  7.0|(780,[203,204,205...|       1.0|
|  3.0|(780,[146,147,148...|       1.0|
|  6.0|(780,[67,68,69,70...|       1.0|
|  1.0|(780,[128,129,130...|       1.0|
|  3.0|(780,[152,153,154...|       1.0|
|  6.0|(780,[71,72,73,74...|       1.0|
|  9.0|(780,[182,183,184...|       1.0|
|  3.0|(780,[149,150,151...|       1.0|
|  1.0|(780,[123,124,125...|       1.0|
|  4.0|(780,[158,159,160...|       1.0|
|  1.0|(780,[149,150,151...|       1.0|
|  7.0|(780,[175,176,177...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  9.0|(780,[177,178,179...|       1.0|
+-----+--------------------+----------+
only showing top 100 rows

Precision:0.1135
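
To confirm that the constant prediction is not just an artifact of looking at the first 100 rows, the check below tallies predictions and computes accuracy. It is only a sketch in plain Scala so that it runs without a cluster: PredictionCheck, predictionCounts, accuracy and sample are hypothetical names, not part of Spark, and sample hard-codes the first ten (prediction, label) pairs from the test output above.

```scala
// Hypothetical helper for illustration; none of these names come from Spark.
object PredictionCheck {

  // Count how many rows were assigned to each predicted class.
  def predictionCounts(pairs: Seq[(Double, Double)]): Map[Double, Int] =
    pairs.groupBy(_._1).map { case (prediction, rows) => prediction -> rows.size }

  // Fraction of rows whose prediction equals the label.
  def accuracy(pairs: Seq[(Double, Double)]): Double =
    if (pairs.isEmpty) 0.0
    else pairs.count { case (prediction, label) => prediction == label }.toDouble / pairs.size

  // First ten rows of the test output above, as (prediction, label).
  // Every prediction in that output is 1.0.
  val sample: Seq[(Double, Double)] =
    Seq(7.0, 2.0, 1.0, 0.0, 4.0, 1.0, 4.0, 9.0, 5.0, 9.0).map(label => (1.0, label))

  def main(args: Array[String]): Unit = {
    println("Prediction counts: " + predictionCounts(sample))
    println("Accuracy: " + accuracy(sample))
  }
}
```

On the full DataFrame the same tally should be available through result.groupBy("prediction").count().show(). A single row in that output would mean the model really does emit one class for every input, and the reported precision of 0.1135 would then simply be the share of 1s among the test labels.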