if((score >=0 && label == 1) || (score <0 && label == 0))
{
return 1; //correct classification
}
else
return 0;
I suspect the score is always between 0 and 1.
On Sat, Nov 28, 2015 at 10:39 AM, Tarek Elgamal <[email protected]>
wrote:
> Hi,
>
> I am trying to run the straightforward example of SVM but I am getting low
> accuracy (around 50%) when I predict using the same data I used for
> training. I am probably doing the prediction in a wrong way. My code is
> below. I would appreciate any help.
>
>
> import java.util.List;
>
> import org.apache.spark.SparkConf;
> import org.apache.spark.SparkContext;
> import org.apache.spark.api.java.JavaRDD;
> import org.apache.spark.api.java.function.Function;
> import org.apache.spark.api.java.function.Function2;
> import org.apache.spark.mllib.classification.SVMModel;
> import org.apache.spark.mllib.classification.SVMWithSGD;
> import org.apache.spark.mllib.regression.LabeledPoint;
> import org.apache.spark.mllib.util.MLUtils;
>
> import scala.Tuple2;
> import edu.illinois.biglbjava.readers.LabeledPointReader;
>
> public class SimpleDistSVM {
> public static void main(String[] args) {
> SparkConf conf = new SparkConf().setAppName("SVM Classifier Example");
> SparkContext sc = new SparkContext(conf);
> String inputPath=args[0];
>
> // Read training data
> JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc,
> inputPath).toJavaRDD();
>
> // Run training algorithm to build the model.
> int numIterations = 3;
> final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations);
>
> // Clear the default threshold.
> model.clearThreshold();
>
>
> // Predict points in test set and map to an RDD of 0/1 values where 0
> is misclassification and 1 is correct classification
> JavaRDD<Integer> classification = data.map(new Function<LabeledPoint,
> Integer>() {
> public Integer call(LabeledPoint p) {
> int label = (int) p.label();
> Double score = model.predict(p.features());
> if((score >=0 && label == 1) || (score <0 && label == 0))
> {
> return 1; //correct classification
> }
> else
> return 0;
>
> }
> }
> );
> // sum up all values in the rdd to get the number of correctly
> classified examples
> int sum=classification.reduce(new Function2<Integer, Integer,
> Integer>()
> {
> public Integer call(Integer arg0, Integer arg1)
> throws Exception {
> return arg0+arg1;
> }});
>
> //compute accuracy as the percentage of the correctly classified
> examples
> double accuracy=((double)sum)/((double)classification.count());
> System.out.println("Accuracy = " + accuracy);
>
> }
> }
> );
> }
> }
>
--
Best Regards
Jeff Zhang