org.apache.lens.ml.impl.TableTestingSpec Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.lens.ml.impl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Table;
import lombok.Getter;
/**
* Table specification for running test on a table.
*/
public class TableTestingSpec {
/** The Constant LOG. */
public static final Log LOG = LogFactory.getLog(TableTestingSpec.class);
/** The db. */
private String db;
/** The table containing input data. */
private String inputTable;
// TODO use partition condition
/** The partition filter. */
private String partitionFilter;
/** The feature columns. */
private List featureColumns;
/** The label column. */
private String labelColumn;
/** The output column. */
private String outputColumn;
/** The output table. */
private String outputTable;
/** The conf. */
private transient HiveConf conf;
/** The algorithm. */
private String algorithm;
/** The model id. */
private String modelID;
@Getter
private boolean outputTableExists;
@Getter
private String testID;
private HashMap columnNameToFieldSchema;
/**
* The Class TableTestingSpecBuilder.
*/
public static class TableTestingSpecBuilder {
/** The spec. */
private final TableTestingSpec spec;
/**
* Instantiates a new table testing spec builder.
*/
public TableTestingSpecBuilder() {
spec = new TableTestingSpec();
}
/**
* Database.
*
* @param database the database
* @return the table testing spec builder
*/
public TableTestingSpecBuilder database(String database) {
spec.db = database;
return this;
}
/**
* Set the input table
*
* @param table the table
* @return the table testing spec builder
*/
public TableTestingSpecBuilder inputTable(String table) {
spec.inputTable = table;
return this;
}
/**
* Partition filter for input table
*
* @param partFilter the part filter
* @return the table testing spec builder
*/
public TableTestingSpecBuilder partitionFilter(String partFilter) {
spec.partitionFilter = partFilter;
return this;
}
/**
* Feature columns.
*
* @param featureColumns the feature columns
* @return the table testing spec builder
*/
public TableTestingSpecBuilder featureColumns(List featureColumns) {
spec.featureColumns = featureColumns;
return this;
}
/**
* Labe column.
*
* @param labelColumn the label column
* @return the table testing spec builder
*/
public TableTestingSpecBuilder lableColumn(String labelColumn) {
spec.labelColumn = labelColumn;
return this;
}
/**
* Output column.
*
* @param outputColumn the output column
* @return the table testing spec builder
*/
public TableTestingSpecBuilder outputColumn(String outputColumn) {
spec.outputColumn = outputColumn;
return this;
}
/**
* Output table.
*
* @param table the table
* @return the table testing spec builder
*/
public TableTestingSpecBuilder outputTable(String table) {
spec.outputTable = table;
return this;
}
/**
* Hive conf.
*
* @param conf the conf
* @return the table testing spec builder
*/
public TableTestingSpecBuilder hiveConf(HiveConf conf) {
spec.conf = conf;
return this;
}
/**
* Algorithm.
*
* @param algorithm the algorithm
* @return the table testing spec builder
*/
public TableTestingSpecBuilder algorithm(String algorithm) {
spec.algorithm = algorithm;
return this;
}
/**
* Model id.
*
* @param modelID the model id
* @return the table testing spec builder
*/
public TableTestingSpecBuilder modelID(String modelID) {
spec.modelID = modelID;
return this;
}
/**
* Builds the.
*
* @return the table testing spec
*/
public TableTestingSpec build() {
return spec;
}
/**
* Set the unique test id
*
* @param testID
* @return
*/
public TableTestingSpecBuilder testID(String testID) {
spec.testID = testID;
return this;
}
}
/**
* New builder.
*
* @return the table testing spec builder
*/
public static TableTestingSpecBuilder newBuilder() {
return new TableTestingSpecBuilder();
}
/**
* Validate.
*
* @return true, if successful
*/
public boolean validate() {
List columns;
try {
Hive metastoreClient = Hive.get(conf);
Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
columns = tbl.getAllCols();
columnNameToFieldSchema = new HashMap();
for (FieldSchema fieldSchema : columns) {
columnNameToFieldSchema.put(fieldSchema.getName(), fieldSchema);
}
// Check if output table exists
Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
outputTableExists = (outTbl != null);
} catch (HiveException exc) {
LOG.error("Error getting table info " + toString(), exc);
return false;
}
// Check if labeled column and feature columns are contained in the table
List testTableColumns = new ArrayList(columns.size());
for (FieldSchema column : columns) {
testTableColumns.add(column.getName());
}
if (!testTableColumns.containsAll(featureColumns)) {
LOG.info("Invalid feature columns: " + featureColumns + ". Actual columns in table:" + testTableColumns);
return false;
}
if (!testTableColumns.contains(labelColumn)) {
LOG.info("Invalid label column: " + labelColumn + ". Actual columns in table:" + testTableColumns);
return false;
}
if (StringUtils.isBlank(outputColumn)) {
LOG.info("Output column is required");
return false;
}
if (StringUtils.isBlank(outputTable)) {
LOG.info("Output table is required");
return false;
}
return true;
}
public String getTestQuery() {
if (!validate()) {
return null;
}
// We always insert a dynamic partition
StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE " + outputTable + " PARTITION (part_testid='" + testID
+ "') SELECT ");
String featureCols = StringUtils.join(featureColumns, ",");
q.append(featureCols).append(",").append(labelColumn).append(", ").append("predict(").append("'").append(algorithm)
.append("', ").append("'").append(modelID).append("', ").append(featureCols).append(") ").append(outputColumn)
.append(" FROM ").append(inputTable);
return q.toString();
}
public String getCreateOutputTableQuery() {
StringBuilder createTableQuery = new StringBuilder("CREATE TABLE IF NOT EXISTS ").append(outputTable).append("(");
// Output table contains feature columns, label column, output column
List outputTableColumns = new ArrayList();
for (String featureCol : featureColumns) {
outputTableColumns.add(featureCol + " " + columnNameToFieldSchema.get(featureCol).getType());
}
outputTableColumns.add(labelColumn + " " + columnNameToFieldSchema.get(labelColumn).getType());
outputTableColumns.add(outputColumn + " string");
createTableQuery.append(StringUtils.join(outputTableColumns, ", "));
// Append partition column
createTableQuery.append(") PARTITIONED BY (part_testid string)");
return createTableQuery.toString();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy