All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.factorization.fm.FieldAwareFactorizationMachineUDTF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.factorization.fm;

import hivemall.factorization.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.collections.arrays.DoubleArray3D;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.math.MathUtils;
import it.unimi.dsi.fastutil.ints.Int2LongMap;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

/**
 * Field-aware Factorization Machines.
 *
 * @link https://www.csie.ntu.edu.tw/~cjlin/libffm/
 * @since v0.5-rc.1
 */
@Description(name = "train_ffm",
        value = "_FUNC_(array x, double y [, const string options]) - Returns a prediction model")
public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachineUDTF {
    private static final Log LOG = LogFactory.getLog(FieldAwareFactorizationMachineUDTF.class);

    // ----------------------------------------
    // Learning hyper-parameters/options
    private boolean _globalBias;
    private boolean _linearCoeff;

    private int _numFeatures;
    private int _numFields;
    // ----------------------------------------

    protected transient FFMStringFeatureMapModel _ffmModel;

    private transient IntArrayList _fieldList;
    @Nullable
    private transient DoubleArray3D _sumVfX;

    public FieldAwareFactorizationMachineUDTF() {
        super();
    }

    @Override
    protected Options getOptions() {
        Options opts = super.getOptions();
        opts.addOption("w0", "global_bias", false,
            "Whether to include global bias term w0 [default: OFF]");
        opts.addOption("enable_wi", "linear_term", false, "Include linear term [default: OFF]");
        opts.addOption("no_norm", "disable_norm", false, "Disable instance-wise L2 normalization");
        // feature hashing
        opts.addOption("feature_hashing", true,
            "The number of bits for feature hashing in range [18,31] [default: -1]. No feature hashing for -1.");
        opts.addOption("num_fields", true,
            "The number of fields [default: " + Feature.DEFAULT_NUM_FIELDS + "]");
        // optimizer
        opts.addOption("opt", "optimizer", true,
            "Gradient Descent optimizer [default: ftrl, adagrad, sgd]");
        // adagrad
        opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default: 1.0]");
        // FTRL
        opts.addOption("alpha", "alphaFTRL", true,
            "Alpha value (learning rate) of Follow-The-Regularized-Reader [default: 0.5]");
        opts.addOption("beta", "betaFTRL", true,
            "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Reader [default: 1.0]");
        opts.addOption("l1", "lambda1", true,
            "L1 regularization value of Follow-The-Regularized-Reader that controls model Sparseness [default: 0.0002]");
        opts.addOption("l2", "lambda2", true,
            "L2 regularization value of Follow-The-Regularized-Reader [default: 0.0001]");
        return opts;
    }

    @Override
    protected boolean isAdaptiveRegularizationSupported() {
        return false;
    }

    @Override
    protected FFMHyperParameters newHyperParameters() {
        return new FFMHyperParameters();
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = super.processOptions(argOIs);

        FFMHyperParameters params = (FFMHyperParameters) _params;
        this._globalBias = params.globalBias;
        this._linearCoeff = params.linearCoeff;
        this._numFeatures = params.numFeatures;
        this._numFields = params.numFields;

        return cl;
    }

    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        StructObjectInspector oi = super.initialize(argOIs);

        this._fieldList = new IntArrayList();
        return oi;
    }

    @Override
    protected StructObjectInspector getOutputOI(@Nonnull FMHyperParameters params) {
        ArrayList fieldNames = new ArrayList();
        ArrayList fieldOIs = new ArrayList();

        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);

        fieldNames.add("i");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

        fieldNames.add("Wi");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);

        fieldNames.add("Vi");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableFloatObjectInspector));

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    protected FFMStringFeatureMapModel initModel(@Nonnull FMHyperParameters params)
            throws UDFArgumentException {
        FFMHyperParameters ffmParams = (FFMHyperParameters) params;

        FFMStringFeatureMapModel model = new FFMStringFeatureMapModel(ffmParams);
        this._ffmModel = model;
        return model;
    }

    @Override
    protected Feature[] parseFeatures(@Nonnull final Object arg) throws HiveException {
        Feature[] features = Feature.parseFFMFeatures(arg, _xOI, _probes, _numFeatures, _numFields);
        if (_params.l2norm) {
            Feature.l2normalize(features);
        }
        return features;
    }

    @Override
    protected void processValidationSample(@Nonnull final Feature[] x, final double y)
            throws HiveException {
        if (_earlyStopping) {
            double p = _model.predict(x);
            double loss = _lossFunction.loss(p, y);
            _validationState.incrLoss(loss);
        }
    }

    @Override
    protected void trainTheta(@Nonnull final Feature[] x, final double y) throws HiveException {
        final double p = _ffmModel.predict(x);
        final double lossGrad = _ffmModel.dloss(p, y);

        double loss = _lossFunction.loss(p, y);
        _cvState.incrLoss(loss);

        if (MathUtils.closeToZero(lossGrad, 1E-9d)) {
            return;
        }

        // w0 update
        if (_globalBias) {
            float eta_t = _etaEstimator.eta(_t);
            _ffmModel.updateW0(lossGrad, eta_t);
        }

        // ViFf update
        final IntArrayList fieldList = getFieldList(x);
        // sumVfX[i as in index for x][index for field list][index for factorized dimension]
        final DoubleArray3D sumVfX = _ffmModel.sumVfX(x, fieldList, _sumVfX);
        for (int i = 0; i < x.length; i++) {
            final Feature x_i = x[i];
            if (x_i.value == 0.f) {
                continue;
            }
            if (_linearCoeff) {
                _ffmModel.updateWi(lossGrad, x_i, _t);// wi update
            }
            for (int fieldIndex = 0, size = fieldList.size(); fieldIndex < size; fieldIndex++) {
                final int yField = fieldList.get(fieldIndex);
                for (int f = 0, k = _factors; f < k; f++) {
                    final double sumViX = sumVfX.get(i, fieldIndex, f);
                    if (MathUtils.closeToZero(sumViX)) {// grad will be 0 => skip it
                        continue;
                    }
                    _ffmModel.updateV(lossGrad, x_i, yField, f, sumViX, _t);
                }
            }
        }

        // clean up per training instance caches
        sumVfX.clear();
        this._sumVfX = sumVfX;
        fieldList.clear();
    }

    @Nonnull
    private IntArrayList getFieldList(@Nonnull final Feature[] x) {
        for (Feature e : x) {
            int field = e.getField();
            _fieldList.add(field);
        }
        return _fieldList;
    }

    @Override
    protected IntFeature instantiateFeature(@Nonnull final ByteBuffer input) {
        return new IntFeature(input);
    }

    @Override
    public void close() throws HiveException {
        if (LOG.isInfoEnabled()) {
            LOG.info(_ffmModel.getStatistics());
        }

        _ffmModel.disableInitV(); // trick to avoid re-instantiating removed (zero-filled) entry of V
        super.close();

        if (LOG.isInfoEnabled()) {
            LOG.info(_ffmModel.getStatistics());
        }
        this._ffmModel = null;
    }

    @Override
    protected void forwardModel() throws HiveException {
        this._model = null;
        this._fieldList = null;
        this._sumVfX = null;

        final int factors = _factors;
        final IntWritable idx = new IntWritable();
        final FloatWritable Wi = new FloatWritable(0.f);
        final FloatWritable[] Vi = HiveUtils.newFloatArray(factors, 0.f);
        final List ViObj = Arrays.asList(Vi);

        final Object[] forwardObjs = new Object[4];
        String modelId = HadoopUtils.getUniqueTaskIdString();
        forwardObjs[0] = new Text(modelId);
        forwardObjs[1] = idx;
        forwardObjs[2] = Wi;
        forwardObjs[3] = null; // Vi

        // W0
        idx.set(0);
        Wi.set(_ffmModel.getW0());
        forward(forwardObjs);

        final Entry entryW = new Entry(_ffmModel._buf, 1);
        final Entry entryV = new Entry(_ffmModel._buf, factors);
        final float[] Vf = new float[factors];

        for (Int2LongMap.Entry e : Fastutil.fastIterable(_ffmModel._map)) {
            // set i
            final int i = e.getIntKey();
            idx.set(i);

            final long offset = e.getLongValue();
            if (Entry.isEntryW(i)) {// set Wi
                entryW.setOffset(offset);
                float w = entryW.getW();
                if (w == 0.f) {
                    continue; // skip w_i=0
                }
                Wi.set(w);
                forwardObjs[2] = Wi;
                forwardObjs[3] = null;
            } else {// set Vif
                entryV.setOffset(offset);
                entryV.getV(Vf);
                for (int f = 0; f < factors; f++) {
                    Vi[f].set(Vf[f]);
                }
                forwardObjs[2] = null;
                forwardObjs[3] = ViObj;
            }

            forward(forwardObjs);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy