Bago,

The code below does not reproduce the issue. In our case, we build the ML
pipeline from a UI, in a particular way that lets a user create a pipeline
behind the scenes using drag and drop. I have yet to dig deeper and recreate
the same pipeline as standalone code. Meanwhile, I am sharing a similar
example that I wrote. I hope to find time next week to put together the
correct one.

import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.StreamingQuery;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Self-contained example that tunes a tweet-sentiment pipeline (label indexing, tokenization,
 * count vectorization, logistic regression, label decoding) with {@link TrainValidationSplit},
 * saves the tuned model, reloads it, and scores the reloaded best pipeline first on a static
 * DataFrame and then on a socket-backed streaming DataFrame.
 */
public class StreamingIssueCountVectorizerSplit {

  /** Path the tuned model is written to and then read back from. */
  private static final String MODEL_PATH = "/tmp/CountSplit.model";

  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder()
        .appName("StreamingIssueCountVectorizer")
        .master("local[2]")
        .getOrCreate();

    Dataset<Row> labeled = spark.createDataFrame(labeledTweets(), labeledSchema());
    Dataset<Row> unlabeled = spark.createDataFrame(unlabeledTweets(), unlabeledSchema());

    TrainValidationSplitModel tuned = buildTuner(labeled).fit(labeled);
    tuned.write().overwrite().save(MODEL_PATH);

    // Scoring below uses the copy loaded from disk, not the in-memory `tuned` model.
    TrainValidationSplitModel reloaded = TrainValidationSplitModel.load(MODEL_PATH);
    PipelineModel best = (PipelineModel) reloaded.bestModel();

    printBatchPredictions(best, unlabeled);
    streamPredictions(spark, best);
  }

  /** Labeled training rows: (tweet text, sentiment label). */
  private static List<Row> labeledTweets() {
    return Arrays.asList(
        RowFactory.create("sunny fantastic day", "Positive"),
        RowFactory.create("fantastic morning match", "Positive"),
        RowFactory.create("good morning", "Positive"),
        RowFactory.create("boring evening", "Negative"),
        RowFactory.create("tragic evening event", "Negative"),
        RowFactory.create("today is bad ", "Negative"));
  }

  /** Unlabeled rows used for batch scoring: (tweet text). */
  private static List<Row> unlabeledTweets() {
    return Arrays.asList(
        RowFactory.create("sunny morning"),
        RowFactory.create("bad evening"));
  }

  /** Schema for the labeled rows: non-null "tweet" plus nullable "sentiment". */
  private static StructType labeledSchema() {
    return new StructType()
        .add("tweet", DataTypes.StringType, false, Metadata.empty())
        .add("sentiment", DataTypes.StringType, true, Metadata.empty());
  }

  /** Schema for the unlabeled rows: a single non-null "tweet" column. */
  private static StructType unlabeledSchema() {
    return new StructType()
        .add("tweet", DataTypes.StringType, false, Metadata.empty());
  }

  /**
   * Builds the train/validation tuner around the full pipeline.
   *
   * <p>The label indexer and the count vectorizer are fitted eagerly on {@code training} and then
   * placed into the pipeline as already-fitted models; only the logistic regression is tuned.
   *
   * @param training labeled data with "tweet" and "sentiment" columns
   * @return a tuner over a regParam x fitIntercept x elasticNetParam grid, using a 0.7 train ratio
   */
  private static TrainValidationSplit buildTuner(Dataset<Row> training) {
    StringIndexerModel labelIndexer = new StringIndexer()
        .setInputCol("sentiment")
        .setOutputCol("label")
        .setHandleInvalid("skip")
        .fit(training);

    Tokenizer tokenizer = new Tokenizer()
        .setInputCol("tweet")
        .setOutputCol("words");

    CountVectorizer vectorizerSpec = new CountVectorizer()
        .setInputCol(tokenizer.getOutputCol())
        .setOutputCol("features")
        .setVocabSize(3)
        .setMinDF(2)
        .setMinTF(2)
        .setBinary(true);
    CountVectorizerModel vectorizer = vectorizerSpec.fit(tokenizer.transform(training));

    LogisticRegression logReg = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.001);

    IndexToString labelDecoder = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predicted")
        .setLabels(labelIndexer.labels());

    // The vectorizer was fitted with minTF = 2; its minTF is lowered to 1 on the fitted model
    // before it becomes a pipeline stage (kept as in the original reproducer).
    vectorizer.setMinTF(1);

    PipelineStage[] stages = {labelIndexer, tokenizer, vectorizer, logReg, labelDecoder};
    Pipeline pipeline = new Pipeline().setStages(stages);

    ParamMap[] grid = new ParamGridBuilder()
        .addGrid(logReg.regParam(), new double[] {0.1, 0.01})
        .addGrid(logReg.fitIntercept())
        .addGrid(logReg.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
        .build();

    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction");

    return new TrainValidationSplit()
        .setEstimator(pipeline)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(grid)
        .setTrainRatio(0.7);
  }

  /** Scores a static DataFrame and prints one "[tweet], prediction=label" line per row. */
  private static void printBatchPredictions(PipelineModel model, Dataset<Row> tweets) {
    List<Row> scored = model.transform(tweets).select("tweet", "predicted").collectAsList();
    for (Row row : scored) {
      Object text = row.get(0);
      Object label = row.get(1);
      System.out.println("[" + text + "], prediction=" + label);
    }
  }

  /**
   * Scores lines read from a socket on localhost:9999 and writes the results to the console sink
   * in append mode. Blocks until the streaming query terminates.
   */
  private static void streamPredictions(SparkSession spark, PipelineModel model)
      throws Exception {
    // The socket source exposes each line in a "value" column; the pipeline reads "tweet".
    Dataset<Row> socketTweets = spark
        .readStream()
        .format("socket")
        .option("host", "localhost")
        .option("port", 9999)
        .load()
        .withColumnRenamed("value", "tweet");

    StreamingQuery query = model.transform(socketTweets)
        .writeStream()
        .outputMode("append")
        .format("console")
        .start();
    query.awaitTermination();
  }
}





--
Sent from: http://apache-spark-developers-list.1001551.n3.nabble.com/

---------------------------------------------------------------------
To unsubscribe e-mail: dev-unsubscr...@spark.apache.org

Reply via email to