All downloads are free. The search and download functionality uses the official Maven repository.

com.dimajix.spark.sql.sources.fixedwidth.FixedWidthFormat.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * Adapted for fixed width format 2018 Kaya Kupferschmidt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.dimajix.spark.sql.sources.fixedwidth

import java.text.DecimalFormat
import java.text.DecimalFormatSymbols
import java.util.Locale

import com.univocity.parsers.fixed.FixedWidthWriter
import org.apache.commons.lang3.time.FastDateFormat
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.CodecStreams
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.execution.datasources.OutputWriterFactory
import org.apache.spark.sql.execution.datasources.TextBasedFileFormat
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DateType
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.types.NumericType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.TimestampType


/**
  * Spark file format for fixed width text files, exposed through the short name `fixedwidth`.
  *
  * Only the write path is implemented in this class; `inferSchema` is left as `???`.
  */
class FixedWidthFormat extends TextBasedFileFormat with DataSourceRegister {
    /** Short name of this data source. */
    override def shortName: String = "fixedwidth"

    override def toString: String = "FixedWidth"

    /** Not implemented: evaluating this method throws `scala.NotImplementedError`. */
    override def inferSchema(
        sparkSession: SparkSession,
        options: Map[String, String],
        files: Seq[FileStatus]): Option[StructType] = ???

    /**
      * Verifies the schema, applies the configured compression codec (if any) to the job
      * configuration and returns a factory producing [[FixedWidthOutputWriter]] instances.
      */
    override def prepareWrite(
        sparkSession: SparkSession,
        job: Job,
        options: Map[String, String],
        dataSchema: StructType): OutputWriterFactory = {
        FixedWidthUtils.verifySchema(dataSchema)

        val hadoopConf = job.getConfiguration
        val sessionTimeZone = sparkSession.sessionState.conf.sessionLocalTimeZone
        val writeOptions = new FixedWidthOptions(options, sessionTimeZone)
        for (codec <- writeOptions.compressionCodec) {
            CompressionCodecs.setCodecConfiguration(hadoopConf, codec)
        }

        new OutputWriterFactory {
            // Files get the ".dat" suffix followed by the extension of the active compression codec.
            override def getFileExtension(context: TaskAttemptContext): String =
                ".dat" + CodecStreams.getCompressionExtension(context)

            override def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter =
                new FixedWidthOutputWriter(path, dataSchema, context, writeOptions)
        }
    }
}


/**
  * Output writer that renders `InternalRow`s as fixed width text lines through a Univocity
  * `FixedWidthWriter`.
  *
  * Every numeric field of the schema must carry a `size` entry (a long) in its metadata, which is
  * used as the target width for zero padding. Date, timestamp, float, double and decimal fields
  * may additionally carry a `format` entry (a string) in their metadata which overrides the
  * default formatting pattern.
  *
  * @param file    path of the file to write
  * @param schema  schema of the rows passed to `write`
  * @param context Hadoop task attempt context used for creating the output stream
  * @param options fixed width writer options (header flag, null value, number formatting, ...)
  */
private[spark] class FixedWidthOutputWriter(
    file: String,
    schema: StructType,
    context: TaskAttemptContext,
    options: FixedWidthOptions) extends OutputWriter with Logging {

    private val writer = CodecStreams.createOutputStreamWriter(context, new Path(file))
    private val writerSettings = options.asWriterSettings(schema)
    writerSettings.setHeaders(schema.fieldNames: _*)
    private val gen = new FixedWidthWriter(writer, writerSettings)
    // True as long as the header line still has to be written before the next row.
    private var printHeader = options.headerFlag

    // Locale independent symbols, so the decimal separator does not depend on the JVM default locale.
    private val decimalSymbols = DecimalFormatSymbols.getInstance(Locale.ROOT)

    // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
    // When the value is null, this converter should not be called.
    private type ValueConverter = (InternalRow, Int) => String

    // `ValueConverter`s for all values in the fields of the schema
    private val valueConverters: Array[ValueConverter] =
        schema.map(makeConverter).toArray

    /**
      * Applies the sign and leading zero options to an already formatted number.
      *
      * A leading `-` is kept as sign; otherwise `+` is used if `options.numbersPositiveSign` is set.
      * If `options.numbersLeadingZeros` is set, zeros are inserted between sign and digits until
      * the result is `fieldSize` characters long. Numbers longer than `fieldSize` are not truncated.
      *
      * @param number    formatted number, optionally starting with `-`
      * @param fieldSize target width of the field
      * @return the number with sign and padding applied
      */
    private def padNumber(number:String, fieldSize:Int) : String = {
        // `startsWith`/`substring` instead of `head`/`tail`, which would throw on an empty string.
        val (sign, abs) =
            if (number.startsWith("-"))
                ("-", number.substring(1))
            else if (options.numbersPositiveSign)
                ("+", number)
            else
                ("", number)

        val padding = fieldSize - sign.length - abs.length
        if (options.numbersLeadingZeros && padding > 0)
            sign + ("0" * padding) + abs
        else
            sign + abs
    }

    /** Reads the field width from the `size` entry of the field metadata. */
    private def fieldSize(field:StructField) : Int = {
        field.metadata.getLong("size").toInt
    }

    /**
      * Creates the format for float and double fields: either the pattern stored in the `format`
      * metadata entry, or a pattern with optional digits only and as many fraction digits as the
      * field is wide.
      */
    private def floatFormat(field:StructField) : DecimalFormat = {
        val width = fieldSize(field)
        if (field.metadata.contains("format"))
            new DecimalFormat(field.metadata.getString("format"), decimalSymbols)
        else
            new DecimalFormat("#." + "#" * width, decimalSymbols)
    }

    /**
      * Creates the format for decimal fields: either the pattern stored in the `format` metadata
      * entry, or a pattern derived from precision and scale of the decimal type with a fixed
      * number of fraction digits.
      */
    private def decimalFormat(field:StructField) : DecimalFormat = {
        if (field.metadata.contains("format")) {
            new DecimalFormat(field.metadata.getString("format"), decimalSymbols)
        }
        else {
            val dt = field.dataType.asInstanceOf[DecimalType]
            if (dt.scale > 0)
                new DecimalFormat("#" * (dt.precision - dt.scale) + "." + "0" * dt.scale, decimalSymbols)
            else
                new DecimalFormat("#" * dt.precision, decimalSymbols)
        }
    }

    /**
      * Builds the converter turning the non-null value of the given field into its textual
      * representation. Formats are created once per field, not once per row.
      */
    private def makeConverter(field: StructField): ValueConverter = field.dataType match {
        case DateType => {
            val format = if (field.metadata.contains("format"))
                FastDateFormat.getInstance(field.metadata.getString("format"), Locale.US)
            else
                options.dateFormat
            (row: InternalRow, ordinal: Int) =>
                format.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
        }

        case TimestampType => {
            val format = if (field.metadata.contains("format"))
                FastDateFormat.getInstance(field.metadata.getString("format"), options.timeZone, Locale.US)
            else
                options.timestampFormat
            (row: InternalRow, ordinal: Int) =>
                format.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
        }

        case FloatType => {
            val format = floatFormat(field)
            val width = fieldSize(field)

            (row: InternalRow, ordinal: Int) =>
                // Passing the float directly would widen it to a double and expose binary
                // representation noise (0.1f would be formatted as 0.1000000015 with enough
                // fraction digits). Going through the shortest decimal string of the float
                // avoids this and still handles NaN and infinite values.
                val v = row.getFloat(ordinal).toString.toDouble
                val str = format.format(v)
                padNumber(str, width)
        }

        case DoubleType => {
            val format = floatFormat(field)
            val width = fieldSize(field)

            (row: InternalRow, ordinal: Int) =>
                val v = row.getDouble(ordinal)
                val str = format.format(v)
                padNumber(str, width)
        }

        case dt: DecimalType => {
            val format = decimalFormat(field)
            val width = fieldSize(field)

            (row: InternalRow, ordinal: Int) =>
                val v = row.getDecimal(ordinal, dt.precision, dt.scale).toJavaBigDecimal
                val str = format.format(v)
                padNumber(str, width)
        }

        // All remaining numeric types are rendered via `toString` and padded.
        case dt: NumericType => {
            val width = fieldSize(field)

            (row: InternalRow, ordinal: Int) =>
                val str = row.get(ordinal, dt).toString
                padNumber(str, width)
        }

        // Any other type is rendered via `toString` without padding.
        case dt: DataType => {
            (row: InternalRow, ordinal: Int) =>
                row.get(ordinal, dt).toString
        }
    }

    /**
      * Converts all fields of a row to strings; null fields are replaced by `options.nullValue`.
      */
    private def convertRow(row: InternalRow): Seq[String] = {
        var i = 0
        val values = new Array[String](row.numFields)
        while (i < row.numFields) {
            if (!row.isNullAt(i)) {
                values(i) = valueConverters(i).apply(row, i)
            } else {
                values(i) = options.nullValue
            }
            i += 1
        }
        values
    }

    /**
      * Writes a single InternalRow using Univocity. If headers are enabled, the header line is
      * written once, right before the first row.
      */
    def write(row: InternalRow): Unit = {
        if (printHeader) {
            gen.writeHeaders()
            // Clear the flag as soon as the header is out, so it can never be written twice.
            printHeader = false
        }
        gen.writeRow(convertRow(row): _*)
    }

    /** Closes the underlying Univocity writer. */
    def close(): Unit = gen.close()

    /** Flushes the underlying Univocity writer. */
    def flush(): Unit = gen.flush()

    /** Path of the file written by this writer. */
    /*override*/ def path(): String = this.file
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy