All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.training.ParameterStore Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Parameter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * The {@code ParameterStore} contains a map from a parameter to the mirrors of it on other devices.
 *
 * <p>Mirrors are created lazily by {@link #getValue(Parameter, Device)} the first time a parameter
 * is requested. When a {@link ParameterServer} has been set, one mirror per configured device is
 * created at that point; otherwise a single mirror is kept.
 */
public class ParameterStore {

    private final NDManager manager;
    /** Mirror data of every parameter seen so far, keyed by the parameter id. */
    private final Map<String, ParameterData> parameterMap;
    /** Maps each known device to the position of its mirror inside a {@code ParameterData}. */
    private final Map<Device, Integer> deviceMap;
    private final boolean copy;
    private ParameterServer parameterServer;

    /**
     * Constructs an empty {@code ParameterStore}.
     *
     * <p>Until {@link #setParameterServer(ParameterServer, Device[])} is called, the only known
     * device is the device of {@code manager}.
     *
     * @param manager the manager to attach mirrored parameters to
     * @param copy whether to always copy even for the same device as the original parameter
     */
    public ParameterStore(NDManager manager, boolean copy) {
        this.manager = manager;
        this.copy = copy;
        parameterMap = new ConcurrentHashMap<>();
        deviceMap = new ConcurrentHashMap<>();
        deviceMap.put(manager.getDevice(), 0);
    }

    /**
     * Sets the parameterServer used to apply updates to the parameters.
     *
     * <p>The known devices are replaced by {@code devices}; the position of a device in the array
     * becomes the index of its mirror.
     *
     * @param parameterServer the parameterServer
     * @param devices the devices to create mirrored parameters on
     * @throws IllegalArgumentException if {@code devices} contains the same device more than once;
     *     the store is left unchanged in that case
     */
    public void setParameterServer(ParameterServer parameterServer, Device[] devices) {
        // Validate into a scratch map first so a rejected device list leaves the store untouched.
        Map<Device, Integer> indices = new ConcurrentHashMap<>();
        for (int i = 0; i < devices.length; ++i) {
            if (indices.put(devices[i], i) != null) {
                throw new IllegalArgumentException("Duplicated devices are not allowed.");
            }
        }
        this.parameterServer = parameterServer;
        deviceMap.clear();
        deviceMap.putAll(indices);
    }

    /**
     * Updates all the mirrored parameters.
     *
     * <p>For every stored parameter that requires a gradient, the gradients of all its mirrors are
     * pushed to the parameter server; in a second pass the mirrors are handed to the parameter
     * server's {@code pull}. Priorities are passed as {@code 0, -1, -2, ...} in iteration order.
     * A parameter server must have been set before calling this method.
     */
    public void updateAllParameters() {
        int priority = 0;
        for (Map.Entry<String, ParameterData> entry : parameterMap.entrySet()) {
            String parameterId = entry.getKey();
            ParameterData data = entry.getValue();
            if (data.requireGradient()) {
                NDArray[] grads =
                        data.getNDArrays()
                                .stream()
                                .map(NDArray::getGradient)
                                .toArray(NDArray[]::new);
                parameterServer.push(parameterId, grads, -priority);
                ++priority;
            }
        }

        priority = 0;
        for (Map.Entry<String, ParameterData> entry : parameterMap.entrySet()) {
            String parameterId = entry.getKey();
            ParameterData data = entry.getValue();
            if (data.requireGradient()) {
                NDArray[] values = data.toArray();
                parameterServer.pull(parameterId, values, -priority);
                ++priority;
            }
        }
    }

    /**
     * Returns the value of a mirrored parameter on a device.
     *
     * <p>On the first request for a parameter its mirrors are created. With a parameter server
     * set, the parameter is registered through {@code init} and one mirror per known device is
     * created; the parameter's own array is reused only for the requested device when it already
     * lives there. Without a parameter server a single mirror is stored, which is the parameter's
     * own array unless {@code copy} is set or the array lives on a different device.
     *
     * @param parameter the parameter to get the value for
     * @param device the device to get the mirror from
     * @return the value of the mirrored parameter on the device
     * @throws IllegalArgumentException if {@code device} is not one of the devices known to this
     *     store
     */
    public NDArray getValue(Parameter parameter, Device device) {
        String parameterId = parameter.getId();
        Integer knownIndex = deviceMap.get(device);
        if (knownIndex == null) {
            // Unboxing the missing entry would otherwise fail with a message-less exception.
            throw new IllegalArgumentException(
                    "Device is not registered in this ParameterStore: " + device);
        }
        int index = knownIndex;
        ParameterData data =
                parameterMap.computeIfAbsent(parameterId, k -> new ParameterData(parameter));

        if (data.isEmpty()) {
            NDArray array = parameter.getArray();

            if (parameterServer != null) {
                // initialize on parameter store for first time
                parameterServer.init(parameterId, new NDArray[] {array});

                // Mirrors are looked up by position, so they must be appended in index order;
                // the iteration order of the device map itself is unspecified.
                Device[] ordered = new Device[deviceMap.size()];
                for (Map.Entry<Device, Integer> entry : deviceMap.entrySet()) {
                    ordered[entry.getValue()] = entry.getKey();
                }
                for (int i = 0; i < ordered.length; ++i) {
                    Device dev = ordered[i];
                    NDArray mirror;
                    if (i == index && array.getDevice().equals(dev)) {
                        mirror = array;
                    } else {
                        mirror = array.toDevice(dev, true);
                        mirror.attach(manager);
                        mirror.attachGradient();
                    }
                    data.add(mirror);
                }
            } else {
                if (copy || !array.getDevice().equals(device)) {
                    array = array.toDevice(device, true);
                    array.attach(manager);
                    array.attachGradient();
                }
                data.add(array);
            }
        }

        return data.get(index);
    }

    /**
     * Synchronizes the values on all mirrors with the main parameter.
     *
     * <p>For each stored parameter whose own array lives on a device unknown to this store, the
     * first mirror is copied into that array.
     */
    public void sync() {
        for (ParameterData data : parameterMap.values()) {
            data.sync();
        }
    }

    /** A helper for {@link ParameterStore} that stores data for a single parameter. */
    private final class ParameterData {

        private final Parameter parameter;
        /** The mirrors of the parameter; positions correspond to the indices in the device map. */
        private final List<NDArray> list;

        private ParameterData(Parameter parameter) {
            this.parameter = parameter;
            list = Collections.synchronizedList(new ArrayList<>());
        }

        /** Returns the live list of mirrors (not a copy). */
        private List<NDArray> getNDArrays() {
            return list;
        }

        /** Returns whether no mirror has been stored yet. */
        private boolean isEmpty() {
            return list.isEmpty();
        }

        /** Appends a mirror to the end of the list. */
        private void add(NDArray array) {
            list.add(array);
        }

        /** Returns the mirror stored at the given position. */
        private NDArray get(int index) {
            return list.get(index);
        }

        /** Returns the mirrors as a newly allocated array. */
        private NDArray[] toArray() {
            return list.toArray(new NDArray[0]);
        }

        /** Returns whether the underlying parameter requires a gradient. */
        private boolean requireGradient() {
            return parameter.requireGradient();
        }

        /** Copies the first mirror into the parameter's array if that array is on an unknown device. */
        private void sync() {
            NDArray array = parameter.getArray();
            Device device = array.getDevice();
            if (!deviceMap.containsKey(device)) {
                // model's parameters maybe loaded on different device than any of training devices.
                list.get(0).copyTo(array);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy