/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package hivemall.optimizer;

import hivemall.model.IWeightValue;
import hivemall.optimizer.Optimizer.OptimizerBase;
import it.unimi.dsi.fastutil.objects.Object2ObjectMap;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;

import java.util.Map;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public final class SparseOptimizerFactory {
    private static final Log LOG = LogFactory.getLog(SparseOptimizerFactory.class);

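    /**
     * Instantiates an {@link Optimizer} for sparse weight vectors based on the {@code optimizer}
     * entry of the given options (e.g., {@code sgd}, {@code momentum}, {@code adagrad},
     * {@code adam}). A minimal illustrative call, assuming the options are held in a plain
     * {@link java.util.HashMap} (option keys other than {@code optimizer} and
     * {@code regularization} are omitted here):
     *
     * <pre>{@code
     * Map<String, String> opts = new HashMap<>();
     * opts.put("optimizer", "adagrad");
     * opts.put("regularization", "rda"); // RDA is only accepted together with AdaGrad
     * Optimizer optimizer = SparseOptimizerFactory.create(1024, opts);
     * }</pre>
     *
     * @param ndims expected number of feature dimensions, used as a sizing hint
     * @param options optimizer name and hyperparameters as string key/value pairs
     * @throws IllegalArgumentException if no optimizer name is given, the name is unsupported,
     *         or {@code -regularization rda} is combined with an optimizer other than AdaGrad
     */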
    @Nonnull
    public static Optimizer create(@Nonnegative final int ndims,
            @Nonnull final Map<String, String> options) {
        final String optimizerName = options.get("optimizer");
        if (optimizerName == null) {
            throw new IllegalArgumentException("`optimizer` not defined");
        }
        final String name = optimizerName.toLowerCase();

        if ("rda".equalsIgnoreCase(options.get("regularization"))
                && "adagrad".equals(name) == false) {
            throw new IllegalArgumentException(
                "`-regularization rda` is only supported for AdaGrad but `-optimizer "
                        + optimizerName + "`. Please specify `-regularization l1` and so on.");
        }
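
        // Instantiate the optimizer implementation that matches the lower-cased name.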
        final OptimizerBase optimizerImpl;
        if ("sgd".equals(name)) {
            optimizerImpl = new Optimizer.SGD(options);
        } else if ("momentum".equals(name)) {
            optimizerImpl = new Momentum(ndims, options);
        } else if ("nesterov".equals(name)) {
            // Nesterov momentum reuses the Momentum implementation with the "nesterov" flag set.
            options.put("nesterov", "");
            optimizerImpl = new Momentum(ndims, options);
        } else if ("adagrad".equals(name)) {
            // If the regularization type is "RDA", wrap AdaGrad with `AdagradRDA`.
            if ("rda".equalsIgnoreCase(options.get("regularization"))) {
                AdaGrad adagrad = new AdaGrad(ndims, options);
                optimizerImpl = new AdagradRDA(ndims, adagrad, options);
            } else {
                optimizerImpl = new AdaGrad(ndims, options);
            }
} else if ("rmsprop".equals(name)) {
optimizerImpl = new RMSprop(ndims, options);
} else if ("rmspropgraves".equals(name) || "rmsprop_graves".equals(name)) {
optimizerImpl = new RMSpropGraves(ndims, options);
} else if ("adadelta".equals(name)) {
optimizerImpl = new AdaDelta(ndims, options);
} else if ("adam".equals(name)) {
optimizerImpl = new Adam(ndims, options);
} else if ("nadam".equals(name)) {
optimizerImpl = new Nadam(ndims, options);
} else if ("eve".equals(name)) {
optimizerImpl = new Eve(ndims, options);
} else if ("adam_hd".equals(name) || "adamhd".equals(name)) {
optimizerImpl = new AdamHD(ndims, options);
} else {
throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
}

        if (LOG.isInfoEnabled()) {
            LOG.info(
                "Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: " + options);
            LOG.info("ETA estimator: " + optimizerImpl._eta);
        }

        return optimizerImpl;
    }

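    /**
     * Sparse variant of {@link Optimizer.Momentum} backed by an {@link Object2ObjectOpenHashMap}
     * of auxiliary {@link IWeightValue} state.
     */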
    @NotThreadSafe
    static final class Momentum extends Optimizer.Momentum {

        @Nonnull
        private final Object2ObjectMap