All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.tree.gbdt.histogram.MultiGradPair Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.tree.gbdt.histogram;

import com.tencent.angel.sona.tree.gbdt.tree.GBDTParam;
import com.tencent.angel.sona.tree.util.MathUtil;

import java.io.Serializable;
import java.util.Arrays;

public class MultiGradPair implements GradPair, Serializable {
    private double[] grad;
    private double[] hess;

    public MultiGradPair(int numClass, boolean fullHessian) {
        this.grad = new double[numClass];
        if (fullHessian)
            this.hess = new double[(numClass * (numClass + 1)) >> 1];
        else
            this.hess = new double[numClass];
    }

    public MultiGradPair(double[] grad, double[] hess) {
        this.grad = grad;
        this.hess = hess;
    }

    @Override
    public void plusBy(GradPair gradPair) {
        double[] grad = ((MultiGradPair) gradPair).grad;
        double[] hess = ((MultiGradPair) gradPair).hess;
        for (int i = 0; i < this.grad.length; i++)
            this.grad[i] += grad[i];
        for (int i = 0; i < this.hess.length; i++)
            this.hess[i] += hess[i];
    }

    public void plusBy(double[] grad, double[] hess) {
        for (int i = 0; i < this.grad.length; i++)
            this.grad[i] += grad[i];
        for (int i = 0; i < this.hess.length; i++)
            this.hess[i] += hess[i];
    }

    @Override
    public void subtractBy(GradPair gradPair) {
        double[] grad = ((MultiGradPair) gradPair).grad;
        double[] hess = ((MultiGradPair) gradPair).hess;
        for (int i = 0; i < this.grad.length; i++)
            this.grad[i] -= grad[i];
        for (int i = 0; i < this.hess.length; i++)
            this.hess[i] -= hess[i];
    }

    public void subtractBy(double[] grad, double[] hess) {
        for (int i = 0; i < this.grad.length; i++)
            this.grad[i] -= grad[i];
        for (int i = 0; i < this.hess.length; i++)
            this.hess[i] -= hess[i];
    }

    @Override
    public GradPair plus(GradPair gradPair) {
        GradPair res = this.copy();
        res.plusBy(gradPair);
        return res;
    }

    public GradPair plus(double[] grad, double[] hess) {
        MultiGradPair res = this.copy();
        res.plusBy(grad, hess);
        return res;
    }

    @Override
    public GradPair subtract(GradPair gradPair) {
        GradPair res = this.copy();
        res.subtractBy(gradPair);
        return res;
    }

    public GradPair subtract(double[] grad, double[] hess) {
        MultiGradPair res = this.copy();
        res.subtractBy(grad, hess);
        return res;
    }

    @Override
    public void timesBy(double x) {
        for (int i = 0; i < this.grad.length; i++)
            this.grad[i] *= x;
        for (int i = 0; i < this.hess.length; i++)
            this.hess[i] *= x;
    }

    @Override
    public boolean satisfyWeight(GBDTParam param) {
        return param.satisfyWeight(hess);
    }

    @Override
    public float calcGain(GBDTParam param) {
        return (float) param.calcGain(grad, hess);
    }

    @Override
    public float calcWeight(GBDTParam param) {
        throw new RuntimeException(String.format("%s does not support binary-class task",
                this.getClass().getSimpleName()));
    }

    @Override
    public float[] calcWeights(GBDTParam param) {
        return MathUtil.doubleArrayToFloatArray(param.calcWeights(grad, hess));
    }

    @Override
    public MultiGradPair copy() {
        return new MultiGradPair(grad.clone(), hess.clone());
    }

    @Override
    public void clear() {
        Arrays.fill(this.grad, 0.0);
        Arrays.fill(this.hess, 0.0);
    }

    public double[] getGrad() {
        return grad;
    }

    public double[] getHess() {
        return hess;
    }

    public void setGrad(double[] grad) {
        this.grad = grad;
    }

    public void setHess(double[] hess) {
        this.hess = hess;
    }

    public void set(double[] grad, double[] hess) {
        this.grad = grad;
        this.hess = hess;
    }

    public void set(double[] grad, double[] hess, int offset) {
        // numClass is usually small, so we do not use arraycopy here
        for (int i = 0; i < this.grad.length; i++) {
            this.grad[i] = grad[i + offset];
            this.hess[i] = hess[i + offset];
        }
    }

    public void set(double[] grad, int gradOffset, double[] hess, int hessOffset) {
        // numClass is usually small, so we do not use arraycopy here
        for (int i = 0; i < this.grad.length; i++)
            this.grad[i] = grad[i + gradOffset];
        for (int i = 0; i < this.hess.length; i++)
            this.hess[i] = hess[i + hessOffset];
    }

    @Override
    public String toString() {
        String gradStr = Arrays.toString(grad);
        if (grad.length == hess.length) {
            return "(" + gradStr + ", diag{" + Arrays.toString(hess) + "})";
        } else {
            int rowSize = 1, offset = 0;
            StringBuilder hessSB = new StringBuilder("[");
            while (rowSize <= grad.length) {
                hessSB.append("[");
                hessSB.append(hess[offset]);
                for (int i = 1; i < rowSize; i++) {
                    hessSB.append(", ");
                    hessSB.append(hess[offset + i]);
                }
                hessSB.append("]");
                offset += rowSize;
                rowSize++;
            }
            hessSB.append("]");
            return "(" + gradStr + ", " + hessSB.toString() + ")";
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy