[ https://issues.apache.org/jira/browse/FLINK-1992?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14550335#comment-14550335 ]

ASF GitHub Bot commented on FLINK-1992:
---------------------------------------

Github user thvasilo commented on a diff in the pull request:

    https://github.com/apache/flink/pull/692#discussion_r30592697
  
    --- Diff: flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala ---
    @@ -205,6 +190,61 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
     
       }
     
    -  // TODO: Need more corner cases
    +  it should "terminate early if the convergence criterion is reached" in {
    +    // TODO(tvas): We need a better way to check the convergence of the weights.
    +    // Ideally we want to have a Breeze-like system, where the optimizers carry a history and that
    +    // can tell us whether we have converged and at which iteration
    +
    +    val env = ExecutionEnvironment.getExecutionEnvironment
    +
    +    env.setParallelism(2)
    +
    +    val sgdEarlyTerminate = GradientDescent()
    +      .setConvergenceThreshold(1e2)
    +      .setStepsize(1.0)
    +      .setIterations(800)
    +      .setLossFunction(SquaredLoss())
    +      .setRegularizationType(NoRegularization())
    +      .setRegularizationParameter(0.0)
    +
    +    val inputDS = env.fromCollection(data)
    +
    +    val weightDSEarlyTerminate = sgdEarlyTerminate.optimize(inputDS, None)
    +
    +    val weightListEarly: Seq[WeightVector] = weightDSEarlyTerminate.collect()
    +
    +    weightListEarly.size should equal(1)
    +
    +    val weightVectorEarly: WeightVector = weightListEarly.head
    +    val weightsEarly = weightVectorEarly.weights.asInstanceOf[DenseVector].data
    +    val weight0Early = weightVectorEarly.intercept
    +
    +    val sgdNoConvergence = GradientDescent()
    +      .setStepsize(1.0)
    --- End diff --
    
    Here we have a problem with the return type: the setters should return the
    runtime type of the solver. If we call setLossFunction first, we get back a
    Solver, which means we can no longer call the IterativeSolver methods.
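One common Scala idiom for fluent setters in a class hierarchy is to declare their result type as `this.type`, which preserves the static type of the receiver through the whole chain, so the order of the setter calls no longer matters. The sketch below illustrates the idea; the classes are hypothetical, simplified stand-ins named after the ones in the diff, not Flink's actual `Solver`/`IterativeSolver` implementation.

```scala
// Hypothetical, simplified stand-ins for the classes named in the diff above;
// they only illustrate the return-type issue and are not Flink's actual API.
trait LossFunction
case object SquaredLossSketch extends LossFunction

abstract class Solver {
  var lossFunction: Option[LossFunction] = None

  // Declaring the result as this.type (rather than Solver) keeps the caller's
  // static type, so subclass setters stay reachable later in the chain.
  def setLossFunction(loss: LossFunction): this.type = {
    lossFunction = Some(loss)
    this
  }
}

abstract class IterativeSolver extends Solver {
  var iterations: Int = 10
  var convergenceThreshold: Option[Double] = None

  def setIterations(n: Int): this.type = {
    iterations = n
    this
  }

  def setConvergenceThreshold(t: Double): this.type = {
    convergenceThreshold = Some(t)
    this
  }
}

class GradientDescent extends IterativeSolver {
  var stepsize: Double = 0.1

  def setStepsize(s: Double): this.type = {
    stepsize = s
    this
  }
}

object ThisTypeDemo {
  def main(args: Array[String]): Unit = {
    // If setLossFunction were declared to return Solver, the next call would
    // not compile, because Solver has no setIterations method.
    val sgd: GradientDescent = new GradientDescent()
      .setLossFunction(SquaredLossSketch)
      .setIterations(800)
      .setConvergenceThreshold(1e2)
      .setStepsize(1.0)
    println(s"iterations=${sgd.iterations}, stepsize=${sgd.stepsize}")
  }
}
```

An alternative is F-bounded polymorphism (`Solver[Self <: Solver[Self]]`), but `this.type` needs no extra type parameter when the setters simply mutate and return the same instance.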


> Add convergence criterion to SGD optimizer
> ------------------------------------------
>
>                 Key: FLINK-1992
>                 URL: https://issues.apache.org/jira/browse/FLINK-1992
>             Project: Flink
>          Issue Type: Improvement
>          Components: Machine Learning Library
>            Reporter: Till Rohrmann
>            Assignee: Theodore Vasiloudis
>              Labels: ML
>
> Currently, Flink's SGD optimizer runs for a fixed number of iterations. It 
> would be good to support a dynamic convergence criterion, too.
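The issue does not say which convergence criterion the optimizer should use. One common choice, assumed here, is to stop as soon as the relative change in loss between two consecutive iterations falls below the configured threshold. The following is a minimal single-machine sketch on a hypothetical one-dimensional objective, not Flink's DataSet-based `GradientDescent`; `hasConverged` and `minimize` are illustrative names only.

```scala
object ConvergenceSketch {
  // Assumed criterion: relative change in loss between consecutive iterations.
  def hasConverged(previousLoss: Double, currentLoss: Double, threshold: Double): Boolean =
    if (previousLoss == 0.0) currentLoss == 0.0 // guard against division by zero
    else math.abs((previousLoss - currentLoss) / previousLoss) < threshold

  // Toy objective f(w) = (w - 3)^2 + 1, whose minimum loss is 1 (not 0),
  // so the relative change in loss shrinks toward 0 as w approaches 3.
  private def loss(w: Double): Double = (w - 3.0) * (w - 3.0) + 1.0
  private def gradient(w: Double): Double = 2.0 * (w - 3.0)

  // Plain gradient descent that stops after maxIterations or as soon as the
  // convergence criterion holds; returns the final weight and iterations run.
  def minimize(stepsize: Double, maxIterations: Int, threshold: Double): (Double, Int) = {
    var w = 0.0
    var currentLoss = loss(w)
    var i = 0
    var converged = false
    while (i < maxIterations && !converged) {
      w -= stepsize * gradient(w)
      val newLoss = loss(w)
      converged = hasConverged(currentLoss, newLoss, threshold)
      currentLoss = newLoss
      i += 1
    }
    (w, i)
  }

  def main(args: Array[String]): Unit = {
    val (wEarly, itersEarly) = minimize(stepsize = 0.1, maxIterations = 800, threshold = 1e-3)
    val (wFull, itersFull) = minimize(stepsize = 0.1, maxIterations = 800, threshold = 0.0)
    println(s"early stop: w=$wEarly after $itersEarly iterations")
    println(s"fixed count: w=$wFull after $itersFull iterations")
  }
}
```

With nonnegative losses, this assumed criterion is met on the first iteration for any threshold above 1 whenever the loss does not increase, so a threshold such as the test's `1e2` would force early termination.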



--
This message was sent by Atlassian JIRA
(v6.3.4#6332)
