Hi,

I’ve been trying to train the new MultilayerPerceptronClassifier in spark 1.5 
for the MNIST digit recognition task. I’m trying to reproduce the work here:

https://github.com/avulanov/ann-benchmark

The API has changed since this work, so I’m not sure that I’m setting up the 
task correctly.

After I've trained the classifier, it classifies everything as a 1. It even 
does this for the training set. I am doing something wrong with the setup? I’m 
not looking for state of the art performance, just something that looks 
reasonable. This experiment is meant to be a quick sanity test.

Here is the job:

import org.apache.log4j._
//Logger.getRootLogger.setLevel(Level.OFF)
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineStage
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.SQLContext
import java.io.FileOutputStream
import java.io.ObjectOutputStream

object MNIST {
 def main(args: Array[String]) {
   val conf = new SparkConf().setAppName("MNIST")
   conf.set("spark.driver.extraJavaOptions", "-XX:MaxPermSize=512M")
   val sc = new SparkContext(conf)
   val batchSize = 100
   val numIterations = 5
   val mlp = new MultilayerPerceptronClassifier
   mlp.setLayers(Array[Int](780, 2500, 2000, 1500, 1000, 500, 10))
   mlp.setMaxIter(numIterations)
   mlp.setBlockSize(batchSize)
   val train = MLUtils.loadLibSVMFile(sc, 
"file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale")
   train.repartition(200)
   val sqlContext = new SQLContext(sc)
   import sqlContext.implicits._
   val df = train.toDF
   val model = mlp.fit(df)
   val trainPredictions = model.transform(df)
   trainPredictions.show(100)
   val test = MLUtils.loadLibSVMFile(sc, 
"file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale.t", 780).toDF
   val result = model.transform(test)
   result.show(100)
   val predictionAndLabels = result.select("prediction", "label")
   val evaluator = new MulticlassClassificationEvaluator()
     .setMetricName("precision")
   println("Precision:" + evaluator.evaluate(predictionAndLabels))
   val fos = new 
FileOutputStream("/home/rwaite/mt-work/ann-benchmark/spark_out/spark_model.obj");
   val oos = new ObjectOutputStream(fos);
   oos.writeObject(model);
   oos.close
 }
}


