All downloads are free. The search and download functionality uses the official Maven repository.

cc.otavia.sql.DriverManager.scala Maven / Gradle / Ivy

/*
 * Copyright 2022 Yan Kun 
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package cc.otavia.sql

import cc.otavia.common.Report
import cc.otavia.sql.spi.ADBCServiceProvider

import java.util
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
import java.util.{ServiceConfigurationError, ServiceLoader}
import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import scala.language.unsafeNulls

/** Registry of ADBC [[DriverFactory]] instances discovered through [[java.util.ServiceLoader]].
 *
 *  Each factory is stored under the name reported by its `driverClassName` and is looked up by that name. Provider
 *  discovery runs whenever a lookup finds the registry empty.
 */
object DriverManager {

    /** Driver names that [[defaultDriver]] resolves for the URL prefixes it recognises. */
    private val DEFAULT_DRIVER =
        Map("mysql" -> "cc.otavia.mysql.MySQLDriver", "postgres" -> "cc.otavia.postgres.PostgresDriver")

    /** Registered factories, keyed by `DriverFactory.driverClassName`. */
    private val drivers: ConcurrentHashMap[String, DriverFactory] = new ConcurrentHashMap[String, DriverFactory]()

    /** Picks the built-in driver factory that matches the prefix of a connection URL.
     *
     *  The URL is trimmed and lower-cased before its prefix is inspected; `jdbc:mysql:` and `jdbc:postgresql:` are the
     *  only prefixes recognised.
     *
     *  @param url
     *    connection URL
     *  @return
     *    the factory registered under the matching default driver name
     *  @throws IllegalArgumentException
     *    if the prefix is not recognised, or if [[getDriverFactory]] cannot find the matching driver
     */
    def defaultDriver(url: String): DriverFactory =
        url.trim.toLowerCase match {
            case normalized if normalized.startsWith("jdbc:mysql:") =>
                getDriverFactory(DEFAULT_DRIVER("mysql"))
            case normalized if normalized.startsWith("jdbc:postgresql:") =>
                getDriverFactory(DEFAULT_DRIVER("postgres"))
            case _ =>
                throw new IllegalArgumentException(s"Schema not supported yet: ${url}")
        }

    /** Looks up a registered driver factory by name, running provider discovery first if nothing is registered yet.
     *
     *  @param driverName
     *    name the factory was registered under (its `driverClassName`)
     *  @return
     *    the registered factory
     *  @throws IllegalArgumentException
     *    if no factory is registered under `driverName`; the message lists the names currently registered
     */
    def getDriverFactory(driverName: String): DriverFactory = {
        if (drivers.isEmpty) findServiceProviders()
        // ConcurrentHashMap never stores null values, so a null result means "not registered".
        Option(drivers.get(driverName)).getOrElse {
            val known = drivers.keys().asScala.mkString("[", ",", "]")
            throw new IllegalArgumentException(
              s"Can't find ADBC driver for name ${driverName}, current supported drivers are ${known}"
            )
        }
    }

    /** Loads every available [[ADBCServiceProvider]] and registers the factory each one exposes.
     *
     *  All providers are instantiated first and only then registered. A report is emitted if the registry is still
     *  empty afterwards.
     */
    private def findServiceProviders(): Unit = {
        val loader    = ServiceLoader.load(classOf[ADBCServiceProvider], getClass.getClassLoader)
        val providers = mutable.ArrayBuffer.empty[ADBCServiceProvider]
        val pending   = loader.iterator()
        while (pending.hasNext) nextProvider(pending).foreach(providers.addOne)

        providers.foreach { provider =>
            val factory = provider.getDriverFactory
            drivers.put(factory.driverClassName, factory)
        }

        if (drivers.isEmpty) Report.report("No ADBC driver found!")
    }

    /** Advances the service-loader iterator by one element.
     *
     *  @param iterator
     *    iterator obtained from the service loader
     *  @return
     *    the instantiated provider, or `None` when instantiation raised a `ServiceConfigurationError` (which is
     *    reported rather than rethrown)
     */
    private def nextProvider(iterator: util.Iterator[ADBCServiceProvider]): Option[ADBCServiceProvider] =
        try Some(iterator.next())
        catch {
            case e: ServiceConfigurationError =>
                Report.report(s"An ADBC service provider failed to instantiate:\n${e.getMessage}", "ADBC")
                None
        }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy