All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.xgboost.XGBoostBatchPredictUDTF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.utils.collections.lists.FloatArrayList;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import hivemall.xgboost.utils.NativeLibLoader;
import hivemall.xgboost.utils.XGBoostUtils;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

//@formatter:off
@Description(name = "xgboost_batch_predict",
        value = "_FUNC_(PRIMITIVE rowid, array features, string model_id, array pred_model [, string options]) "
                + "- Returns a prediction result as (string rowid, array predicted)",
        extended = "select\n" + 
                "  rowid, \n" + 
                "  array_avg(predicted) as predicted,\n" + 
                "  avg(predicted[0]) as predicted0\n" + 
                "from (\n" + 
                "  select\n" + 
                "    xgboost_batch_predict(rowid, features, model_id, model) as (rowid, predicted)\n" + 
                "  from\n" + 
                "    xgb_model l\n" + 
                "    LEFT OUTER JOIN xgb_input r\n" + 
                ") t\n" + 
                "group by rowid;")
//@formatter:on
public final class XGBoostBatchPredictUDTF extends UDTFWithOptions {

    // For input parameters
    private PrimitiveObjectInspector rowIdOI;
    private ListObjectInspector featureListOI;
    private boolean denseFeatures;
    @Nullable
    private PrimitiveObjectInspector featureElemOI;
    private StringObjectInspector modelIdOI;
    private StringObjectInspector modelOI;

    // For input buffer
    private transient Map mapToModel;
    private transient Map> rowBuffer;

    private int _batchSize;

    @Nonnull
    protected transient final Object[] _forwardObj;

    // Settings for the XGBoost native library
    static {
        NativeLibLoader.initXGBoost();
    }

    public XGBoostBatchPredictUDTF() {
        super();
        this._forwardObj = new Object[2];
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("batch_size", true, "Number of rows to predict together [default: 128]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        int batchSize = 128;
        CommandLine cl = null;
        if (argOIs.length >= 5) {
            String rawArgs = HiveUtils.getConstString(argOIs, 4);
            cl = parseOptions(rawArgs);
            batchSize = Primitives.parseInt(cl.getOptionValue("batch_size"), batchSize);
            if (batchSize < 1) {
                throw new UDFArgumentException("batch_size must be greater than 0: " + batchSize);
            }
        }
        this._batchSize = batchSize;
        return cl;
    }

    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            showHelp("Invalid argment length=" + argOIs.length);
        }
        processOptions(argOIs);

        this.rowIdOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 0);

        this.featureListOI = HiveUtils.asListOI(argOIs, 1);
        ObjectInspector elemOI = featureListOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseFeatures = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.denseFeatures = false;
        } else {
            throw new UDFArgumentException(
                "Expected array for the 2nd argment but got an unexpected features type: "
                        + featureListOI.getTypeName());
        }
        this.modelIdOI = HiveUtils.asStringOI(argOIs, 2);
        this.modelOI = HiveUtils.asStringOI(argOIs, 3);

        return getReturnOI(rowIdOI);
    }

    /** Override this to output predicted results depending on a task type */
    /** Return (string rowid, array predicted) as a result */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
        List fieldNames = new ArrayList<>(2);
        List fieldOIs = new ArrayList<>(2);
        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
            rowIdOI.getPrimitiveCategory()));
        fieldNames.add("predicted");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(Object[] args) throws HiveException {
        if (mapToModel == null) {
            this.mapToModel = new HashMap();
            this.rowBuffer = new HashMap>();
        }
        if (args[1] == null) {
            return;
        }

        String modelId =
                PrimitiveObjectInspectorUtils.getString(nonNullArgument(args, 2), modelIdOI);
        Booster model = mapToModel.get(modelId);
        if (model == null) {
            Text arg3 = modelOI.getPrimitiveWritableObject(nonNullArgument(args, 3));
            model = XGBoostUtils.deserializeBooster(arg3);
            mapToModel.put(modelId, model);
        }

        List rowBatch = rowBuffer.get(modelId);
        if (rowBatch == null) {
            rowBatch = new ArrayList(_batchSize);
            rowBuffer.put(modelId, rowBatch);
        }
        LabeledPointWithRowId row = parseRow(args);
        rowBatch.add(row);
        if (rowBatch.size() >= _batchSize) {
            predictAndFlush(model, rowBatch);
        }
    }

    @Nonnull
    private LabeledPointWithRowId parseRow(@Nonnull Object[] args) throws UDFArgumentException {
        final Writable rowId = HiveUtils.copyToWritable(nonNullArgument(args, 0), rowIdOI);

        final Object arg1 = args[1];
        if (denseFeatures) {
            return parseDenseFeatures(rowId, arg1, featureListOI, featureElemOI);
        } else {
            return parseSparseFeatures(rowId, arg1, featureListOI);
        }
    }

    @Nonnull
    private static LabeledPointWithRowId parseDenseFeatures(@Nonnull final Writable rowId,
            @Nonnull final Object argObj, @Nonnull final ListObjectInspector featureListOI,
            @Nonnull final PrimitiveObjectInspector featureElemOI) throws UDFArgumentException {
        final int size = featureListOI.getListLength(argObj);

        final float[] values = new float[size];
        for (int i = 0; i < size; i++) {
            final Object o = featureListOI.getListElement(argObj, i);
            if (o == null) {
                values[i] = Float.NaN;
            } else {
                float v = PrimitiveObjectInspectorUtils.getFloat(o, featureElemOI);
                values[i] = v;
            }
        }

        return new LabeledPointWithRowId(rowId, /* dummy label */ 0.f, null, values);

    }

    @Nonnull
    private static LabeledPointWithRowId parseSparseFeatures(@Nonnull final Writable rowId,
            @Nonnull final Object argObj, @Nonnull final ListObjectInspector featureListOI)
            throws UDFArgumentException {
        final int size = featureListOI.getListLength(argObj);
        final IntArrayList indices = new IntArrayList(size);
        final FloatArrayList values = new FloatArrayList(size);

        for (int i = 0; i < size; i++) {
            Object f = featureListOI.getListElement(argObj, i);
            if (f == null) {
                continue;
            }
            final String str = f.toString();
            final int pos = str.indexOf(':');
            if (pos < 1) {
                throw new UDFArgumentException("Invalid feature format: " + str);
            }
            final int index;
            final float value;
            try {
                index = Integer.parseInt(str.substring(0, pos));
                value = Float.parseFloat(str.substring(pos + 1));
            } catch (NumberFormatException e) {
                throw new UDFArgumentException("Failed to parse a feature value: " + str);
            }
            indices.add(index);
            values.add(value);
        }

        return new LabeledPointWithRowId(rowId, /* dummy label */ 0.f, indices.toArray(),
            values.toArray());
    }


    @Override
    public void close() throws HiveException {
        for (Entry> e : rowBuffer.entrySet()) {
            String modelId = e.getKey();
            List rowBatch = e.getValue();
            if (rowBatch.isEmpty()) {
                continue;
            }
            final Booster model = Objects.requireNonNull(mapToModel.get(modelId));
            try {
                predictAndFlush(model, rowBatch);
            } finally {
                XGBoostUtils.close(model);
            }
        }
        this.rowBuffer = null;
        this.mapToModel = null;
    }

    private void predictAndFlush(@Nonnull final Booster model,
            @Nonnull final List rowBatch) throws HiveException {
        DMatrix testData = null;
        final float[][] predicted;
        try {
            testData = XGBoostUtils.createDMatrix(rowBatch);
            predicted = model.predict(testData);
        } catch (XGBoostError e) {
            throw new HiveException("Exception caused at prediction", e);
        } finally {
            XGBoostUtils.close(testData);
        }
        forwardPredicted(rowBatch, predicted);
        rowBatch.clear();
    }

    private void forwardPredicted(@Nonnull final List rowBatch,
            @Nonnull final float[][] predicted) throws HiveException {
        if (rowBatch.size() != predicted.length) {
            throw new HiveException(String.format("buf.size() = %d but predicted.length = %d",
                rowBatch.size(), predicted.length));
        }
        if (predicted.length == 0) {
            return;
        }

        final int ncols = predicted[0].length;
        final List list = WritableUtils.newFloatList(ncols);

        final Object[] forwardObj = this._forwardObj;
        forwardObj[1] = list;

        for (int i = 0; i < predicted.length; i++) {
            Writable rowId = Objects.requireNonNull(rowBatch.get(i)).getRowId();
            forwardObj[0] = rowId;
            WritableUtils.setValues(predicted[i], list);
            forward(forwardObj);
        }
    }

    public static final class LabeledPointWithRowId extends LabeledPoint {
        private static final long serialVersionUID = -7150841669515184648L;

        @Nonnull
        final Writable rowId;

        LabeledPointWithRowId(@Nonnull Writable rowId, float label, @Nullable int[] indices,
                @Nonnull float[] values) {
            super(label, indices, values);
            this.rowId = rowId;
        }

        @Nonnull
        public Writable getRowId() {
            return rowId;
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy