// org.deeplearning4j.nn.simple.multiclass.RankClassificationResult (Maven / Gradle / Ivy)
package org.deeplearning4j.nn.simple.multiclass;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
 * A {@link RankClassificationResult} is an abstraction over an activation
 * matrix for ranking classes.
 *
 * <p>For every row of the outcome matrix it stores the column indices as
 * returned by {@code Nd4j.sortWithIndices(outcome, -1, false)}, a copy of the
 * raw values of that row, and the label associated with each column.
 *
 * <p>NOTE(review): the ranking is presumably descending along the last
 * dimension (so that column 0 of {@code rankedIndices} is the most likely
 * class) - confirm against the ND4J documentation of {@code sortWithIndices}.
 *
 * @author Adam Gibson
 */
@Data
public class RankClassificationResult implements Serializable {
    /** Per row: the column indices of the outcome matrix in ranked order. */
    private int[][] rankedIndices;
    /**
     * Per row: the values of the outcome matrix, copied in their ORIGINAL
     * column order (they are not reordered to match {@link #rankedIndices}).
     */
    private float[][] probabilities;
    /** The label for each column of the outcome matrix. */
    private List<String> labels;
    /** Lazily computed cache of the top-ranked label for every row. */
    private List<String> maxLabels;

    /**
     * Takes in just a classification matrix
     * and initializes the labels to just be indices
     *
     * @param outcome the outcome matrix (usually from a softmax
     *                or sigmoid output)
     */
    public RankClassificationResult(INDArray outcome) {
        this(outcome, null);
    }

    /**
     * Takes in a classification matrix
     * and the labels for each column
     *
     * @param outcome the outcome matrix; its rank must not exceed 2
     * @param labels  the labels for the outcomes, one per column; when
     *                {@code null} the column indices {@code "0"}, {@code "1"}, ...
     *                are used as labels. The list is copied.
     * @throws ND4JIllegalStateException if {@code outcome} has a rank greater than 2
     */
    public RankClassificationResult(INDArray outcome, List<String> labels) {
        if (outcome.rank() > 2) {
            throw new ND4JIllegalStateException("Only works with vectors and matrices right now");
        }

        INDArray[] maxWithIndices = Nd4j.sortWithIndices(outcome, -1, false);
        // The first element of the returned pair is used as the index matrix.
        INDArray indexes = maxWithIndices[0];

        //default to integers for labels
        if (labels == null) {
            this.labels = new ArrayList<>(outcome.columns());
            for (int i = 0; i < outcome.columns(); i++) {
                this.labels.add(String.valueOf(i));
            }
        } else {
            // Defensive copy so later changes to the caller's list do not leak in.
            this.labels = new ArrayList<>(labels);
        }

        rankedIndices = new int[indexes.rows()][indexes.columns()];
        probabilities = new float[outcome.rows()][outcome.columns()];

        for (int i = 0; i < indexes.rows(); i++) {
            for (int j = 0; j < indexes.columns(); j++) {
                rankedIndices[i][j] = indexes.getInt(i, j);
                probabilities[i][j] = outcome.getFloat(new int[] {i, j});
            }
        }

        //initialize max outcomes
        maxOutcomes();
    }

    /**
     * Get the label of the top-ranked class for the given row.
     *
     * @param r the row to get the top-ranked label for
     * @return the label whose column index is ranked first in row {@code r}
     */
    public String maxOutcomeForRow(int r) {
        return labels.get(rankedIndices[r][0]);
    }

    /**
     * Get the top-ranked label for every row.
     *
     * <p>The list is built on first use and cached; subsequent calls return
     * the same list instance.
     *
     * @return the top-ranked label of each row, in row order
     */
    public List<String> maxOutcomes() {
        if (maxLabels == null) {
            maxLabels = new ArrayList<>(rankedIndices.length);
            for (int i = 0; i < rankedIndices.length; i++) {
                maxLabels.add(maxOutcomeForRow(i));
            }
        }
        return maxLabels;
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy