org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.earlystopping.termination;
import lombok.Data;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.concurrent.TimeUnit;
/**
 * Iteration termination condition that stops training once a maximum amount of wall-clock time
 * has elapsed since {@link #initialize()} was called.
 *
 * <p>Time is measured with {@link System#currentTimeMillis()}, so adjustments to the system clock
 * while training is running shift the effective deadline.
 */
@Data
public class MaxTimeIterationTerminationCondition implements IterationTerminationCondition {
    /** Maximum training time, in units of {@link #maxTimeUnit}; validated positive by the constructor. */
    private long maxTimeAmount;
    /** Unit of {@link #maxTimeAmount}; validated non-null by the constructor. */
    private TimeUnit maxTimeUnit;
    /** Epoch milliseconds at which {@link #initialize()} was last called (0 before that). */
    private long initializationTime;
    /**
     * Epoch-millisecond deadline: {@link #terminate(double)} returns true once the clock reaches it.
     * Remains 0 until {@link #initialize()} is called.
     */
    private long endTime;

    /**
     * @param maxTimeAmount maximum training time; must be positive
     * @param maxTimeUnit   unit of {@code maxTimeAmount}; must not be null
     * @throws IllegalArgumentException if {@code maxTimeAmount <= 0} or {@code maxTimeUnit} is null
     */
    public MaxTimeIterationTerminationCondition(@JsonProperty("maxTimeAmount") long maxTimeAmount,
                                                @JsonProperty("maxTimeUnit") TimeUnit maxTimeUnit) {
        if (maxTimeAmount <= 0 || maxTimeUnit == null) {
            throw new IllegalArgumentException(
                    "Invalid maximum training time: " + "amount = " + maxTimeAmount + " unit = " + maxTimeUnit);
        }
        this.maxTimeAmount = maxTimeAmount;
        this.maxTimeUnit = maxTimeUnit;
    }

    /**
     * Starts the clock: records the current time and computes the termination deadline.
     *
     * <p>The deadline saturates instead of overflowing. {@link TimeUnit#toMillis(long)} clamps very
     * large durations to {@link Long#MAX_VALUE}; adding that to the current time with plain
     * {@code +} would wrap to a negative deadline and end training immediately.
     */
    @Override
    public void initialize() {
        initializationTime = System.currentTimeMillis();
        endTime = saturatedAdd(initializationTime, maxTimeUnit.toMillis(maxTimeAmount));
    }

    /**
     * @param lastMiniBatchScore ignored; this condition depends only on elapsed time
     * @return true once the current time has reached {@link #endTime}. If {@link #initialize()} has
     *         not been called, the deadline is still 0 and this returns true.
     */
    @Override
    public boolean terminate(double lastMiniBatchScore) {
        return System.currentTimeMillis() >= endTime;
    }

    @Override
    public String toString() {
        return "MaxTimeIterationTerminationCondition(" + maxTimeAmount + ",unit=" + maxTimeUnit + ")";
    }

    /** Returns {@code a + b}, clamped to {@code [Long.MIN_VALUE, Long.MAX_VALUE]} instead of wrapping. */
    private static long saturatedAdd(long a, long b) {
        long sum = a + b;
        // Overflow happened iff a and b share a sign that the wrapped sum does not
        // (the same check Math.addExact performs).
        if (((a ^ sum) & (b ^ sum)) < 0) {
            return b > 0 ? Long.MAX_VALUE : Long.MIN_VALUE;
        }
        return sum;
    }
}