All downloads are free. Search and download functionality uses the official Maven repository.

com.alibaba.alink.operator.common.classification.ann.AnnObjFunc Maven / Gradle / Ivy

package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;

/**
 * The {@link OptimObjFunc} for multilayer perceptron.
 *
 * <p>Every sample is unstacked by a {@link Stacker} and then passed to a {@link TopologyModel}
 * obtained from the configured {@link Topology}. Only the loss and the gradient are provided;
 * {@link #updateHessian} always throws and {@link #hasSecondDerivative()} returns {@code false}.
 *
 * <p>NOTE(review): {@code Tuple3} and {@code Tuple2} are used as raw types throughout this class.
 * Their type arguments may have been lost; confirm against the {@code OptimObjFunc} and
 * {@code Stacker} declarations and restore them if so.
 */
public class AnnObjFunc extends OptimObjFunc {

    /** Network layout used to create the model on first use. */
    private final Topology topology;

    /** Converts a labeled sample into the pair consumed by the model. */
    private final Stacker stacker;

    /**
     * Model bound to the most recently supplied coefficients. Marked {@code transient}, so it is
     * excluded from Java serialization and is created lazily by {@link #syncModel(DenseVector)}.
     */
    private transient TopologyModel topologyModel = null;

    /**
     * Creates the objective function.
     *
     * @param topology    topology from which the model is obtained
     * @param inputSize   forwarded to the {@link Stacker} constructor
     * @param outputSize  forwarded to the {@link Stacker} constructor
     * @param oneHotLabel forwarded to the {@link Stacker} constructor
     * @param params      forwarded to the {@link OptimObjFunc} constructor
     */
    public AnnObjFunc(Topology topology,
                      int inputSize, int outputSize, boolean oneHotLabel,
                      Params params) {
        super(params);
        this.topology = topology;
        this.stacker = new Stacker(inputSize, outputSize, oneHotLabel);
    }

    /**
     * Returns the value produced by the model's {@code computeGradient} for one sample, without
     * supplying a gradient buffer.
     *
     * @param labledVector sample to evaluate
     * @param coefVector   coefficients the model is synchronized to before evaluation
     * @return the value returned by {@code TopologyModel#computeGradient}
     */
    @Override
    protected double calcLoss(Tuple3 labledVector, DenseVector coefVector) {
        TopologyModel model = syncModel(coefVector);
        Tuple2 unstacked = stacker.unstack(labledVector);
        // No gradient buffer is passed here; presumably the model then only computes the loss
        // (TODO confirm against TopologyModel#computeGradient).
        return model.computeGradient(unstacked.f0, unstacked.f1, null);
    }

    /**
     * Runs the model's {@code computeGradient} for one sample with {@code updateGrad} as the
     * gradient buffer. The value returned by the model is discarded.
     *
     * @param labledVector sample to evaluate
     * @param coefVector   coefficients the model is synchronized to before evaluation
     * @param updateGrad   buffer handed to the model to receive the gradient
     */
    @Override
    protected void updateGradient(Tuple3 labledVector, DenseVector coefVector,
                                  DenseVector updateGrad) {
        TopologyModel model = syncModel(coefVector);
        Tuple2 unstacked = stacker.unstack(labledVector);
        model.computeGradient(unstacked.f0, unstacked.f1, updateGrad);
    }

    /**
     * Not supported by this objective function.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void updateHessian(Tuple3 labledVector, DenseVector coefVector,
                                 DenseMatrix updateHessian) {
        // UnsupportedOperationException is a RuntimeException, so existing handlers still match.
        throw new UnsupportedOperationException("not supported.");
    }

    /**
     * @return {@code false}; see {@link #updateHessian}
     */
    @Override
    public boolean hasSecondDerivative() {
        return false;
    }

    /**
     * Ensures {@link #topologyModel} reflects {@code coefVector}: obtains it from the topology on
     * first use, otherwise resets the existing instance.
     *
     * @param coefVector coefficients to bind the model to
     * @return the synchronized model
     */
    private TopologyModel syncModel(DenseVector coefVector) {
        if (topologyModel == null) {
            topologyModel = topology.getModel(coefVector);
        } else {
            topologyModel.resetModel(coefVector);
        }
        return topologyModel;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy