
org.deeplearning4j.arbiter.conf.updater.schedule.PolyScheduleSpace Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.conf.updater.schedule;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.InverseSchedule;
import org.nd4j.linalg.schedule.PolySchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@NoArgsConstructor //JSON
@Data
public class PolyScheduleSpace implements ParameterSpace {
private ScheduleType scheduleType;
private ParameterSpace initialValue;
private ParameterSpace power;
private ParameterSpace maxIter;
public PolyScheduleSpace(@NonNull ScheduleType scheduleType, @NonNull ParameterSpace initialValue,
double power, int maxIter){
this(scheduleType, initialValue, new FixedValue<>(power), new FixedValue<>(maxIter));
}
public PolyScheduleSpace(@NonNull @JsonProperty("scheduleType") ScheduleType scheduleType,
@NonNull @JsonProperty("initialValue") ParameterSpace initialValue,
@NonNull @JsonProperty("power") ParameterSpace power,
@NonNull @JsonProperty("maxIter") ParameterSpace maxIter){
this.scheduleType = scheduleType;
this.initialValue = initialValue;
this.power = power;
this.maxIter = maxIter;
}
@Override
public ISchedule getValue(double[] parameterValues) {
return new PolySchedule(scheduleType, initialValue.getValue(parameterValues),
power.getValue(parameterValues), maxIter.getValue(parameterValues));
}
@Override
public int numParameters() {
return initialValue.numParameters() + power.numParameters() + maxIter.numParameters();
}
@Override
public List collectLeaves() {
return Arrays.asList(initialValue, power, maxIter);
}
@Override
public Map getNestedSpaces() {
Map out = new LinkedHashMap<>();
out.put("initialValue", initialValue);
out.put("power", power);
out.put("maxIter", maxIter);
return out;
}
@Override
public boolean isLeaf() {
return false;
}
@Override
public void setIndices(int... indices) {
if(initialValue.numParameters() > 0){
int[] sub = Arrays.copyOfRange(indices, 0, initialValue.numParameters());
initialValue.setIndices(sub);
}
if(power.numParameters() > 0){
int np = initialValue.numParameters();
int[] sub = Arrays.copyOfRange(indices, np, np + power.numParameters());
power.setIndices(sub);
}
if(maxIter.numParameters() > 0){
int np = initialValue.numParameters() + power.numParameters();
int[] sub = Arrays.copyOfRange(indices, np, np + maxIter.numParameters());
maxIter.setIndices(sub);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy