All Downloads are FREE. Search and download functionality uses the official Maven repository.

jvmTest.okhttp3.internal.connection.RetryConnectionTest.kt Maven / Gradle / Ivy

There is a newer version: 5.0.0-alpha.14
Show newest version
/*
 * Copyright (C) 2015 Square, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okhttp3.internal.connection

import java.io.IOException
import java.security.cert.CertificateException
import javax.net.ssl.SSLHandshakeException
import javax.net.ssl.SSLSocket
import okhttp3.ConnectionSpec
import okhttp3.OkHttpClientTestRule
import okhttp3.TestValueFactory
import okhttp3.TlsVersion
import okhttp3.tls.internal.TlsUtil.localhost
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.RegisterExtension

/**
 * Tests for [retryTlsHandshake] and for stepping through a list of [ConnectionSpec]s while
 * planning successive TLS connection attempts.
 */
class RetryConnectionTest {
  private val factory = TestValueFactory()
  private val handshakeCertificates = localhost()
  private val retryableException = SSLHandshakeException("Simulated handshake exception")

  @RegisterExtension
  val clientTestRule = OkHttpClientTestRule()

  // Never reassigned in this class, so a read-only property is sufficient.
  private val client = clientTestRule.newClient()

  @AfterEach internal fun tearDown() {
    factory.close()
  }

  /** A plain [IOException] is not treated as a retryable handshake failure. */
  @Test fun nonRetryableIOException() {
    val exception = IOException("Non-handshake exception")
    assertThat(retryTlsHandshake(exception)).isFalse()
  }

  /** An [SSLHandshakeException] caused by a [CertificateException] is not retried. */
  @Test fun nonRetryableSSLHandshakeException() {
    val exception = SSLHandshakeException("Certificate handshake exception").apply {
      initCause(CertificateException())
    }
    assertThat(retryTlsHandshake(exception)).isFalse()
  }

  /** An [SSLHandshakeException] without a cause is retried. */
  @Test fun retryableSSLHandshakeException() {
    assertThat(retryTlsHandshake(retryableException)).isTrue()
  }

  /**
   * Walks the connection spec list against sockets that only enable TLS 1.2, 1.1 and 1.0, and
   * checks that a fallback is offered after the first spec but not after the second.
   */
  @Test fun someFallbacksSupported() {
    val sslV3 = ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS)
      .tlsVersions(TlsVersion.SSL_3_0)
      .build()
    val routePlanner = factory.newRoutePlanner(client)
    val route = factory.newRoute()
    val connectionSpecs = listOf(ConnectionSpec.MODERN_TLS, ConnectionSpec.COMPATIBLE_TLS, sslV3)
    val enabledSocketTlsVersions = arrayOf(
      TlsVersion.TLS_1_2,
      TlsVersion.TLS_1_1,
      TlsVersion.TLS_1_0
    )

    // MODERN_TLS is used here. `use` closes the socket even if an assertion fails.
    val attempt1 = createSocketWithEnabledProtocols(*enabledSocketTlsVersions).use { socket ->
      val attempt0 = routePlanner.planConnectToRoute(route)
        .planWithCurrentOrInitialConnectionSpec(connectionSpecs, socket)
      assertThat(attempt0.isTlsFallback).isFalse()
      connectionSpecs[attempt0.connectionSpecIndex].apply(socket, attempt0.isTlsFallback)
      assertEnabledProtocols(socket, TlsVersion.TLS_1_2)
      val next = attempt0.nextConnectionSpec(connectionSpecs, socket)
      assertThat(next).isNotNull()
      // The assertion above already failed the test if this is null; this narrows the type.
      checkNotNull(next) { "expected a fallback connection spec after the first attempt" }
    }
    assertThat(attempt1.isTlsFallback).isTrue()

    // COMPATIBLE_TLS is used here.
    createSocketWithEnabledProtocols(*enabledSocketTlsVersions).use { socket ->
      connectionSpecs[attempt1.connectionSpecIndex].apply(socket, attempt1.isTlsFallback)
      assertEnabledProtocols(socket, TlsVersion.TLS_1_2, TlsVersion.TLS_1_1, TlsVersion.TLS_1_0)
      val attempt2 = attempt1.nextConnectionSpec(connectionSpecs, socket)
      assertThat(attempt2).isNull()
    }

    // sslV3 is not used because SSLv3 is not enabled on the socket.
  }

  /** Creates an unconnected [SSLSocket] whose enabled protocols are exactly [tlsVersions]. */
  private fun createSocketWithEnabledProtocols(vararg tlsVersions: TlsVersion): SSLSocket {
    return (handshakeCertificates.sslSocketFactory().createSocket() as SSLSocket).apply {
      enabledProtocols = javaNames(*tlsVersions)
    }
  }

  /** Asserts that [socket] has exactly the [required] protocols enabled, in any order. */
  private fun assertEnabledProtocols(socket: SSLSocket, vararg required: TlsVersion) {
    assertThat(socket.enabledProtocols)
      .containsExactlyInAnyOrder(*javaNames(*required))
  }

  /** Maps each [TlsVersion] to its `javaName`, preserving order. */
  private fun javaNames(vararg tlsVersions: TlsVersion) =
    tlsVersions.map { it.javaName }.toTypedArray()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy