All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.api.ops.performance.PerformanceTracker Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.performance;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.performance.primitives.AveragingTransactionsHolder;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemcpyDirection;

import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;

/**
 * This class provides routines for performance tracking and holder for corresponding results
 *
 * @author [email protected]
 */
@Slf4j
public class PerformanceTracker {
    private static final PerformanceTracker INSTANCE = new PerformanceTracker();

    private Map bandwidth = new HashMap<>();
    private Map operations = new HashMap<>();

    private PerformanceTracker() {
        // we put in initial holders, one per device
        val nd = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int e = 0; e < nd; e++) {
            bandwidth.put(e, new AveragingTransactionsHolder());
            operations.put(e, new AveragingTransactionsHolder());
        }
    }

    public static PerformanceTracker getInstance() {
        return INSTANCE;
    }

    /**
     * This method stores bandwidth used for given transaction.
     *
     * PLEASE NOTE: Bandwidth is stored in per millisecond value.
     *
     * @param deviceId device used for this transaction
     * @param timeSpent time spent on this transaction in nanoseconds
     * @param numberOfBytes number of bytes
     */
    public long addMemoryTransaction(int deviceId, long timeSpentNanos, long numberOfBytes) {
        // default is H2H transaction
        return addMemoryTransaction(deviceId, timeSpentNanos, numberOfBytes, MemcpyDirection.HOST_TO_HOST);
    }

    /**
     * This method stores bandwidth used for given transaction.
     *
     * PLEASE NOTE: Bandwidth is stored in per millisecond value.
     *
     * @param deviceId device used for this transaction
     * @param timeSpent time spent on this transaction in nanoseconds
     * @param numberOfBytes number of bytes
     * @param direction direction for the given memory transaction
     */
    public long addMemoryTransaction(int deviceId, long timeSpentNanos, long numberOfBytes, @NonNull MemcpyDirection direction) {
        // we calculate bytes per microsecond now
        val bw = (long) (numberOfBytes / (timeSpentNanos / (double) 1000.0));

        // we skip too small values
        if (bw > 0)
            bandwidth.get(deviceId).addValue(direction, bw);

        return bw;
    }

    public void clear() {
        for (val k: bandwidth.keySet())
            bandwidth.get(k).clear();
    }


    public long helperStartTransaction() {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.BANDWIDTH)
            return System.nanoTime();
        else
            return 0L;
    }


    public void helperRegisterTransaction(int deviceId, long timeSpentNanos, long numberOfBytes, @NonNull MemcpyDirection direction) {
        // only do something if profiling is enabled
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.BANDWIDTH) {
            addMemoryTransaction(deviceId, System.nanoTime() - timeSpentNanos, numberOfBytes, direction);
        }
    }

    public Map> getCurrentBandwidth() {
        val result = new HashMap>();
        val keys = bandwidth.keySet();
        for (val d: keys) {

            result.put(d, new HashMap());

            // get average for each MemcpyDirection and store it
            for (val m: MemcpyDirection.values())
                result.get(d).put(m, bandwidth.get(d).getAverageValue(m));

        }

        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy