okhttp3.dnsoverhttps.DnsOverHttps.kt Maven / Gradle / Ivy
/*
* Copyright (C) 2018 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.dnsoverhttps
import java.io.IOException
import java.net.HttpURLConnection
import java.net.InetAddress
import java.net.UnknownHostException
import java.util.ArrayList
import java.util.concurrent.CountDownLatch
import okhttp3.CacheControl
import okhttp3.Call
import okhttp3.Callback
import okhttp3.Dns
import okhttp3.HttpUrl
import okhttp3.MediaType
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.Response
import okhttp3.internal.platform.Platform
import okhttp3.internal.publicsuffix.PublicSuffixDatabase
/**
* [DNS over HTTPS implementation][doh_spec].
*
* > A DNS API client encodes a single DNS query into an HTTP request
* > using either the HTTP GET or POST method and the other requirements
* > of this section. The DNS API server defines the URI used by the
* > request through the use of a URI Template.
*
* ### Warning: This is a non-final API.
*
* As of OkHttp 3.14, this feature is an unstable preview: the API is subject to change, and the
* implementation is incomplete. We expect that OkHttp 4.6 or 4.7 will finalize this API. Until
* then, expect API and behavior changes when you update your OkHttp dependency.
*
* [doh_spec]: https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-13
*/
class DnsOverHttps internal constructor(
@get:JvmName("client") val client: OkHttpClient,
@get:JvmName("url") val url: HttpUrl,
@get:JvmName("includeIPv6") val includeIPv6: Boolean,
@get:JvmName("post") val post: Boolean,
@get:JvmName("resolvePrivateAddresses") val resolvePrivateAddresses: Boolean,
@get:JvmName("resolvePublicAddresses") val resolvePublicAddresses: Boolean
) : Dns {
/**
 * Resolves [hostname] over HTTPS, honoring the private/public host filters of this instance.
 *
 * The private-host classification is only computed when at least one of
 * [resolvePrivateAddresses] or [resolvePublicAddresses] is disabled; otherwise the lookup goes
 * straight to [lookupHttps].
 *
 * @param hostname the host name to resolve.
 * @return the addresses produced by [lookupHttps].
 * @throws UnknownHostException if the host falls in a category this instance does not resolve,
 *     or if the HTTPS lookup fails.
 */
@Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> {
  if (!resolvePrivateAddresses || !resolvePublicAddresses) {
    val privateHost = isPrivateHost(hostname)

    if (privateHost && !resolvePrivateAddresses) {
      throw UnknownHostException("private hosts not resolved")
    }

    if (!privateHost && !resolvePublicAddresses) {
      throw UnknownHostException("public hosts not resolved")
    }
  }

  return lookupHttps(hostname)
}
/**
 * Resolves [hostname] with a [DnsRecordCodec.TYPE_A] query and, when [includeIPv6] is enabled,
 * an additional [DnsRecordCodec.TYPE_AAAA] query.
 *
 * Queries that cannot be answered from the cache are collected and run by [executeRequests].
 *
 * @param hostname the host name to resolve.
 * @return every address decoded from the responses; never empty.
 * @throws UnknownHostException if no address was obtained; built by [throwBestFailure] from the
 *     recorded failures.
 */
@Throws(UnknownHostException::class)
private fun lookupHttps(hostname: String): List<InetAddress> {
  val networkRequests = ArrayList<Call>(2)
  val failures = ArrayList<Exception>(2)
  val results = ArrayList<InetAddress>(5)

  buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_A)

  if (includeIPv6) {
    buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_AAAA)
  }

  executeRequests(hostname, networkRequests, results, failures)

  // executeRequests stops waiting when the calling thread is interrupted, so callbacks may still
  // be appending to these lists. Snapshot them under the same monitors the callbacks use.
  val addresses = synchronized(results) { results.toList() }
  if (addresses.isNotEmpty()) {
    return addresses
  }

  val errors = synchronized(failures) { failures.toList() }
  return throwBestFailure(hostname, errors)
}
/**
 * Prepares the query of record [type] for [hostname].
 *
 * If [getCacheOnlyResponse] yields a response it is processed immediately into [results] or
 * [failures]; otherwise a new call for the request is appended to [networkRequests].
 *
 * @param hostname the host name being resolved.
 * @param networkRequests receives the call when the query has to be executed.
 * @param results receives addresses decoded from a cached response.
 * @param failures receives any error raised while processing a cached response.
 * @param type the DNS record type to query.
 */
private fun buildRequest(
  hostname: String,
  networkRequests: MutableList<Call>,
  results: MutableList<InetAddress>,
  failures: MutableList<Exception>,
  type: Int
) {
  val request = buildRequest(hostname, type)
  val cachedResponse = getCacheOnlyResponse(request)

  // An explicit branch makes it impossible for both paths to run.
  if (cachedResponse != null) {
    processResponse(cachedResponse, hostname, results, failures)
  } else {
    networkRequests.add(client.newCall(request))
  }
}
/**
 * Enqueues every call in [networkRequests] and blocks until each one has reported a response or
 * a failure, or until the calling thread is interrupted.
 *
 * @param hostname the host name being resolved; forwarded to [processResponse].
 * @param networkRequests the calls to run.
 * @param responses receives the decoded addresses; written under its own monitor.
 * @param failures receives every error; written under its own monitor.
 */
private fun executeRequests(
  hostname: String,
  networkRequests: List<Call>,
  responses: MutableList<InetAddress>,
  failures: MutableList<Exception>
) {
  val latch = CountDownLatch(networkRequests.size)

  for (call in networkRequests) {
    call.enqueue(object : Callback {
      override fun onFailure(call: Call, e: IOException) {
        synchronized(failures) {
          failures.add(e)
        }
        latch.countDown()
      }

      override fun onResponse(call: Call, response: Response) {
        processResponse(response, hostname, responses, failures)
        latch.countDown()
      }
    })
  }

  try {
    latch.await()
  } catch (e: InterruptedException) {
    // Callbacks may still be running, so take the same lock they use.
    synchronized(failures) {
      failures.add(e)
    }
    // Re-assert the interrupt so callers further up the stack can observe it.
    Thread.currentThread().interrupt()
  }
}
/**
 * Decodes [response] with [readResponse] and records the outcome.
 *
 * Decoded addresses are appended to [results]; any exception raised while reading is appended
 * to [failures] instead of being propagated. Each list is modified while holding its own
 * monitor.
 *
 * @param response the HTTP response carrying the DNS answer.
 * @param hostname the host name being resolved.
 * @param results receives the decoded addresses.
 * @param failures receives the exception if reading or decoding fails.
 */
private fun processResponse(
  response: Response,
  hostname: String,
  results: MutableList<InetAddress>,
  failures: MutableList<Exception>
) {
  try {
    val addresses = readResponse(hostname, response)
    synchronized(results) {
      results.addAll(addresses)
    }
  } catch (e: Exception) {
    synchronized(failures) {
      failures.add(e)
    }
  }
}
/**
 * Always throws; turns the recorded [failures] into a single [UnknownHostException].
 *
 * - No failures: a fresh exception whose message is [hostname].
 * - First failure is already an [UnknownHostException]: that instance is rethrown as is.
 * - Otherwise: a new exception for [hostname] with the first failure as its cause and every
 *   remaining failure attached as suppressed.
 *
 * The declared return type lets callers use the call as an expression; no value is ever
 * returned.
 *
 * @param hostname the host name that could not be resolved.
 * @param failures the errors collected during the lookup, in the order they were recorded.
 * @throws UnknownHostException always.
 */
@Throws(UnknownHostException::class)
private fun throwBestFailure(hostname: String, failures: List<Exception>): List<InetAddress> {
  val first = failures.firstOrNull() ?: throw UnknownHostException(hostname)

  if (first is UnknownHostException) {
    throw first
  }

  val unknownHostException = UnknownHostException(hostname)
  unknownHostException.initCause(first)

  for (i in 1 until failures.size) {
    unknownHostException.addSuppressed(failures[i])
  }

  throw unknownHostException
}
/**
 * Tries to answer [request] from the client's cache without touching the network.
 *
 * Only attempted for GET queries ([post] is false) on a client that has a cache configured.
 *
 * @param request the DNS query request.
 * @return the cached response, or `null` when the cache is skipped, the cache-only call answers
 *     with HTTP 504, or the call fails with an [IOException].
 */
private fun getCacheOnlyResponse(request: Request): Response? {
  if (post || client.cache == null) {
    return null
  }

  return try {
    val onlyIfCached = CacheControl.Builder().onlyIfCached().build()
    val cacheOnlyRequest = request.newBuilder().cacheControl(onlyIfCached).build()
    val cached = client.newCall(cacheOnlyRequest).execute()

    // A 504 code indicates that the cache is stale, so the caller must go to the network.
    if (cached.code == HttpURLConnection.HTTP_GATEWAY_TIMEOUT) null else cached
  } catch (ignored: IOException) {
    // Deliberately ignored: the caller falls back to the network, which will hopefully
    // repopulate the cache.
    null
  }
}
/**
 * Reads and decodes the DNS answer carried by [response], closing the response afterwards.
 *
 * A warning is logged when a response that did not come from the cache used a protocol other
 * than HTTP/2; the response is still processed.
 *
 * @param hostname the host name the answer is expected to describe.
 * @param response the HTTP response to consume.
 * @return the addresses decoded by [DnsRecordCodec.decodeAnswers].
 * @throws IOException if the response is unsuccessful or its body exceeds [MAX_RESPONSE_SIZE].
 * @throws IllegalStateException if the response has no body.
 * @throws Exception anything raised while decoding the answer.
 */
@Throws(Exception::class)
private fun readResponse(hostname: String, response: Response): List<InetAddress> {
  if (response.cacheResponse == null && response.protocol !== Protocol.HTTP_2) {
    Platform.get().log("Incorrect protocol: ${response.protocol}", Platform.WARN)
  }

  response.use {
    if (!response.isSuccessful) {
      throw IOException("response: ${response.code} ${response.message}")
    }

    val body = checkNotNull(response.body) { "response has no body" }

    if (body.contentLength() > MAX_RESPONSE_SIZE) {
      throw IOException(
        "response size exceeds limit ($MAX_RESPONSE_SIZE bytes): ${body.contentLength()} bytes"
      )
    }

    // The declared length can be unknown (OkHttp documents -1 for that case), which passes the
    // check above. Buffer at most one byte past the limit so an undeclared body stays bounded.
    val source = body.source()
    if (source.request(MAX_RESPONSE_SIZE + 1L)) {
      throw IOException("response size exceeds limit ($MAX_RESPONSE_SIZE bytes)")
    }

    val responseBytes = source.readByteString()

    return DnsRecordCodec.decodeAnswers(hostname, responseBytes)
  }
}
/**
 * Builds the HTTP request that carries a DNS query of record [type] for [hostname].
 *
 * In POST mode the encoded query is the request body, typed as [DNS_MESSAGE]. Otherwise it is
 * base64url-encoded, stripped of `=` characters, and sent as the `dns` query parameter of [url].
 *
 * @param hostname the host name to query.
 * @param type the DNS record type to query.
 * @return the request, always carrying an `Accept` header for [DNS_MESSAGE].
 */
private fun buildRequest(hostname: String, type: Int): Request {
  val requestBuilder = Request.Builder().header("Accept", DNS_MESSAGE.toString())
  val query = DnsRecordCodec.encodeQuery(hostname, type)

  if (post) {
    requestBuilder.url(url).post(query.toRequestBody(DNS_MESSAGE))
  } else {
    val encoded = query.base64Url().replace("=", "")
    val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build()
    requestBuilder.url(requestUrl)
  }

  return requestBuilder.build()
}
/**
 * Collects the settings for a [DnsOverHttps] instance.
 *
 * [client] and [url] have no default and must be set before [build] is called. Every setter
 * returns this builder so calls can be chained.
 */
class Builder {
  internal var client: OkHttpClient? = null
  internal var url: HttpUrl? = null
  internal var includeIPv6 = true
  internal var post = false
  internal var systemDns = Dns.SYSTEM
  internal var bootstrapDnsHosts: List<InetAddress>? = null
  internal var resolvePrivateAddresses = false
  internal var resolvePublicAddresses = true

  /**
   * Creates the resolver.
   *
   * The configured client is not used directly: a copy is built from it whose DNS is the
   * resolver returned by `buildBootstrapClient`.
   *
   * @throws NullPointerException if no client was set.
   * @throws IllegalStateException if no URL was set.
   */
  fun build(): DnsOverHttps {
    val client = this.client ?: throw NullPointerException("client not set")
    return DnsOverHttps(
      client.newBuilder().dns(buildBootstrapClient(this)).build(),
      checkNotNull(url) { "url not set" },
      includeIPv6,
      post,
      resolvePrivateAddresses,
      resolvePublicAddresses
    )
  }

  /** Sets the client that the resolver's own client is derived from. */
  fun client(client: OkHttpClient) = apply {
    this.client = client
  }

  /** Sets the URL that DNS queries are sent to. */
  fun url(url: HttpUrl) = apply {
    this.url = url
  }

  /** Sets whether an AAAA query is issued in addition to the A query. Defaults to `true`. */
  fun includeIPv6(includeIPv6: Boolean) = apply {
    this.includeIPv6 = includeIPv6
  }

  /** Sets whether queries are sent with POST instead of GET. Defaults to `false`. */
  fun post(post: Boolean) = apply {
    this.post = post
  }

  /** Sets whether hosts classified as private are resolved. Defaults to `false`. */
  fun resolvePrivateAddresses(resolvePrivateAddresses: Boolean) = apply {
    this.resolvePrivateAddresses = resolvePrivateAddresses
  }

  /** Sets whether hosts classified as public are resolved. Defaults to `true`. */
  fun resolvePublicAddresses(resolvePublicAddresses: Boolean) = apply {
    this.resolvePublicAddresses = resolvePublicAddresses
  }

  /**
   * Sets the addresses used to reach the host of [url]. Passing `null` makes the resolver fall
   * back to [systemDns] for that purpose.
   */
  fun bootstrapDnsHosts(bootstrapDnsHosts: List<InetAddress>?) = apply {
    this.bootstrapDnsHosts = bootstrapDnsHosts
  }

  /** Vararg convenience for [bootstrapDnsHosts]. */
  fun bootstrapDnsHosts(vararg bootstrapDnsHosts: InetAddress): Builder =
    bootstrapDnsHosts(bootstrapDnsHosts.toList())

  /** Sets the DNS used when no bootstrap hosts are configured. Defaults to [Dns.SYSTEM]. */
  fun systemDns(systemDns: Dns) = apply {
    this.systemDns = systemDns
  }
}
companion object {
  /** Media type used for the `Accept` header and for POST query bodies. */
  val DNS_MESSAGE: MediaType = "application/dns-message".toMediaType()

  /** Upper bound, in bytes, on the size of a DNS response body (64 KiB). */
  const val MAX_RESPONSE_SIZE = 64 * 1024

  /**
   * Chooses the DNS used by the resolver's own HTTP client.
   *
   * @return a [BootstrapDns] over the builder's bootstrap hosts for the host of its URL, or the
   *     builder's system DNS when no bootstrap hosts were configured.
   * @throws IllegalStateException if bootstrap hosts are configured but no URL was set.
   */
  private fun buildBootstrapClient(builder: Builder): Dns {
    val hosts = builder.bootstrapDnsHosts ?: return builder.systemDns

    // Same message as the URL check in Builder.build(), which runs only after this function.
    val url = checkNotNull(builder.url) { "url not set" }

    return BootstrapDns(url.host, hosts)
  }

  /**
   * Reports whether [host] is treated as private, i.e. the public suffix database yields no
   * effective TLD plus one for it.
   */
  internal fun isPrivateHost(host: String): Boolean =
    PublicSuffixDatabase.get().getEffectiveTldPlusOne(host) == null
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy