All downloads are free. Search and download functionality uses the official Maven repository.

hivemall.fm.FieldAwareFactorizationMachineModel Maven / Gradle / Ivy

There is a newer version: 0.6.0-incubating
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.fm;

import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.collections.arrays.DoubleArray3D;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;

import java.util.Arrays;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.metadata.HiveException;

/**
 * Base class for field-aware factorization machine (FFM) models.
 *
 * <p>
 * Unlike the plain factorization machine it extends, an FFM keeps a separate latent vector per
 * (feature, opposing field) pair, so the two-argument {@code getV}/{@code setV} accessors of the
 * superclass are unsupported here and the three-argument variants must be used instead.
 *
 * <p>
 * Depending on the hyper-parameters, weights are updated either by (optionally AdaGrad-scaled)
 * gradient descent or by Follow-the-Regularized-Leader (FTRL). Storage of the individual
 * entries is left to subclasses via {@link #getEntryW(Feature)},
 * {@link #getEntryV(Feature, int)} and {@link #removeEntry(Entry)}.
 */
public abstract class FieldAwareFactorizationMachineModel extends FactorizationMachineModel {

    /** Updated weights whose magnitude is within this tolerance of zero are pruned. */
    private static final float ZERO_TOLERANCE = 1E-9f;

    @Nonnull
    protected final FFMHyperParameters _params;
    /** Constant added to the accumulated squared gradients in the AdaGrad denominator. */
    protected final float _eps;

    protected final boolean _useAdaGrad;
    protected final boolean _useFTRL;

    // FTRL
    private final float _alpha;
    private final float _beta;
    private final float _lambda1;
    private final float _lambda2;

    /**
     * Creates a model configured from the given FFM hyper-parameters.
     *
     * @param params hyper-parameters; the optimizer switches and FTRL constants are copied from it
     */
    public FieldAwareFactorizationMachineModel(@Nonnull FFMHyperParameters params) {
        super(params);
        this._params = params;
        this._eps = params.eps;
        this._useAdaGrad = params.useAdaGrad;
        this._useFTRL = params.useFTRL;
        this._alpha = params.alphaFTRL;
        this._beta = params.betaFTRL;
        this._lambda1 = params.lambda1;
        this._lambda2 = params.lambda2;
    }

    /**
     * Returns the {@code f}-th latent factor of feature {@code x} with respect to field
     * {@code yField}.
     *
     * @param x feature whose latent vector is looked up
     * @param yField field of the opposing feature in the pairwise interaction
     * @param f index of the latent factor
     */
    public abstract float getV(@Nonnull Feature x, int yField, int f);

    /**
     * Sets the {@code f}-th latent factor of feature {@code x} with respect to field
     * {@code yField}.
     */
    @Deprecated
    protected abstract void setV(@Nonnull Feature x, int yField, int f, float nextVif);

    /**
     * Not supported: an FFM latent factor is only defined relative to an opposing field.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public float getV(Feature x, int f) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported: an FFM latent factor is only defined relative to an opposing field.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void setV(Feature x, int f, float nextVif) {
        throw new UnsupportedOperationException();
    }

    /**
     * Computes {@code w0 + sum_i(w_i * x_i) + sum_{i<j} sum_f(v[i,field(j),f] * v[j,field(i),f] * x_i * x_j)}.
     *
     * @param x features of a single example
     * @return the raw model output
     * @throws HiveException if the result is not finite; the message contains a variable dump
     */
    @Override
    protected final double predict(@Nonnull final Feature[] x) throws HiveException {
        // w0
        double ret = getW0();
        // W
        for (Feature e : x) {
            double xi = e.getValue();
            float wi = getW(e);
            double wx = wi * xi;
            ret += wx;
        }
        // V
        for (int i = 0; i < x.length; i++) {
            final Feature ei = x[i];
            final double xi = ei.getValue();
            final int iField = ei.getField();
            for (int j = i + 1; j < x.length; j++) {
                final Feature ej = x[j];
                final double xj = ej.getValue();
                final int jField = ej.getField();
                for (int f = 0, k = _factor; f < k; f++) {
                    // each feature contributes the latent vector it keeps for the OTHER one's field
                    float vijf = getV(ei, jField, f);
                    float vjif = getV(ej, iField, f);
                    ret += vijf * vjif * xi * xj;
                    assert (!Double.isNaN(ret));
                }
            }
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new HiveException(
                "Detected " + ret + " in predict. We recommend to normalize training examples.\n"
                        + "Dumping variables ...\n" + varDump(x));
        }
        return ret;
    }

    /**
     * Updates the linear weight of feature {@code x}.
     *
     * <p>
     * Delegates to the FTRL update when FTRL is enabled; otherwise performs a gradient step with
     * L2 regularization. The entry is removed when the new weight is (close to) zero.
     *
     * @param dloss derivative of the loss with respect to the prediction
     * @param x feature whose weight is updated
     * @param t iteration counter passed to the learning-rate schedule
     * @throws IllegalStateException if the new weight is not finite
     */
    void updateWi(final double dloss, @Nonnull final Feature x, final long t) {
        if (_useFTRL) {
            updateWi_FTRL(dloss, x);
            return;
        }

        final double Xi = x.getValue();
        float gradWi = (float) (dloss * Xi);

        final Entry theta = getEntryW(x);
        float wi = theta.getW();

        final float eta = eta(theta, t, gradWi);
        float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException(
                "Got " + nextWi + " for next W[" + x.getFeature() + "]\n" + "Xi=" + Xi + ", gradWi="
                        + gradWi + ", wi=" + wi + ", dloss=" + dloss + ", eta=" + eta + ", t=" + t);
        }
        if (MathUtils.closeToZero(nextWi, ZERO_TOLERANCE)) {
            removeEntry(theta);
            return;
        }
        theta.setW(nextWi);
    }

    /**
     * Update Wi using Follow-the-Regularized-Leader.
     *
     * <p>
     * The entry is removed when {@code |z| <= lambda1} or when the resulting weight is (close
     * to) zero.
     *
     * @throws IllegalStateException if the new weight is not finite
     */
    private void updateWi_FTRL(final double dloss, @Nonnull final Feature x) {
        final double Xi = x.getValue();
        float gradWi = (float) (dloss * Xi);

        final Entry theta = getEntryW(x);

        final float z = theta.updateZ(gradWi, _alpha);
        final double n = theta.updateN(gradWi);

        if (Math.abs(z) <= _lambda1) {
            removeEntry(theta);
            return;
        }

        final float nextWi = solveFTRL(z, n);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
                    + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + theta.getW()
                    + ", dloss=" + dloss + ", n=" + n + ", z=" + z);
        }
        if (MathUtils.closeToZero(nextWi, ZERO_TOLERANCE)) {
            removeEntry(theta);
            return;
        }
        theta.setW(nextWi);
    }

    /** Removes the given entry from the model's storage. */
    protected abstract void removeEntry(@Nonnull final Entry entry);

    /**
     * Updates the {@code f}-th latent factor that feature {@code x} keeps for field
     * {@code yField}.
     *
     * <p>
     * Delegates to the FTRL update when FTRL is enabled; otherwise performs a gradient step with
     * L2 regularization. Does nothing when no entry exists for {@code (x, yField)}. A factor that
     * becomes (close to) zero is set to zero and the entry is removed if it reports itself as
     * removable.
     *
     * @param dloss derivative of the loss with respect to the prediction
     * @param x feature whose latent factor is updated
     * @param yField field of the opposing features
     * @param f index of the latent factor
     * @param sumViX precomputed sum for {@code (x, yField, f)}, see
     *        {@link #sumVfX(Feature[], IntArrayList, DoubleArray3D)}
     * @param t iteration counter passed to the learning-rate schedule
     * @throws IllegalStateException if the new factor is not finite
     */
    void updateV(final double dloss, @Nonnull final Feature x, final int yField, final int f,
            final double sumViX, long t) {
        if (_useFTRL) {
            updateV_FTRL(dloss, x, yField, f, sumViX);
            return;
        }

        final Entry theta = getEntryV(x, yField);
        if (theta == null) {
            return;
        }

        final double Xi = x.getValue();
        final double h = Xi * sumViX;
        final float gradV = (float) (dloss * h);
        final float lambdaVf = getLambdaV(f);

        final float currentV = theta.getV(f);
        final float eta = eta(theta, f, t, gradV);
        final float nextV = currentV - eta * (gradV + 2.f * lambdaVf * currentV);
        if (!NumberUtils.isFinite(nextV)) {
            throw new IllegalStateException(
                "Got " + nextV + " for next V" + f + '[' + x.getFeatureIndex() + "]\n" + "Xi=" + Xi
                        + ", Vif=" + currentV + ", h=" + h + ", gradV=" + gradV + ", lambdaVf="
                        + lambdaVf + ", dloss=" + dloss + ", sumViX=" + sumViX + ", t=" + t);
        }
        if (MathUtils.closeToZero(nextV, ZERO_TOLERANCE)) {
            clearFactorAndPrune(theta, f);
            return;
        }
        theta.setV(f, nextV);
    }

    /**
     * Update V using Follow-the-Regularized-Leader.
     *
     * <p>
     * Does nothing when no entry exists for {@code (x, yField)}. The factor is zeroed (and the
     * entry pruned if removable) when {@code |z| <= lambda1} or when the resulting factor is
     * (close to) zero.
     *
     * @throws IllegalStateException if the new factor is not finite
     */
    private void updateV_FTRL(final double dloss, @Nonnull final Feature x, final int yField,
            final int f, final double sumViX) {
        final Entry theta = getEntryV(x, yField);
        if (theta == null) {
            return;
        }

        final double Xi = x.getValue();
        final double h = Xi * sumViX;
        final float gradV = (float) (dloss * h);

        float oldV = theta.getV(f);
        final float z = theta.updateZ(f, oldV, gradV, _alpha);
        final double n = theta.updateN(f, gradV);

        if (Math.abs(z) <= _lambda1) {
            clearFactorAndPrune(theta, f);
            return;
        }

        final float nextV = solveFTRL(z, n);
        if (!NumberUtils.isFinite(nextV)) {
            throw new IllegalStateException(
                "Got " + nextV + " for next V" + f + '[' + x.getFeatureIndex() + "]\n" + "Xi=" + Xi
                        + ", Vif=" + theta.getV(f) + ", h=" + h + ", gradV=" + gradV + ", dloss="
                        + dloss + ", sumViX=" + sumViX + ", n=" + n + ", z=" + z);
        }
        if (MathUtils.closeToZero(nextV, ZERO_TOLERANCE)) {
            clearFactorAndPrune(theta, f);
            return;
        }
        theta.setV(f, nextV);
    }

    /**
     * Evaluates the FTRL weight
     * {@code (sign(z) * lambda1 - z) / ((beta + sqrt(n)) / alpha + lambda2)}.
     *
     * <p>
     * Callers handle the {@code |z| <= lambda1} case themselves before calling this.
     */
    private float solveFTRL(final float z, final double n) {
        return (float) ((MathUtils.sign(z) * _lambda1 - z)
                / ((_beta + Math.sqrt(n)) / _alpha + _lambda2));
    }

    /**
     * Sets factor {@code f} of the entry to zero and removes the entry if it reports itself as
     * removable.
     */
    private void clearFactorAndPrune(@Nonnull final Entry theta, final int f) {
        theta.setV(f, 0.f);
        if (theta.removable()) { // Whether other factors are zero filled or not? Remove if zero filled
            removeEntry(theta);
        }
    }

    /**
     * Learning rate for the linear weight of an entry; equivalent to
     * {@link #eta(Entry, int, long, float)} with factor index 0.
     */
    protected final float eta(@Nonnull final Entry theta, final long t, final float grad) {
        return eta(theta, 0, t, grad);
    }

    /**
     * Returns the learning rate for the given entry and factor.
     *
     * <p>
     * With AdaGrad enabled, the scheduled rate is divided by
     * {@code sqrt(eps + sumOfSquaredGradients)}, and {@code grad} is recorded on the entry as a
     * side effect. The accumulated value is read <em>before</em> {@code grad} is recorded, so the
     * current gradient only influences subsequent calls. Without AdaGrad the scheduled rate is
     * returned unchanged.
     *
     * @param theta entry holding the gradient history
     * @param f factor index within the entry
     * @param t iteration counter passed to the learning-rate schedule
     * @param grad current gradient
     */
    protected final float eta(@Nonnull final Entry theta, @Nonnegative final int f, final long t,
            final float grad) {
        if (_useAdaGrad) {
            double gg = theta.getSumOfSquaredGradients(f);
            theta.addGradient(f, grad);
            return (float) (_eta.eta(t) / Math.sqrt(_eps + gg));
        } else {
            return _eta.eta(t);
        }
    }

    /**
     * Precomputes, for every feature {@code x[i]}, every field in {@code fieldList} and every
     * factor {@code f}, the sum over all other features {@code e} of that field of
     * {@code x_e * v[e, field(x[i]), f]}.
     *
     * @param x features of a single example
     * @param fieldList fields to compute sums for; the second array dimension follows its order
     * @param cached array to reconfigure and reuse, or {@code null} to allocate a new one
     * @return array indexed by {@code [i][index into fieldList][f]}; this is {@code cached} when
     *         it was non-null
     * @throws IllegalStateException if any sum is not finite
     */
    @Nonnull
    final DoubleArray3D sumVfX(@Nonnull final Feature[] x, @Nonnull final IntArrayList fieldList,
            @Nullable DoubleArray3D cached) {
        final int xSize = x.length;
        final int fieldSize = fieldList.size();
        final int factors = _factor;

        final DoubleArray3D mdarray;
        if (cached == null) {
            mdarray = new DoubleArray3D();
            mdarray.setSanityCheck(false);
        } else {
            mdarray = cached;
        }
        mdarray.configure(xSize, fieldSize, factors);

        for (int i = 0; i < xSize; i++) {
            for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) {
                final int yField = fieldList.get(fieldIndex);
                for (int f = 0; f < factors; f++) {
                    double val = sumVfX(x, i, yField, f);
                    mdarray.set(i, fieldIndex, f, val);
                }
            }
        }

        return mdarray;
    }

    /**
     * Sums {@code x_e * v[e, field(x[i]), f]} over every feature {@code e} in field
     * {@code yField}, skipping features that share the feature index of {@code x[i]}.
     *
     * <p>
     * {@code updateV} multiplies this sum by the value of {@code x[i]}, so the sum itself must be
     * weighted by the values of the <em>other</em> features.
     *
     * @throws IllegalStateException if the sum is not finite
     */
    private double sumVfX(@Nonnull final Feature[] x, final int i, final int yField,
            final int f) {
        final Feature xi = x[i];
        final int xiFeature = xi.getFeatureIndex();
        final int xiField = xi.getField();
        double ret = 0.d;
        // find all other features whose field matches field
        for (Feature e : x) {
            if (e.getFeatureIndex() == xiFeature) { // ignore x[i] = e
                continue;
            }
            if (e.getField() == yField) { // multiply x_e and v_e,field(x[i]),f
                final float Vjf = getV(e, xiField, f);
                final double xeValue = e.getValue();
                ret += Vjf * xeValue;
            }
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new IllegalStateException("Got " + ret + " for sumV[ " + i + "][ " + f + "]X.\n"
                    + "x = " + Arrays.toString(x));
        }
        return ret;
    }

    /** Returns the entry holding the linear weight of feature {@code x}. */
    @Nonnull
    protected abstract Entry getEntryW(@Nonnull Feature x);

    /**
     * Returns the entry holding the latent factors feature {@code x} keeps for field
     * {@code yField}, or {@code null} if there is none.
     */
    @Nullable
    protected abstract Entry getEntryV(@Nonnull Feature x, int yField);

    /**
     * Builds a human-readable trace of the prediction for {@code x}: the feature values, then the
     * prediction formula in symbolic form, then the same formula with numeric values.
     *
     * <p>
     * The running total is recomputed term by term; as soon as it stops being finite, the dump is
     * cut short with {@code " + ... = "} followed by the offending value.
     */
    @Override
    protected final String varDump(@Nonnull final Feature[] x) {
        final StringBuilder buf1 = new StringBuilder(1024); // symbolic form
        final StringBuilder buf2 = new StringBuilder(1024); // numeric form

        // X
        for (int i = 0; i < x.length; i++) {
            Feature e = x[i];
            String j = e.getFeature();
            double xj = e.getValue();
            if (i != 0) {
                buf1.append(", ");
            }
            buf1.append("x[").append(j).append("] = ").append(xj);
        }
        buf1.append("\n");

        // w0
        double ret = getW0();
        buf1.append("predict(x) = w0");
        buf2.append("predict(x) = ").append(ret);

        // W
        for (Feature e : x) {
            String i = e.getFeature();
            double xi = e.getValue();
            float wi = getW(e);

            buf1.append(" + (w[").append(i).append("] * x[").append(i).append("])");
            buf2.append(" + (").append(wi).append(" * ").append(xi).append(')');

            double wx = wi * xi;
            ret += wx;

            if (!NumberUtils.isFinite(ret)) {
                return finishDump(buf1, buf2, " + ... = ", ret);
            }
        }

        // V
        for (int i = 0; i < x.length; i++) {
            final Feature ei = x[i];
            final String fi = ei.getFeature();
            final double xi = ei.getValue();
            final int iField = ei.getField();
            for (int j = i + 1; j < x.length; j++) {
                final Feature ej = x[j];
                final String fj = ej.getFeature();
                final double xj = ej.getValue();
                final int jField = ej.getField();
                for (int f = 0, k = _factor; f < k; f++) {
                    float vijf = getV(ei, jField, f);
                    float vjif = getV(ej, iField, f);

                    buf1.append(" + (v[i")
                        .append(fi)
                        .append("-j")
                        .append(jField)
                        .append("-f")
                        .append(f)
                        .append("] * v[j")
                        .append(fj)
                        .append("-i")
                        .append(iField)
                        .append("-f")
                        .append(f)
                        .append("] * x[")
                        .append(fi)
                        .append("] * x[")
                        .append(fj)
                        .append("])");

                    buf2.append(" + (")
                        .append(vijf)
                        .append(" * ")
                        .append(vjif)
                        .append(" * ")
                        .append(xi)
                        .append(" * ")
                        .append(xj)
                        .append(')');

                    ret += vijf * vjif * xi * xj;

                    if (!NumberUtils.isFinite(ret)) {
                        return finishDump(buf1, buf2, " + ... = ", ret);
                    }
                }
            }
        }

        return finishDump(buf1, buf2, " = ", ret);
    }

    /**
     * Terminates both formula lines with {@code separator} and {@code total} and joins them with
     * a newline, symbolic line first.
     */
    @Nonnull
    private static String finishDump(@Nonnull final StringBuilder symbolic,
            @Nonnull final StringBuilder numeric, @Nonnull final String separator,
            final double total) {
        return symbolic.append(separator)
                       .append(total)
                       .append('\n')
                       .append(numeric)
                       .append(separator)
                       .append(total)
                       .toString();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy