All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.mariadb.r2dbc.util.SslConfig Maven / Gradle / Ivy

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2020-2024 MariaDB Corporation Ab

package org.mariadb.r2dbc.util;

import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.r2dbc.spi.R2dbcTransientResourceException;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.StringJoiner;
import java.util.function.UnaryOperator;
import javax.net.ssl.SSLException;
import org.mariadb.r2dbc.SslMode;

/**
 * SSL/TLS settings for a connection, plus the Netty {@link SslContextBuilder} derived from them.
 *
 * <p>When the mode is not {@link SslMode#DISABLE}, the builder is prepared eagerly in the
 * constructor, so configuration errors (missing certificate files, missing {@code serverSslCert})
 * surface at construction time rather than at connect time.
 *
 * <p>Certificate and key options accept three forms:
 *
 * <ul>
 *   <li>inline PEM content (text starting with {@code -----BEGIN });
 *   <li>{@code classpath:<resource>}, resolved through the thread context class loader;
 *   <li>anything else, treated as a file-system path.
 * </ul>
 */
public class SslConfig {

  /** Shared configuration for connections that do not use SSL. */
  public static final SslConfig DISABLE_INSTANCE = new SslConfig(SslMode.DISABLE);

  private static final String SQL_STATE_CONNECTION_EXCEPTION = "08000";
  private static final String CLASSPATH_PREFIX = "classpath:";
  private static final String PEM_HEADER_PREFIX = "-----BEGIN ";

  private final SslMode sslMode;
  private final String serverSslCert;
  private final String clientSslCert;
  private final String clientSslKey;
  private final CharSequence clientSslPassword;
  private final List<String> tlsProtocol;
  private final SslContextBuilder sslContextBuilder;
  private final UnaryOperator<SslContextBuilder> sslContextBuilderCustomizer;

  private final boolean sslTunnelDisableHostVerification;

  /**
   * Creates a fully specified SSL configuration and, unless {@code sslMode} is {@link
   * SslMode#DISABLE}, prepares the Netty SSL context builder.
   *
   * @param sslMode SSL mode. {@code TRUST} and {@code TUNNEL} trust any server certificate; other
   *     non-DISABLE modes require {@code serverSslCert}.
   * @param serverSslCert server CA certificate (inline PEM, {@code classpath:} resource or file
   *     path); may be null for DISABLE/TRUST/TUNNEL
   * @param clientSslCert client certificate chain; only used when {@code clientSslKey} is also set
   * @param clientSslKey client private key; only used when {@code clientSslCert} is also set
   * @param clientSslPassword password of the client private key, or null
   * @param tlsProtocol TLS protocol names to enable, or null for Netty's defaults
   * @param sslTunnelDisableHostVerification value exposed through {@link
   *     #sslTunnelDisableHostVerification()}; not consulted by this class itself
   * @param sslContextBuilderCustomizer optional hook applied last to the builder; its return value
   *     becomes the builder used by {@link #getSslContext()}
   * @throws R2dbcTransientResourceException (SQLState 08000) when a required certificate is
   *     missing or a certificate/key source cannot be found
   */
  public SslConfig(
      SslMode sslMode,
      String serverSslCert,
      String clientSslCert,
      String clientSslKey,
      CharSequence clientSslPassword,
      List<String> tlsProtocol,
      boolean sslTunnelDisableHostVerification,
      UnaryOperator<SslContextBuilder> sslContextBuilderCustomizer)
      throws R2dbcTransientResourceException {
    this.sslMode = sslMode;
    this.serverSslCert = serverSslCert;
    this.clientSslCert = clientSslCert;
    this.tlsProtocol = tlsProtocol;
    this.clientSslKey = clientSslKey;
    this.clientSslPassword = clientSslPassword;
    this.sslTunnelDisableHostVerification = sslTunnelDisableHostVerification;
    this.sslContextBuilderCustomizer = sslContextBuilderCustomizer;
    // Assigned last: getSslContextBuilder() reads every field set above.
    this.sslContextBuilder = sslMode == SslMode.DISABLE ? null : getSslContextBuilder();
  }

  /**
   * Creates a configuration that only carries a mode. No certificates are configured and no SSL
   * context builder is prepared, so {@link #getSslContext()} cannot be used on it.
   *
   * @param sslMode SSL mode
   */
  public SslConfig(SslMode sslMode) {
    this.sslMode = sslMode;
    this.serverSslCert = null;
    this.clientSslCert = null;
    this.clientSslKey = null;
    this.clientSslPassword = null;
    this.tlsProtocol = null;
    this.sslContextBuilder = null;
    this.sslContextBuilderCustomizer = null;
    this.sslTunnelDisableHostVerification = false;
  }

  /** Returns the configured SSL mode. */
  public SslMode getSslMode() {
    return sslMode;
  }

  /** Returns the {@code sslTunnelDisableHostVerification} option as given at construction. */
  public boolean sslTunnelDisableHostVerification() {
    return this.sslTunnelDisableHostVerification;
  }

