All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.schedule.MapSchedule Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.schedule;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * A schedule backed by an explicit map of position to value.
 * <p>
 * The position is the iteration number when the schedule type is {@link ScheduleType#ITERATION},
 * and the epoch number otherwise. The map must contain an entry for position 0. For a position
 * that has no entry of its own, the value of the nearest lower position is used.
 */
@Data
@EqualsAndHashCode
@JsonIgnoreProperties({"allKeysSorted"})
public class MapSchedule implements ISchedule {

    private ScheduleType scheduleType;
    private Map<Integer, Double> values;

    // Keys of 'values' in ascending order; derived from the map in the constructor and
    // therefore excluded from JSON serialization.
    private int[] allKeysSorted;

    /**
     * @param scheduleType Whether positions in {@code values} refer to iterations or epochs
     * @param values       Map of position to value. Must contain an entry for position 0
     * @throws IllegalArgumentException if {@code values} has no entry for position 0
     */
    public MapSchedule(@JsonProperty("scheduleType") @NonNull ScheduleType scheduleType,
                       @JsonProperty("values") @NonNull Map<Integer, Double> values) {
        if (!values.containsKey(0)) {
            throw new IllegalArgumentException("Invalid set of values: must contain initial value (position 0)");
        }
        this.scheduleType = scheduleType;
        // Defensive copy: the caller (for example a reused Builder, or clone()) may keep mutating
        // its map, which would otherwise leave allKeysSorted out of sync with the stored values.
        this.values = new HashMap<>(values);

        this.allKeysSorted = new int[this.values.size()];
        int pos = 0;
        for (Integer key : this.values.keySet()) {
            allKeysSorted[pos++] = key;
        }
        Arrays.sort(allKeysSorted);
    }

    /**
     * Get the scheduled value for the given iteration/epoch.
     *
     * @param iteration Current iteration; used when the schedule type is ITERATION
     * @param epoch     Current epoch; used for any other schedule type
     * @return The value mapped to the position if present; otherwise the value of the nearest
     *         lower key, or of the smallest key if the position is below every key
     */
    @Override
    public double valueAt(int iteration, int epoch) {
        int i = (scheduleType == ScheduleType.ITERATION ? iteration : epoch);

        if (values.containsKey(i)) {
            return values.get(i);
        }

        //Key doesn't exist - find nearest key...
        int lastKey = allKeysSorted[allKeysSorted.length - 1];
        if (i >= lastKey) {
            return values.get(lastKey);
        }

        /*
        Arrays.binarySearch returns:
        index of the search key, if it is contained in the array; otherwise, (-(insertion point) - 1). The
         insertion point is defined as the point at which the key would be inserted into the array: the index
          of the first element greater than the key
         */
        int pt = Arrays.binarySearch(allKeysSorted, i);
        int insertionPoint = -(pt + 1);
        // The nearest lower key sits just before the insertion point. Clamp to the first key so
        // that a position below every key (e.g. a negative iteration) does not index at -1.
        int idx = Math.max(insertionPoint - 1, 0);
        return values.get(allKeysSorted[idx]);
    }

    /**
     * @return A new MapSchedule with the same schedule type and an independent copy of the values
     */
    @Override
    public ISchedule clone() {
        return new MapSchedule(scheduleType, values);
    }

    /**
     * Builder for conveniently constructing map schedules
     */
    public static class Builder {

        private final ScheduleType scheduleType;
        private final Map<Integer, Double> values = new HashMap<>();

        /**
         * @param scheduleType Schedule type to use
         */
        public Builder(ScheduleType scheduleType) {
            this.scheduleType = scheduleType;
        }

        /**
         * Add a single point to the map schedule. Indexes start at 0
         *
         * @param position Position to add (iteration or epoch index, depending on setting)
         * @param value    Value for that iteration/epoch
         * @return This builder, to allow chained calls
         */
        public Builder add(int position, double value) {
            values.put(position, value);
            return this;
        }

        /**
         * @return A new MapSchedule containing the points added so far
         * @throws IllegalArgumentException if no value was added for position 0
         */
        public MapSchedule build() {
            return new MapSchedule(scheduleType, values);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy