All downloads are free. Search and download functionality uses the official Maven repository.

org.campagnelab.dl.somatic.learning.calibrate.CalibratingModel Maven / Gradle / Ivy

The newest version!
package org.campagnelab.dl.somatic.learning.calibrate;

import it.unimi.dsi.fastutil.floats.FloatArrayList;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Combines a predictive network with a calibrating network.
 * <p>
 * A record's features are fed forward through the predictive model, every activation array
 * returned by that feed-forward pass is flattened and concatenated into a single vector,
 * and that vector is used as the single-row input of the calibrating model. The first output
 * value of the calibrating model is returned as the calibrated estimate.
 * <p>
 * Created by fac2003 on 7/17/16.
 */
public class CalibratingModel {
    /** Number of features produced by the feature mapper, captured at construction time. */
    private final int numPredictiveInputFeatures;
    /** Total number of activation values returned by a feed-forward pass of the predictive model. */
    private final int numCalibratingInputFeatures;
    private final FeatureMapper featureMapper;
    MultiLayerNetwork predictiveModel;
    MultiLayerNetwork calibratingModel;

    /**
     * Creates a calibrating model and determines the width of the calibrating model's input by
     * running one feed-forward pass of the predictive model on an all-zero input.
     *
     * @param predictiveModel  network whose activations are used as calibration features.
     * @param featureMapper    mapper describing the input features of the predictive model.
     * @param calibratingModel network that maps the predictive model's activations to the calibrated value.
     */
    public CalibratingModel(MultiLayerNetwork predictiveModel, FeatureMapper featureMapper,
                            MultiLayerNetwork calibratingModel) {
        this.predictiveModel = predictiveModel;
        this.featureMapper = featureMapper;
        this.calibratingModel = calibratingModel;
        this.numPredictiveInputFeatures = featureMapper.numberOfFeatures();
        this.numCalibratingInputFeatures = getModelActivationNumber(predictiveModel, featureMapper);
    }

    /**
     * Estimates the calibrated value for one record.
     *
     * @param testFeatures features of the record, in the layout expected by the predictive model.
     *                     When assertions are enabled, the length of the backing data buffer must
     *                     equal the number of features reported by the feature mapper.
     * @return the first element of the first output row of the calibrating model.
     */
    public float estimateCalibratedP(INDArray testFeatures) {
        assert testFeatures.data().length() == featureMapper.numberOfFeatures() : "number of features does not match predictive model input length.";
        // calculate the model output and use as features for the calibration model:
        final FloatArrayList activations = getModelInternalActivations(testFeatures);
        assert activations.size() == numCalibratingInputFeatures : "number of features does not match calibrating model input length.";

        final INDArray inputs = Nd4j.zeros(1, numCalibratingInputFeatures);
        // Scratch coordinates are local to the call so that overlapping calls on the same
        // instance cannot overwrite each other's indices. Row 0 is the only record in the minibatch.
        final int[] indices = {0, 0};
        for (int i = 0; i < numCalibratingInputFeatures; i++) {
            indices[1] = i;
            inputs.putScalar(indices, activations.getFloat(i));
        }
        // NOTE(review): init() is invoked on every call, as in the original code; this presumably
        // relies on init() being a no-op once the network is initialized -- confirm for the DL4J version in use.
        calibratingModel.init();
        final float[] predicted = calibratingModel.output(inputs, false).getRow(0).data().asFloat();
        return predicted[0];
    }

    /**
     * Feeds the features through the predictive model and concatenates, in the order returned by
     * feedForward, the float content of every activation array into one list.
     *
     * @param testFeatures input of the predictive model.
     * @return all activation values of the feed-forward pass as a flat list.
     */
    private FloatArrayList getModelInternalActivations(INDArray testFeatures) {
        final FloatArrayList floats = new FloatArrayList();
        for (final INDArray activation : predictiveModel.feedForward(testFeatures)) {
            floats.addAll(FloatArrayList.wrap(activation.data().asFloat()));
        }
        return floats;
    }

    /**
     * Counts the activation values produced by one feed-forward pass of the model on an all-zero,
     * single-row input, and prints that count to standard output.
     *
     * @param model              network to probe.
     * @param modelFeatureMapper mapper that provides the number of input features of the model.
     * @return the summed length of all activation arrays returned by feedForward.
     */
    private int getModelActivationNumber(MultiLayerNetwork model, FeatureMapper modelFeatureMapper) {
        final INDArray inputFeatures = Nd4j.zeros(1, modelFeatureMapper.numberOfFeatures());

        int sum = 0;
        // The second argument is presumably DL4J's train flag (false = inference) -- confirm.
        for (final INDArray activation : model.feedForward(inputFeatures, false)) {
            sum += activation.data().asFloat().length;
        }
        System.out.println("Number of activations: " + sum);
        return sum;
    }

}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy