templates.data-delivery-data-records.spark.schema.base.vm

package ${basePackage};

#foreach ($import in $record.recordBaseImports)
import ${import};
#end

#foreach ($import in $record.baseImports)
import ${import};
#end

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import scala.collection.JavaConverters;
import scala.collection.Seq;

import com.boozallen.aiops.data.delivery.spark.SparkSchema;
import com.boozallen.aiops.data.delivery.spark.SparkDatasetUtils;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;

/**
 * Base implementation of the Spark schema for ${record.capitalizedName}.
 *
 * GENERATED CODE - DO NOT MODIFY (add your customizations in ${record.capitalizedName}Schema).
 *
 * Generated from: ${templateName} 
 */
public abstract class ${record.capitalizedName}SchemaBase extends SparkSchema {
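## The $columnVars map below associates each field name with its generated column-constant identifier so that the constructor, cast, and validation logic all reference the same constants.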

#set($columnVars = {})
#foreach ($field in $record.fields)
    #set ($columnVars[$field.name] = "${field.upperSnakecaseName}_COLUMN")
    public static final String ${columnVars[$field.name]} = "${field.sparkAttributes.columnName}";
#end

    public ${record.capitalizedName}SchemaBase() {
        super();
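        // Register each field's column name, Spark data type, nullability, and description with the base SparkSchema.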

  #foreach ($field in $record.fields)
      #if ($field.sparkAttributes.isDecimalType())
        add(${columnVars[$field.name]}, new ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), ${field.sparkAttributes.isNullable()}, "${field.description}");
      #else
        add(${columnVars[$field.name]}, ${field.shortType}, ${field.sparkAttributes.isNullable()}, "${field.description}");
      #end
  #end
    }

    /**
     * Casts the given dataset to this schema.
     * 
     * @param dataset
     *            the dataset to cast
     * @return the dataset cast to this schema
     */
    public Dataset<Row> cast(Dataset<Row> dataset) {
    #foreach ($field in $record.fields)
        DataType ${field.name}Type = getDataType(${columnVars[$field.name]});
    #end
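        // Rebuild the dataset with every column cast to the DataType registered for it in this schema.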

        return dataset
        #foreach ($field in $record.fields)
            .withColumn(${columnVars[$field.name]}, col(${columnVars[$field.name]}).cast(${field.name}Type))#if (!$foreach.hasNext);#end
        #end
    }


    /**
     * Gets a list of field names that are of type Date (org.apache.spark.sql.types.DataTypes.DateType).
     *
     * @return list of date field names
     */
    public List<String> getDateFields() {
        List<String> dateFields = new ArrayList<>();
        List<String> headerFields = Arrays.asList(getStructType().fieldNames());
        for(String headerField: headerFields) {
            DataType dataType = getDataType(headerField);
            if(dataType.equals(DataTypes.DateType)) {
                dateFields.add(headerField);
            }
        }

        return dateFields;
    }

    /**
     * Appends Spark validation logic to an unvalidated Spark DataFrame (org.apache.spark.sql.Dataset).
     *
     * @param data
     *            the dataset to validate
     * @return Dataset with appended validation logic
     */
    public Dataset<Row> validateDataFrame(Dataset<Row> data) {
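        // Each enabled validation appends a boolean helper column (e.g. <COLUMN>_IS_NOT_NULL); a row is valid only when every helper column is true.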
        Dataset<Row> dataWithValidations = data
            #foreach ($field in $record.fields)
                #if (${field.isRequired()})
                    .withColumn(${columnVars[$field.name]} + "_IS_NOT_NULL", col(${columnVars[$field.name]}).isNotNull())
                #end
                #if (${field.getValidation().getMinValue()})
                    .withColumn(${columnVars[$field.name]} + "_GREATER_THAN_MIN", col(${columnVars[$field.name]}).gt(lit(${field.getValidation().getMinValue()})).or(col(${columnVars[$field.name]}).equalTo(lit(${field.getValidation().getMinValue()}))))
                #end
                #if (${field.getValidation().getMaxValue()})
                    .withColumn(${columnVars[$field.name]} + "_LESS_THAN_MAX", col(${columnVars[$field.name]}).lt(lit(${field.getValidation().getMaxValue()})).or(col(${columnVars[$field.name]}).equalTo(lit(${field.getValidation().getMaxValue()}))))
                #end
                #if (${field.getValidation().getScale()})
                    .withColumn(${columnVars[$field.name]} + "_MATCHES_SCALE", col(${columnVars[$field.name]}).rlike(("^[0-9]*(?:\\.[0-9]{0,${field.getValidation().getScale()}})?$")))
                #end
                #if (${field.getValidation().getMinLength()})
                    .withColumn(${columnVars[$field.name]} + "_GREATER_THAN_MIN_LENGTH", col(${columnVars[$field.name]}).rlike(("^.{${field.getValidation().getMinLength()},}")))
                #end
                #if (${field.getValidation().getMaxLength()})
                    .withColumn(${columnVars[$field.name]} + "_LESS_THAN_MAX_LENGTH", col(${columnVars[$field.name]}).rlike(("^.{${field.getValidation().getMaxLength()},}")).equalTo(lit(false)))
                #end
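                ## Formats are OR'ed together: the generated check passes when the value matches any configured format.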
                #foreach ($format in $field.getValidation().getFormats())
                    #if ($foreach.first)
                        .withColumn(${columnVars[$field.name]} + "_MATCHES_FORMAT", col(${columnVars[$field.name]}).rlike(("$format.replace("\","\\")"))
                    #else
                        .or(col(${columnVars[$field.name]}).rlike(("$format.replace("\","\\")")))
                    #end
                    #if ($foreach.last)
                        )
                    #end
                #end
            #end ;
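        // Combine all of the generated validation columns with AND so that a row is kept only when it passes every check.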

        Column filterSchema = null;
        List<String> validationColumns = new ArrayList<>();
        Collections.addAll(validationColumns, dataWithValidations.columns());
        validationColumns.removeAll(Arrays.asList(data.columns()));
        for (String columnName : validationColumns) {
            if (filterSchema == null) {
                filterSchema = col(columnName).equalTo(lit(true));
            } else {
                filterSchema = filterSchema.and(col(columnName).equalTo(lit(true)));
            }
        }
        // Isolate the valid data
        Dataset<Row> validData = data;
        if (filterSchema != null) {
            validData = dataWithValidations.filter(filterSchema);
        }

        // Remove validation columns from valid data
        Seq<String> columnsToDrop =
                JavaConverters.collectionAsScalaIterableConverter(validationColumns).asScala().toSeq();
        validData = validData.drop(columnsToDrop);

        return validData;
    }

    /**
     * Returns a given record as a Spark dataset row.
     *
     * @param record
     *            the record to convert
     * @return the record as a Spark dataset row
     */
    public static Row asRow(${record.capitalizedName} record) {
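        // Complex (dictionary) field types are unwrapped to their underlying value before being placed in the Row.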
        return RowFactory.create(
        #foreach ($field in $record.fields)
            #if ($field.type.dictionaryType.isComplex())
            record.get${field.capitalizedName}() != null ? record.get${field.capitalizedName}().getValue() : null#if ($foreach.hasNext),#end
            #else
            record.get${field.capitalizedName}()#if ($foreach.hasNext),#end
            #end
        #end
        );
    }


    /**
     * Returns a given Spark dataset row as a record.
     *
     * @param row
     *            the row to convert
     * @return the row as a record
     */
    public static ${record.capitalizedName} mapRow(Row row) {
        ${record.capitalizedName} record = new ${record.capitalizedName}();
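        // Read each field's value from the Row, re-wrapping complex (dictionary) types in their wrapper type.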
    #foreach ($field in $record.fields)
        ${field.type.dictionaryType.genericShortType} ${field.name}Value = (${field.type.dictionaryType.genericShortType}) SparkDatasetUtils.getRowValue(row, "${field.fieldName}");
        #if ($field.type.dictionaryType.isComplex())
        record.set${field.capitalizedName}(new ${field.genericType}(${field.name}Value));
        #else
        record.set${field.capitalizedName}(${field.name}Value);
        #end
    #end
        return record;
    }

}
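#*
  Illustrative sketch only (not emitted by this template): assuming a hypothetical record
  named "Person" with a single required string field "firstName" stored in column
  "first_name", the constants and constructor generated above would look roughly like:

      public static final String FIRST_NAME_COLUMN = "first_name";

      public PersonSchemaBase() {
          super();
          add(FIRST_NAME_COLUMN, DataTypes.StringType, false, "The person's first name");
      }
*#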



