All Downloads are FREE. Search and download functionalities are using the official Maven repository.

za.co.absa.cobrix.spark.cobol.schema.CobolSchema.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.cobrix.spark.cobol.schema

import org.apache.spark.sql.types._
import za.co.absa.cobrix.cobol.internal.Logging
import za.co.absa.cobrix.cobol.parser.Copybook
import za.co.absa.cobrix.cobol.parser.ast._
import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, COMP1, COMP2, Decimal, Integral}
import za.co.absa.cobrix.cobol.parser.common.Constants
import za.co.absa.cobrix.cobol.parser.encoding.RAW
import za.co.absa.cobrix.cobol.parser.policies.MetadataPolicy
import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy
import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy.SchemaRetentionPolicy
import za.co.absa.cobrix.cobol.reader.schema.{CobolSchema => CobolReaderSchema}
import za.co.absa.cobrix.spark.cobol.parameters.CobolParametersParser.getReaderProperties
import za.co.absa.cobrix.spark.cobol.parameters.MetadataFields.{MAX_ELEMENTS, MAX_LENGTH, MIN_ELEMENTS}
import za.co.absa.cobrix.spark.cobol.parameters.{CobolParametersParser, Parameters}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer


/**
  * This class provides a view on a COBOL schema from the perspective of Spark. When provided with a parsed copybook the class
  * provides the corresponding Spark schema and also other properties for the Spark data source.
  *
  * Generated (non-copybook) fields are placed before the copybook fields in this order: record id fields,
  * record bytes, input file name, segment id fields.
  *
  * @param copybook                A parsed copybook.
  * @param schemaRetentionPolicy   Specifies a policy to transform the input schema. With `CollapseRoot` the fields of root groups are
  *                                promoted to the top level of the Spark schema.
  * @param strictIntegralPrecision If true, Cobrix will not generate short/integer/long Spark data types, and always use decimal(n) with the exact precision that matches the copybook.
  * @param inputFileNameField      If non-empty, a source file name field with this name will be prepended to the schema.
  * @param generateRecordId        If true, the file id, record id and record byte length fields will be prepended to the beginning of the schema.
  * @param generateRecordBytes     If true, a record bytes field will be prepended to the schema.
  * @param generateSegIdFieldsCnt  A number of segment ID levels to generate
  * @param segmentIdProvidedPrefix A prefix for each segment id levels to make segment ids globally unique (by default the current timestamp will be used)
  * @param metadataPolicy          Specifies a policy to generate metadata fields.
  */
class CobolSchema(copybook: Copybook,
                  schemaRetentionPolicy: SchemaRetentionPolicy,
                  strictIntegralPrecision: Boolean = false,
                  inputFileNameField: String = "",
                  generateRecordId: Boolean = false,
                  generateRecordBytes: Boolean = false,
                  generateSegIdFieldsCnt: Int = 0,
                  segmentIdProvidedPrefix: String = "",
                  metadataPolicy: MetadataPolicy = MetadataPolicy.Basic)
  // NOTE(review): metadataPolicy is not forwarded to the base class constructor below, while the companion
  // object reads `schema.metadataPolicy` from a base reader schema. Confirm whether it should be passed on.
  extends CobolReaderSchema(
    copybook, schemaRetentionPolicy, strictIntegralPrecision, inputFileNameField, generateRecordId, generateRecordBytes,
    generateSegIdFieldsCnt, segmentIdProvidedPrefix
    ) with Logging with Serializable {

  // The hierarchical Spark schema, built on first access.
  @throws(classOf[IllegalStateException])
  private[this] lazy val sparkSchema = createSparkSchema()

  // The flattened Spark schema, built on first access. Every top-level AST child is expected to be a group;
  // its name (followed by '_') becomes the prefix of all column names generated from it.
  @throws(classOf[IllegalStateException])
  private[this] lazy val sparkFlatSchema = {
    val arraySchema = copybook.ast.children.toArray
    val records = arraySchema.flatMap(record => {
      parseGroupFlat(record.asInstanceOf[Group], s"${record.name}_")
    })
    StructType(records)
  }

  /** Returns the Spark schema that mirrors the (possibly nested) structure of the copybook. */
  def getSparkSchema: StructType = {
    sparkSchema
  }

  /** Returns the Spark schema in which nested groups and arrays are expanded into top-level columns. */
  def getSparkFlatSchema: StructType = {
    sparkFlatSchema
  }

  /**
    * Builds the hierarchical Spark schema from the root records of the copybook and prepends the
    * generated fields requested by the constructor options.
    *
    * @return The resulting Spark schema.
    * @throws IllegalStateException if the copybook contains a primitive of an unsupported data type.
    */
  @throws(classOf[IllegalStateException])
  private def createSparkSchema(): StructType = {
    // The list of segment redefines does not depend on the record, so it is fetched once.
    val segmentRedefines = copybook.getAllSegmentRedefines
    val records = for (record <- copybook.getRootRecords) yield {
      parseGroup(record.asInstanceOf[Group], segmentRedefines)
    }

    val expandRecords = if (schemaRetentionPolicy == SchemaRetentionPolicy.CollapseRoot || copybook.isFlatCopybook) {
      // Expand root group fields
      records.toArray.flatMap(group => group.dataType.asInstanceOf[StructType].fields)
    } else {
      records.toArray
    }

    val recordsWithSegmentFields = if (generateSegIdFieldsCnt > 0) {
      // All segment id fields share the same maximum length, so the metadata is built once.
      val segFieldMetadata = new MetadataBuilder()
        .putLong(MAX_LENGTH, getMaximumSegmentIdLength(segmentIdProvidedPrefix).toLong)
        .build()

      val segmentIdFields = (0 until generateSegIdFieldsCnt).map { level =>
        StructField(s"${Constants.segmentIdField}$level", StringType, nullable = true, metadata = segFieldMetadata)
      }

      segmentIdFields.toArray ++ expandRecords
    } else {
      expandRecords
    }

    val recordsWithFileName = if (inputFileNameField.nonEmpty) {
      StructField(inputFileNameField, StringType, nullable = true) +: recordsWithSegmentFields
    } else {
      recordsWithSegmentFields
    }

    val recordsWithRecordBytes = if (generateRecordBytes) {
      StructField(Constants.recordBytes, BinaryType, nullable = false) +: recordsWithFileName
    } else {
      recordsWithFileName
    }

    val recordsWithRecordId = if (generateRecordId) {
      StructField(Constants.fileIdField, IntegerType, nullable = false) +:
        StructField(Constants.recordIdField, LongType, nullable = false) +:
        StructField(Constants.recordByteLength, IntegerType, nullable = false) +: recordsWithRecordBytes
    } else {
      recordsWithRecordBytes
    }

    StructType(recordsWithRecordId)
  }

  /**
    * Returns the value used as the maximum length metadata of generated segment id fields.
    *
    * @param segmentIdProvidedPrefix The user-provided segment id prefix. When empty, a fixed length of 15
    *                                characters is used for the prefix part instead.
    * @return The prefix length plus 50 characters reserved for the generated part of the id.
    */
  private [cobrix] def getMaximumSegmentIdLength(segmentIdProvidedPrefix: String): Int = {
    val DATETIME_PREFIX_LENGTH = 15
    val SEGMENT_ID_MAX_GENERATED_LENGTH = 50

    val prefixLength = if (segmentIdProvidedPrefix.isEmpty) DATETIME_PREFIX_LENGTH else segmentIdProvidedPrefix.length

    prefixLength + SEGMENT_ID_MAX_GENERATED_LENGTH
  }

  /**
    * Converts a copybook group into a Spark struct field (or an array of structs if the group is an array).
    *
    * Fillers are skipped. Nested groups that have a parent segment are skipped at their position and are
    * instead attached to their parent via [[getChildSegments]].
    *
    * @param group            The group to convert.
    * @param segmentRedefines All segment redefines of the copybook, used to locate child segments.
    * @return The Spark field corresponding to the group.
    * @throws IllegalStateException if the group contains a primitive of an unsupported data type.
    */
  @throws(classOf[IllegalStateException])
  private def parseGroup(group: Group, segmentRedefines: List[Group]): StructField = {
    val fields = group.children.flatMap(field => {
      if (field.isFiller) {
        // Skipping fillers
        Nil
      } else {
        field match {
          case childGroup: Group =>
            if (childGroup.parentSegment.isEmpty) {
              parseGroup(childGroup, segmentRedefines) :: Nil
            } else {
              // Skipping child segments on this level
              Nil
            }
          case p: Primitive =>
            parsePrimitive(p) :: Nil
        }
      }
    })

    val fieldsWithChildrenSegments = fields ++ getChildSegments(group, segmentRedefines)
    val metadata = new MetadataBuilder()

    if (metadataPolicy == MetadataPolicy.Extended)
      addExtendedMetadata(metadata, group)

    if (group.isArray) {
      if (metadataPolicy != MetadataPolicy.NoMetadata)
        addArrayMetadata(metadata, group)
      StructField(group.name, ArrayType(StructType(fieldsWithChildrenSegments.toArray)), nullable = true, metadata.build())
    } else {
      StructField(group.name, StructType(fieldsWithChildrenSegments.toArray), nullable = true, metadata.build())
    }
  }

  /**
    * Converts a copybook primitive into a Spark field, choosing the Spark type from the COBOL data type.
    *
    * @param p The primitive field to convert.
    * @return The Spark field (wrapped into an array type if the primitive is an array).
    * @throws IllegalStateException if the data type of the primitive is not supported.
    */
  @throws(classOf[IllegalStateException])
  private def parsePrimitive(p: Primitive): StructField = {
    val metadata = new MetadataBuilder()
    val dataType: DataType = p.dataType match {
      case d: Decimal      =>
        d.compact match {
          case Some(COMP1()) => FloatType
          case Some(COMP2()) => DoubleType
          case _             => DecimalType(d.getEffectivePrecision, d.getEffectiveScale)
        }
      case a: AlphaNumeric =>
        if (metadataPolicy != MetadataPolicy.NoMetadata)
          addAlphaNumericMetadata(metadata, a)
        a.enc match {
          case Some(RAW) => BinaryType
          case _         => StringType
        }
      case dt: Integral  if strictIntegralPrecision  =>
        DecimalType(precision = dt.precision, scale = 0)
      case dt: Integral  =>
        // Use the narrowest Spark type that can hold the declared precision.
        if (dt.precision > Constants.maxLongPrecision) {
          DecimalType(precision = dt.precision, scale = 0)
        } else if (dt.precision > Constants.maxIntegerPrecision) {
          LongType
        }
        else {
          IntegerType
        }
      case _               => throw new IllegalStateException("Unknown AST object")
    }

    if (metadataPolicy == MetadataPolicy.Extended)
      addExtendedMetadata(metadata, p)

    if (p.isArray) {
      if (metadataPolicy != MetadataPolicy.NoMetadata)
        addArrayMetadata(metadata, p)
      StructField(p.name, ArrayType(dataType), nullable = true, metadata.build())
    } else {
      StructField(p.name, dataType, nullable = true, metadata.build())
    }
  }

  /** Records the minimum and maximum number of array elements of the statement in the metadata. */
  private def addArrayMetadata(metadataBuilder: MetadataBuilder, st: Statement): MetadataBuilder = {
    metadataBuilder.putLong(MIN_ELEMENTS, st.arrayMinSize)
    metadataBuilder.putLong(MAX_ELEMENTS, st.arrayMaxSize)
  }

  /** Records the length of the alphanumeric field as the maximum length in the metadata. */
  private def addAlphaNumericMetadata(metadataBuilder: MetadataBuilder, a: AlphaNumeric): MetadataBuilder = {
    metadataBuilder.putLong(MAX_LENGTH, a.length)
  }

  /**
    * Adds the metadata generated under the extended metadata policy: level, original name (only when it
    * differs from the field name), redefines, depending-on, offset and byte size, followed by attributes
    * specific to primitives or groups.
    */
  private def addExtendedMetadata(metadataBuilder: MetadataBuilder, s: Statement): MetadataBuilder = {
    metadataBuilder.putLong("level", s.level)
    if (s.originalName.nonEmpty && s.originalName != s.name)
      metadataBuilder.putString("originalName", s.originalName)
    s.redefines.foreach(redefines => metadataBuilder.putString("redefines", redefines))
    s.dependingOn.foreach(dependingOn => metadataBuilder.putString("depending_on", dependingOn))
    metadataBuilder.putLong("offset", s.binaryProperties.offset)
    metadataBuilder.putLong("byte_size", s.binaryProperties.dataSize)

    s match {
      case p: Primitive => addExtendedPrimitiveMetadata(metadataBuilder, p)
      case g: Group     => addExtendedGroupMetadata(metadataBuilder, g)
    }

    metadataBuilder
  }

  /**
    * Adds the PIC clause (the original one when available) and, for numeric types, usage, precision,
    * scale and sign attributes to the metadata.
    */
  private def addExtendedPrimitiveMetadata(metadataBuilder: MetadataBuilder, p: Primitive): MetadataBuilder = {
    metadataBuilder.putString("pic", p.dataType.originalPic.getOrElse(p.dataType.pic))
    p.dataType match {
      case a: Integral =>
        a.compact.foreach(usage => metadataBuilder.putString("usage", usage.toString))
        metadataBuilder.putLong("precision", a.precision)
        metadataBuilder.putBoolean("signed", a.signPosition.nonEmpty)
        metadataBuilder.putBoolean("sign_separate", a.isSignSeparate)
      case a: Decimal  =>
        a.compact.foreach(usage => metadataBuilder.putString("usage", usage.toString))
        metadataBuilder.putLong("precision", a.precision)
        metadataBuilder.putLong("scale", a.scale)
        // The scale factor is only recorded when it is non-zero.
        if (a.scaleFactor != 0)
          metadataBuilder.putLong("scale_factor", a.scaleFactor)
        metadataBuilder.putBoolean("signed", a.signPosition.nonEmpty)
        metadataBuilder.putBoolean("sign_separate", a.isSignSeparate)
        metadataBuilder.putBoolean("implied_decimal", !a.explicitDecimal)
      case _           =>
        // Other data types have no additional attributes.
    }
    metadataBuilder
  }

  /** Adds the group-level usage, if one is defined, to the metadata. */
  private def addExtendedGroupMetadata(metadataBuilder: MetadataBuilder, g: Group): MetadataBuilder = {
    g.groupUsage.foreach(usage => metadataBuilder.putString("usage", usage.toString))
    metadataBuilder
  }

  /**
    * Returns the fields for segments whose parent segment is the given group (names are compared
    * case-insensitively). Each child segment is exposed as a nullable array of structs.
    *
    * @param group            The potential parent group.
    * @param segmentRedefines All segment redefines of the copybook.
    * @return The child segment fields, in the order of `segmentRedefines`.
    */
  private def getChildSegments(group: Group, segmentRedefines: List[Group]): ArrayBuffer[StructField] = {
    val childSegments = new mutable.ArrayBuffer[StructField]()

    segmentRedefines.foreach(segment => {
      segment.parentSegment.foreach(parent => {
        if (parent.name.equalsIgnoreCase(group.name)) {
          val child = parseGroup(segment, segmentRedefines)
          // parseGroup() may return an array of structs if the segment itself is an array.
          // NOTE(review): the cast below assumes a non-array segment group - confirm.
          val fields = child.dataType.asInstanceOf[StructType].fields
          childSegments += StructField(segment.name, ArrayType(StructType(fields)), nullable = true)
        }
      })
    })
    childSegments
  }

  /**
    * Flattens a group into a list of top-level fields. Nested group names are concatenated into the
    * column name using '_' as the separator, and arrays are expanded into one column (or one set of
    * columns) per element with a 1-based index in the name.
    *
    * @param group      The group to flatten.
    * @param structPath The column name prefix accumulated from the enclosing groups.
    * @return The flattened fields of the group. Fillers are skipped.
    * @throws IllegalStateException if the group contains a primitive of an unsupported data type.
    */
  @throws(classOf[IllegalStateException])
  private def parseGroupFlat(group: Group, structPath: String = ""): ArrayBuffer[StructField] = {
    val fields = new ArrayBuffer[StructField]()
    for (field <- group.children if !field.isFiller) {
      field match {
        case childGroup: Group =>
          if (childGroup.isArray) {
            for (i <- 1 to childGroup.arrayMaxSize) {
              fields ++= parseGroupFlat(childGroup, s"$structPath${childGroup.name}_${i}_")
            }
          } else {
            fields ++= parseGroupFlat(childGroup, s"$structPath${childGroup.name}_")
          }
        case s: Primitive =>
          val dataType: DataType = s.dataType match {
            case d: Decimal      =>
              DecimalType(d.getEffectivePrecision, d.getEffectiveScale)
            case a: AlphaNumeric =>
              a.enc match {
                case Some(RAW) => BinaryType
                case _         => StringType
              }
            case dt: Integral    =>
              if (dt.precision > Constants.maxIntegerPrecision) {
                LongType
              }
              else {
                IntegerType
              }
            case _               => throw new IllegalStateException("Unknown AST object")
          }
          if (s.isArray) {
            // NOTE(review): each expanded element column is still typed as an array of the element
            // type - confirm this is intended for the flat schema.
            for (i <- 1 to s.arrayMaxSize) {
              // The field name must be interpolated with '${...}' so that each column is named
              // '<path><field>_<index>'.
              fields += StructField(s"$structPath${s.name}_$i", ArrayType(dataType), nullable = true)
            }
          } else {
            fields += StructField(s"$structPath${s.name}", dataType, nullable = true)
          }
      }
    }

    fields
  }
}

/** Factory methods for creating the Spark view of a COBOL schema. */
object CobolSchema {
  /**
    * Creates a Spark COBOL schema from a reader-level schema by copying its copybook and all of its options.
    *
    * @param schema The reader-level COBOL schema.
    * @return The Spark COBOL schema with the same settings.
    */
  def fromBaseReader(schema: CobolReaderSchema): CobolSchema = {
    new CobolSchema(
      schema.copybook,
      schema.policy,
      schema.strictIntegralPrecision,
      schema.inputFileNameField,
      schema.generateRecordId,
      schema.generateRecordBytes,
      schema.generateSegIdFieldsCnt,
      schema.segmentIdPrefix,
      schema.metadataPolicy
      )
  }

  /**
    * Creates a Spark COBOL schema from copybook contents and Spark reader options.
    *
    * @param copyBookContents   The contents of one or more copybooks.
    * @param sparkReaderOptions The options as passed to the Spark reader. Option names are matched case-insensitively.
    * @return The Spark COBOL schema built according to the options.
    */
  def fromSparkOptions(copyBookContents: Seq[String], sparkReaderOptions: Map[String, String]): CobolSchema = {
    // Lowercase keys with a fixed locale so that the result does not depend on the JVM default locale
    // (for example, 'I' does not map to 'i' under a Turkish locale).
    val lowercaseOptions = sparkReaderOptions.map { case (k, v) => (k.toLowerCase(java.util.Locale.ROOT), v) }
    val cobolParameters = CobolParametersParser.parse(new Parameters(lowercaseOptions))
    val readerParameters = getReaderProperties(cobolParameters, None)

    CobolSchema.fromBaseReader(CobolReaderSchema.fromReaderParameters(copyBookContents, readerParameters))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy