All Downloads are FREE. Search and download functionality uses the official Maven repository.

.pgjdbc-ng.tools.udt-gen.0.8.source-code.UDTGenerator.kt Maven / Gradle / Ivy

There is a newer version: 0.8.9
Show newest version
package com.impossibl.postgres.tools

import com.impossibl.postgres.api.jdbc.PGAnyType
import com.impossibl.postgres.api.jdbc.PGConnection
import com.impossibl.postgres.api.jdbc.PGType
import com.impossibl.postgres.types.QualifiedName
import com.squareup.javapoet.*
import com.xenomachina.argparser.*
import java.io.File
import java.io.InputStream
import java.io.Reader
import java.sql.*
import java.util.*
import javax.lang.model.element.Modifier


/**
 * Generates Java source files for PostgreSQL user-defined types.
 *
 * Composite types become POJOs implementing [SQLData]; enum types become Java enums that carry their
 * SQL label. The requested type names are resolved against `pg_type` once, when the generator is
 * constructed.
 *
 * @param connection connection used to query the system catalogs and resolve attribute types
 * @param targetPackage Java package of the generated classes
 * @param typeNames SQL names of the types to generate
 */
class UDTGenerator(
   private val connection: PGConnection,
   private val targetPackage: String,
   typeNames: List<String>
) {

  companion object {

    /** Command line options accepted by [main]. */
    private class Arguments(parser: ArgParser) {

      val types by parser.positionalList("TYPES", "type names")
      val url by parser.storing("--url", help = "Connection URL").default(null as String?)
      val host by parser.storing("-H", "--host", help = "Server host name").default(null as String?)
      val port by parser.storing("-T", "--port", help = "Server port number") { toInt() }.default(null as Int?)
      val db by parser.storing("-D", "--database", help = "Database name").default(null as String?)
      val user by parser.storing("-U", "--user", help = "Database username").default(null as String?)
      val password by parser.storing("-P", "--password", help = "Database password").default(null as String?)
      val outDirectory by parser.storing("-o", "--out", help = "Output directory") { File(this) }
         .default(File("out"))
         .addValidator {
           // Creates a missing output directory; rejects an existing non-directory path
           if (!value.exists())
             value.mkdirs()
           else if (!value.isDirectory)
             throw InvalidArgumentException("Out must be a directory")
         }
      val targetPackage by parser.storing("-p", "--pkg", help = "Target Java package")

    }

    /** Builds `jdbc:pgsql:[//host[:port]/]database`; the port is only used together with a host. */
    private fun buildUrl(host: String?, port: Int?, database: String): String {
      val authority = host?.let { h -> "//$h${port?.let { ":$it" } ?: ""}/" } ?: ""
      return "jdbc:pgsql:$authority$database"
    }

    /**
     * Command line entry point: connects using `--url` (or a URL built from host/port/database),
     * generates the requested types into the output directory, then closes the connection.
     */
    @JvmStatic
    fun main(args: Array<String>) {

      try {

        ArgParser(args).parseInto(::Arguments).run {

          val connectionUrl = url ?: buildUrl(host, port, db ?: throw MissingValueException("missing DB or URL"))

          val props = Properties()
          user?.let { props.setProperty("user", it) }
          password?.let { props.setProperty("password", it) }

          // Open the connection here (instead of using the URL constructor) so it is always closed
          DriverManager.getConnection(connectionUrl, props).use { conn ->
            UDTGenerator(conn, targetPackage, types)
               .generate(outDirectory)
          }
        }
      } catch (x: SystemExitException) {
        x.printAndExit("UDT Generator")
      }
    }

  }

  /**
   * Opens a connection through [DriverManager] using [connectionUrl] and [connectionProperties].
   *
   * NOTE: the connection is held privately by this generator and nothing here closes it; use the
   * [Connection] constructor when the caller needs to manage its lifetime.
   */
  constructor(connectionUrl: String, connectionProperties: Properties? = null, targetPackage: String, types: List<String>)
     : this(
     DriverManager.getConnection(connectionUrl, connectionProperties ?: Properties()),
     targetPackage,
     types
  )

  /** Accepts any JDBC [Connection] that can be unwrapped to a [PGConnection]. */
  constructor(connection: Connection, targetPackage: String, types: List<String>)
     : this(
     connection.unwrap(PGConnection::class.java)
        ?: throw IllegalArgumentException("Requires a compatible PGConnection"),
     targetPackage,
     types
  )

  /** Category of every requested type, keyed by its resolved qualified name. */
  private val typesInfo = getTypesInfo(connection, typeNames)

  /**
   * Generates one [JavaFile] per requested type; types for which no class is produced (no
   * attributes / no labels) are omitted.
   */
  fun generate(): List<JavaFile> =
     typesInfo.mapNotNull { (sqlTypeName, typeCategory) ->

       val spec = when (typeCategory) {
         TypeCategory.Composite -> generatePOJO(sqlTypeName)
         TypeCategory.Enum -> generateEnum(sqlTypeName)
       }

       spec?.let {
         JavaFile.builder(targetPackage, it)
            .skipJavaLangImports(true)
            .build()
       }
     }

  /** Generates all requested types and writes each resulting [JavaFile] to [outDirectory]. */
  fun generate(outDirectory: File) {
    generate().forEach { it.writeTo(outDirectory) }
  }

  /**
   * Generates a Java enum for the SQL enum [sqlTypeName].
   *
   * Each constant stores its SQL label, exposed through `getLabel()`; `valueOfLabel(String)` maps a
   * label back to its constant, returns `null` for a `null` label and throws
   * `IllegalArgumentException` for an unknown one. Returns `null` when the type has no labels.
   */
  private fun generateEnum(sqlTypeName: QualifiedName): TypeSpec? {

    val labels = getTypeEnumerations(connection, sqlTypeName)
    if (labels.isEmpty()) {
      println("Type `$sqlTypeName` contains no labels")
      return null
    }

    val enumName = ClassName.get(targetPackage, sqlTypeName.localName.javaTypeName())

    val enumBldr = TypeSpec.enumBuilder(enumName)
       .addModifiers(Modifier.PUBLIC)
       .addField(
          FieldSpec.builder(String::class.java, "label")
             .addModifiers(Modifier.PRIVATE)
             .build()
       )
       .addMethod(
          MethodSpec.constructorBuilder()
             .addParameter(String::class.java, "label")
             .addStatement("this.label = label")
             .build()
       )
       .addMethod(
          MethodSpec.methodBuilder("getLabel")
             .addModifiers(Modifier.PUBLIC)
             .returns(String::class.java)
             .addStatement("return label")
             .build()
       )
       .addMethod(
          MethodSpec.methodBuilder("valueOfLabel")
             .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
             .returns(enumName)
             .addParameter(String::class.java, "label")
             // NULL SQL values arrive as a null label; map them to a null constant
             .addCode("if (label == null) return null;\n")
             .addCode("for (\$T value : values()) {\$>\n", enumName)
             .addCode("if (value.label.equals(label)) return value;\$<\n")
             .addCode("}\n")
             .addCode("throw new \$T(\$S);", java.lang.IllegalArgumentException::class.java, "Invalid label")
             .build()
       )

    val usedNames = mutableSetOf<String>()
    labels.forEach { label ->
      enumBldr.addEnumConstant(
         enumConstantName(label, usedNames),
         TypeSpec.anonymousClassBuilder("\$S", label).build()
      )
    }

    return enumBldr.build()
  }

  /**
   * Derives a valid, unique Java enum constant name from SQL enum [label].
   *
   * Characters are upper-cased locale-independently; characters not allowed in a Java identifier
   * become `_`; `VALUE_` is prefixed when the result would not start a valid identifier; `_` is
   * appended until the name is not already in [usedNames] (which is updated).
   */
  private fun enumConstantName(label: String, usedNames: MutableSet<String>): String {
    val sanitized = label.map { ch ->
      if (Character.isJavaIdentifierPart(ch)) Character.toUpperCase(ch) else '_'
    }.joinToString("")
    // A lone "_" is a reserved word in modern Java, so it also gets the prefix
    var name =
       if (sanitized.isNotEmpty() && sanitized != "_" && Character.isJavaIdentifierStart(sanitized[0])) sanitized
       else "VALUE_$sanitized"
    while (!usedNames.add(name)) {
      name += "_"
    }
    return name
  }

  /**
   * Generates a POJO implementing [SQLData] for the composite type [sqlTypeName].
   *
   * Each attribute gets a private boxed field with a getter and setter; `readSQL`/`writeSQL`
   * transfer the attributes in the order returned by [getTypeAttributes]. Returns `null` when the
   * type has no attributes.
   */
  private fun generatePOJO(sqlTypeName: QualifiedName): TypeSpec? {

    val attributes = getTypeAttributes(connection, sqlTypeName)
    if (attributes.isEmpty()) {
      println("Type `$sqlTypeName` contains no attributes")
      return null
    }

    val className = ClassName.get(targetPackage, sqlTypeName.localName.javaTypeName())

    val classBldr = TypeSpec.classBuilder(className)
       .addModifiers(Modifier.PUBLIC)
       .addSuperinterface(SQLData::class.java)
       .addField(
          FieldSpec.builder(String::class.java, "TYPE_NAME")
             .addModifiers(Modifier.PRIVATE, Modifier.STATIC, Modifier.FINAL)
             .initializer("\$S", sqlTypeName.toString(false))
             .build()
       )
       .addMethod(
          MethodSpec.methodBuilder("getSQLTypeName")
             .addAnnotation(Override::class.java)
             .addModifiers(Modifier.PUBLIC)
             .returns(String::class.java)
             .addException(SQLException::class.java)
             .addStatement("return TYPE_NAME")
             .build()
       )

    val readSQLBldr = MethodSpec.methodBuilder("readSQL")
       .addAnnotation(Override::class.java)
       .addModifiers(Modifier.PUBLIC)
       .addParameter(SQLInput::class.java, "in")
       .addParameter(String::class.java, "typeName")
       .addException(SQLException::class.java)

    val writeSQLBldr = MethodSpec.methodBuilder("writeSQL")
       .addAnnotation(Override::class.java)
       .addModifiers(Modifier.PUBLIC)
       .addParameter(SQLOutput::class.java, "out")
       .addException(SQLException::class.java)

    for (attr in attributes) {

      val attrPropName = attr.name.javaPropertyName()
      val attrSqlType = connection.resolveType(attr.typeName.toString())
      val attrTypeName = resolveTypeName(attr.typeName, attrSqlType)
      val accessorSuffix = attrPropName.upperFirst()

      classBldr.addField(
         FieldSpec.builder(attrTypeName.box(), attrPropName)
            .addModifiers(Modifier.PRIVATE)
            .build()
      )

      classBldr.addMethod(
         MethodSpec.methodBuilder("get$accessorSuffix")
            .addModifiers(Modifier.PUBLIC)
            .returns(attrTypeName)
            .addStatement("return \$L", attrPropName)
            .build()
      )

      classBldr.addMethod(
         MethodSpec.methodBuilder("set$accessorSuffix")
            .addModifiers(Modifier.PUBLIC)
            .addParameter(attrTypeName, attrPropName)
            .addStatement("this.\$1L = \$1L", attrPropName)
            .build()
      )

      readSQLBldr.addCode(readStatement(attr.typeName, attrPropName, attrSqlType, attrTypeName))
      writeSQLBldr.addCode(writeStatement(attr.typeName, attrPropName, attrSqlType, attrTypeName))
    }

    classBldr.addMethod(readSQLBldr.build())
    classBldr.addMethod(writeSQLBldr.build())

    return classBldr.build()
  }

  /** Builds the `readSQL` statement that assigns field [propName] from the `SQLInput` named `in`. */
  private fun readStatement(sqlTypeName: QualifiedName, propName: String, sqlType: PGAnyType, javaType: TypeName): CodeBlock =
     when {
       javaType is ArrayTypeName ->
         CodeBlock.of("this.\$L = in.readObject(\$T.class);\n", propName, javaType)

       sqlType.javaType.readerTypeName == "Object" ->
         CodeBlock.of("this.\$L = in.readObject(\$T.class);\n", propName, javaType)

       typesInfo[sqlTypeName] == TypeCategory.Enum ->
         CodeBlock.of("this.\$L = \$T.valueOfLabel(in.readString());\n", propName, javaType)

       // NOTE(review): primitive readers (readInt, readLong, ...) turn SQL NULL into 0/false;
       // consider checking in.wasNull() if NULL must be preserved.
       else ->
         CodeBlock.of("this.\$L = in.read\$L();\n", propName, sqlType.javaType.readerTypeName)
     }

  /** Builds the `writeSQL` statement that writes field [propName] to the `SQLOutput` named `out`. */
  private fun writeStatement(sqlTypeName: QualifiedName, propName: String, sqlType: PGAnyType, javaType: TypeName): CodeBlock =
     when {
       javaType is ArrayTypeName ->
         CodeBlock.of("out.writeObject(this.\$L, null);\n", propName)

       typesInfo[sqlTypeName] == TypeCategory.Composite ->
         CodeBlock.of("out.writeObject(this.\$L, \$T.\$L);\n", propName, PGType::class.java, "RECORD")

       // Null-safe so a NULL enum attribute is written as a NULL string instead of throwing
       typesInfo[sqlTypeName] == TypeCategory.Enum ->
         CodeBlock.of("out.writeString(this.\$1L != null ? this.\$1L.getLabel() : null);\n", propName)

       sqlType.javaType.writerTypeName == "Object" ->
         CodeBlock.of("out.writeObject(this.\$L, null);\n", propName)

       else ->
         CodeBlock.of("out.write\$L(this.\$L);\n", sqlType.javaType.writerTypeName, propName)
     }

  /**
   * Maps a SQL attribute type to the Java type used for its field.
   *
   * Types being generated map to their generated class; SQL arrays map to a Java array of the
   * resolved element type; anything else maps to the driver's Java type, boxed if primitive.
   */
  private fun resolveTypeName(srcTypeName: QualifiedName, sqlType: PGAnyType): TypeName =
     when {
       srcTypeName in typesInfo ->
         ClassName.get(targetPackage, srcTypeName.localName.javaTypeName())

       java.sql.Array::class.java.isAssignableFrom(sqlType.javaType) -> {
         val elemSqlTypeName = getArrayElementType(connection, srcTypeName)
            ?: throw IllegalStateException("Cannot determine array element type: $srcTypeName")
         val elemSqlType = connection.resolveType(elemSqlTypeName.toString(false))
         ArrayTypeName.of(resolveTypeName(elemSqlTypeName, elemSqlType))
       }

       // TypeName.get (unlike ClassName.get) also accepts array classes such as byte[];
       // box() only changes primitive types
       else ->
         TypeName.get(sqlType.javaType).box()
     }

  /** Upper-cases the first character using the locale-independent [Character.toUpperCase]. */
  private fun String.upperFirst(): String =
     if (isEmpty()) this else "${Character.toUpperCase(this[0])}${substring(1)}"

}


/** Suffix of the `SQLOutput.writeXxx` method for values of this Java type; mirrors [readerTypeName]. */
private val Class<*>.writerTypeName: String get() = this.readerTypeName

/**
 * Suffix of the `SQLInput.readXxx` method for values of this Java type (e.g. `"String"` for
 * `readString`). `"Object"` means the value must be transferred with `readObject`/`writeObject`.
 */
private val Class<*>.readerTypeName: String
  get() =
    when (this) {
      // readBytes/writeBytes work on byte[] (ByteArray), not on Byte[]
      ByteArray::class.java -> "Bytes"
      Integer::class.java -> "Int"
      Reader::class.java -> "CharacterStream"
      InputStream::class.java -> "BinaryStream"
      Struct::class.java -> "Object"
      // getPackage() returns null for array and primitive classes, so guard the lookup
      else -> when (`package`?.name) {
        "java.lang", "java.sql" -> simpleName
        else -> "Object"
      }
    }


/** Kind of user-defined type handled by the generator (`pg_type.typcategory` 'C' and 'E' respectively). */
enum class TypeCategory { Composite, Enum }


/**
 * Resolves each of [typeNames] through a `?::text::regtype` cast and returns its qualified name
 * and category.
 *
 * Names for which the query returns no row are skipped. Throws [InvalidArgumentException] when a
 * type is neither composite (`typcategory = 'C'`) nor enum (`'E'`).
 */
private fun getTypesInfo(connection: Connection, typeNames: List<String>): Map<QualifiedName, TypeCategory> {
  // trimIndent (as in the sibling queries): the text has no margin prefix for trimMargin to strip
  val sql =
     """
       select
         nspname, typname, typcategory
       from pg_type t
        left join pg_namespace n on (n.oid = typnamespace)
       where t.oid = ?::text::regtype
     """.trimIndent()

  connection.prepareStatement(sql).use { stmt ->
    // The one prepared statement is re-executed for every requested name
    return typeNames.mapNotNull { typeName ->
      stmt.setString(1, typeName)
      stmt.executeQuery().use { rs ->
        if (rs.next()) {
          val schemaName = rs.getString(1)
          val localName = rs.getString(2)
          val category = when (rs.getString(3)) {
            "C" -> TypeCategory.Composite
            "E" -> TypeCategory.Enum
            else -> throw InvalidArgumentException("Type category not supported for $typeName")
          }
          QualifiedName(schemaName, localName) to category
        } else {
          null
        }
      }
    }.toMap()
  }
}

/**
 * Looks up the element type of array type [typeName] via `pg_type.typelem`.
 *
 * Returns `null` when no row is found, or when `typelem` matches no `pg_type` row. In that case the
 * LEFT JOINs yield NULL name columns, and a `QualifiedName` must not be built from them.
 */
private fun getArrayElementType(connection: Connection, typeName: QualifiedName): QualifiedName? {
  connection.prepareStatement(
     """
       SELECT
        n.nspname, e.typname
       FROM pg_type a
        LEFT JOIN pg_type e ON (a.typelem = e.oid)
        LEFT JOIN pg_namespace n ON (e.typnamespace = n.oid)
       WHERE a.oid = ?::text::regtype
     """.trimIndent()
  ).use { stmt ->
    stmt.setString(1, typeName.toString(false))
    stmt.executeQuery().use { rs ->
      if (!rs.next()) return null
      val schemaName: String? = rs.getString(1)
      val localName: String? = rs.getString(2)
      if (schemaName == null || localName == null) return null
      return QualifiedName(schemaName, localName)
    }
  }
}

/**
 * Returns the labels of enum type [typeName] from `pg_enum`, ordered by `enumsortorder`.
 * The list is empty when the query returns no rows.
 */
private fun getTypeEnumerations(connection: Connection, typeName: QualifiedName): List<String> {

  connection.prepareStatement(
     "SELECT enumlabel FROM pg_enum WHERE enumtypid = ?::text::regtype ORDER BY enumsortorder"
  ).use { stmt ->
    stmt.setString(1, typeName.toString(false))
    stmt.executeQuery().use { rs ->
      val labels = mutableListOf<String>()
      while (rs.next()) {
        labels.add(rs.getString(1))
      }
      return labels
    }
  }
}

/** One attribute of a composite type, as read from `pg_attribute` by [getTypeAttributes]. */
private data class TypeAttribute(
   /** Attribute name (`attname`). */
   val name: String,
   /** Qualified name of the attribute's SQL type. */
   val typeName: QualifiedName,
   /** Attribute number (`attnum`). */
   val number: Int,
   /** `true` unless the attribute is declared NOT NULL (`not attnotnull`). */
   val nullable: Boolean,
   /** `true` when the attribute's type has `typtype = 'c'`. */
   val isStruct: Boolean
)

/**
 * Lists the attributes of composite type [typeName] that have `attnum > 0` and are not dropped.
 *
 * Results are ordered by `attnum`. The generated `readSQL`/`writeSQL` transfer attributes in this
 * order, so it must match the type's declared order. Without ORDER BY the database is free to
 * return rows in any order.
 */
private fun getTypeAttributes(connection: Connection, typeName: QualifiedName): List<TypeAttribute> {

  val sql =
     """
       SELECT
          attname as name, n.nspname as type_namespace, typname as type_name, attnum as number,
          not attnotnull as nullable, case when typtype = 'c' then true else false end as is_struct
       from pg_catalog.pg_attribute a
        LEFT JOIN pg_catalog.pg_type t ON (atttypid = t.oid)
        LEFT JOIN pg_namespace n ON (t.typnamespace = n.oid)
       where
        attrelid = ?::text::regclass and attnum > 0 and not attisdropped
       order by attnum
     """.trimIndent()

  val attrs = mutableListOf<TypeAttribute>()

  connection.prepareStatement(sql).use { stmt ->
    stmt.setString(1, typeName.toString())
    stmt.executeQuery().use { rs ->
      while (rs.next()) {

        val attr = TypeAttribute(
           rs.getString("name"),
           QualifiedName(rs.getString("type_namespace"), rs.getString("type_name")),
           rs.getInt("number"),
           rs.getBoolean("nullable"),
           rs.getBoolean("is_struct")
        )

        attrs.add(attr)
      }
    }
  }

  return attrs
}

/** Runs of characters that separate words in a SQL identifier: anything that is not a letter or digit. */
private val IDENTIFIER_WORD_SEPARATORS = """[^\p{L}\p{Nd}]+""".toRegex()

/**
 * Converts a SQL identifier (e.g. `order_status`) to a Java type name (`OrderStatus`).
 *
 * Each word's first character is upper-cased and the rest lower-cased, using locale-independent
 * [Character] conversions. If the result would start with a character that cannot begin a Java
 * identifier (e.g. a digit), it is prefixed with `_`.
 */
private fun String.javaTypeName(): String {
  val joined = split(IDENTIFIER_WORD_SEPARATORS).joinToString("") { word ->
    val chars = word.toCharArray()
    for (i in chars.indices) {
      chars[i] = if (i == 0) Character.toUpperCase(chars[i]) else Character.toLowerCase(chars[i])
    }
    String(chars)
  }
  return if (joined.isNotEmpty() && !Character.isJavaIdentifierStart(joined[0])) "_$joined" else joined
}

/**
 * Converts a SQL identifier to a Java property name (`first_name` -> `firstName`).
 * A trailing `_` is appended when the result is a Java keyword (e.g. `class` -> `class_`).
 */
private fun String.javaPropertyName(): String {
  val typeName = javaTypeName()
  if (typeName.isEmpty()) return typeName
  val name = "${Character.toLowerCase(typeName[0])}${typeName.substring(1)}"
  return if (javax.lang.model.SourceVersion.isKeyword(name)) "${name}_" else name
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy