// All downloads are FREE. Search and download functionality uses the official Maven repository.

// org.apache.lens.ml.algo.spark.TableTrainingSpec Maven / Gradle / Ivy

// There is a newer version: 2.7.1
// Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.ml.algo.spark;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.apache.lens.api.LensException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hive.hcatalog.data.HCatRecord;
import org.apache.hive.hcatalog.data.schema.HCatFieldSchema;
import org.apache.hive.hcatalog.data.schema.HCatSchema;
import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

import com.google.common.base.Preconditions;
import lombok.Getter;
import lombok.ToString;

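/**
 * Describes the Hive table used to train a Spark ML model: the database, table and partition
 * filter to read, the label column, the feature columns, and an optional training fraction.
 * <p>
 * Instances are assembled through {@link #newBuilder()}. Calling
 * {@link #createRDDs(JavaSparkContext)} validates the spec against the table schema and
 * populates the training and testing RDDs.
 */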
@ToString
public class TableTrainingSpec implements Serializable {

  /** The Constant LOG. */
  public static final Log LOG = LogFactory.getLog(TableTrainingSpec.class);

  /** The training rdd. */
  @Getter
  private transient RDD trainingRDD;

  /** The testing rdd. */
  @Getter
  private transient RDD testingRDD;

  /** The database. */
  @Getter
  private String database;

  /** The table. */
  @Getter
  private String table;

  /** The partition filter. */
  @Getter
  private String partitionFilter;

  /** The feature columns. */
  @Getter
  private List featureColumns;

  /** The label column. */
  @Getter
  private String labelColumn;

  /** The conf. */
  @Getter
  private transient HiveConf conf;

  // By default all samples are considered for training
  /** The split training. */
  private boolean splitTraining;

  /** The training fraction. */
  private double trainingFraction = 1.0;

  /** The label pos. */
  int labelPos;

  /** The feature positions. */
  int[] featurePositions;

  /** The num features. */
  int numFeatures;

  /** The labeled rdd. */
  transient JavaRDD labeledRDD;

  /**
   * Creates a fresh builder for assembling a {@link TableTrainingSpec}.
   *
   * @return a new builder wrapping an unconfigured spec
   */
  public static TableTrainingSpecBuilder newBuilder() {
    final TableTrainingSpecBuilder builder = new TableTrainingSpecBuilder();
    return builder;
  }

  /**
   * Fluent builder for {@link TableTrainingSpec}.
   * <p>
   * Every setter writes straight into a single spec instance created by the constructor and
   * returns the builder so calls can be chained. {@link #build()} hands back that same instance
   * without performing any validation.
   */
  public static class TableTrainingSpecBuilder {

    /** The spec being populated; {@link #build()} returns this very object. */
    final TableTrainingSpec spec;

    /**
     * Instantiates a new builder around an empty {@link TableTrainingSpec}.
     */
    public TableTrainingSpecBuilder() {
      spec = new TableTrainingSpec();
    }

    /**
     * Sets the Hive configuration used when reading the table.
     *
     * @param conf the Hive configuration
     * @return this builder
     */
    public TableTrainingSpecBuilder hiveConf(HiveConf conf) {
      spec.conf = conf;
      return this;
    }

    /**
     * Sets the database that contains the table.
     *
     * @param db the database name
     * @return this builder
     */
    public TableTrainingSpecBuilder database(String db) {
      spec.database = db;
      return this;
    }

    /**
     * Sets the table to read training data from.
     *
     * @param table the table name
     * @return this builder
     */
    public TableTrainingSpecBuilder table(String table) {
      spec.table = table;
      return this;
    }

    /**
     * Sets the partition filter applied when reading the table.
     *
     * @param partFilter the partition filter expression
     * @return this builder
     */
    public TableTrainingSpecBuilder partitionFilter(String partFilter) {
      spec.partitionFilter = partFilter;
      return this;
    }

    /**
     * Sets the name of the column that holds the label.
     *
     * @param labelColumn the label column name
     * @return this builder
     */
    public TableTrainingSpecBuilder labelColumn(String labelColumn) {
      spec.labelColumn = labelColumn;
      return this;
    }

    /**
     * Sets the names of the columns used as features. The list is stored by reference, not copied.
     *
     * @param featureColumns the feature column names
     * @return this builder
     */
    public TableTrainingSpecBuilder featureColumns(List<String> featureColumns) {
      spec.featureColumns = featureColumns;
      return this;
    }

    /**
     * Returns the spec populated so far. No validation happens here.
     *
     * @return the table training spec held by this builder
     */
    public TableTrainingSpec build() {
      return spec;
    }

    /**
     * Sets the fraction of samples used for training and turns on train/test splitting.
     * <p>
     * Splitting is enabled for any accepted value, including {@code 1.0}.
     *
     * @param trainingFraction the training fraction, between 0 and 1 inclusive
     * @return this builder
     * @throws IllegalArgumentException if the fraction lies outside [0, 1] (NaN is rejected too)
     */
    public TableTrainingSpecBuilder trainingFraction(double trainingFraction) {
      Preconditions.checkArgument(trainingFraction >= 0 && trainingFraction <= 1.0,
        "Training fraction should be between 0 and 1");
      spec.trainingFraction = trainingFraction;
      spec.splitTraining = true;
      return this;
    }
  }

  /**
   * Pairs a {@link LabeledPoint} with a random number drawn once when the sample is created.
   * <p>
   * {@link TrainingFilter} and {@link TestingFilter} compare that number against the training
   * fraction to decide which split the point belongs to.
   */
  public static class DataSample implements Serializable {

    /** The wrapped labeled point. */
    private final LabeledPoint labeledPoint;

    /** The value of {@link Math#random()} captured at construction time. */
    private final double sample;

    /**
     * Wraps the given point and assigns it a random sample value.
     *
     * @param labeledPoint the labeled point to wrap
     */
    public DataSample(LabeledPoint labeledPoint) {
      this.labeledPoint = labeledPoint;
      this.sample = Math.random();
    }
  }

  /**
   * Predicate that selects the samples belonging to the training split: a sample is kept when its
   * random value is less than or equal to the training fraction.
   */
  public static class TrainingFilter implements Function<DataSample, Boolean> {

    /** Inclusive upper bound on a sample's random value for it to count as training data. */
    private final double trainingFraction;

    /**
     * Instantiates a new training filter.
     *
     * @param fraction the training fraction to compare sample values against
     */
    public TrainingFilter(double fraction) {
      trainingFraction = fraction;
    }

    /**
     * Tests whether the sample falls in the training split.
     *
     * @param v1 the sample to test
     * @return true if the sample's random value is at most the training fraction
     * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
     */
    @Override
    public Boolean call(DataSample v1) throws Exception {
      return v1.sample <= trainingFraction;
    }
  }

  /**
   * Predicate that selects the samples belonging to the testing split: a sample is kept when its
   * random value is strictly greater than the training fraction. This is the exact complement of
   * {@link TrainingFilter}.
   */
  public static class TestingFilter implements Function<DataSample, Boolean> {

    /** Exclusive lower bound on a sample's random value for it to count as testing data. */
    private final double trainingFraction;

    /**
     * Instantiates a new testing filter.
     *
     * @param fraction the training fraction to compare sample values against
     */
    public TestingFilter(double fraction) {
      trainingFraction = fraction;
    }

    /**
     * Tests whether the sample falls in the testing split.
     *
     * @param v1 the sample to test
     * @return true if the sample's random value exceeds the training fraction
     * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
     */
    @Override
    public Boolean call(DataSample v1) throws Exception {
      return v1.sample > trainingFraction;
    }
  }

  /**
   * Mapping function that unwraps a {@link DataSample}, discarding its random sample value and
   * returning the {@link LabeledPoint} it carries.
   */
  public static class GetLabeledPoint implements Function<DataSample, LabeledPoint> {

    /**
     * Extracts the labeled point from the sample.
     *
     * @param v1 the sample to unwrap
     * @return the labeled point held by the sample
     * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
     */
    @Override
    public LabeledPoint call(DataSample v1) throws Exception {
      return v1.labeledPoint;
    }
  }

  /**
   * Validates this spec against the schema of the configured table and resolves column positions.
   * <p>
   * The table schema is fetched through {@link HCatInputFormat} using {@code conf}; a null
   * database is looked up as {@code "default"}. The spec is valid when the table has more than one
   * column, contains the label column, and contains every explicitly requested feature column.
   * When no feature columns were supplied, every column other than the label becomes a feature.
   * <p>
   * Side effects: once the label column is found, {@code labelPos} and {@code numFeatures} are
   * set. {@code featurePositions} is set only when the feature columns resolve, and
   * {@code featureColumns} is replaced when it was null or empty.
   *
   * @return true if the spec is consistent with the table schema; false if the schema could not
   *         be read or a required column is missing
   */
  boolean validate() {
    List<HCatFieldSchema> columns;
    try {
      HCatInputFormat.setInput(conf, database == null ? "default" : database, table, partitionFilter);
      HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf);
      columns = tableSchema.getFields();
    } catch (IOException exc) {
      LOG.error("Error getting table info " + toString(), exc);
      return false;
    }

    // Plain concatenation prints "null" for a null list instead of throwing before the check below
    LOG.info(table + " columns " + columns);

    if (columns == null || columns.isEmpty()) {
      return false;
    }

    List<String> columnNames = new ArrayList<String>(columns.size());
    for (HCatFieldSchema col : columns) {
      columnNames.add(col.getName());
    }

    // Need at least one feature column and one label column
    if (!columnNames.contains(labelColumn) || columnNames.size() <= 1) {
      return false;
    }
    labelPos = columnNames.indexOf(labelColumn);

    boolean valid = true;
    if (featureColumns == null || featureColumns.isEmpty()) {
      // Feature columns are not provided, so all columns except the label column are features
      featurePositions = new int[columnNames.size() - 1];
      int next = 0;
      for (int i = 0; i < columnNames.size(); i++) {
        if (i != labelPos) {
          featurePositions[next++] = i;
        }
      }

      // labelPos is an int, so this removes by index and leaves only the feature names
      columnNames.remove(labelPos);
      featureColumns = columnNames;
    } else {
      // Feature columns were provided, verify all of them are present in the table
      valid = columnNames.containsAll(featureColumns);
      if (valid) {
        featurePositions = new int[featureColumns.size()];
        for (int i = 0; i < featureColumns.size(); i++) {
          featurePositions[i] = columnNames.indexOf(featureColumns.get(i));
        }
      }
    }

    // Recorded even when a requested feature column is missing, matching the existing contract
    numFeatures = featureColumns.size();
    return valid;
  }

  /**
   * Validates the spec and builds the training and testing RDDs from the configured table.
   * <p>
   * The table is loaded through {@code HiveTableRDD} and each row is converted to a
   * {@link LabeledPoint} by a {@code ColumnFeatureFunction} that maps every feature column with a
   * {@code DoubleValueMapper}. If a training fraction was set on the builder, every point is
   * tagged with a random value and routed to the training or testing RDD by
   * {@link TrainingFilter} and {@link TestingFilter}. Otherwise both RDDs are the same RDD of all
   * labeled points.
   *
   * @param sparkContext the spark context used to create the table RDD
   * @throws LensException if the spec fails validation or the table RDD cannot be created
   */
  public void createRDDs(JavaSparkContext sparkContext) throws LensException {
    // Validate the spec; this also resolves labelPos, featurePositions and numFeatures
    if (!validate()) {
      throw new LensException("Table spec not valid: " + toString());
    }

    LOG.info("Creating RDDs with spec " + toString());

    // Get the RDD for table
    JavaPairRDD<WritableComparable, HCatRecord> tableRDD;
    try {
      tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, database, table, partitionFilter);
    } catch (IOException e) {
      throw new LensException(e);
    }

    // Map into trainable RDD
    // TODO: Figure out a way to use custom value mappers
    FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures];
    final DoubleValueMapper doubleMapper = new DoubleValueMapper();
    for (int i = 0; i < numFeatures; i++) {
      valueMappers[i] = doubleMapper;
    }

    ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos,
      numFeatures, 0);
    final JavaRDD<LabeledPoint> labeled = tableRDD.map(trainPrepFunction);
    labeledRDD = labeled;

    if (splitTraining) {
      // We have to split the RDD between a training RDD and a testing RDD
      LOG.info("Splitting RDD for table " + database + "." + table + " with split fraction " + trainingFraction);
      // NOTE(review): an anonymous class declared in an instance method holds a reference to this
      // spec; presumably that is why the spec is Serializable -- confirm before changing either.
      JavaRDD<DataSample> sampledRDD = labeled.map(new Function<LabeledPoint, DataSample>() {
        @Override
        public DataSample call(LabeledPoint v1) throws Exception {
          return new DataSample(v1);
        }
      });

      trainingRDD = sampledRDD.filter(new TrainingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
      testingRDD = sampledRDD.filter(new TestingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
    } else {
      LOG.info("Using same RDD for train and test");
      trainingRDD = labeled.rdd();
      testingRDD = trainingRDD;
    }
    LOG.info("Generated RDDs");
  }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy