Thanks Feynman, that is useful.

I am interested in comparing the Spark MLP with Caffe. If I understand 
correctly, the changes to the Spark MLP API now restrict the functionality such 
that:

- The top layer must be a softmax
- The network can only be trained with L-BFGS

I think this benchmark originally used a sigmoid top layer and SGD to optimise 
the network for Spark, so the Caffe config used in the benchmark and the Spark 
setup are no longer equivalent.
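
To make the difference between the two top layers concrete, here is a minimal sketch in plain Scala (no Spark or Caffe dependency). The object name, function names and input values are hypothetical and chosen only for illustration; this is not how either library implements its layers.

```scala
object OutputLayers {
  // Element-wise logistic sigmoid: each output is squashed independently of the others.
  def sigmoid(z: Array[Double]): Array[Double] =
    z.map(v => 1.0 / (1.0 + math.exp(-v)))

  // Softmax: outputs are coupled and sum to 1.
  // The maximum is subtracted before exponentiating for numerical stability.
  def softmax(z: Array[Double]): Array[Double] = {
    val m = z.max
    val e = z.map(v => math.exp(v - m))
    val s = e.sum
    e.map(_ / s)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical pre-activation values for a three-class top layer.
    val logits = Array(2.0, 1.0, -1.0)
    println("sigmoid: " + sigmoid(logits).mkString(", "))
    println("softmax: " + softmax(logits).mkString(", "))
  }
}
```

Both functions rank the classes in the same order, but only the softmax outputs form a probability distribution, so the loss and gradients seen during training differ between the two setups.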

Also, this benchmark is designed for speed testing. I just want to do a quick 
sanity test and make sure that Caffe and Spark yield similar accuracies on 
MNIST before I try Spark on our own task. I may be duplicating existing 
efforts. Is there an example of this kind of sanity test that I could 
reproduce?




SDL PLC confidential, all rights reserved. If you are not the intended 
recipient of this mail SDL requests and requires that you delete it without 
acting upon or copying any of its contents, and we further request that you 
advise us.

SDL PLC is a public limited company registered in England and Wales. Registered 
number: 02675207.
Registered address: Globe House, Clivemont Road, Maidenhead, Berkshire SL6 7DY, 
UK.

________________________________
From: Feynman Liang [fli...@databricks.com]
Sent: 11 September 2015 20:34
To: Rory Waite
Cc: user@spark.apache.org
Subject: Re: Training the MultilayerPerceptronClassifier

Rory,

I just sent a PR (https://github.com/avulanov/ann-benchmark/pull/1) to bring 
that benchmark up to date. Hope it helps.

On Fri, Sep 11, 2015 at 6:39 AM, Rory Waite <rwa...@sdl.com> wrote:
Hi,

I’ve been trying to train the new MultilayerPerceptronClassifier in Spark 1.5 
for the MNIST digit recognition task. I’m trying to reproduce the work here:

https://github.com/avulanov/ann-benchmark

The API has changed since this work, so I’m not sure that I’m setting up the 
task correctly.

After I've trained the classifier, it classifies everything as a 1. It even 
does this for the training set. Am I doing something wrong with the setup? I’m 
not looking for state-of-the-art performance, just something that looks 
reasonable. This experiment is meant to be a quick sanity test.

Here is the job:

import org.apache.log4j._
//Logger.getRootLogger.setLevel(Level.OFF)
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineStage
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.SQLContext
import java.io.FileOutputStream
import java.io.ObjectOutputStream

object MNIST {
 def main(args: Array[String]) {
   val conf = new SparkConf().setAppName("MNIST")
   conf.set("spark.driver.extraJavaOptions", "-XX:MaxPermSize=512M")
   val sc = new SparkContext(conf)
   val batchSize = 100
   val numIterations = 5
   val mlp = new MultilayerPerceptronClassifier
   mlp.setLayers(Array[Int](780, 2500, 2000, 1500, 1000, 500, 10))
   mlp.setMaxIter(numIterations)
   mlp.setBlockSize(batchSize)
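   // repartition returns a new RDD rather than modifying the original, so keep the result.
   // The feature count (780) is passed explicitly to match the test set loaded below.
   val train = MLUtils.loadLibSVMFile(sc, "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale", 780)
     .repartition(200)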
   val sqlContext = new SQLContext(sc)
   import sqlContext.implicits._
   val df = train.toDF
   val model = mlp.fit(df)
   val trainPredictions = model.transform(df)
   trainPredictions.show(100)
   val test = MLUtils.loadLibSVMFile(sc, "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale.t", 780).toDF
   val result = model.transform(test)
   result.show(100)
   val predictionAndLabels = result.select("prediction", "label")
   val evaluator = new MulticlassClassificationEvaluator()
     .setMetricName("precision")
   println("Precision:" + evaluator.evaluate(predictionAndLabels))
   // Serialize the trained model with plain Java serialization.
   val fos = new FileOutputStream("/home/rwaite/mt-work/ann-benchmark/spark_out/spark_model.obj")
   val oos = new ObjectOutputStream(fos)
   oos.writeObject(model)
   oos.close()
 }
}


And here is the output:

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  5.0|(780,[152,153,154...|       1.0|
|  0.0|(780,[127,128,129...|       1.0|
|  4.0|(780,[160,161,162...|       1.0|
|  1.0|(780,[158,159,160...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  2.0|(780,[155,156,157...|       1.0|
|  1.0|(780,[124,125,126...|       1.0|
|  3.0|(780,[151,152,153...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  4.0|(780,[134,135,161...|       1.0|
|  3.0|(780,[123,124,125...|       1.0|
|  5.0|(780,[216,217,218...|       1.0|
|  3.0|(780,[143,144,145...|       1.0|
|  6.0|(780,[72,73,74,99...|       1.0|
|  1.0|(780,[151,152,153...|       1.0|
|  7.0|(780,[211,212,213...|       1.0|
|  2.0|(780,[151,152,153...|       1.0|
|  8.0|(780,[159,160,161...|       1.0|
|  6.0|(780,[100,101,102...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  4.0|(780,[129,130,131...|       1.0|
|  0.0|(780,[129,130,131...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  1.0|(780,[158,159,160...|       1.0|
|  1.0|(780,[99,100,101,...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  4.0|(780,[185,186,187...|       1.0|
|  3.0|(780,[150,151,152...|       1.0|
|  2.0|(780,[145,146,147...|       1.0|
|  7.0|(780,[240,241,242...|       1.0|
|  3.0|(780,[201,202,203...|       1.0|
|  8.0|(780,[153,154,155...|       1.0|
|  6.0|(780,[71,72,73,74...|       1.0|
|  9.0|(780,[210,211,212...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  5.0|(780,[188,189,190...|       1.0|
|  6.0|(780,[98,99,100,1...|       1.0|
|  0.0|(780,[127,128,129...|       1.0|
|  7.0|(780,[201,202,203...|       1.0|
|  6.0|(780,[125,126,127...|       1.0|
|  1.0|(780,[154,155,156...|       1.0|
|  8.0|(780,[131,132,133...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  9.0|(780,[181,182,183...|       1.0|
|  3.0|(780,[174,175,176...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  5.0|(780,[186,187,188...|       1.0|
|  9.0|(780,[150,151,152...|       1.0|
|  3.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[122,123,124...|       1.0|
|  0.0|(780,[153,154,155...|       1.0|
|  7.0|(780,[203,204,205...|       1.0|
|  4.0|(780,[212,213,214...|       1.0|
|  9.0|(780,[205,206,207...|       1.0|
|  8.0|(780,[181,182,183...|       1.0|
|  0.0|(780,[151,152,153...|       1.0|
|  9.0|(780,[210,211,212...|       1.0|
|  4.0|(780,[156,157,158...|       1.0|
|  1.0|(780,[129,130,131...|       1.0|
|  4.0|(780,[149,159,160...|       1.0|
|  4.0|(780,[187,188,189...|       1.0|
|  6.0|(780,[127,128,129...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  4.0|(780,[152,153,154...|       1.0|
|  5.0|(780,[219,220,221...|       1.0|
|  6.0|(780,[74,75,101,1...|       1.0|
|  1.0|(780,[150,151,152...|       1.0|
|  0.0|(780,[124,125,126...|       1.0|
|  0.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[97,98,99,12...|       1.0|
|  7.0|(780,[237,238,239...|       1.0|
|  1.0|(780,[124,125,126...|       1.0|
|  6.0|(780,[70,71,72,73...|       1.0|
|  3.0|(780,[149,150,151...|       1.0|
|  0.0|(780,[154,155,156...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
|  1.0|(780,[127,128,129...|       1.0|
|  7.0|(780,[213,214,215...|       1.0|
|  9.0|(780,[123,124,125...|       1.0|
|  0.0|(780,[153,154,155...|       1.0|
|  2.0|(780,[94,95,96,97...|       1.0|
|  6.0|(780,[72,73,99,10...|       1.0|
|  7.0|(780,[199,200,201...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[171,172,173...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  0.0|(780,[122,123,124...|       1.0|
|  4.0|(780,[189,190,191...|       1.0|
|  6.0|(780,[73,74,75,76...|       1.0|
|  7.0|(780,[238,239,240...|       1.0|
|  4.0|(780,[158,159,177...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  8.0|(780,[154,155,156...|       1.0|
|  0.0|(780,[126,127,128...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  8.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[150,151,152...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
+-----+--------------------+----------+
only showing top 100 rows

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  7.0|(780,[202,203,204...|       1.0|
|  2.0|(780,[94,95,96,97...|       1.0|
|  1.0|(780,[128,129,130...|       1.0|
|  0.0|(780,[124,125,126...|       1.0|
|  4.0|(780,[150,151,159...|       1.0|
|  1.0|(780,[156,157,158...|       1.0|
|  4.0|(780,[149,150,151...|       1.0|
|  9.0|(780,[179,180,181...|       1.0|
|  5.0|(780,[129,130,131...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  0.0|(780,[123,124,125...|       1.0|
|  6.0|(780,[94,95,96,97...|       1.0|
|  9.0|(780,[208,209,210...|       1.0|
|  0.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[125,126,127...|       1.0|
|  5.0|(780,[124,125,126...|       1.0|
|  9.0|(780,[179,180,181...|       1.0|
|  7.0|(780,[200,201,202...|       1.0|
|  3.0|(780,[118,119,120...|       1.0|
|  4.0|(780,[158,159,185...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  6.0|(780,[96,97,98,99...|       1.0|
|  6.0|(780,[93,94,95,12...|       1.0|
|  5.0|(780,[156,157,158...|       1.0|
|  4.0|(780,[151,152,178...|       1.0|
|  0.0|(780,[125,126,127...|       1.0|
|  7.0|(780,[230,234,235...|       1.0|
|  4.0|(780,[152,153,179...|       1.0|
|  0.0|(780,[149,150,151...|       1.0|
|  1.0|(780,[123,124,125...|       1.0|
|  3.0|(780,[175,176,177...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  3.0|(780,[148,149,150...|       1.0|
|  4.0|(780,[122,123,150...|       1.0|
|  7.0|(780,[175,176,177...|       1.0|
|  2.0|(780,[124,125,126...|       1.0|
|  7.0|(780,[202,203,204...|       1.0|
|  1.0|(780,[151,152,153...|       1.0|
|  2.0|(780,[125,126,127...|       1.0|
|  1.0|(780,[126,127,128...|       1.0|
|  1.0|(780,[125,126,153...|       1.0|
|  7.0|(780,[207,208,209...|       1.0|
|  4.0|(780,[176,177,178...|       1.0|
|  2.0|(780,[126,127,128...|       1.0|
|  3.0|(780,[121,122,123...|       1.0|
|  5.0|(780,[152,153,154...|       1.0|
|  1.0|(780,[122,123,124...|       1.0|
|  2.0|(780,[65,66,67,68...|       1.0|
|  4.0|(780,[177,178,179...|       1.0|
|  4.0|(780,[147,148,157...|       1.0|
|  6.0|(780,[100,101,102...|       1.0|
|  3.0|(780,[172,173,174...|       1.0|
|  5.0|(780,[163,164,165...|       1.0|
|  5.0|(780,[126,127,128...|       1.0|
|  6.0|(780,[93,94,95,12...|       1.0|
|  0.0|(780,[151,152,153...|       1.0|
|  4.0|(780,[148,149,150...|       1.0|
|  1.0|(780,[155,156,157...|       1.0|
|  9.0|(780,[209,210,211...|       1.0|
|  5.0|(780,[190,191,192...|       1.0|
|  7.0|(780,[198,199,200...|       1.0|
|  8.0|(780,[153,154,155...|       1.0|
|  9.0|(780,[178,179,180...|       1.0|
|  3.0|(780,[95,96,97,98...|       1.0|
|  7.0|(780,[200,201,202...|       1.0|
|  4.0|(780,[156,157,184...|       1.0|
|  6.0|(780,[67,68,69,95...|       1.0|
|  4.0|(780,[160,161,162...|       1.0|
|  3.0|(780,[148,149,150...|       1.0|
|  0.0|(780,[152,153,179...|       1.0|
|  7.0|(780,[206,207,208...|       1.0|
|  0.0|(780,[123,124,125...|       1.0|
|  2.0|(780,[119,120,121...|       1.0|
|  9.0|(780,[180,181,182...|       1.0|
|  1.0|(780,[152,153,154...|       1.0|
|  7.0|(780,[213,214,215...|       1.0|
|  3.0|(780,[124,125,126...|       1.0|
|  2.0|(780,[205,206,207...|       1.0|
|  9.0|(780,[183,184,185...|       1.0|
|  7.0|(780,[209,210,211...|       1.0|
|  7.0|(780,[205,206,207...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  2.0|(780,[96,97,98,99...|       1.0|
|  7.0|(780,[204,205,206...|       1.0|
|  8.0|(780,[156,157,159...|       1.0|
|  4.0|(780,[147,148,158...|       1.0|
|  7.0|(780,[203,204,205...|       1.0|
|  3.0|(780,[146,147,148...|       1.0|
|  6.0|(780,[67,68,69,70...|       1.0|
|  1.0|(780,[128,129,130...|       1.0|
|  3.0|(780,[152,153,154...|       1.0|
|  6.0|(780,[71,72,73,74...|       1.0|
|  9.0|(780,[182,183,184...|       1.0|
|  3.0|(780,[149,150,151...|       1.0|
|  1.0|(780,[123,124,125...|       1.0|
|  4.0|(780,[158,159,160...|       1.0|
|  1.0|(780,[149,150,151...|       1.0|
|  7.0|(780,[175,176,177...|       1.0|
|  6.0|(780,[99,100,101,...|       1.0|
|  9.0|(780,[177,178,179...|       1.0|
+-----+--------------------+----------+
only showing top 100 rows

Precision:0.1135
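
As a cross-check on that number, the sketch below computes the score a constant predictor would get: the fraction of examples whose true label is the one class always predicted. It assumes the "precision" metric here is the overall fraction of correct predictions, and the object and function names are hypothetical. A value of 0.1135 would be consistent with 1135 of 10000 test examples being labelled 1, which is assumed here to be the size of that class in the standard MNIST test split rather than something shown in the output above.

```scala
object ConstantBaseline {
  // Fraction of examples whose true label equals the single class that is always predicted.
  def constantAccuracy(labels: Seq[Double], predicted: Double): Double =
    labels.count(_ == predicted).toDouble / labels.size

  def main(args: Array[String]): Unit = {
    // Labels of the first ten test rows shown in the output above.
    val labels = Seq(7.0, 2.0, 1.0, 0.0, 4.0, 1.0, 4.0, 9.0, 5.0, 9.0)
    println(constantAccuracy(labels, 1.0))

    // Assumed class size: 1135 examples labelled 1 out of 10000 test examples.
    println(1135.0 / 10000)
  }
}
```

If the reported score matches this baseline, the model has learned nothing beyond predicting a single class, which agrees with the prediction column above.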



