All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.optimizer.Regularization Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.optimizer;

import hivemall.utils.lang.Primitives;

import java.util.Map;

import javax.annotation.Nonnull;

public abstract class Regularization {
    /** the default regularization term 0.0001 */
    private static final float DEFAULT_LAMBDA = 0.0001f;

    protected final float lambda;

    public Regularization(@Nonnull Map options) {
        this.lambda = Primitives.parseFloat(options.get("lambda"), DEFAULT_LAMBDA);
    }

    public void getHyperParameters(@Nonnull Map hyperParams) {
        hyperParams.put("lambda", lambda);
    }

    public float regularize(final float weight, final float gradient) {
        return gradient + lambda * getRegularizer(weight);
    }

    abstract float getRegularizer(float weight);

    public static final class PassThrough extends Regularization {

        public PassThrough(final Map options) {
            super(options);
        }

        @Override
        public float getRegularizer(float weight) {
            return 0.f;
        }

        @Override
        public float regularize(final float weight, final float gradient) {
            return gradient;
        }

        @Override
        public void getHyperParameters(@Nonnull Map hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "no");
        }
    }

    public static final class L1 extends Regularization {

        public L1(Map options) {
            super(options);
        }

        @Override
        public float getRegularizer(final float weight) {
            return weight > 0.f ? 1.f : -1.f;
        }

        @Override
        public void getHyperParameters(@Nonnull Map hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "L1");
        }
    }

    public static final class L2 extends Regularization {

        public L2(final Map options) {
            super(options);
        }

        @Override
        public float getRegularizer(float weight) {
            return weight;
        }

        @Override
        public void getHyperParameters(@Nonnull Map hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "L2");
        }
    }

    public static final class ElasticNet extends Regularization {
        private static final float DEFAULT_L1_RATIO = 0.5f;

        @Nonnull
        private final L1 l1;
        @Nonnull
        private final L2 l2;

        private final float l1Ratio;

        public ElasticNet(@Nonnull Map options) {
            super(options);

            this.l1 = new L1(options);
            this.l2 = new L2(options);

            this.l1Ratio = Primitives.parseFloat(options.get("l1_ratio"), DEFAULT_L1_RATIO);
            if (l1Ratio < 0.f || l1Ratio > 1.f) {
                throw new IllegalArgumentException(
                    "L1 ratio should be in [0.0, 1.0], but got " + l1Ratio);
            }
        }

        @Override
        public float getRegularizer(final float weight) {
            return l1Ratio * l1.getRegularizer(weight)
                    + (1.f - l1Ratio) * l2.getRegularizer(weight);
        }

        @Override
        public void getHyperParameters(@Nonnull Map hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "ElasticNet");
            hyperParams.put("l1_ratio", l1Ratio);
        }
    }

    @Nonnull
    public static Regularization get(@Nonnull final Map options)
            throws IllegalArgumentException {
        final String regName = options.get("regularization");
        if (regName == null) {
            return new PassThrough(options);
        }

        if ("no".equalsIgnoreCase(regName)) {
            return new PassThrough(options);
        } else if ("l1".equalsIgnoreCase(regName)) {
            return new L1(options);
        } else if ("l2".equalsIgnoreCase(regName)) {
            return new L2(options);
        } else if ("elasticnet".equalsIgnoreCase(regName)) {
            return new ElasticNet(options);
        } else if ("rda".equalsIgnoreCase(regName)) {
            // Return `PassThrough` because we need special handling for RDA.
            // See an implementation of `Optimizer#RDA`.
            return new PassThrough(options);
        } else {
            throw new IllegalArgumentException("Unsupported regularization name: " + regName);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy