All Downloads are FREE. Search and download functionalities are using the official Maven repository.

za.co.absa.cobrix.spark.cobol.schema.CobolSchema.scala Maven / Gradle / Ivy

/*
 * Copyright 2018-2019 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.cobrix.spark.cobol.schema

import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import org.apache.spark.sql.types._
import org.slf4j.LoggerFactory
import za.co.absa.cobrix.cobol.parser.Copybook
import za.co.absa.cobrix.cobol.parser.ast._
import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, Decimal, Integral}
import za.co.absa.cobrix.cobol.parser.common.{Constants, ReservedWords}
import za.co.absa.cobrix.spark.cobol.schema.SchemaRetentionPolicy.SchemaRetentionPolicy

import scala.collection.mutable.ArrayBuffer

/**
  * This class provides a view on a COBOL schema from the perspective of Spark. When provided with a parsed copybook the class
  * provides the corresponding Spark schema and also other properties for the Spark data source.
  *
  * @param copybook                A parsed copybook.
  * @param policy                  Specifies a policy to transform the input schema. The default policy is to keep the schema exactly as it is in the copybook.
  * @param generateRecordId        If true, a record id field will be prepended to to the begginning of the schema.
  * @param generateSegIdFieldsCnt  A number of segment ID levels to generate
  * @param segmentIdProvidedPrefix A prefix for each segment id levels to make segment ids globally unique (by default the current timestamp will be used)
  */
class CobolSchema(val copybook: Copybook,
                  policy: SchemaRetentionPolicy,
                  generateRecordId: Boolean,
                  generateSegIdFieldsCnt: Int = 0,
                  segmentIdProvidedPrefix: String = "") extends Serializable {

  private val logger = LoggerFactory.getLogger(this.getClass)

  val segmentIdPrefix: String = if (segmentIdProvidedPrefix.isEmpty) getDefaultSegmentIdPrefix else segmentIdProvidedPrefix

  def getCobolSchema: Copybook = copybook

  @throws(classOf[IllegalStateException])
  private[this] lazy val sparkSchema = {
    logger.info("Layout positions:\n" + copybook.generateRecordLayoutPositions())
    val records = for (record <- copybook.ast.children) yield {
      parseGroup(record.asInstanceOf[Group])
    }
    val expandRecords = if (policy == SchemaRetentionPolicy.CollapseRoot) {
      // Expand root group fields
      records.toArray.flatMap(group => group.dataType.asInstanceOf[StructType].fields)
    } else {
      records.toArray
    }

    val recordsWithSegmentFields = if (generateSegIdFieldsCnt > 0) {
      val newFields = for (level <- Range(0, generateSegIdFieldsCnt))
        yield StructField(s"${Constants.segmentIdField}$level", StringType, nullable = true)
      newFields.toArray ++ expandRecords
    } else {
      expandRecords
    }

    val recordsWithRecordId = if (generateRecordId) {
      StructField(Constants.fileIdField, IntegerType, nullable = false) +:
        StructField(Constants.recordIdField, LongType, nullable = false) +: recordsWithSegmentFields
    } else {
      recordsWithSegmentFields
    }
    StructType(recordsWithRecordId)
  }

  @throws(classOf[IllegalStateException])
  private[this] lazy val sparkFlatSchema = {
    logger.info("Layout positions:\n" + copybook.generateRecordLayoutPositions())
    val arraySchema = copybook.ast.children.toArray
    val records = arraySchema.flatMap(record => {
      parseGroupFlat(record.asInstanceOf[Group], s"${record.name}_")
    })
    StructType(records)
  }

  def getSparkSchema: StructType = {
    sparkSchema
  }

  def getSparkFlatSchema: StructType = {
    sparkFlatSchema
  }

  lazy val getRecordSize: Int = copybook.getRecordSize

  def isRecordFixedSize: Boolean = copybook.isRecordFixedSize

  @throws(classOf[IllegalStateException])
  private def parseGroup(group: Group): StructField = {
    val fields = for (field <- group.children if !field.isFiller) yield {
      field match {
        case group: Group =>
          parseGroup(group)
        case s: Primitive =>
          val dataType: DataType = s.dataType match {
            case d: Decimal =>
              val computation = d.compact.getOrElse(-1)
              if (computation == 1) {
                FloatType
              } else if (computation == 2) {
                DoubleType
              } else {
                DecimalType(d.getEffectivePrecision, d.getEffectiveScale)
              }
            case _: AlphaNumeric => StringType
            case dt: Integral =>
              if (dt.precision > Constants.maxLongPrecision) {
                DecimalType(precision = dt.precision, scale = 0)
              } else if (dt.precision > Constants.maxIntegerPrecision) {
                LongType
              }
              else {
                IntegerType
              }
            case _ => throw new IllegalStateException("Unknown AST object")
          }
          if (s.isArray) {
            StructField(s.name, ArrayType(dataType), nullable = true)
          } else {
            StructField(s.name, dataType, nullable = true)
          }
      }

    }
    if (group.isArray) {
      StructField(group.name, ArrayType(StructType(fields.toArray)), nullable = true)
    } else {
      StructField(group.name, StructType(fields.toArray), nullable = true)
    }

  }

  @throws(classOf[IllegalStateException])
  private def parseGroupFlat(group: Group, structPath: String = ""): ArrayBuffer[StructField] = {
    val fields = new ArrayBuffer[StructField]()
    for (field <- group.children if !field.isFiller) {
      field match {
        case group: Group =>
          if (group.isArray) {
            for (i <- Range(1, group.arrayMaxSize + 1)) {
              val path = s"$structPath${group.name}_${i}_"
              fields ++= parseGroupFlat(group, path)
            }
          } else {
            val path = s"$structPath${group.name}_"
            fields ++= parseGroupFlat(group, path)
          }
        case s: Primitive =>
          val dataType: DataType = s.dataType match {
            case d: Decimal =>
              DecimalType(d.getEffectivePrecision, d.getEffectiveScale)
            case _: AlphaNumeric => StringType
            case dt: Integral =>
              if (dt.precision > Constants.maxIntegerPrecision) {
                LongType
              }
              else {
                IntegerType
              }
            case _ => throw new IllegalStateException("Unknown AST object")
          }
          val path = s"$structPath" //${group.name}_"
          if (s.isArray) {
            for (i <- Range(1, s.arrayMaxSize + 1)) {
              fields += StructField(s"$path{s.name}_$i", ArrayType(dataType), nullable = true)
            }
          } else {
            fields += StructField(s"$path${s.name}", dataType, nullable = true)
          }
      }
    }

    fields
  }

  private def getDefaultSegmentIdPrefix: String = {
    val timestampFormat = DateTimeFormatter.ofPattern("yyyyMMddHHmmss")
    val now = ZonedDateTime.now()
    timestampFormat.format(now)
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy