/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.lens.ml.algo.spark;

import java.lang.reflect.Field;
import java.util.*;

import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.algo.api.AlgoParam;
import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.api.MLAlgo;
import org.apache.lens.ml.algo.api.MLModel;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

/**
 * Base class for machine learning algorithms implemented on Spark MLlib. It handles the concerns
 * common to all Spark algos, i.e. parameter parsing and building a training RDD from a Hive table;
 * concrete algorithms implement {@link #parseAlgoParams(Map)} and {@link #trainInternal(String, RDD)}.
 */
public abstract class BaseSparkAlgo implements MLAlgo {
/** The Constant LOG. */
public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
/** The name. */
private final String name;
/** The description. */
private final String description;
/** The spark context. */
protected JavaSparkContext sparkContext;
/** The params. */
  protected Map<String, String> params;
/** The conf. */
protected transient LensConf conf;
/** The training fraction. */
@AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
protected double trainingFraction;
/** The use training fraction. */
private boolean useTrainingFraction;
/** The label. */
@AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
protected String label;
/** The partition filter. */
@AlgoParam(name = "partition", help = "Partition filter used to create create HCatInputFormats")
protected String partitionFilter;
/** The features. */
@AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
  protected List<String> features;
/**
* Instantiates a new base spark algo.
*
* @param name the name
* @param description the description
*/
public BaseSparkAlgo(String name, String description) {
this.name = name;
this.description = description;
}
public void setSparkContext(JavaSparkContext sparkContext) {
this.sparkContext = sparkContext;
}
@Override
public LensConf getConf() {
return conf;
}
/*
* (non-Javadoc)
*
* @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
*/
@Override
public void configure(LensConf configuration) {
this.conf = configuration;
}
/*
* (non-Javadoc)
*
* @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
* java.lang.String, java.lang.String[])
*/
@Override
  public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
throws LensException {
parseParams(params);
TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
.database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
if (useTrainingFraction) {
builder.trainingFraction(trainingFraction);
}
TableTrainingSpec spec = builder.build();
LOG.info("Training " + " with " + features.size() + " features");
spec.createRDDs(sparkContext);
    RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
    BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
model.setTable(table);
model.setParams(Arrays.asList(params));
model.setLabelColumn(label);
model.setFeatureColumns(features);
return model;
}
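
  // A hypothetical invocation sketch; the database, table, column names and model id below are
  // illustrative, not part of Lens:
  //
  //   algo.setSparkContext(jsc);
  //   algo.configure(lensConf);
  //   MLModel<?> model = algo.train(lensConf, "exampledb", "clicks", "model-1",
  //     "feature", "age", "feature", "income", "label", "clicked", "trainingFraction", "0.8");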
  /**
   * Converts the given {@link LensConf} to a {@link HiveConf} by copying every property across.
   *
   * @param conf the Lens configuration
   * @return the equivalent hive conf
   */
protected HiveConf toHiveConf(LensConf conf) {
HiveConf hiveConf = new HiveConf();
for (String key : conf.getProperties().keySet()) {
hiveConf.set(key, conf.getProperties().get(key));
}
return hiveConf;
}
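
  // For example, a LensConf carrying a property such as "hive.metastore.uris" (illustrative)
  // comes back as a HiveConf exposing the identical key-value pair.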
  /**
   * Parses the params. The args array is read as a flat sequence of key-value pairs: "f"/"feature"
   * pairs accumulate into {@link #features}, "l"/"label" sets the label column, and every other pair
   * is stored in {@link #params} with leading dashes stripped from the key.
   *
   * @param args the args
   */
public void parseParams(String[] args) {
if (args.length % 2 != 0) {
throw new IllegalArgumentException("Invalid number of params " + args.length);
}
    params = new LinkedHashMap<String, String>();
for (int i = 0; i < args.length; i += 2) {
if ("f".equalsIgnoreCase(args[i]) || "feature".equalsIgnoreCase(args[i])) {
if (features == null) {
        features = new ArrayList<String>();
}
features.add(args[i + 1]);
} else if ("l".equalsIgnoreCase(args[i]) || "label".equalsIgnoreCase(args[i])) {
label = args[i + 1];
} else {
params.put(args[i].replaceAll("\\-+", ""), args[i + 1]);
}
}
if (params.containsKey("trainingFraction")) {
// Get training Fraction
String trainingFractionStr = params.get("trainingFraction");
try {
trainingFraction = Double.parseDouble(trainingFractionStr);
useTrainingFraction = true;
} catch (NumberFormatException nfe) {
throw new IllegalArgumentException("Invalid training fraction", nfe);
}
}
if (params.containsKey("partition") || params.containsKey("p")) {
partitionFilter = params.containsKey("partition") ? params.get("partition") : params.get("p");
}
parseAlgoParams(params);
}
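
  // Worked example, with illustrative column names:
  //
  //   parseParams(new String[] {"f", "age", "f", "income", "l", "clicked",
  //     "trainingFraction", "0.8", "--iterations", "50"});
  //
  // yields features = [age, income], label = "clicked", trainingFraction = 0.8 (with
  // useTrainingFraction = true) and params = {trainingFraction=0.8, iterations=50},
  // the leading dashes having been stripped from the unrecognized key.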
  /**
   * Gets the param value as a double.
   *
   * @param param      the param name
   * @param defaultVal the value returned when the param is absent or cannot be parsed
   * @return the param value
   */
public double getParamValue(String param, double defaultVal) {
if (params.containsKey(param)) {
try {
return Double.parseDouble(params.get(param));
} catch (NumberFormatException nfe) {
LOG.warn("Couldn't parse param value: " + param + " as double.");
}
}
return defaultVal;
}
  /**
   * Gets the param value as an integer.
   *
   * @param param      the param name
   * @param defaultVal the value returned when the param is absent or cannot be parsed
   * @return the param value
   */
public int getParamValue(String param, int defaultVal) {
if (params.containsKey(param)) {
try {
return Integer.parseInt(params.get(param));
} catch (NumberFormatException nfe) {
LOG.warn("Couldn't parse param value: " + param + " as integer.");
}
}
return defaultVal;
}
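
  // For instance, given params = {iterations=50, stepSize=abc} (illustrative values),
  // getParamValue("iterations", 100) returns 50, while getParamValue("stepSize", 1.0)
  // logs a warning and falls back to the default 1.0.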
public String getName() {
return name;
}
public String getDescription() {
return description;
}
  public Map<String, String> getArgUsage() {
    Map<String, String> usage = new LinkedHashMap<String, String>();
    Class<?> clz = this.getClass();
// Put class name and description as well as part of the usage
Algorithm algorithm = clz.getAnnotation(Algorithm.class);
if (algorithm != null) {
usage.put("Algorithm Name", algorithm.name());
usage.put("Algorithm Description", algorithm.description());
}
// Get all algo params including base algo params
while (clz != null) {
for (Field field : clz.getDeclaredFields()) {
AlgoParam param = field.getAnnotation(AlgoParam.class);
if (param != null) {
usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
}
}
if (clz.equals(BaseSparkAlgo.class)) {
break;
}
clz = clz.getSuperclass();
}
return usage;
}
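
  // For a subclass annotated with @Algorithm(name = "example_algo", ...) (a hypothetical name),
  // the returned map would contain entries such as:
  //   "Algorithm Name"           -> "example_algo"
  //   "[param] trainingFraction" -> "% of dataset to be used for training Default Value = 0"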
/**
* Parses the algo params.
*
* @param params the params
*/
  public abstract void parseAlgoParams(Map<String, String> params);
/**
* Train internal.
*
* @param modelId the model id
* @param trainingRDD the training rdd
* @return the base spark classification model
* @throws LensException the lens exception
*/
  protected abstract BaseSparkClassificationModel<?> trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
throws LensException;
}
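
// A minimal subclass sketch, kept as a comment: ExampleSparkAlgo, its "iterations" param and the
// null return value are illustrative only. A real algorithm would train an MLlib model on the
// RDD inside trainInternal and wrap it in a BaseSparkClassificationModel, as the concrete Lens
// Spark algos do.
//
//   @Algorithm(name = "example_algo", description = "Illustrative Spark algorithm")
//   class ExampleSparkAlgo extends BaseSparkAlgo {
//     @AlgoParam(name = "iterations", help = "Maximum number of iterations", defaultValue = "100")
//     private int iterations;
//
//     public ExampleSparkAlgo(String name, String description) {
//       super(name, description);
//     }
//
//     @Override
//     public void parseAlgoParams(Map<String, String> params) {
//       iterations = getParamValue("iterations", 100);
//     }
//
//     @Override
//     protected BaseSparkClassificationModel<?> trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
//       throws LensException {
//       // A real implementation trains and wraps an MLlib model here; omitted in this sketch.
//       return null;
//     }
//   }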