Bago, the code I wrote does not reproduce the issue. In our case we build an ML pipeline from a UI: a user creates the pipeline behind the scenes using drag and drop, so the stages are assembled dynamically at runtime (see the sketch below). I have yet to dig deeper and recreate the same setup as standalone code. Meanwhile, I am sharing a similar example that I wrote. I hope to find time next week to put together the exact one.
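For context, this is roughly how the UI assembles stages behind the scenes. It is only a simplified sketch written from memory, not our actual builder code; the class name, column names, and stage choices here are made up:

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.Tokenizer;

public class DynamicPipelineSketch {
  // Each drag-and-drop action appends one stage; the Pipeline is built last.
  public static Pipeline build() {
    List<PipelineStage> stages = new ArrayList<>();
    stages.add(new Tokenizer().setInputCol("tweet").setOutputCol("words"));
    stages.add(new CountVectorizer().setInputCol("words").setOutputCol("features"));
    stages.add(new LogisticRegression().setMaxIter(10));
    return new Pipeline().setStages(stages.toArray(new PipelineStage[0]));
  }
}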
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.StreamingQuery;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StreamingIssueCountVectorizerSplit {

  public static void main(String[] args) throws Exception {
    SparkSession sparkSession = SparkSession.builder()
        .appName("StreamingIssueCountVectorizer")
        .master("local[2]")
        .getOrCreate();

    List<Row> _trainData = Arrays.asList(
        RowFactory.create("sunny fantastic day", "Positive"),
        RowFactory.create("fantastic morning match", "Positive"),
        RowFactory.create("good morning", "Positive"),
        RowFactory.create("boring evening", "Negative"),
        RowFactory.create("tragic evening event", "Negative"),
        RowFactory.create("today is bad ", "Negative")
    );
    List<Row> _testData = Arrays.asList(
        RowFactory.create("sunny morning"),
        RowFactory.create("bad evening")
    );

    StructType schema = new StructType(new StructField[]{
        new StructField("tweet", DataTypes.StringType, false, Metadata.empty()),
        new StructField("sentiment", DataTypes.StringType, true, Metadata.empty())
    });
    StructType testSchema = new StructType(new StructField[]{
        new StructField("tweet", DataTypes.StringType, false, Metadata.empty())
    });

    Dataset<Row> trainData = sparkSession.createDataFrame(_trainData, schema);
    Dataset<Row> testData = sparkSession.createDataFrame(_testData, testSchema);

    // Index the string sentiment into a numeric label column.
    StringIndexerModel labelIndexerModel = new StringIndexer()
        .setInputCol("sentiment")
        .setOutputCol("label")
        .setHandleInvalid("skip")
        .fit(trainData);

    Tokenizer tokenizer = new Tokenizer()
        .setInputCol("tweet")
        .setOutputCol("words");

    CountVectorizer countVectorizer = new CountVectorizer()
        .setInputCol(tokenizer.getOutputCol())
        .setOutputCol("features")
        .setVocabSize(3)
        .setMinDF(2)
        .setMinTF(2)
        .setBinary(true);

    // Fit the CountVectorizer up front so the fitted model (with its
    // vocabulary) goes into the pipeline rather than the estimator.
    Dataset<Row> words = tokenizer.transform(trainData);
    CountVectorizerModel countVectorizerModel = countVectorizer.fit(words);

    LogisticRegression lr = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.001);

    // Map numeric predictions back to the original string labels.
    IndexToString labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predicted")
        .setLabels(labelIndexerModel.labels());

    // Relax minTF on the already-fitted model before adding it as a stage.
    countVectorizerModel.setMinTF(1);

    Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[]{
            labelIndexerModel, tokenizer, countVectorizerModel, lr, labelConverter});

    ParamMap[] paramGrid = new ParamGridBuilder()
        .addGrid(lr.regParam(), new double[]{0.1, 0.01})
        .addGrid(lr.fitIntercept())
        .addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
        .build();

    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
    evaluator.setLabelCol("label");
    evaluator.setPredictionCol("prediction");

    TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
        .setEstimator(pipeline)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(paramGrid)
        .setTrainRatio(0.7);

    // Fit the pipeline to the training documents, then save and reload
    // the best model.
    TrainValidationSplitModel trainValidationSplitModel = trainValidationSplit.fit(trainData);
    trainValidationSplitModel.write().overwrite().save("/tmp/CountSplit.model");

    TrainValidationSplitModel _loadedModel =
        TrainValidationSplitModel.load("/tmp/CountSplit.model");
    PipelineModel loadedModel = (PipelineModel) _loadedModel.bestModel();

    // Test on non-streaming data.
    Dataset<Row> predicted = loadedModel.transform(testData);
    List<Row> rows = predicted.select("tweet", "predicted").collectAsList();
    for (Row r : rows) {
      System.out.println("[" + r.get(0) + "], prediction=" + r.get(1));
    }

    // Test on streaming data read from a local socket.
    Dataset<Row> lines = sparkSession
        .readStream()
        .format("socket")
        .option("host", "localhost")
        .option("port", 9999)
        .load();
    lines = lines.withColumnRenamed("value", "tweet");

    StreamingQuery query = loadedModel.transform(lines).writeStream()
        .outputMode("append")
        .format("console")
        .start();

    query.awaitTermination();
  }
}
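If you want to run the streaming part: the socket source expects something listening on localhost:9999 before the query starts, so run nc -lk 9999 in another terminal first and type test lines such as "sunny morning" there; the predictions then appear batch by batch on the console sink.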