lindong28 commented on a change in pull request #4: URL: https://github.com/apache/flink-ml/pull/4#discussion_r715591990
########## File path: flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java ########## @@ -19,241 +19,101 @@ package org.apache.flink.ml.api.core; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.ml.api.pipeline.PipelineModel; import org.apache.flink.table.api.Table; -import org.apache.flink.table.api.TableEnvironment; -import org.apache.flink.util.InstantiationUtil; - -import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException; -import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; /** - * A pipeline is a linear workflow which chains {@link Estimator}s and {@link Transformer}s to - * execute an algorithm. - * - * <p>A pipeline itself can either act as an Estimator or a Transformer, depending on the stages it - * includes. More specifically: - * - * <ul> - * <li>If a Pipeline has an {@link Estimator}, one needs to call {@link - * Pipeline#fit(TableEnvironment, Table)} before use the pipeline as a {@link Transformer} . - * In this case the Pipeline is an {@link Estimator} and can produce a Pipeline as a {@link - * Model}. - * <li>If a Pipeline has no {@link Estimator}, it is a {@link Transformer} and can be applied to a - * Table directly. In this case, {@link Pipeline#fit(TableEnvironment, Table)} will simply - * return the pipeline itself. - * </ul> - * - * <p>In addition, a pipeline can also be used as a {@link PipelineStage} in another pipeline, just - * like an ordinary {@link Estimator} or {@link Transformer} as describe above. + * A Pipeline acts as an Estimator. It consists of an ordered list of stages, each of which could be + * an Estimator, Model, Transformer or AlgoOperator. 
*/ @PublicEvolving -public final class Pipeline - implements Estimator<Pipeline, Pipeline>, Transformer<Pipeline>, Model<Pipeline> { - private static final long serialVersionUID = 1L; - private final List<PipelineStage> stages = new ArrayList<>(); +public final class Pipeline implements Estimator<Pipeline, PipelineModel> { + private static final long serialVersionUID = 6384850154817512318L; + private final List<Stage<?>> stages; private final Params params = new Params(); - private int lastEstimatorIndex = -1; - - public Pipeline() {} - - public Pipeline(String pipelineJson) { - this.loadJson(pipelineJson); - } - - public Pipeline(List<PipelineStage> stages) { - for (PipelineStage s : stages) { - appendStage(s); - } - } - - // is the stage a simple Estimator or pipeline with Estimator - private static boolean isStageNeedFit(PipelineStage stage) { - return (stage instanceof Pipeline && ((Pipeline) stage).needFit()) - || (!(stage instanceof Pipeline) && stage instanceof Estimator); + public Pipeline(List<Stage<?>> stages) { + this.stages = stages; } /** - * Appends a PipelineStage to the tail of this pipeline. Pipeline is editable only via this - * method. The PipelineStage must be Estimator, Transformer, Model or Pipeline. + * Trains the pipeline to fit on the given tables. * - * @param stage the stage to be appended - */ - public Pipeline appendStage(PipelineStage stage) { - if (isStageNeedFit(stage)) { - lastEstimatorIndex = stages.size(); - } else if (!(stage instanceof Transformer)) { - throw new RuntimeException( - "All PipelineStages should be Estimator or Transformer, got:" - + stage.getClass().getSimpleName()); - } - stages.add(stage); - return this; - } - - /** - * Returns a list of all stages in this pipeline in order, the list is immutable. - * - * @return an immutable list of all stages in this pipeline in order. 
- */ - public List<PipelineStage> getStages() { - return Collections.unmodifiableList(stages); - } - - /** - * Check whether the pipeline acts as an {@link Estimator} or not. When the return value is - * true, that means this pipeline contains an {@link Estimator} and thus users must invoke - * {@link #fit(TableEnvironment, Table)} before they can use this pipeline as a {@link - * Transformer}. Otherwise, the pipeline can be used as a {@link Transformer} directly. - * - * @return {@code true} if this pipeline has an Estimator, {@code false} otherwise - */ - public boolean needFit() { - return this.getIndexOfLastEstimator() >= 0; - } - - public Params getParams() { - return params; - } - - // find the last Estimator or Pipeline that needs fit in stages, -1 stand for no Estimator in - // Pipeline - private int getIndexOfLastEstimator() { - return lastEstimatorIndex; - } - - /** - * Train the pipeline to fit on the records in the given {@link Table}. - * - * <p>This method go through all the {@link PipelineStage}s in order and does the following on - * each stage until the last {@link Estimator}(inclusive). + * <p>This method goes through all stages of this pipeline in order and does the following on + * each stage until the last Estimator (inclusive). * * <ul> - * <li>If a stage is an {@link Estimator}, invoke {@link Estimator#fit(TableEnvironment, - * Table)} with the input table to generate a {@link Model}, transform the the input table - * with the generated {@link Model} to get a result table, then pass the result table to - * the next stage as input. - * <li>If a stage is a {@link Transformer}, invoke {@link - * Transformer#transform(TableEnvironment, Table)} on the input table to get a result - * table, and pass the result table to the next stage as input. 
+ * <li>If a stage is an Estimator, invoke {@link Estimator#fit(Table...)} with the input + * tables to generate a Model, transform the the input tables with the generated Model to + * get result tables, then pass the result tables to the next stage as inputs. + * <li>If a stage is an AlgoOperator, invoke {@link AlgoOperator#transform(Table...)} on the + * input tables to get result tables, and pass the result tables to the next stage as + * inputs. * </ul> * - * <p>After all the {@link Estimator}s are trained to fit their input tables, a new pipeline - * will be created with the same stages in this pipeline, except that all the Estimators in the - * new pipeline are replaced with their corresponding Models generated in the above process. - * - * <p>If there is no {@link Estimator} in the pipeline, the method returns a copy of this - * pipeline. + * <p>After all the Estimators are trained to fit their input tables, a new PipelineModel will + * be created with the same stages in this pipeline, except that all the Estimators in the + * PipelineModel are replaced with the models generated in the above process. * - * @param tEnv the table environment to which the input table is bound. - * @param input the table with records to train the Pipeline. - * @return a pipeline with same stages as this Pipeline except all Estimators replaced with - * their corresponding Models. + * @param inputs a list of tables + * @return a PipelineModel */ @Override - public Pipeline fit(TableEnvironment tEnv, Table input) { - List<PipelineStage> transformStages = new ArrayList<>(stages.size()); - int lastEstimatorIdx = getIndexOfLastEstimator(); + public PipelineModel fit(Table... 
inputs) { + List<Stage<?>> modelStages = new ArrayList<>(stages.size()); + int lastEstimatorIdx = -1; + for (int i = 0; i < stages.size(); i++) { + if (stages.get(i) instanceof Estimator) { + lastEstimatorIdx = i; + } + } + for (int i = 0; i < stages.size(); i++) { - PipelineStage s = stages.get(i); + Stage<?> s = stages.get(i); if (i <= lastEstimatorIdx) { - Transformer t; - boolean needFit = isStageNeedFit(s); - if (needFit) { - t = ((Estimator) s).fit(tEnv, input); + AlgoOperator<?> t; + if (s instanceof Estimator<?, ?>) { + t = ((Estimator<?, ?>) s).fit(inputs); } else { - // stage is Transformer, guaranteed in appendStage() method - t = (Transformer) s; + t = (AlgoOperator<?>) s; } - transformStages.add(t); - input = t.transform(tEnv, input); + modelStages.add(t); + inputs = t.transform(inputs); Review comment: Thanks for catching this. Both the code and the Javadoc have been updated accordingly. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org