Hi, I have a problem using Spark 1.2 with Pipeline + GridSearch +
LogisticRegression. I’ve reimplemented the LogisticRegression.fit method and
commented out instances.unpersist():

    override def fit(dataset: SchemaRDD,
                     paramMap: ParamMap): LogisticRegressionModel = {
      // Log a hash of the first rows so identical input datasets can be recognised.
      println(s"Fitting dataset ${dataset.take(1000).toSeq.hashCode()} with ParamMap $paramMap.")
      transformSchema(dataset.schema, paramMap, logging = true)
      import dataset.sqlContext._
      val map = this.paramMap ++ paramMap
      val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
        .map {
          case Row(label: Double, features: Vector) =>
            LabeledPoint(label, features)
        }
      // Persist only if the RDD has no storage level yet.
      if (instances.getStorageLevel == StorageLevel.NONE) {
        println("Instances not persisted")
        instances.persist(StorageLevel.MEMORY_AND_DISK)
      }
      val lr = (new LogisticRegressionWithLBFGS)
        .setValidateData(false)
        .setIntercept(true)
      lr.optimizer
        .setRegParam(map(regParam))
        .setNumIterations(map(maxIter))
      val lrm = new LogisticRegressionModel(this, map,
        lr.run(instances).weights)
      //instances.unpersist()
      // copy model params
      Params.inheritValues(map, this, lrm)
      lrm
    }

CrossValidator feeds the same SchemaRDD for each parameter (same hash
code), but somewhere the cache is being flushed. There is enough memory. Here’s
the output:

    Fitting dataset 2051470010 with ParamMap {
    DRLogisticRegression-f35ae4d3-regParam: 0.1
    }.
    Instances not persisted
    Fitting dataset 2051470010 with ParamMap {
    DRLogisticRegression-f35ae4d3-regParam: 0.01
    }.
    Instances not persisted
    Fitting dataset 2051470010 with ParamMap {
    DRLogisticRegression-f35ae4d3-regParam: 0.001
    }.
    Instances not persisted
    Fitting dataset 802615223 with ParamMap {
    DRLogisticRegression-f35ae4d3-regParam: 0.1
    }.
    Instances not persisted
    Fitting dataset 802615223 with ParamMap {
    DRLogisticRegression-f35ae4d3-regParam: 0.01
    }.
    Instances not persisted

I have 3 parameter values in GridSearch and 3 folds for CrossValidation:

    val paramGrid = new ParamGridBuilder()
      .addGrid(model.regParam, Array(0.1, 0.01, 0.001))
      .build()
    crossval.setEstimatorParamMaps(paramGrid)
    crossval.setNumFolds(3)

I assume that the data should be read and cached 3 times,
(1 to numFolds).combinations(2), independent of the number of parameters.
But the data is being read and cached 9 times.
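
To make the two counts concrete, here is a small plain-Scala sketch (no Spark needed). The object name CacheCount and its field names are hypothetical, and reading the observed 9 as numFolds times the grid size is an assumption based on the log above, where each fold's dataset is fitted once per regParam value.

```scala
object CacheCount {
  // Settings taken from the grid and CrossValidator configuration above.
  val numFolds: Int = 3
  val regParams: Seq[Double] = Seq(0.1, 0.01, 0.001)

  // Expected: one cached training set per pair of folds.
  val expectedCaches: Int = (1 to numFolds).combinations(2).size

  // Observed (assumption): one "Instances not persisted" per (fold, ParamMap) pair.
  val observedCaches: Int = numFolds * regParams.size

  def main(args: Array[String]): Unit = {
    println(s"expected=$expectedCaches observed=$observedCaches")
  }
}
```

So the number of cache operations seems to scale with the size of the parameter grid, which is what I did not expect.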
Thanks,
Peter Rudenko