And here is the output:

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  5.0|(780,[152,153,154...|       1.0|
|  0.0|(780,[127,128,129...|       1.0|
|  4.0|(780,[160,161,162...|       1.0|
|  1.0|(780,[158,159,160...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  2.0|(780,[155,156,157...|       1.0|
|  1.0|(780,[124,125,126...|       1.0|
|  3.0|(780,[151,152,153...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  4.0|(780,[134,135,161...|       1.0|
|  3.0|(780,[123,124,125...|       1.0|
|  5.0|(780,[216,217,218...|       1.0|
|  3.0|(780,[143,144,145...|       1.0|
|  6.0|(780,[72,73,74,99...|       1.0|
|  1.0|(780,[151,152,153...|       1.0|
|  7.0|(780,[211,212,213...|       1.0|
|  2.0|(780,[151,152,153...|       1.0|
|  8.0|(780,[159,160,161...|       1.0|
|  6.0|(780,[100,101,102...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  4.0|(780,[129,130,131...|       1.0|
|  0.0|(780,[129,130,131...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  1.0|(780,[158,159,160...|       1.0|
|  1.0|(780,[99,100,101,...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  4.0|(780,[185,186,187...|       1.0|
|  3.0|(780,[150,151,152...|       1.0|
|  2.0|(780,[145,146,147...|       1.0|
|  7.0|(780,[240,241,242...|       1.0|
|  3.0|(780,[201,202,203...|       1.0|
|  8.0|(780,[153,154,155...|       1.0|
|  6.0|(780,[71,72,73,74...|       1.0|
|  9.0|(780,[210,211,212...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  5.0|(780,[188,189,190...|       1.0|
|  6.0|(780,[98,99,100,1...|       1.0|
|  0.0|(780,[127,128,129...|       1.0|
|  7.0|(780,[201,202,203...|       1.0|
|  6.0|(780,[125,126,127...|       1.0|
|  1.0|(780,[154,155,156...|       1.0|
|  8.0|(780,[131,132,133...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  9.0|(780,[181,182,183...|       1.0|
|  3.0|(780,[174,175,176...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  5.0|(780,[186,187,188...|       1.0|
|  9.0|(780,[150,151,152...|       1.0|
|  3.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[122,123,124...|       1.0|
|  0.0|(780,[153,154,155...|       1.0|
|  7.0|(780,[203,204,205...|       1.0|
|  4.0|(780,[212,213,214...|       1.0|
|  9.0|(780,[205,206,207...|       1.0|
|  8.0|(780,[181,182,183...|       1.0|
|  0.0|(780,[151,152,153...|       1.0|
|  9.0|(780,[210,211,212...|       1.0|
|  4.0|(780,[156,157,158...|       1.0|
|  1.0|(780,[129,130,131...|       1.0|
|  4.0|(780,[149,159,160...|       1.0|
|  4.0|(780,[187,188,189...|       1.0|
|  6.0|(780,[127,128,129...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  4.0|(780,[152,153,154...|       1.0|
|  5.0|(780,[219,220,221...|       1.0|
|  6.0|(780,[74,75,101,1...|       1.0|
|  1.0|(780,[150,151,152...|       1.0|
|  0.0|(780,[124,125,126...|       1.0|
|  0.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[97,98,99,12...|       1.0|
|  7.0|(780,[237,238,239...|       1.0|
|  1.0|(780,[124,125,126...|       1.0|
|  6.0|(780,[70,71,72,73...|       1.0|
|  3.0|(780,[149,150,151...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
|  1.0|(780,[127,128,129...|       1.0|
|  7.0|(780,[213,214,215...|       1.0|
|  9.0|(780,[123,124,125...|       1.0|
|  0.0|(780,[153,154,155...|       1.0|
|  2.0|(780,[94,95,96,97...|       1.0|
|  6.0|(780,[72,73,99,10...|       1.0|
|  7.0|(780,[199,200,201...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[171,172,173...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  0.0|(780,[122,123,124...|       1.0|
|  4.0|(780,[189,190,191...|       1.0|
|  6.0|(780,[73,74,75,76...|       1.0|
|  7.0|(780,[238,239,240...|       1.0|
|  4.0|(780,[158,159,177...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  8.0|(780,[154,155,156...|       1.0|
|  0.0|(780,[126,127,128...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[150,151,152...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
+-----+--------------------+----------+
only showing top 100 rows

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  7.0|(780,[202,203,204...|       1.0|
|  2.0|(780,[94,95,96,97...|       1.0|
|  1.0|(780,[128,129,130...|       1.0|
|  0.0|(780,[124,125,126...|       1.0|
|  4.0|(780,[150,151,159...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
|  4.0|(780,[149,150,151...|       1.0|
|  9.0|(780,[179,180,181...|       1.0|
|  5.0|(780,[129,130,131...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  0.0|(780,[123,124,125...|       1.0|
|  6.0|(780,[94,95,96,97...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  0.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[125,126,127...|       1.0|
|  5.0|(780,[124,125,126...|       1.0|
|  9.0|(780,[179,180,181...|       1.0|
|  7.0|(780,[200,201,202...|       1.0|
|  3.0|(780,[118,119,120...|       1.0|
|  4.0|(780,[158,159,185...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  6.0|(780,[96,97,98,99...|       1.0|
|  6.0|(780,[93,94,95,12...|       1.0|
|  5.0|(780,[156,157,158...|       1.0|
|  4.0|(780,[151,152,178...|       1.0|
|  0.0|(780,[125,126,127...|       1.0|
|  7.0|(780,[230,234,235...|       1.0|
|  4.0|(780,[152,153,179...|       1.0|
|  0.0|(780,[149,150,151...|       1.0|
|  1.0|(780,[123,124,125...|       1.0|
|  3.0|(780,[175,176,177...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[148,149,150...|       1.0|
|  4.0|(780,[122,123,150...|       1.0|
|  7.0|(780,[175,176,177...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  7.0|(780,[202,203,204...|       1.0|
|  1.0|(780,[151,152,153...|       1.0|
|  2.0|(780,[125,126,127...|       1.0|
|  1.0|(780,[126,127,128...|       1.0|
|  1.0|(780,[125,126,153...|       1.0|
|  7.0|(780,[207,208,209...|       1.0|
|  4.0|(780,[176,177,178...|       1.0|
|  2.0|(780,[126,127,128...|       1.0|
|  3.0|(780,[121,122,123...|       1.0|
|  5.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[122,123,124...|       1.0|
|  2.0|(780,[65,66,67,68...|       1.0|
|  4.0|(780,[177,178,179...|       1.0|
|  4.0|(780,[147,148,157...|       1.0|
|  6.0|(780,[100,101,102...|       1.0|
|  3.0|(780,[172,173,174...|       1.0|
|  5.0|(780,[163,164,165...|       1.0|
|  5.0|(780,[126,127,128...|       1.0|
|  6.0|(780,[93,94,95,12...|       1.0|
|  0.0|(780,[151,152,153...|       1.0|
|  4.0|(780,[148,149,150...|       1.0|
|  1.0|(780,[155,156,157...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  5.0|(780,[190,191,192...|       1.0|
|  7.0|(780,[198,199,200...|       1.0|
|  8.0|(780,[153,154,155...|       1.0|
|  9.0|(780,[178,179,180...|       1.0|
|  3.0|(780,[95,96,97,98...|       1.0|
|  7.0|(780,[200,201,202...|       1.0|
|  4.0|(780,[156,157,184...|       1.0|
|  6.0|(780,[67,68,69,95...|       1.0|
|  4.0|(780,[160,161,162...|       1.0|
|  3.0|(780,[148,149,150...|       1.0|
|  0.0|(780,[152,153,179...|       1.0|
|  7.0|(780,[206,207,208...|       1.0|
|  0.0|(780,[123,124,125...|       1.0|
|  2.0|(780,[119,120,121...|       1.0|
|  9.0|(780,[180,181,182...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  7.0|(780,[213,214,215...|       1.0|
|  3.0|(780,[124,125,126...|       1.0|
|  2.0|(780,[205,206,207...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  7.0|(780,[205,206,207...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  2.0|(780,[96,97,98,99...|       1.0|
|  7.0|(780,[204,205,206...|       1.0|
|  8.0|(780,[156,157,159...|       1.0|
|  4.0|(780,[147,148,158...|       1.0|
|  7.0|(780,[203,204,205...|       1.0|
|  3.0|(780,[146,147,148...|       1.0|
|  6.0|(780,[67,68,69,70...|       1.0|
|  1.0|(780,[128,129,130...|       1.0|
|  3.0|(780,[152,153,154...|       1.0|
|  6.0|(780,[71,72,73,74...|       1.0|
|  9.0|(780,[182,183,184...|       1.0|
|  3.0|(780,[149,150,151...|       1.0|
|  1.0|(780,[123,124,125...|       1.0|
|  4.0|(780,[158,159,160...|       1.0|
|  1.0|(780,[149,150,151...|       1.0|
|  7.0|(780,[175,176,177...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  9.0|(780,[177,178,179...|       1.0|
+-----+--------------------+----------+
only showing top 100 rows

Precision:0.1135




 [http://dr0muzwhcp26z.cloudfront.net/static/corporate/SDL-logo-2014.png] 
<www.sdl.com/>
www.sdl.com


SDL PLC confidential, all rights reserved. If you are not the intended 
recipient of this mail SDL requests and requires that you delete it without 
acting upon or copying any of its contents, and we further request that you 
advise us.

SDL PLC is a public limited company registered in England and Wales. Registered 
number: 02675207.
Registered address: Globe House, Clivemont Road, Maidenhead, Berkshire SL6 7DY, 
UK.



This message has been scanned for malware by Websense. www.websense.com

Reply via email to