[ https://issues.apache.org/jira/browse/FLINK-4964?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15633603#comment-15633603 ]
ASF GitHub Bot commented on FLINK-4964: --------------------------------------- Github user tfournier314 commented on a diff in the pull request: https://github.com/apache/flink/pull/2740#discussion_r86402282 --- Diff: flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StringIndexer.scala --- @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.preprocessing + +import org.apache.flink.api.scala._ +import org.apache.flink.ml.common.{Parameter, ParameterMap} +import org.apache.flink.ml.pipeline.{FitOperation, TransformDataSetOperation, Transformer} +import org.apache.flink.ml.preprocessing.StringIndexer.HandleInvalid + +import scala.collection.immutable.Seq + +/** + * StringIndexer maps a DataSet[String] to a DataSet[(String,Int)] where each label is + * associated with an index. The indices are in [0,numLabels) and are ordered by label + * frequencies. The most frequent label has index 0. 
+ * + * @example + * {{{ + * val trainingDS: DataSet[String] = env.fromCollection(data) + * val transformer = StringIndexer().setHandleInvalid("skip") + * + * transformer.fit(trainingDS) + * val transformedDS = transformer.transform(trainingDS) + * }}} + * + * + * You can manage unseen labels using the HandleInvalid parameter. If HandleInvalid is + * set to "skip" (see example), then each line containing an unseen label is skipped. + * Otherwise an exception is raised. + * + * =Parameters= + * + * -[[HandleInvalid]]: Defines how to handle unseen labels; defaults to "skip" + * + * + */ +class StringIndexer extends Transformer[StringIndexer] { + + private[preprocessing] var metricsOption: Option[Map[String, Int]] = None + + + /** + * Sets how to handle unseen labels + * @param value set to "skip" if you want to filter out lines with unseen labels + * @return StringIndexer instance with HandleInvalid value + */ + def setHandleInvalid(value: String): this.type ={ + parameters.add( HandleInvalid, value ) + this + } + +} + +object StringIndexer { + + case object HandleInvalid extends Parameter[String] { + val defaultValue: Option[String] = Some( "skip" ) + } + + // ==================================== Factory methods ======================================== + + def apply(): StringIndexer ={ + new StringIndexer( ) + } + + // ====================================== Operations =========================================== + + /** + * Trains [[StringIndexer]] by learning the count of each label in the input DataSet. 
+ * + * @return [[FitOperation]] training the [[StringIndexer]] on string labels + */ + implicit def fitStringIndexer ={ + new FitOperation[StringIndexer, String] { + def fit(instance: StringIndexer, fitParameters: ParameterMap, + input: DataSet[String]): Unit = { + val metrics = extractIndices( input ) + instance.metricsOption = Some( metrics ) + } + } + } + + /** + * Count the frequency of each label, sort them in a decreasing order and assign an index + * + * @param input input Dataset containing labels + * @return a map that returns for each label (key) its index (value) + */ + private def extractIndices(input: DataSet[String]): Map[String, Int] ={ + + implicit val resultTypeInformation = createTypeInformation[(String, Int)] + + val mapper = input + .map( s => (s, 1) ) + .groupBy( 0 ) + .reduce( (a, b) => (a._1, a._2 + b._2) ) + .collect( ) + .sortBy( r => (r._2, r._1) ) + .zipWithIndex + .map { case ((s, c), ind) => (s, ind) } + .toMap --- End diff -- Exactly a DataSet would be better > FlinkML - Add StringIndexer > --------------------------- > > Key: FLINK-4964 > URL: https://issues.apache.org/jira/browse/FLINK-4964 > Project: Flink > Issue Type: New Feature > Reporter: Thomas FOURNIER > Priority: Minor > > Add StringIndexer as described here: > http://spark.apache.org/docs/latest/ml-features.html#stringindexer > This will be added in package preprocessing of FlinkML -- This message was sent by Atlassian JIRA (v6.3.4#6332)