// org.nd4j.linalg.profiler.data.StringAggregator (Maven / Gradle / Ivy)
package org.nd4j.linalg.profiler.data;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.profiler.data.primitives.ComparableAtomicLong;
import org.nd4j.linalg.profiler.data.primitives.TimeSet;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Collects timing samples keyed by an arbitrary string label and renders them
 * as human-readable reports.
 *
 * <p>Each key accumulates its samples in a {@link TimeSet}. Calls recorded via
 * {@link #putTime(String, Op, long)} whose duration exceeds {@link #THRESHOLD}
 * are additionally counted per "key opName (opNum)" combination.
 *
 * <p>The reports label raw values with "ns" and divide by 1,000,000 for "ms",
 * so callers are presumably expected to pass nanoseconds (TODO confirm against
 * callers). Key registration is atomic; whether {@link TimeSet} itself is safe
 * for concurrent {@code addTime} calls is not visible here.
 *
 * @author [email protected]
 */
public class StringAggregator {
    /** Timing samples per key; concrete type kept for its atomic putIfAbsent. */
    private final ConcurrentHashMap<String, TimeSet> times = new ConcurrentHashMap<>();
    /** Count of calls slower than THRESHOLD, keyed by "key opName (opNum)". */
    private final ConcurrentHashMap<String, ComparableAtomicLong> longCalls = new ConcurrentHashMap<>();
    /** Duration above which a call is counted as "long" (same unit as timeSpent). */
    private static final long THRESHOLD = 100000;

    public StringAggregator() {
    }

    /**
     * Discards all recorded data while keeping the set of known keys: every
     * timing key is remapped to a fresh TimeSet and every long-call counter
     * is reset to zero.
     */
    public void reset() {
        for (String key : times.keySet()) {
            times.put(key, new TimeSet());
        }
        for (String key : longCalls.keySet()) {
            longCalls.put(key, new ComparableAtomicLong(0));
        }
    }

    /**
     * Records one timing sample for {@code key}. If the duration exceeds
     * THRESHOLD, also increments the long-call counter for this key/op pair.
     *
     * @param key       label the sample is aggregated under
     * @param op        operation that was timed; only dereferenced for long calls
     * @param timeSpent measured duration
     */
    public void putTime(String key, Op op, long timeSpent) {
        putTime(key, timeSpent);

        if (timeSpent > THRESHOLD) {
            String keyExt = key + " " + op.name() + " (" + op.opNum() + ")";
            longCallCounterFor(keyExt).incrementAndGet();
        }
    }

    /**
     * Records one timing sample for {@code key}.
     *
     * @param key       label the sample is aggregated under
     * @param timeSpent measured duration
     */
    public void putTime(String key, long timeSpent) {
        timeSetFor(key).addTime(timeSpent);
    }

    /** Returns the TimeSet for {@code key}, atomically creating it if absent. */
    private TimeSet timeSetFor(String key) {
        TimeSet existing = times.get(key);
        if (existing == null) {
            TimeSet created = new TimeSet();
            // putIfAbsent returns the winner if another thread registered the key first.
            existing = times.putIfAbsent(key, created);
            if (existing == null) {
                existing = created;
            }
        }
        return existing;
    }

    /** Returns the long-call counter for {@code keyExt}, atomically creating it if absent. */
    private ComparableAtomicLong longCallCounterFor(String keyExt) {
        ComparableAtomicLong existing = longCalls.get(keyExt);
        if (existing == null) {
            ComparableAtomicLong created = new ComparableAtomicLong(0);
            existing = longCalls.putIfAbsent(keyExt, created);
            if (existing == null) {
                existing = created;
            }
        }
        return existing;
    }

    /** Median sample for {@code key}; throws NullPointerException for an unknown key. */
    protected long getMedian(String key) {
        return times.get(key).getMedian();
    }

    /** Average sample for {@code key}; throws NullPointerException for an unknown key. */
    protected long getAverage(String key) {
        return times.get(key).getAverage();
    }

    /** Largest sample for {@code key}; throws NullPointerException for an unknown key. */
    protected long getMaximum(String key) {
        return times.get(key).getMaximum();
    }

    /** Smallest sample for {@code key}; throws NullPointerException for an unknown key. */
    protected long getMinimum(String key) {
        return times.get(key).getMinimum();
    }

    /** Sum of all samples for {@code key}; throws NullPointerException for an unknown key. */
    protected long getSum(String key) {
        return times.get(key).getSum();
    }

    /**
     * Builds a report of the total recorded time and each key's share of it.
     * Keys are listed in the order produced by ArrayUtil.sortMapByValue.
     *
     * @return multi-line report, one line per key after the total line
     */
    public String asPercentageString() {
        StringBuilder builder = new StringBuilder();
        Map<String, TimeSet> sortedTimes = ArrayUtil.sortMapByValue(times);

        long total = 0;
        for (TimeSet set : sortedTimes.values()) {
            total += set.getSum();
        }

        builder.append("Total time spent: ").append(total / 1000000).append(" ms.").append("\n");

        for (Map.Entry<String, TimeSet> entry : sortedTimes.entrySet()) {
            long currentSum = entry.getValue().getSum();
            // Floating-point division keeps the fractional part; the guard avoids
            // ArithmeticException when every recorded sample sums to zero.
            float perc = total == 0 ? 0f : (float) (currentSum * 100.0 / total);
            long sumMs = currentSum / 1000000;
            builder.append(entry.getKey()).append(" >>> ")
                    .append(" perc: ").append(perc).append(" ")
                    .append("Time spent: ").append(sumMs).append(" ms");
            builder.append("\n");
        }
        return builder.toString();
    }

    /**
     * Builds a report with min/max/average/median per key, followed by the
     * long-call counters.
     *
     * @return multi-line report
     */
    public String asString() {
        StringBuilder builder = new StringBuilder();
        Map<String, TimeSet> sortedTimes = ArrayUtil.sortMapByValue(times);

        for (Map.Entry<String, TimeSet> entry : sortedTimes.entrySet()) {
            TimeSet set = entry.getValue();
            long currentMax = set.getMaximum();
            long currentMin = set.getMinimum();
            long currentAvg = set.getAverage();
            long currentMed = set.getMedian();

            builder.append(entry.getKey()).append(" >>> ");
            // NOTE(review): the per-key call count is only printed when no long
            // calls were recorded at all; intent unclear — confirm.
            if (longCalls.isEmpty())
                builder.append(" ").append(set.size()).append(" calls; ");

            builder.append("Min: ").append(currentMin).append(" ns; ")
                    .append("Max: ").append(currentMax).append(" ns; ")
                    .append("Average: ").append(currentAvg).append(" ns; ")
                    .append("Median: ").append(currentMed).append(" ns; ");
            builder.append("\n");
        }

        builder.append("\n");

        Map<String, ComparableAtomicLong> sortedCalls = ArrayUtil.sortMapByValue(longCalls);
        for (Map.Entry<String, ComparableAtomicLong> entry : sortedCalls.entrySet()) {
            long numCalls = entry.getValue().get();
            builder.append(entry.getKey()).append(" >>> ")
                    .append(numCalls);
            builder.append("\n");
        }
        builder.append("\n");

        return builder.toString();
    }
}