All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.optimizer.EtaEstimator Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.optimizer;

import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;

import java.util.Map;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Base class of learning-rate (eta) schedules.
 * <p>
 * An estimator is created with an initial learning rate {@code eta0} and returns the learning
 * rate to use at step {@code t} through {@link #eta(long)}. Concrete schedules are the nested
 * classes below; they are usually obtained through one of the {@code get(...)} factories.
 */
public abstract class EtaEstimator {

    /** Initial learning rate used when none is specified. */
    public static final float DEFAULT_ETA0 = 0.1f;
    /** Starting learning rate of the bold-driver schedule when {@code -eta} is not given. */
    public static final float DEFAULT_ETA = 0.3f;
    /** Exponent of the inverse-scaling schedule when {@code power_t} is not given. */
    public static final double DEFAULT_POWER_T = 0.1d;

    /** Initial learning rate of this schedule. */
    protected final float eta0;

    /**
     * @param eta0 initial learning rate
     */
    public EtaEstimator(float eta0) {
        this.eta0 = eta0;
    }

    /**
     * Returns the short name of this schedule; it is stored under the {@code "eta"} key by
     * {@link #getHyperParameters(Map)}.
     */
    @Nonnull
    public abstract String typeName();

    /** Returns the initial learning rate. */
    public float eta0() {
        return eta0;
    }

    /**
     * Returns the learning rate to use at step {@code t}.
     *
     * @param t the current step
     */
    public abstract float eta(long t);

    /**
     * Scales the current learning rate by {@code multiplier}. The default implementation does
     * nothing; in this file only {@link AdjustingEtaEstimator} overrides it.
     *
     * @param multiplier non-negative scale factor
     */
    public void update(@Nonnegative float multiplier) {}

    /**
     * Records the hyper-parameters of this schedule.
     *
     * @param hyperParams destination map; receives {@code "eta"} (the {@link #typeName()}) and
     *        {@code "eta0"}, plus schedule-specific entries added by subclasses
     */
    public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
        hyperParams.put("eta", typeName());
        hyperParams.put("eta0", eta0());
    }

    /**
     * Constant schedule: {@code eta(t) == eta0} for every step.
     */
    public static final class FixedEtaEstimator extends EtaEstimator {

        public FixedEtaEstimator(float eta) {
            super(eta);
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Fixed";
        }

        @Override
        public float eta(long t) {
            return eta0;
        }

        @Override
        public String toString() {
            return "FixedEtaEstimator [ eta0 = " + eta0 + " ]";
        }

    }

    /**
     * Decaying schedule {@code eta(t) = eta0 / (1 + t / total_steps)}. It reaches
     * {@code eta0 / 2} at {@code t == total_steps} and stays at that value afterwards.
     */
    public static final class SimpleEtaEstimator extends EtaEstimator {

        private final float finalEta;
        private final double total_steps;

        /**
         * @param eta0 initial learning rate
         * @param total_steps step at which the rate has decayed to {@code eta0 / 2}
         */
        public SimpleEtaEstimator(float eta0, long total_steps) {
            super(eta0);
            this.finalEta = (float) (eta0 / 2.d);
            this.total_steps = total_steps;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Simple";
        }

        @Override
        public float eta(final long t) {
            // At t == total_steps the formula already evaluates to eta0 / 2 (== finalEta), so
            // using >= is equivalent there, and it avoids 0 / 0 = NaN when total_steps == 0.
            if (t >= total_steps) {
                return finalEta;
            }
            return (float) (eta0 / (1.d + (t / total_steps)));
        }

        @Override
        public String toString() {
            return "SimpleEtaEstimator [ eta0 = " + eta0 + ", totalSteps = " + total_steps
                    + ", finalEta = " + finalEta + " ]";
        }

        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("total_steps", total_steps);
        }

    }

    /**
     * Inverse-scaling schedule {@code eta(t) = eta0 / t^power_t}.
     */
    public static final class InvscalingEtaEstimator extends EtaEstimator {

        private final double power_t;

        /**
         * @param eta0 initial learning rate
         * @param power_t exponent applied to the step count
         */
        public InvscalingEtaEstimator(float eta0, double power_t) {
            super(eta0);
            this.power_t = power_t;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Invscaling";
        }

        @Override
        public float eta(final long t) {
            // Treat steps below 1 as step 1. For t == 0, pow(0, power_t) is 0 when
            // power_t > 0 (giving an infinite rate). For t < 0, a negative base with a
            // fractional exponent yields NaN.
            final long step = Math.max(t, 1L);
            return (float) (eta0 / Math.pow(step, power_t));
        }

        @Override
        public String toString() {
            return "InvscalingEtaEstimator [ eta0 = " + eta0 + ", power_t = " + power_t + " ]";
        }

        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("power_t", power_t);
        }
    }

    /**
     * Adaptive "bold driver" schedule: Gemulla et al., Large-scale matrix factorization with
     * distributed stochastic gradient descent, KDD 2011.
     * <p>
     * The current rate starts at {@code eta0}, is scaled by each {@link #update(float)} call,
     * and is never allowed to exceed {@code eta0}. The value of {@code t} is ignored.
     */
    public static final class AdjustingEtaEstimator extends EtaEstimator {

        private float eta;

        public AdjustingEtaEstimator(float eta) {
            super(eta);
            this.eta = eta;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "boldDriver";
        }

        @Override
        public float eta(long t) {
            return eta;
        }

        /**
         * Multiplies the current rate by {@code multiplier} and caps it at {@code eta0}. If the
         * product is not finite, the call is ignored.
         */
        @Override
        public void update(@Nonnegative float multiplier) {
            float newEta = eta * multiplier;
            if (!NumberUtils.isFinite(newEta)) {
                // avoid NaN or INFINITY
                return;
            }
            this.eta = Math.min(eta0, newEta); // never be larger than eta0
        }

        @Override
        public String toString() {
            return "AdjustingEtaEstimator [ eta0 = " + eta0 + ", eta = " + eta + " ]";
        }

    }

    /**
     * Equivalent to {@code get(cl, DEFAULT_ETA0)}.
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
        return get(cl, DEFAULT_ETA0);
    }

    /**
     * Builds an estimator from command-line options. Options are checked in this order:
     * <ol>
     * <li>no command line: {@link InvscalingEtaEstimator} with {@code defaultEta0} and
     * {@link #DEFAULT_POWER_T};</li>
     * <li>{@code -boldDriver}: {@link AdjustingEtaEstimator} starting at {@code -eta}
     * (default {@link #DEFAULT_ETA});</li>
     * <li>{@code -eta}: {@link FixedEtaEstimator} with that value;</li>
     * <li>{@code -t}: {@link SimpleEtaEstimator} with {@code -eta0} and {@code -t} as total
     * steps;</li>
     * <li>otherwise: {@link InvscalingEtaEstimator} with {@code -eta0} and {@code -power_t}.</li>
     * </ol>
     * Absent {@code -eta0}/{@code -power_t} values fall back to their defaults via
     * {@code Primitives.parse*}.
     *
     * @param cl parsed options, may be {@code null}
     * @param defaultEta0 initial rate used when {@code -eta0} is not given
     * @throws NumberFormatException if {@code -eta} (without {@code -boldDriver}) or {@code -t}
     *         is not a valid number
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0)
            throws UDFArgumentException {
        if (cl == null) {
            return new InvscalingEtaEstimator(defaultEta0, DEFAULT_POWER_T);
        }

        final String etaValue = cl.getOptionValue("eta");
        if (cl.hasOption("boldDriver")) {
            float eta = Primitives.parseFloat(etaValue, DEFAULT_ETA);
            return new AdjustingEtaEstimator(eta);
        }
        if (etaValue != null) {
            float eta = Float.parseFloat(etaValue);
            return new FixedEtaEstimator(eta);
        }

        final float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0);
        if (cl.hasOption("t")) {
            final long t = Long.parseLong(cl.getOptionValue("t"));
            return new SimpleEtaEstimator(eta0, t);
        }

        final double power_t =
                Primitives.parseDouble(cl.getOptionValue("power_t"), DEFAULT_POWER_T);
        return new InvscalingEtaEstimator(eta0, power_t);
    }

    /**
     * Builds an estimator from a string option map. The {@code "eta"} entry selects the scheme
     * (case-insensitive):
     * <ul>
     * <li>absent, {@code inv}, {@code inverse} or {@code invscaling}:
     * {@link InvscalingEtaEstimator} with {@code eta0} and {@code power_t};</li>
     * <li>{@code fixed}: {@link FixedEtaEstimator} with {@code eta0};</li>
     * <li>{@code simple}: {@link SimpleEtaEstimator} with {@code eta0} and the required
     * {@code total_steps};</li>
     * <li>a number: {@link FixedEtaEstimator} with that value.</li>
     * </ul>
     *
     * @param options option name to raw string value
     * @throws IllegalArgumentException if the scheme is unknown, if {@code total_steps} is
     *         missing for {@code simple}, or if a numeric value cannot be parsed
     */
    @Nonnull
    public static EtaEstimator get(@Nonnull final Map<String, String> options)
            throws IllegalArgumentException {
        final float eta0 = Primitives.parseFloat(options.get("eta0"), DEFAULT_ETA0);
        final double power_t = Primitives.parseDouble(options.get("power_t"), DEFAULT_POWER_T);

        final String etaScheme = options.get("eta");
        if (etaScheme == null) {
            return new InvscalingEtaEstimator(eta0, power_t);
        }

        if ("fixed".equalsIgnoreCase(etaScheme)) {
            return new FixedEtaEstimator(eta0);
        } else if ("simple".equalsIgnoreCase(etaScheme)) {
            // a single lookup also rejects a key that is present but mapped to null
            final String totalSteps = options.get("total_steps");
            if (totalSteps == null) {
                throw new IllegalArgumentException(
                    "-total_steps MUST be provided when `-eta simple` is specified");
            }
            final long t = Long.parseLong(totalSteps);
            return new SimpleEtaEstimator(eta0, t);
        } else if ("inv".equalsIgnoreCase(etaScheme) || "inverse".equalsIgnoreCase(etaScheme)
                || "invscaling".equalsIgnoreCase(etaScheme)) {
            return new InvscalingEtaEstimator(eta0, power_t);
        } else {
            if (StringUtils.isNumber(etaScheme)) {
                float eta = Float.parseFloat(etaScheme);
                return new FixedEtaEstimator(eta);
            }
            throw new IllegalArgumentException("Unsupported ETA name: " + etaScheme);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy