// org.nd4j.linalg.schedule.CycleSchedule (Maven / Gradle / Ivy)
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.schedule;
import lombok.Data;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
 * Cyclic learning-rate schedule. Every cycle of {@code cycleLength} positions consists of
 * three consecutive phases:
 * <ol>
 *     <li>a linear ramp from {@code initialLearningRate} up to {@code maxLearningRate}
 *         over {@code stepSize} positions,</li>
 *     <li>a linear ramp back down over the next {@code stepSize} positions,</li>
 *     <li>an annealing phase for the remaining positions, where the rate is
 *         {@code initialLearningRate * annealingDecay^k} and {@code k} counts the positions
 *         elapsed since {@code cycleLength - annealingLength} (never less than 0).</li>
 * </ol>
 * {@code stepSize} is {@code (cycleLength - annealingLength) / 2} using integer division.
 * The position inside the cycle is taken from the epoch counter when the schedule type is
 * {@code ScheduleType.EPOCH}, and from the iteration counter otherwise.
 */
@Data
public class CycleSchedule implements ISchedule {

    private final ScheduleType scheduleType;
    private final double initialLearningRate;
    private final double maxLearningRate;
    private final int cycleLength;
    private final int annealingLength;
    /** Length of each of the two linear ramps; derived in the constructor. */
    private final int stepSize;
    /** Learning-rate change per position during the linear ramps; derived in the constructor. */
    private final double increment;
    // Deliberately left non-final: Lombok's @Data therefore also generates a setter for it.
    private double annealingDecay;

    /**
     * Creates a fully specified cycle schedule.
     *
     * @param scheduleType        whether the cycle position follows the epoch or the iteration counter
     * @param initialLearningRate learning rate at the start of each cycle
     * @param maxLearningRate     learning rate at the peak of each cycle
     * @param cycleLength         number of positions in one full cycle; must be positive
     * @param annealingLength     number of positions at the end of the cycle used for annealing
     * @param annealingDecay      multiplicative factor applied once per annealing position
     * @throws IllegalArgumentException if {@code cycleLength} is not positive
     */
    public CycleSchedule(@JsonProperty("scheduleType") ScheduleType scheduleType,
                         @JsonProperty("initialLearningRate") double initialLearningRate,
                         @JsonProperty("maxLearningRate") double maxLearningRate,
                         @JsonProperty("cycleLength") int cycleLength,
                         @JsonProperty("annealingLength") int annealingLength,
                         @JsonProperty("annealingDecay") double annealingDecay){
        // Fail fast: a zero cycle length would otherwise only surface later as an
        // ArithmeticException from the modulo in valueAt.
        if (cycleLength <= 0) {
            throw new IllegalArgumentException("cycleLength must be positive, got " + cycleLength);
        }
        this.scheduleType = scheduleType;
        this.initialLearningRate = initialLearningRate;
        this.maxLearningRate = maxLearningRate;
        this.cycleLength = cycleLength;
        this.annealingDecay = annealingDecay;
        this.annealingLength = annealingLength;

        this.stepSize = (cycleLength - annealingLength) / 2;
        // With stepSize == 0 the ramp phases are never entered, so the increment is unused;
        // keep it finite instead of letting the division produce Infinity/NaN.
        this.increment = stepSize != 0 ? (maxLearningRate - initialLearningRate) / stepSize : 0.0;
    }

    /**
     * Convenience constructor: starts at 10% of {@code maxLearningRate}, anneals during the
     * last (rounded) 10% of the cycle, and uses an annealing decay of 0.1.
     *
     * @param scheduleType    whether the cycle position follows the epoch or the iteration counter
     * @param maxLearningRate learning rate at the peak of each cycle
     * @param cycleLength     number of positions in one full cycle; must be positive
     */
    public CycleSchedule(ScheduleType scheduleType,
                         double maxLearningRate,
                         int cycleLength){
        this(scheduleType, maxLearningRate * 0.1, maxLearningRate, cycleLength, (int) Math.round(cycleLength * 0.1), 0.1);
    }

    /**
     * Computes the learning rate for the given training position.
     *
     * @param iteration current iteration counter
     * @param epoch     current epoch counter
     * @return the learning rate for the corresponding position in the cycle
     */
    @Override
    public double valueAt(int iteration, int epoch) {
        double learningRate;
        final int positionInCycle = (scheduleType == ScheduleType.EPOCH ? epoch : iteration) % cycleLength;

        if(positionInCycle < stepSize){
            // Phase 1: linear ramp up from the initial towards the maximum rate.
            learningRate = initialLearningRate + increment * positionInCycle;
        }else if(positionInCycle < 2*stepSize){
            // Phase 2: linear ramp back down from the maximum rate.
            learningRate = maxLearningRate - increment * (positionInCycle - stepSize);
        }else {
            // Phase 3: annealing. Integer division in stepSize can leave 2*stepSize one short of
            // the annealing start (odd cycleLength - annealingLength), which would give a
            // negative exponent and a one-step jump to initialLearningRate / annealingDecay.
            // Clamping keeps that position at the initial rate instead.
            final int annealingStep = Math.max(0, annealingLength - (cycleLength - positionInCycle));
            learningRate = initialLearningRate * Math.pow(annealingDecay, annealingStep);
        }

        return learningRate;
    }

    /**
     * Creates a new schedule from this schedule's current constructor arguments.
     *
     * @return a new {@code CycleSchedule} with the same configuration
     */
    @Override
    public ISchedule clone() {
        return new CycleSchedule(scheduleType, initialLearningRate, maxLearningRate, cycleLength, annealingLength, annealingDecay);
    }
}