  /**
   * Builds the client-side Netty SSL context builder from this configuration.
   *
   * @return the configured builder, or the customizer's result when a customizer is set
   * @throws R2dbcTransientResourceException when {@code serverSslCert} is required but missing, or
   *     a certificate/key source cannot be found
   */
  private SslContextBuilder getSslContextBuilder() throws R2dbcTransientResourceException {
    final SslContextBuilder sslCtxBuilder = SslContextBuilder.forClient();

    if (sslMode == SslMode.TRUST || sslMode == SslMode.TUNNEL) {
      // These modes intentionally skip server certificate validation.
      sslCtxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
    } else {
      if (serverSslCert == null) {
        throw new R2dbcTransientResourceException(
            "Server certificate needed (option `serverSslCert`) for ssl mode " + sslMode,
            SQL_STATE_CONNECTION_EXCEPTION);
      }
      try (InputStream serverCertStream = openCert(serverSslCert, "serverSslCert")) {
        sslCtxBuilder.trustManager(serverCertStream);
      } catch (IOException ignored) {
        // Only close() can throw IOException here; the certificate has already been consumed,
        // so a failed close is deliberately ignored (best effort, as before).
      }
    }

    if (clientSslCert != null && clientSslKey != null) {
      // Both streams are closed on every path, including when opening the key fails.
      try (InputStream certificatesStream = openCert(clientSslCert, "clientSslCert");
          InputStream privateKeyStream = openCert(clientSslKey, "clientSslKey")) {
        sslCtxBuilder.keyManager(
            certificatesStream,
            privateKeyStream,
            clientSslPassword == null ? null : clientSslPassword.toString());
      } catch (IOException ignored) {
        // Only close() can throw IOException here; ignored for the same reason as above.
      }
    }

    if (tlsProtocol != null) {
      sslCtxBuilder.protocols(tlsProtocol.toArray(new String[0]));
    }

    if (sslContextBuilderCustomizer == null) {
      return sslCtxBuilder;
    }
    return sslContextBuilderCustomizer.apply(sslCtxBuilder);
  }

  /**
   * Builds a new Netty {@link SslContext} from the prepared builder.
   *
   * @return a freshly built SSL context
   * @throws IllegalStateException when this configuration has no prepared builder (mode DISABLE,
   *     or created with the single-argument constructor)
   * @throws SSLException when Netty fails to build the context
   */
  public SslContext getSslContext() throws R2dbcTransientResourceException, SSLException {
    if (sslContextBuilder == null) {
      throw new IllegalStateException("No SSL context available for sslMode=" + sslMode);
    }
    return sslContextBuilder.build();
  }

  /**
   * Opens a certificate/key source, translating a missing source into the driver's connection
   * exception.
   *
   * @param path option value (inline PEM, {@code classpath:} resource or file path)
   * @param optionName option name used in the error message
   * @return an open stream that the caller must close
   * @throws R2dbcTransientResourceException (SQLState 08000) when the source cannot be found
   */
  private InputStream openCert(String path, String optionName)
      throws R2dbcTransientResourceException {
    try {
      return loadCert(path);
    } catch (FileNotFoundException fileNotFoundEx) {
      throw new R2dbcTransientResourceException(
          "Failed to find " + optionName + " file. " + optionName + "=" + path,
          SQL_STATE_CONNECTION_EXCEPTION,
          fileNotFoundEx);
    }
  }

  /**
   * Opens a PEM source given as inline content, a {@code classpath:} resource, or a file path.
   *
   * @param path option value
   * @return an open, never-null stream that the caller must close
   * @throws FileNotFoundException when the file or classpath resource does not exist
   */
  private InputStream loadCert(String path) throws FileNotFoundException {
    if (path.startsWith(PEM_HEADER_PREFIX)) {
      // Inline PEM (certificate or private key). PEM is ASCII, so UTF-8 encodes it unchanged.
      return new ByteArrayInputStream(path.getBytes(StandardCharsets.UTF_8));
    }
    if (path.startsWith(CLASSPATH_PREFIX)) {
      String classpathFile = path.substring(CLASSPATH_PREFIX.length());
      ClassLoader loader = Thread.currentThread().getContextClassLoader();
      if (loader == null) {
        // A thread may have no context class loader; fall back to the one that loaded the driver.
        loader = SslConfig.class.getClassLoader();
      }
      InputStream resource = loader.getResourceAsStream(classpathFile);
      if (resource == null) {
        // getResourceAsStream signals "not found" with null; report it like a missing file.
        throw new FileNotFoundException("Classpath resource not found: " + classpathFile);
      }
      return resource;
    }
    return new FileInputStream(path);
  }

  /**
   * Renders the configuration as {@code &}-separated {@code key=value} pairs. Only options that
   * are set are included, and the key password is masked.
   */
  @Override
  public String toString() {
    StringJoiner options = new StringJoiner("&");
    if (sslMode != SslMode.DISABLE) {
      options.add("sslMode=" + sslMode.value);
    }
    if (serverSslCert != null) {
      options.add("serverSslCert=" + serverSslCert);
    }
    if (clientSslCert != null) {
      options.add("clientSslCert=" + clientSslCert);
    }
    if (tlsProtocol != null) {
      options.add("tlsProtocol=" + String.join(",", tlsProtocol));
    }
    if (clientSslKey != null) {
      options.add("clientSslKey=" + clientSslKey);
    }
    if (clientSslPassword != null) {
      options.add("clientSslPassword=***");
    }
    if (sslTunnelDisableHostVerification) {
      options.add("sslTunnelDisableHostVerification=true");
    }
    return options.toString();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy