lindong28 commented on code in PR #237: URL: https://github.com/apache/flink-ml/pull/237#discussion_r1194570809
########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java: ########## @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.updater; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A model updater that could be used to handle push/pull request from workers. + * + * <p>Note that model updater should also ensure that model data is robust to failures. + */ +public interface ModelUpdater extends Serializable { + + /** Initialize the model data. */ + void open(long startFeatureIndex, long endFeatureIndex); + + /** Applies the push to update the model data, e.g., using gradient to update model. */ + void handlePush(long[] keys, double[] values); + + /** Applies the pull and return the retrieved model data. */ + double[] handlePull(long[] keys); + + /** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */ + Iterator<Tuple3<Long, Long, double[]>> getModelPieces(); + + /** Recover the model data from state. */ Review Comment: It would be useful to make the comment style consistent. E.g. Recover -> Recovers. Same for other comments. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java: ########## @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.updater; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A model updater that could be used to handle push/pull request from workers. + * + * <p>Note that model updater should also ensure that model data is robust to failures. 
+ */ +public interface ModelUpdater extends Serializable { + + /** Initialize the model data. */ + void open(long startFeatureIndex, long endFeatureIndex); + + /** Applies the push to update the model data, e.g., using gradient to update model. */ + void handlePush(long[] keys, double[] values); + + /** Applies the pull and return the retrieved model data. */ + double[] handlePull(long[] keys); + + /** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */ + Iterator<Tuple3<Long, Long, double[]>> getModelPieces(); Review Comment: It would be useful to know what is the expected output of this API w.r.t. the invocation of other APIs (e.g. handlePush). ########## flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java: ########## @@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) { public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs) throws IOException { Preconditions.checkArgument(modelDataInputs.length == 1); + List<LogisticRegressionModelData> modelPieces = new ArrayList<>(); + while (true) { + try { + LogisticRegressionModelData piece = + LogisticRegressionModelData.decode(modelDataInputs[0]); + modelPieces.add(piece); + } catch (IOException e) { + // Reached the end of model stream. + break; + } + } - modelData = LogisticRegressionModelData.decode(modelDataInputs[0]); + modelData = mergePieces(modelPieces); return this; } + @VisibleForTesting + public static LogisticRegressionModelData mergePieces( Review Comment: Would it be more intuitive to put this method in `LogisticRegressionModelData`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java: ########## @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.updater; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** The FTRL model updater. */ Review Comment: Would it be useful to provide doc or reference link to explain what is FTRL? Maybe something like https://github.com/Angel-ML/angel/blob/master/docs/algo/ftrl_lr_spark.md. 
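For reference, the per-coordinate FTRL-Proximal update that such a doc would describe (following McMahan et al., 2013) is sketched below. `alpha` and `beta` are the learning-rate parameters from the `FTRL` constructor; how `reg`/`elasticNet` map onto the lambda terms is an assumption here, not something this PR confirms:

```latex
% FTRL-Proximal per-coordinate update for gradient g_i (McMahan et al., 2013).
% Assumption: the constructor's (reg, elasticNet) parameters determine
% (lambda_1, lambda_2); this mapping is not confirmed by the PR itself.
\begin{aligned}
\sigma_i &= \frac{\sqrt{n_i + g_i^2} - \sqrt{n_i}}{\alpha}, \qquad
z_i \leftarrow z_i + g_i - \sigma_i w_i, \qquad
n_i \leftarrow n_i + g_i^2 \\
w_i &=
\begin{cases}
0 & \text{if } |z_i| \le \lambda_1 \\
-\dfrac{z_i - \operatorname{sgn}(z_i)\,\lambda_1}{(\beta + \sqrt{n_i})/\alpha + \lambda_2} & \text{otherwise}
\end{cases}
\end{aligned}
```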
########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java: ########## @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +/** Message Type between workers and servers. */ +public enum MessageType { + ZEROS_TO_PUSH((char) 0), Review Comment: How about using the following names: PUSH_ZERO, PUSH_KV, PULL_INDICES, PULL_VALUES? I am not sure what the meaning of `zero` in `PUSH_ZERO` is. Should we rename it to something like `INITIALIZE_MODEL`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java: ########## @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.ps.message.ValuesPulledM; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +/** + * Merges the message from different servers for one pull request. + * + * <p>Note that for each single-thread worker, there are at exactly #numServers pieces for each pull + * request in the feedback edge. + */ +public class MirrorWorkerOperator extends AbstractStreamOperator<byte[]> Review Comment: It is not clear what `mirror` means here. Maybe we can discuss offline.
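To make the `MessageType` naming suggestion above concrete, one possible shape for the enum (a sketch only; the names and comments follow the proposals and message classes discussed in this thread, not code that exists in the PR):

```java
/** Message type between workers and servers. */
public enum MessageType {
    INITIALIZE_MODEL((char) 0), // was ZEROS_TO_PUSH: zero-initialize a model range on the server
    PUSH_KV((char) 1),          // worker pushes (indices, values) pairs, e.g. gradients
    PULL_INDICES((char) 2),     // worker asks servers for the values at the given indices
    PULL_VALUES((char) 3);      // server answers a pull request with the pulled values

    public final char type;

    MessageType(char type) {
        this.type = type;
    }
}
```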
########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java: ########## @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.updater; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A model updater that could be used to handle push/pull request from workers. + * + * <p>Note that model updater should also ensure that model data is robust to failures. + */ +public interface ModelUpdater extends Serializable { Review Comment: Given that `ModelUpdater` is used only by classes in the package `org.apache.flink.ml.common.ps`, would it be better to move it to that package? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java: ########## @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.updater; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A model updater that could be used to handle push/pull request from workers. + * + * <p>Note that model updater should also ensure that model data is robust to failures. + */ +public interface ModelUpdater extends Serializable { + + /** Initialize the model data. */ + void open(long startFeatureIndex, long endFeatureIndex); + + /** Applies the push to update the model data, e.g., using gradient to update model. */ + void handlePush(long[] keys, double[] values); + + /** Applies the pull and return the retrieved model data. */ + double[] handlePull(long[] keys); Review Comment: What is the relationship between this method and `handlePush`? 
For example, does this only handle `keys` that have been updated with `handlePush()`? If it works like a map, maybe re-use the `Map` API so that it is more intuitive. ########## flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java: ########## @@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) { public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs) throws IOException { Preconditions.checkArgument(modelDataInputs.length == 1); + List<LogisticRegressionModelData> modelPieces = new ArrayList<>(); + while (true) { + try { + LogisticRegressionModelData piece = + LogisticRegressionModelData.decode(modelDataInputs[0]); Review Comment: Other `XXXModelData#decode` methods will finish reading the given input stream and return a self-contained model data instance. We will break this convention by having `LogisticRegressionModelData.decode` return a segment of the full model data. Would it be simpler to have `LogisticRegressionModelData` maintain a list of `LogisticRegressionModelDataSegment` internally? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java: ########## @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.training.PushStage; +import org.apache.flink.ml.common.ps.training.SerializableConsumer; +import org.apache.flink.ml.common.ps.training.TrainingContext; +import org.apache.flink.ml.common.ps.training.TrainingUtils; +import org.apache.flink.ml.common.updater.FTRL; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; +import org.apache.flink.util.function.SerializableSupplier; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer. + * + * <p>See https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegressionWithFtrl Review Comment: Since we keep both `LogisticRegressionWithFtrl` and `LogisticRegression` and both classes implement the same algorithm, I suppose these two algorithms have different pros/cons that address different use-cases. Can you provide information to help users decide which algorithm to use? ########## flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java: ########## @@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) { public LogisticRegressionModelServable setModelData(InputStream... 
modelDataInputs) throws IOException { Preconditions.checkArgument(modelDataInputs.length == 1); + List<LogisticRegressionModelData> modelPieces = new ArrayList<>(); + while (true) { + try { + LogisticRegressionModelData piece = + LogisticRegressionModelData.decode(modelDataInputs[0]); + modelPieces.add(piece); Review Comment: It is probably more common and intuitive to use `segment` instead of `piece`. We can find a lot of classes in Flink with `segment` in the class name. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java: ########## @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.IndicesToPullM; +import org.apache.flink.ml.common.ps.message.KVsToPushM; +import org.apache.flink.ml.common.ps.message.MessageType; +import org.apache.flink.ml.common.ps.message.MessageUtils; +import org.apache.flink.ml.common.ps.message.ValuesPulledM; +import org.apache.flink.ml.common.ps.message.ZerosToPushM; +import org.apache.flink.ml.common.updater.ModelUpdater; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.SerializableObject; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +/** + * The server operator maintains the shared parameters. It receives push/pull requests from {@link + * WorkerOperator} and sends the answer request to {@link MirrorWorkerOperator}. It works closely + * with {@link ModelUpdater} in the following way: + * + * <ul> + * <li>The server operator deals with the message from workers and decide when to process the + * received message. (i.e., synchronous vs.
asynchronous). + * <li>The server operator calls {@link ModelUpdater#handlePush(long[], double[])} and {@link + * ModelUpdater#handlePull(long[])} to process the messages in detail. + * <li>The server operator ensures that {@link ModelUpdater} is robust to failures. Review Comment: Instead of using `robust to failures`, it might be simpler and more explicit to say something like this: The server operator triggers checkpoint for {@link ModelUpdater}. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java: ########## @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.training.PushStage; +import org.apache.flink.ml.common.ps.training.SerializableConsumer; +import org.apache.flink.ml.common.ps.training.TrainingContext; +import org.apache.flink.ml.common.ps.training.TrainingUtils; +import org.apache.flink.ml.common.updater.FTRL; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; +import org.apache.flink.util.function.SerializableSupplier; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; +import 
it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer. + * + * <p>See https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegressionWithFtrl + implements Estimator<LogisticRegressionWithFtrl, LogisticRegressionModel>, + LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrl> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public LogisticRegressionWithFtrl() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public LogisticRegressionModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + String classificationType = getMultiClass(); + Preconditions.checkArgument( + "auto".equals(classificationType) || "binomial".equals(classificationType), + "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<LabeledLargePointWithWeight> trainData = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction<Row, LabeledLargePointWithWeight>) + dataPoint -> { + double weight = + getWeightCol() == null + ? 1.0 + : ((Number) + dataPoint.getField( + getWeightCol())) + .doubleValue(); + double label = + ((Number) dataPoint.getField(getLabelCol())) + .doubleValue(); + boolean isBinomial = + Double.compare(0., label) == 0 + || Double.compare(1., label) == 0; + if (!isBinomial) { + throw new RuntimeException( + "Multinomial classification is not supported yet. 
Supported options: [auto, binomial]."); + } + Tuple2<long[], double[]> features = + dataPoint.getFieldAs(getFeaturesCol()); + return new LabeledLargePointWithWeight( + features, label, weight); + }); + + DataStream<Long> modelDim; + if (getModelDim() > 0) { + modelDim = trainData.getExecutionEnvironment().fromElements(getModelDim()); + } else { + modelDim = + DataStreamUtils.reduce( + trainData.map(x -> x.features.f0[x.features.f0.length - 1]), + (ReduceFunction<Long>) Math::max) + .map((MapFunction<Long, Long>) value -> value + 1); + } + + LogisticRegressionWithFtrlTrainingContext trainingContext = + new LogisticRegressionWithFtrlTrainingContext(getParamMap()); + + IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages = + new IterationStageList<>(trainingContext); + iterationStages + .addTrainingStage(new ComputeIndices()) + .addTrainingStage( + new PullStage( + (SerializableSupplier<long[]>) () -> trainingContext.pullIndices, + (SerializableConsumer<double[]>) + x -> trainingContext.pulledValues = x)) + .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) + .addTrainingStage( + new PushStage( + (SerializableSupplier<long[]>) () -> trainingContext.pushIndices, + (SerializableSupplier<double[]>) () -> trainingContext.pushValues)) + .setTerminationCriteria( + (SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>) + o -> o.iterationId >= getMaxIter()); + FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet()); + + DataStream<Tuple3<Long, Long, double[]>> rawModelData = + TrainingUtils.train( + modelDim, + trainData, + ftrl, + iterationStages, + getNumServers(), + getNumServerCores()); + + final long modelVersion = 0L; + + DataStream<LogisticRegressionModelData> modelData = + rawModelData.map( + tuple3 -> + new LogisticRegressionModelData( + Vectors.dense(tuple3.f2), + tuple3.f0, + tuple3.f1, + modelVersion)); + + LogisticRegressionModel model = + new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); + ParamUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static LogisticRegressionWithFtrl load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} + +/** + * An iteration stage that samples a batch of training data and computes the indices needed to + * compute gradients. 
+ */ +class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> { + + @Override + public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception { + context.readInNextBatchData(); + context.pullIndices = computeIndices(context.batchData); + } + + public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) { + LongOpenHashSet indices = new LongOpenHashSet(); + for (LabeledLargePointWithWeight dataPoint : dataPoints) { + long[] notZeros = dataPoint.features.f0; + for (long index : notZeros) { + indices.add(index); + } + } + + long[] sortedIndices = new long[indices.size()]; + Iterator<Long> iterator = indices.iterator(); + int i = 0; + while (iterator.hasNext()) { + sortedIndices[i++] = iterator.next(); + } + Arrays.sort(sortedIndices); + return sortedIndices; + } +} + +/** + * An iteration stage that uses the pulled model values and sampled batch data to compute the + * gradients. + */ +class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> { Review Comment: Since APIs of this class may be invoked directly outside `LogisticRegressionWithFtrl`, it seems more conventional and readable to move this class outside `LogisticRegressionWithFtrl`. ########## flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java: ########## @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.feature; + +import org.apache.flink.api.java.tuple.Tuple2; + +/** A data point to represent values that use long as index and double as values. */ +public class LabeledLargePointWithWeight { + public Tuple2<long[], double[]> features; Review Comment: Can you explain why we can't re-use `LabeledPointWithWeight`? If the features presented here encodes a sparse vector, then we should be able to re-use `LabeledPointWithWeight` because `LabeledPointWithWeight#features` can be a SparseVector. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java: ########## @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; + +import java.io.Serializable; + +/** + * Stores the context information that is alive during the training process. Note that the context + * information will be updated by each {@link IterationStage}. + * + * <p>Note that subclasses should take care of the snapshot of object stored in {@link + * TrainingContext} if the object satisfies that: the write-process is followed by an {@link + * PullStage}, which is later again read by other stages. + */ +public interface TrainingContext extends Serializable { Review Comment: `context` is typically used for APIs that read state rather than write state. Would it be more intuitive to name it `IterationStageListener`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java: ########## @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.util.function.SerializableFunction; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * A list of iteration stages to express the logic of an iterative machine learning training + * process. + */ +public class IterationStageList<T extends TrainingContext> implements Serializable { + public final T context; + public Function<T, Boolean> shouldTerminate; + public List<IterationStage> stageList; + + public IterationStageList(T context) { + this.stageList = new ArrayList<>(); + this.context = context; + } + + /** Sets the criteria of termination. */ + public void setTerminationCriteria(SerializableFunction<T, Boolean> shouldTerminate) { + this.shouldTerminate = shouldTerminate; + } + + /** Adds an iteration stage into the stage list. */ + public IterationStageList<T> addTrainingStage(IterationStage stage) { Review Comment: Given that the class name is `IterationStageList`, would it be simpler to name the method `add(...)` or `addStage(...)`?
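If the method were renamed to `addStage(...)`, the call site in `LogisticRegressionWithFtrl#fit` (quoted further below) would read as a fluent chain. A sketch assuming only the rename, with everything else as in the PR:

```java
// Same stage wiring as in LogisticRegressionWithFtrl#fit, with the proposed rename applied.
IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
        new IterationStageList<>(trainingContext);
iterationStages
        .addStage(new ComputeIndices())
        .addStage(
                new PullStage(
                        (SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
                        (SerializableConsumer<double[]>) x -> trainingContext.pulledValues = x))
        .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
        .addStage(
                new PushStage(
                        (SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
                        (SerializableSupplier<double[]>) () -> trainingContext.pushValues))
        .setTerminationCriteria(o -> o.iterationId >= getMaxIter());
```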
########## flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java: ########## @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.training.PushStage; +import org.apache.flink.ml.common.ps.training.SerializableConsumer; +import org.apache.flink.ml.common.ps.training.TrainingContext; +import org.apache.flink.ml.common.ps.training.TrainingUtils; +import org.apache.flink.ml.common.updater.FTRL; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; +import org.apache.flink.util.function.SerializableSupplier; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer. + * + * <p>See https://en.wikipedia.org/wiki/Logistic_regression. 
+ */ +public class LogisticRegressionWithFtrl + implements Estimator<LogisticRegressionWithFtrl, LogisticRegressionModel>, + LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrl> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public LogisticRegressionWithFtrl() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public LogisticRegressionModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + String classificationType = getMultiClass(); + Preconditions.checkArgument( + "auto".equals(classificationType) || "binomial".equals(classificationType), + "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<LabeledLargePointWithWeight> trainData = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction<Row, LabeledLargePointWithWeight>) + dataPoint -> { + double weight = + getWeightCol() == null + ? 1.0 + : ((Number) + dataPoint.getField( + getWeightCol())) + .doubleValue(); + double label = + ((Number) dataPoint.getField(getLabelCol())) + .doubleValue(); + boolean isBinomial = + Double.compare(0., label) == 0 + || Double.compare(1., label) == 0; + if (!isBinomial) { + throw new RuntimeException( + "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); + } + Tuple2<long[], double[]> features = + dataPoint.getFieldAs(getFeaturesCol()); + return new LabeledLargePointWithWeight( + features, label, weight); + }); + + DataStream<Long> modelDim; + if (getModelDim() > 0) { + modelDim = trainData.getExecutionEnvironment().fromElements(getModelDim()); + } else { + modelDim = + DataStreamUtils.reduce( + trainData.map(x -> x.features.f0[x.features.f0.length - 1]), + (ReduceFunction<Long>) Math::max) + .map((MapFunction<Long, Long>) value -> value + 1); + } + + LogisticRegressionWithFtrlTrainingContext trainingContext = + new LogisticRegressionWithFtrlTrainingContext(getParamMap()); + + IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages = + new IterationStageList<>(trainingContext); + iterationStages + .addTrainingStage(new ComputeIndices()) + .addTrainingStage( + new PullStage( + (SerializableSupplier<long[]>) () -> trainingContext.pullIndices, + (SerializableConsumer<double[]>) + x -> trainingContext.pulledValues = x)) + .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) + .addTrainingStage( + new PushStage( + (SerializableSupplier<long[]>) () -> trainingContext.pushIndices, + (SerializableSupplier<double[]>) () -> trainingContext.pushValues)) + .setTerminationCriteria( + (SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>) + o -> o.iterationId >= getMaxIter()); + FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet()); + + DataStream<Tuple3<Long, Long, double[]>> rawModelData = + TrainingUtils.train( + modelDim, + trainData, + ftrl, + iterationStages, + getNumServers(), + getNumServerCores()); + + final long modelVersion = 0L; + + DataStream<LogisticRegressionModelData> modelData = + rawModelData.map( + tuple3 -> + new LogisticRegressionModelData( + Vectors.dense(tuple3.f2), + tuple3.f0, + tuple3.f1, + modelVersion)); + + LogisticRegressionModel model = + new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); + ParamUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws 
IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static LogisticRegressionWithFtrl load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} + +/** + * An iteration stage that samples a batch of training data and computes the indices needed to + * compute gradients. + */ +class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> { + + @Override + public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception { + context.readInNextBatchData(); + context.pullIndices = computeIndices(context.batchData); + } + + public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) { + LongOpenHashSet indices = new LongOpenHashSet(); + for (LabeledLargePointWithWeight dataPoint : dataPoints) { + long[] notZeros = dataPoint.features.f0; + for (long index : notZeros) { + indices.add(index); + } + } + + long[] sortedIndices = new long[indices.size()]; + Iterator<Long> iterator = indices.iterator(); + int i = 0; + while (iterator.hasNext()) { + sortedIndices[i++] = iterator.next(); + } + Arrays.sort(sortedIndices); + return sortedIndices; + } +} + +/** + * An iteration stage that uses the pulled model values and sampled batch data to compute the + * gradients. + */ +class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> { + private final LossFunc lossFunc; + + public ComputeGradients(LossFunc lossFunc) { + this.lossFunc = lossFunc; + } + + @Override + public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException { + long[] indices = ComputeIndices.computeIndices(context.batchData); + double[] pulledModelValues = context.pulledValues; + double[] gradients = computeGradient(context.batchData, indices, pulledModelValues); + + context.pushIndices = indices; + context.pushValues = gradients; + } + + private double[] computeGradient( + List<LabeledLargePointWithWeight> batchData, + long[] sortedBatchIndices, + double[] pulledModelValues) { + Long2DoubleOpenHashMap coefficient = new Long2DoubleOpenHashMap(sortedBatchIndices.length); + for (int i = 0; i < sortedBatchIndices.length; i++) { + coefficient.put(sortedBatchIndices[i], pulledModelValues[i]); + } + Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(sortedBatchIndices.length); + + for (LabeledLargePointWithWeight dataPoint : batchData) { + double dot = dot(dataPoint.features, coefficient); + double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight; + + long[] featureIndices = dataPoint.features.f0; + double[] featureValues = dataPoint.features.f1; + double z; + for (int i = 0; i < featureIndices.length; i++) { + long currentIndex = featureIndices[i]; + z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.); + cumGradients.put(currentIndex, z); + } + } + double[] cumGradientValues = new double[sortedBatchIndices.length]; + for (int i = 0; i < sortedBatchIndices.length; i++) { + cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]); + } + return cumGradientValues; + } + + private static double dot( + Tuple2<long[], double[]> features, Long2DoubleOpenHashMap coefficient) { + double dot = 0; + for (int i = 0; i < features.f0.length; i++) { + dot += features.f1[i] * coefficient.get(features.f0[i]); + } + return dot; + } +} + +/** The context information of local computing process. 
*/ +class LogisticRegressionWithFtrlTrainingContext Review Comment: Would it be more intuitive to name it something like `FtrlIterationStageState`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java: ########## @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +/** + * Message sent by worker to server that initializes the model as a dense array with defined range. + */ +public class ZerosToPushM implements Message { + public final int workerId; + public final int serverId; + public final long startIndex; + public final long endIndex; + + public static final MessageType MESSAGE_TYPE = MessageType.ZEROS_TO_PUSH; Review Comment: Can we make this field `private` or even remove this field for simplicity? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java: ########## @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.util.Bits; + +/** Utility functions for processing messages. */ +public class MessageUtils { + + /** Retrieves the message type from the byte array. */ + public static MessageType getMessageType(byte[] bytesData) { + char type = Bits.getChar(bytesData, 0); + return MessageType.valueOf(type); + } + + /** Reads a long array from the byte array starting from the given offset. */ + public static long[] readLongArray(byte[] bytesData, int offset) { + int size = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + long[] result = new long[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getLong(bytesData, offset); + offset += Long.BYTES; + } + return result; + } + + /** + * Writes a long array to the byte array starting from the given offset. + * + * @return the next position to write on. 
+ */ + public static int writeLongArray(long[] array, byte[] bytesData, int offset) { + Bits.putInt(bytesData, offset, array.length); + offset += Integer.BYTES; + for (int i = 0; i < array.length; i++) { + Bits.putLong(bytesData, offset, array[i]); + offset += Long.BYTES; + } + return offset; + } + + /** Returns the size of a long array in bytes. */ + public static int getLongArraySizeInBytes(long[] array) { + return Integer.BYTES + array.length * Long.BYTES; + } + + /** Reads a double array from the byte array starting from the given offset. */ + public static double[] readDoubleArray(byte[] bytesData, int offset) { + int size = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getDouble(bytesData, offset); + offset += Long.BYTES; + } + return result; + } + + /** + * Writes a double array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int writeDoubleArray(double[] array, byte[] bytesData, int offset) { Review Comment: Would it be simpler to move these methods to `Bits.java` and make the method and parameter names consistent with the existing methods in `Bits`? ########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java: ########## @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +/** The indices one worker needs to pull from servers. */ +public class IndicesToPullM implements Message { + public final int serverId; + public final int workerId; + public final long[] indicesToPull; + + public static final MessageType MESSAGE_TYPE = MessageType.INDICES_TO_PULL; + + public IndicesToPullM(int serverId, int workerId, long[] indicesToPull) { + this.serverId = serverId; + this.workerId = workerId; + this.indicesToPull = indicesToPull; + } + + public static IndicesToPullM fromBytes(byte[] bytesData) { Review Comment: It seems simpler to rename `bytesData` as `bytes`.
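For the `Bits` consolidation suggested above, the array helpers could mirror the existing scalar `Bits.putLong`/`Bits.getLong` style. A sketch of hypothetical additions to `org.apache.flink.ml.util.Bits` (the names `putLongArray`/`getLongArray` are illustrative, not part of this PR; the double-array variants would follow the same pattern):

```java
/** Writes a length-prefixed long array at the given offset; returns the next write position. */
public static int putLongArray(byte[] b, int offset, long[] values) {
    Bits.putInt(b, offset, values.length); // length header, as in MessageUtils#writeLongArray
    offset += Integer.BYTES;
    for (long value : values) {
        Bits.putLong(b, offset, value);
        offset += Long.BYTES;
    }
    return offset;
}

/** Reads a length-prefixed long array written by {@code putLongArray}. */
public static long[] getLongArray(byte[] b, int offset) {
    int size = Bits.getInt(b, offset);
    offset += Integer.BYTES;
    long[] values = new long[size];
    for (int i = 0; i < size; i++) {
        values[i] = Bits.getLong(b, offset);
        offset += Long.BYTES;
    }
    return values;
}
```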