All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lens.ml.impl.TableTestingSpec Maven / Gradle / Ivy

There is a newer version: 2.7.1
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.ml.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Table;

import lombok.Getter;

/**
 * Table specification for running a model test on a Hive table.
 *
 * <p>Instances are created through {@link TableTestingSpecBuilder}. {@link #validate()} checks the spec against the
 * metastore and caches the input table's schema; {@link #getTestQuery()} and {@link #getCreateOutputTableQuery()}
 * render the HiveQL used to run the test and to create the table that stores its results.
 */
public class TableTestingSpec {

  /** The Constant LOG. */
  public static final Log LOG = LogFactory.getLog(TableTestingSpec.class);

  /**
   * Database of the input table. When {@code null}, the input table is looked up with {@code Hive.getTable(String)}
   * and the output table is looked up in {@code default}.
   */
  private String db;

  /** The table containing input data. */
  private String inputTable;

  // TODO use partition condition
  /** Partition filter for the input table; currently stored but not used when building queries. */
  private String partitionFilter;

  /** Feature columns; all of them must exist in the input table. */
  private List<String> featureColumns;

  /** Label column; must exist in the input table. */
  private String labelColumn;

  /** Column of the output table that receives the prediction (created with type {@code string}). */
  private String outputColumn;

  /** The output table. */
  private String outputTable;

  /** Hive configuration used to obtain the metastore client. */
  private transient HiveConf conf;

  /** Algorithm name passed as the first argument of {@code predict(...)}. */
  private String algorithm;

  /** Model id passed as the second argument of {@code predict(...)}. */
  private String modelID;

  /** Whether the output table existed when {@link #validate()} last reached the metastore. */
  @Getter
  private boolean outputTableExists;

  /** Value written to the {@code part_testid} partition of the output table. */
  @Getter
  private String testID;

  /** Input table schema keyed by column name; populated by {@link #validate()}, {@code null} before that. */
  private HashMap<String, FieldSchema> columnNameToFieldSchema;

  /**
   * Builder for {@link TableTestingSpec}.
   *
   * <p>{@link #build()} returns the instance this builder mutates, so further setter calls on the builder also
   * change a previously built spec.
   */
  public static class TableTestingSpecBuilder {

    /** The spec being built. */
    private final TableTestingSpec spec;

    /**
     * Instantiates a new table testing spec builder.
     */
    public TableTestingSpecBuilder() {
      spec = new TableTestingSpec();
    }

    /**
     * Set the database of the input table.
     *
     * @param database the database
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder database(String database) {
      spec.db = database;
      return this;
    }

    /**
     * Set the input table.
     *
     * @param table the table
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder inputTable(String table) {
      spec.inputTable = table;
      return this;
    }

    /**
     * Set the partition filter for the input table.
     *
     * @param partFilter the partition filter
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder partitionFilter(String partFilter) {
      spec.partitionFilter = partFilter;
      return this;
    }

    /**
     * Set the feature columns.
     *
     * @param featureColumns the feature columns; must be non-empty and present in the input table
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder featureColumns(List<String> featureColumns) {
      spec.featureColumns = featureColumns;
      return this;
    }

    /**
     * Set the label column.
     *
     * @param labelColumn the label column
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder labelColumn(String labelColumn) {
      spec.labelColumn = labelColumn;
      return this;
    }

    /**
     * Set the label column. Misspelled name kept for compatibility with existing callers; equivalent to
     * {@link #labelColumn(String)}.
     *
     * @param labelColumn the label column
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder lableColumn(String labelColumn) {
      return labelColumn(labelColumn);
    }

    /**
     * Set the output column.
     *
     * @param outputColumn the output column
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder outputColumn(String outputColumn) {
      spec.outputColumn = outputColumn;
      return this;
    }

    /**
     * Set the output table.
     *
     * @param table the table
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder outputTable(String table) {
      spec.outputTable = table;
      return this;
    }

    /**
     * Set the Hive configuration.
     *
     * @param conf the conf
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder hiveConf(HiveConf conf) {
      spec.conf = conf;
      return this;
    }

    /**
     * Set the algorithm.
     *
     * @param algorithm the algorithm
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder algorithm(String algorithm) {
      spec.algorithm = algorithm;
      return this;
    }

    /**
     * Set the model id.
     *
     * @param modelID the model id
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder modelID(String modelID) {
      spec.modelID = modelID;
      return this;
    }

    /**
     * Return the spec configured by this builder.
     *
     * @return the table testing spec
     */
    public TableTestingSpec build() {
      return spec;
    }

    /**
     * Set the unique test id.
     *
     * @param testID the test id, used as the {@code part_testid} partition value
     * @return the table testing spec builder
     */
    public TableTestingSpecBuilder testID(String testID) {
      spec.testID = testID;
      return this;
    }
  }

  /**
   * New builder.
   *
   * @return the table testing spec builder
   */
  public static TableTestingSpecBuilder newBuilder() {
    return new TableTestingSpecBuilder();
  }

  /**
   * Validate the spec against the metastore.
   *
   * <p>On reaching the metastore, this caches the input table schema and records whether the output table exists.
   *
   * @return true if the spec is complete and its feature/label columns exist in the input table
   */
  public boolean validate() {
    // Cheap, metastore-independent checks first so an incomplete spec never reaches the metastore.
    if (featureColumns == null || featureColumns.isEmpty()) {
      LOG.info("Feature columns are required");
      return false;
    }

    if (StringUtils.isBlank(outputColumn)) {
      LOG.info("Output column is required");
      return false;
    }

    if (StringUtils.isBlank(outputTable)) {
      LOG.info("Output table is required");
      return false;
    }

    List<String> testTableColumns;
    try {
      Hive metastoreClient = Hive.get(conf);
      Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
      List<FieldSchema> columns = tbl.getAllCols();

      columnNameToFieldSchema = new HashMap<>();
      testTableColumns = new ArrayList<>(columns.size());
      for (FieldSchema column : columns) {
        columnNameToFieldSchema.put(column.getName(), column);
        testTableColumns.add(column.getName());
      }

      // Check if output table exists; 'false' makes the lookup return null instead of throwing when it is absent.
      Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
      outputTableExists = (outTbl != null);
    } catch (HiveException exc) {
      LOG.error("Error getting table info for db=" + db + ", inputTable=" + inputTable + ", outputTable="
        + outputTable, exc);
      return false;
    }

    // Check if labeled column and feature columns are contained in the table
    if (!testTableColumns.containsAll(featureColumns)) {
      LOG.info("Invalid feature columns: " + featureColumns + ". Actual columns in table:" + testTableColumns);
      return false;
    }

    if (!testTableColumns.contains(labelColumn)) {
      LOG.info("Invalid label column: " + labelColumn + ". Actual columns in table:" + testTableColumns);
      return false;
    }

    return true;
  }

  /**
   * Build the HiveQL statement that scores the input table with the configured model and writes the features,
   * label and prediction into this test's partition of the output table.
   *
   * <p>NOTE(review): table/column names and the algorithm, model id and test id values are interpolated into the
   * query without escaping; callers are assumed to pass trusted values.
   *
   * @return the test query, or {@code null} if {@link #validate()} fails
   */
  public String getTestQuery() {
    if (!validate()) {
      return null;
    }

    // Each test run writes into its own static partition, part_testid='<testID>'.
    StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE " + outputTable + " PARTITION (part_testid='" + testID
      + "')  SELECT ");
    String featureCols = StringUtils.join(featureColumns, ",");
    q.append(featureCols).append(",").append(labelColumn).append(", ").append("predict(").append("'").append(algorithm)
      .append("', ").append("'").append(modelID).append("', ").append(featureCols).append(") ").append(outputColumn)
      .append(" FROM ").append(inputTable);

    return q.toString();
  }

  /**
   * Build the {@code CREATE TABLE IF NOT EXISTS} statement for the output table: the feature and label columns
   * (with their input-table types), the output column as {@code string}, partitioned by {@code part_testid}.
   *
   * <p>Column types come from the schema cached by {@link #validate()}; if it has not run yet it is run here.
   *
   * @return the create table query
   * @throws IllegalStateException if the spec is invalid or a feature/label column is not in the input table
   */
  public String getCreateOutputTableQuery() {
    if (columnNameToFieldSchema == null && !validate()) {
      throw new IllegalStateException("Invalid table testing spec for input table " + inputTable
        + ", cannot build create table query for output table " + outputTable);
    }

    StringBuilder createTableQuery = new StringBuilder("CREATE TABLE IF NOT EXISTS ").append(outputTable).append("(");
    // Output table contains feature columns, label column, output column
    List<String> outputTableColumns = new ArrayList<>();
    for (String featureCol : featureColumns) {
      outputTableColumns.add(featureCol + " " + columnType(featureCol));
    }

    outputTableColumns.add(labelColumn + " " + columnType(labelColumn));
    outputTableColumns.add(outputColumn + " string");

    createTableQuery.append(StringUtils.join(outputTableColumns, ", "));

    // Append partition column
    createTableQuery.append(") PARTITIONED BY (part_testid string)");

    return createTableQuery.toString();
  }

  /**
   * Look up the Hive type of a column in the input table schema cached by {@link #validate()}.
   *
   * @param column the column name
   * @return the column's Hive type string
   * @throws IllegalStateException if the column is not in the cached schema
   */
  private String columnType(String column) {
    FieldSchema schema = columnNameToFieldSchema.get(column);
    if (schema == null) {
      throw new IllegalStateException("Column '" + column + "' not found in input table " + inputTable);
    }
    return schema.getType();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy