All downloads are free. Search and download functionality uses the official Maven repository.

hivemall.LearnerBaseUDTF Maven / Gradle / Ivy

There is a newer version: 0.6.0-incubating
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall;

import hivemall.mix.MixMessage.MixEventName;
import hivemall.mix.client.MixClient;
import hivemall.model.DenseModel;
import hivemall.model.NewDenseModel;
import hivemall.model.NewSpaceEfficientDenseModel;
import hivemall.model.NewSparseModel;
import hivemall.model.PredictionModel;
import hivemall.model.SpaceEfficientDenseModel;
import hivemall.model.SparseModel;
import hivemall.model.SynchronizedModelWrapper;
import hivemall.optimizer.DenseOptimizerFactory;
import hivemall.optimizer.Optimizer;
import hivemall.optimizer.SparseOptimizerFactory;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;

import java.util.Map;

import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * Base class for learner UDTFs.
 *
 * <p>
 * It declares and parses the options shared by learners (dense/sparse model selection, model
 * dimensions, mini-batch size and MIX settings), and offers factory methods that build the
 * {@link PredictionModel}, the {@link Optimizer} and, when MIX servers are given, the
 * {@link MixClient} from the parsed values.
 */
public abstract class LearnerBaseUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class);
    private static final int DEFAULT_SPARSE_DIMS = 16384;
    private static final int DEFAULT_DENSE_DIMS = 16777216;

    /** Selects the {@code New*} model implementations in {@link #createModel()}. */
    protected final boolean enableNewModel;
    /** {@code true} when the {@code -dense} option was given. */
    protected boolean dense_model;
    /** Model dimensions; stays {@code -1} unless {@code -dense} was given. */
    protected int model_dims;
    /** {@code true} when the {@code -disable_halffloat} option was given. */
    protected boolean disable_halffloat;
    /** {@code true} when {@link #mini_batch_size} is greater than 1. */
    protected boolean is_mini_batch;
    /** Mini batch size; defaults to 1. */
    protected int mini_batch_size;
    /** Value of the {@code -mix} option, or {@code null} when MIX is not requested. */
    protected String mixConnectInfo;
    /** Value of the {@code -mix_session} option, or {@code null} when not given. */
    protected String mixSessionName;
    /** MIX threshold; {@code -1} when no option string was passed, otherwise in (0,127]. */
    protected int mixThreshold;
    /** {@code true} when the {@code -mix_cancel} option was given. */
    protected boolean mixCancel;
    /** {@code true} when the {@code -ssl} option was given. */
    protected boolean ssl;

    /** MIX client created while building the model; {@code null} when MIX is not used. */
    @Nullable
    protected MixClient mixClient;

    /**
     * @param enableNewModel whether {@link #createModel()} builds the {@code New*} model classes
     */
    public LearnerBaseUDTF(boolean enableNewModel) {
        this.enableNewModel = enableNewModel;
    }

    /**
     * Tells whether the learner keeps covariances in its model.
     *
     * @return {@code false} in this base implementation
     */
    protected boolean useCovariance() {
        return false;
    }

    /**
     * Declares the command line options shared by learners.
     *
     * @return the options understood by {@link #processOptions(ObjectInspector[])}
     */
    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("dense", "densemodel", false, "Use dense model or not");
        opts.addOption("dims", "feature_dimensions", true,
            "The dimension of model [default: 16777216 (2^24)]");
        opts.addOption("disable_halffloat", false,
            "Toggle this option to disable the use of SpaceEfficientDenseModel");
        opts.addOption("mini_batch", "mini_batch_size", true,
            "Mini batch size [default: 1]. Expecting the value in range [1,100] or so.");
        opts.addOption("mix", "mix_servers", true, "Comma separated list of MIX servers");
        opts.addOption("mix_session", "mix_session_name", true,
            "Mix session name [default: ${mapred.job.id}]");
        opts.addOption("mix_threshold", true,
            "Threshold to mix local updates in range (0,127] [default: 3]");
        opts.addOption("mix_cancel", "enable_mix_canceling", false, "Enable mix cancel requests");
        opts.addOption("ssl", false, "Use SSL for the communication with mix servers");
        return opts;
    }

    /**
     * Parses the option string found in the third argument (if any) and stores the results in
     * the fields of this instance. When fewer than three arguments are given, the defaults are
     * stored instead.
     *
     * @param argOIs object inspectors of the UDTF arguments
     * @return the parsed command line, or {@code null} when no option string was given
     * @throws UDFArgumentException if {@code mini_batch_size} is not positive or
     *         {@code mix_threshold} is outside (0,127]
     */
    @Nullable
    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        boolean denseModel = false;
        int modelDims = -1;
        boolean disableHalfFloat = false;
        int miniBatchSize = 1;
        String mixConnectInfo = null;
        String mixSessionName = null;
        int mixThreshold = -1;
        boolean mixCancel = false;
        boolean ssl = false;

        CommandLine cl = null;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = parseOptions(rawArgs);

            denseModel = cl.hasOption("dense");
            // "dims" is only honored for dense models; otherwise modelDims stays -1
            if (denseModel) {
                modelDims = Primitives.parseInt(cl.getOptionValue("dims"), DEFAULT_DENSE_DIMS);
            }
            disableHalfFloat = cl.hasOption("disable_halffloat");

            miniBatchSize =
                    Primitives.parseInt(cl.getOptionValue("mini_batch_size"), miniBatchSize);
            if (miniBatchSize <= 0) {
                throw new UDFArgumentException(
                    "mini_batch_size must be greater than 0: " + miniBatchSize);
            }

            mixConnectInfo = cl.getOptionValue("mix");
            mixSessionName = cl.getOptionValue("mix_session");
            mixThreshold = Primitives.parseInt(cl.getOptionValue("mix_threshold"), 3);
            // enforce both ends of the documented (0,127] range, not just the upper one
            if (mixThreshold <= 0 || mixThreshold > Byte.MAX_VALUE) {
                throw new UDFArgumentException(
                    "mix_threshold must be in range (0,127]: " + mixThreshold);
            }
            mixCancel = cl.hasOption("mix_cancel");
            ssl = cl.hasOption("ssl");
        }

        this.dense_model = denseModel;
        this.model_dims = modelDims;
        this.disable_halffloat = disableHalfFloat;
        this.is_mini_batch = miniBatchSize > 1;
        this.mini_batch_size = miniBatchSize;
        this.mixConnectInfo = mixConnectInfo;
        this.mixSessionName = mixSessionName;
        this.mixThreshold = mixThreshold;
        this.mixCancel = mixCancel;
        this.ssl = ssl;
        return cl;
    }

    /**
     * Builds the prediction model according to the parsed options, using the {@code New*}
     * model classes when {@link #enableNewModel} is set. No label is attached to the MIX job id.
     *
     * @return the created model
     */
    @Nullable
    protected PredictionModel createModel() {
        if (enableNewModel) {
            return createNewModel(null);
        } else {
            return createOldModel(null);
        }
    }

    /**
     * Builds a {@link SpaceEfficientDenseModel}, {@link DenseModel} or {@link SparseModel}
     * depending on the parsed options, then attaches MIX when requested.
     *
     * @param label optional suffix for the MIX job id
     * @return the created model, wrapped for MIX when MIX servers were given
     */
    @Nonnull
    private final PredictionModel createOldModel(@Nullable String label) {
        final PredictionModel model;
        final boolean useCovar = useCovariance();
        if (dense_model) {
            // the half-float backed model is only picked above the default dense size
            if (!disable_halffloat && model_dims > DEFAULT_DENSE_DIMS) {
                logger.info("Build a space efficient dense model with " + model_dims
                        + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
                model = new SpaceEfficientDenseModel(model_dims, useCovar);
            } else {
                logger.info("Build a dense model with initial with " + model_dims
                        + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
                model = new DenseModel(model_dims, useCovar);
            }
        } else {
            int initModelSize = getInitialModelSize();
            logger.info(
                "Build a sparse model with initial with " + initModelSize + " initial dimensions");
            model = new SparseModel(initModelSize, useCovar);
        }
        return attachMixIfRequested(model, label);
    }

    /**
     * Builds a {@link NewSpaceEfficientDenseModel}, {@link NewDenseModel} or
     * {@link NewSparseModel} depending on the parsed options, then attaches MIX when requested.
     *
     * @param label optional suffix for the MIX job id
     * @return the created model, wrapped for MIX when MIX servers were given
     */
    @Nonnull
    private final PredictionModel createNewModel(@Nullable String label) {
        final PredictionModel model;
        final boolean useCovar = useCovariance();
        if (dense_model) {
            // the half-float backed model is only picked above the default dense size
            if (!disable_halffloat && model_dims > DEFAULT_DENSE_DIMS) {
                logger.info("Build a space efficient dense model with " + model_dims
                        + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
                model = new NewSpaceEfficientDenseModel(model_dims, useCovar);
            } else {
                logger.info("Build a dense model with initial with " + model_dims
                        + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
                model = new NewDenseModel(model_dims, useCovar);
            }
        } else {
            int initModelSize = getInitialModelSize();
            logger.info(
                "Build a sparse model with initial with " + initModelSize + " initial dimensions");
            model = new NewSparseModel(initModelSize, useCovar);
        }
        return attachMixIfRequested(model, label);
    }

    /**
     * Wires the given model to the MIX servers when {@link #mixConnectInfo} is set; otherwise
     * returns the model untouched.
     *
     * <p>
     * The clock is configured on the bare model first, the model is then wrapped in a
     * {@link SynchronizedModelWrapper}, and the MIX client is created against and registered
     * on the wrapper. The client is remembered in {@link #mixClient} so that {@link #close()}
     * can release it.
     *
     * @param model the freshly created model
     * @param label optional suffix for the MIX job id
     * @return {@code model} itself, or its synchronized wrapper when MIX is used
     */
    @Nonnull
    private PredictionModel attachMixIfRequested(@Nonnull PredictionModel model,
            @Nullable String label) {
        if (mixConnectInfo == null) {
            return model;
        }
        model.configureClock();
        PredictionModel wrapped = new SynchronizedModelWrapper(model);
        MixClient client = configureMixClient(mixConnectInfo, label, wrapped);
        wrapped.configureMix(client, mixCancel);
        this.mixClient = client;
        return wrapped;
    }

    /**
     * Creates the optimizer matching the selected model type. When {@link #model_dims} is
     * negative, the default dense or sparse dimension is passed to the factory instead.
     *
     * @param options optimizer options; must not be {@code null}
     * @return the optimizer built by the dense or sparse optimizer factory
     */
    @Nonnull
    protected final Optimizer createOptimizer(@CheckForNull Map options) {
        Preconditions.checkNotNull(options);
        if (dense_model) {
            return DenseOptimizerFactory.create(model_dims < 0 ? DEFAULT_DENSE_DIMS : model_dims,
                options);
        } else {
            return SparseOptimizerFactory.create(model_dims < 0 ? DEFAULT_SPARSE_DIMS : model_dims,
                options);
        }
    }

    /**
     * Creates a MIX client for the given model.
     *
     * <p>
     * The job id is {@link #mixSessionName}, or {@code MixClient.DUMMY_JOB_ID} when no session
     * name was given, followed by {@code '-' + label} when a label is present. The mix event is
     * {@code argminKLD} when {@link #useCovariance()} is {@code true} and {@code average}
     * otherwise.
     *
     * @param connectURIs the MIX servers to connect to
     * @param label optional suffix for the job id
     * @param model the model handed to the client
     * @return the configured client
     */
    @Nonnull
    protected MixClient configureMixClient(@Nonnull String connectURIs, @Nullable String label,
            @Nonnull PredictionModel model) {
        String jobId = (mixSessionName == null) ? MixClient.DUMMY_JOB_ID : mixSessionName;
        if (label != null) {
            jobId = jobId + '-' + label;
        }
        MixEventName event = useCovariance() ? MixEventName.argminKLD : MixEventName.average;
        MixClient client = new MixClient(event, jobId, connectURIs, ssl, mixThreshold, model);
        logger.info("Successfully configured mix client: " + connectURIs);
        return client;
    }

    /**
     * @return the initial size used when building a sparse model
     */
    protected int getInitialModelSize() {
        return DEFAULT_SPARSE_DIMS;
    }

    /**
     * Chooses the object inspector used for emitted features.
     *
     * @param featureInputOI object inspector of the input features
     * @return the Java int inspector for dense models, otherwise the standard object inspector
     *         derived from {@code featureInputOI}
     * @throws UDFArgumentException declared for overriding implementations; not thrown here
     */
    @Nonnull
    protected ObjectInspector getFeatureOutputOI(@Nonnull PrimitiveObjectInspector featureInputOI)
            throws UDFArgumentException {
        if (dense_model) {
            // TODO validation
            return PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel
        }
        return ObjectInspectorUtils.getStandardObjectInspector(featureInputOI);
    }

    /**
     * Closes the MIX client, if one was created, and clears the reference to it.
     */
    @Override
    public void close() throws HiveException {
        if (mixClient != null) {
            IOUtils.closeQuietly(mixClient);
            this.mixClient = null;
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy