if((score >=0 && label == 1) || (score <0 && label == 0)) { return 1; /* correct classification */ } else return 0;
I suspect the score is always between 0 and 1. On Sat, Nov 28, 2015 at 10:39 AM, Tarek Elgamal <tarek.elga...@gmail.com> wrote: > Hi, > > I am trying to run the straightforward example of SVM but I am getting low > accuracy (around 50%) when I predict using the same data I used for > training. I am probably doing the prediction the wrong way. My code is > below. I would appreciate any help. > > > import java.util.List; > > import org.apache.spark.SparkConf; > import org.apache.spark.SparkContext; > import org.apache.spark.api.java.JavaRDD; > import org.apache.spark.api.java.function.Function; > import org.apache.spark.api.java.function.Function2; > import org.apache.spark.mllib.classification.SVMModel; > import org.apache.spark.mllib.classification.SVMWithSGD; > import org.apache.spark.mllib.regression.LabeledPoint; > import org.apache.spark.mllib.util.MLUtils; > > import scala.Tuple2; > import edu.illinois.biglbjava.readers.LabeledPointReader; > > public class SimpleDistSVM { > public static void main(String[] args) { > SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); > SparkContext sc = new SparkContext(conf); > String inputPath=args[0]; > > // Read training data > JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, > inputPath).toJavaRDD(); > > // Run training algorithm to build the model. > int numIterations = 3; > final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations); > > // Clear the default threshold. 
> model.clearThreshold(); > > > // Predict points in test set and map to an RDD of 0/1 values where 0 > is misclassification and 1 is correct classification > JavaRDD<Integer> classification = data.map(new Function<LabeledPoint, > Integer>() { > public Integer call(LabeledPoint p) { > int label = (int) p.label(); > Double score = model.predict(p.features()); > if((score >=0 && label == 1) || (score <0 && label == 0)) > { > return 1; //correct classification > } > else > return 0; > > } > } > ); > // sum up all values in the rdd to get the number of correctly > classified examples > int sum=classification.reduce(new Function2<Integer, Integer, > Integer>() > { > public Integer call(Integer arg0, Integer arg1) > throws Exception { > return arg0+arg1; > }}); > > //compute accuracy as the percentage of the correctly classified > examples > double accuracy=((double)sum)/((double)classification.count()); > System.out.println("Accuracy = " + accuracy); > > } > } > ); > } > } > -- Best Regards Jeff Zhang