All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.xbill.DNS.DohResolver Maven / Gradle / Ivy

There is a newer version: 3.6.2_1
Show newest version
// SPDX-License-Identifier: BSD-3-Clause
package org.xbill.DNS;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.HttpURLConnection;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.URL;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import lombok.SneakyThrows;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.xbill.DNS.AsyncSemaphore.Permit;
import org.xbill.DNS.utils.base64;

/**
 * Proof-of-concept DNS over HTTP (DoH) resolver.
 * This class is not suitable for high load scenarios because of the shortcomings of Java's built-in
 * HTTP clients. For more control, implement your own {@link Resolver} using e.g. OkHttp.
 *
 * 

On Java 8, it uses HTTP/1.1, which is against the recommendation of RFC 8484 to use HTTP/2 and * thus slower. On Java 11 or newer, HTTP/2 is always used, but the built-in HttpClient has its own * issues with connection handling. * *

As of 2020-09-13, the following limits of public resolvers for HTTP/2 were observed: *

  • https://cloudflare-dns.com/dns-query: max streams=250, idle timeout=400s *
  • https://dns.google/dns-query: max streams=100, idle timeout=240s * * @since 3.0 */ @Slf4j public final class DohResolver implements Resolver { private static final boolean USE_HTTP_CLIENT; private static final Map httpClients = Collections.synchronizedMap(new WeakHashMap<>()); private final SSLSocketFactory sslSocketFactory; private static Object defaultHttpRequestBuilder; private static Method publisherOfByteArrayMethod; private static Method requestBuilderTimeoutMethod; private static Method requestBuilderCopyMethod; private static Method requestBuilderUriMethod; private static Method requestBuilderBuildMethod; private static Method requestBuilderPostMethod; private static Method httpClientNewBuilderMethod; private static Method httpClientBuilderTimeoutMethod; private static Method httpClientBuilderExecutorMethod; private static Method httpClientBuilderBuildMethod; private static Method httpClientSendAsyncMethod; private static Method byteArrayBodyPublisherMethod; private static Method httpResponseBodyMethod; private static Method httpResponseStatusCodeMethod; private boolean usePost = false; private Duration timeout = Duration.ofSeconds(5); private String uriTemplate; private final Duration idleConnectionTimeout; private OPTRecord queryOPT = new OPTRecord(0, 0, 0); private TSIG tsig; private Executor defaultExecutor = ForkJoinPool.commonPool(); /** * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections. * *

    rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2. */ private final AsyncSemaphore maxConcurrentRequests; private final AtomicLong lastRequest = new AtomicLong(0); private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1); private static final String APPLICATION_DNS_MESSAGE = "application/dns-message"; static { boolean initSuccess = false; if (!System.getProperty("java.version").startsWith("1.")) { try { Class httpClientBuilderClass = Class.forName("java.net.http.HttpClient$Builder"); Class httpClientClass = Class.forName("java.net.http.HttpClient"); Class httpVersionEnum = Class.forName("java.net.http.HttpClient$Version"); Class httpRequestBuilderClass = Class.forName("java.net.http.HttpRequest$Builder"); Class httpRequestClass = Class.forName("java.net.http.HttpRequest"); Class bodyPublishersClass = Class.forName("java.net.http.HttpRequest$BodyPublishers"); Class bodyPublisherClass = Class.forName("java.net.http.HttpRequest$BodyPublisher"); Class httpResponseClass = Class.forName("java.net.http.HttpResponse"); Class bodyHandlersClass = Class.forName("java.net.http.HttpResponse$BodyHandlers"); Class bodyHandlerClass = Class.forName("java.net.http.HttpResponse$BodyHandler"); // HttpClient.Builder httpClientBuilderTimeoutMethod = httpClientBuilderClass.getDeclaredMethod("connectTimeout", Duration.class); httpClientBuilderExecutorMethod = httpClientBuilderClass.getDeclaredMethod("executor", Executor.class); httpClientBuilderBuildMethod = httpClientBuilderClass.getDeclaredMethod("build"); // HttpClient httpClientNewBuilderMethod = httpClientClass.getDeclaredMethod("newBuilder"); httpClientSendAsyncMethod = httpClientClass.getDeclaredMethod("sendAsync", httpRequestClass, bodyHandlerClass); // HttpRequestBuilder Method requestBuilderHeaderMethod = httpRequestBuilderClass.getDeclaredMethod("header", String.class, String.class); Method requestBuilderVersionMethod = httpRequestBuilderClass.getDeclaredMethod("version", httpVersionEnum); requestBuilderTimeoutMethod = httpRequestBuilderClass.getDeclaredMethod("timeout", Duration.class); requestBuilderUriMethod = httpRequestBuilderClass.getDeclaredMethod("uri", URI.class); requestBuilderCopyMethod = httpRequestBuilderClass.getDeclaredMethod("copy"); requestBuilderBuildMethod = httpRequestBuilderClass.getDeclaredMethod("build"); requestBuilderPostMethod = httpRequestBuilderClass.getDeclaredMethod("POST", bodyPublisherClass); // HttpRequest Method requestBuilderNewBuilderMethod = httpRequestClass.getDeclaredMethod("newBuilder"); // BodyPublishers publisherOfByteArrayMethod = bodyPublishersClass.getDeclaredMethod("ofByteArray", byte[].class); // BodyPublisher byteArrayBodyPublisherMethod = bodyHandlersClass.getDeclaredMethod("ofByteArray"); // HttpResponse httpResponseBodyMethod = httpResponseClass.getDeclaredMethod("body"); httpResponseStatusCodeMethod = httpResponseClass.getDeclaredMethod("statusCode"); // defaultHttpRequestBuilder = HttpRequest.newBuilder(); // defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2); // defaultHttpRequestBuilder.header("Content-Type", "application/dns-message"); // defaultHttpRequestBuilder.header("Accept", "application/dns-message"); defaultHttpRequestBuilder = requestBuilderNewBuilderMethod.invoke(null); @SuppressWarnings({"unchecked", "rawtypes"}) Enum http2Version = Enum.valueOf((Class) httpVersionEnum, "HTTP_2"); requestBuilderVersionMethod.invoke(defaultHttpRequestBuilder, http2Version); requestBuilderHeaderMethod.invoke( defaultHttpRequestBuilder, "Content-Type", APPLICATION_DNS_MESSAGE); requestBuilderHeaderMethod.invoke( defaultHttpRequestBuilder, "Accept", APPLICATION_DNS_MESSAGE); initSuccess = true; } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { // fallback to Java 8 log.warn("Java >= 11 detected, but HttpRequest not available"); } } USE_HTTP_CLIENT = initSuccess; } /** * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). * * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ public DohResolver(String uriTemplate) { this(uriTemplate, 100, Duration.ofMinutes(2)); } /** * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). * * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1 * connections for Java 8. On Java 8 this cannot exceed the system property {@code * http.maxConnections}. * @param idleConnectionTimeout Max. idle time for HTTP/2 connections until a request is * serialized. Applies to Java 11+ only. * @since 3.3 */ public DohResolver( String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) { this.uriTemplate = uriTemplate; this.idleConnectionTimeout = idleConnectionTimeout; if (maxConcurrentRequests <= 0) { throw new IllegalArgumentException("maxConcurrentRequests must be > 0"); } if (!USE_HTTP_CLIENT) { try { int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5")); if (maxConcurrentRequests > javaMaxConn) { maxConcurrentRequests = javaMaxConn; } } catch (NumberFormatException nfe) { // well, use what we got } } this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests); try { sslSocketFactory = SSLContext.getDefault().getSocketFactory(); } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } } @SneakyThrows private Object getHttpClient(Executor executor) { return httpClients.computeIfAbsent( executor, key -> { try { // return HttpClient.newBuilder() // .connectTimeout(timeout). // .executor(executor) // .build(); Object httpClientBuilder = httpClientNewBuilderMethod.invoke(null); httpClientBuilderTimeoutMethod.invoke(httpClientBuilder, timeout); httpClientBuilderExecutorMethod.invoke(httpClientBuilder, key); return httpClientBuilderBuildMethod.invoke(httpClientBuilder); } catch (IllegalAccessException | InvocationTargetException e) { log.warn("Could not create a HttpClient with for Executor {}", key, e); return null; } }); } /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */ @Override public void setPort(int port) { // Not implemented, port is part of the URI } /** Not implemented. */ @Override public void setTCP(boolean flag) { // Not implemented, HTTP is always TCP } /** Not implemented. */ @Override public void setIgnoreTruncation(boolean flag) { // Not implemented, protocol uses TCP and doesn't have truncation } /** * Sets the EDNS information on outgoing messages. * * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS. * @param payloadSize ignored * @param flags EDNS extended flags to be set in the OPT record. * @param options EDNS options to be set in the OPT record */ @Override public void setEDNS(int version, int payloadSize, int flags, List options) { switch (version) { case -1: queryOPT = null; break; case 0: queryOPT = new OPTRecord(0, 0, version, flags, options); break; default: throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable"); } } @Override public void setTSIGKey(TSIG key) { this.tsig = key; } @Override public void setTimeout(Duration timeout) { this.timeout = timeout; httpClients.clear(); } @Override public Duration getTimeout() { return timeout; } @Override public CompletionStage sendAsync(Message query) { return sendAsync(query, defaultExecutor); } @Override public CompletionStage sendAsync(Message query, Executor executor) { if (USE_HTTP_CLIENT) { return sendAsync11(query, executor); } return sendAsync8(query, executor); } private CompletionStage sendAsync8(final Message query, Executor executor) { byte[] queryBytes = prepareQuery(query).toWire(); String url = getUrl(queryBytes); long startTime = System.nanoTime(); return maxConcurrentRequests .acquire(timeout) .handleAsync( (permit, ex) -> { if (ex != null) { return this.timeoutFailedFuture(query, ex); } else { try { SendAndGetMessageBytesResponse result = sendAndGetMessageBytes(url, queryBytes, startTime); Message response; if (result.rc == Rcode.NOERROR) { response = new Message(result.responseBytes); verifyTSIG(query, response, result.responseBytes, tsig); } else { response = new Message(0); response.getHeader().setRcode(result.rc); } response.setResolver(this); return CompletableFuture.completedFuture(response); } catch (SocketTimeoutException e) { return this.timeoutFailedFuture(query, e); } catch (IOException e) { return this.failedFuture(e); } finally { permit.release(); } } }, executor) .thenCompose(Function.identity()); } @Value private static class SendAndGetMessageBytesResponse { int rc; byte[] responseBytes; } private SendAndGetMessageBytesResponse sendAndGetMessageBytes( String url, byte[] queryBytes, long startTime) throws IOException { HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection(); if (conn instanceof HttpsURLConnection) { ((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory); } Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); conn.setConnectTimeout((int) remainingTimeout.toMillis()); conn.setReadTimeout((int) remainingTimeout.toMillis()); conn.setRequestMethod(usePost ? "POST" : "GET"); conn.setRequestProperty("Content-Type", APPLICATION_DNS_MESSAGE); conn.setRequestProperty("Accept", APPLICATION_DNS_MESSAGE); if (usePost) { conn.setDoOutput(true); conn.getOutputStream().write(queryBytes); } int rc = conn.getResponseCode(); if (rc < 200 || rc >= 300) { discardStream(conn.getInputStream()); discardStream(conn.getErrorStream()); return new SendAndGetMessageBytesResponse(Rcode.SERVFAIL, null); } try (InputStream is = conn.getInputStream()) { int length = conn.getContentLength(); if (length > -1) { byte[] responseBytes = new byte[conn.getContentLength()]; int r; int offset = 0; while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) { offset += r; remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { throw new SocketTimeoutException(); } } if (offset < responseBytes.length) { throw new EOFException("Could not read expected content length"); } return new SendAndGetMessageBytesResponse(Rcode.NOERROR, responseBytes); } else { try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { byte[] buffer = new byte[4096]; int r; while ((r = is.read(buffer, 0, buffer.length)) > 0) { remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { throw new SocketTimeoutException(); } bos.write(buffer, 0, r); } return new SendAndGetMessageBytesResponse(Rcode.NOERROR, bos.toByteArray()); } } } catch (IOException ioe) { discardStream(conn.getErrorStream()); throw ioe; } } private void discardStream(InputStream es) throws IOException { if (es != null) { try (InputStream in = es) { byte[] buf = new byte[4096]; while (in.read(buf) > 0) { // discard } } catch (IOException ioe) { // ignore } } } private CompletionStage sendAsync11(final Message query, Executor executor) { long startTime = System.nanoTime(); byte[] queryBytes = prepareQuery(query).toWire(); String url = getUrl(queryBytes); // var requestBuilder = defaultHttpRequestBuilder.copy(); // requestBuilder.uri(URI.create(url)); Object requestBuilder; try { requestBuilder = requestBuilderCopyMethod.invoke(defaultHttpRequestBuilder); requestBuilderUriMethod.invoke(requestBuilder, URI.create(url)); if (usePost) { // requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes)); requestBuilderPostMethod.invoke( requestBuilder, publisherOfByteArrayMethod.invoke(null, queryBytes)); } } catch (IllegalAccessException | InvocationTargetException e) { return failedFuture(e); } // check if this request needs to be done synchronously because of HttpClient's stupidity to // not use the connection pool for HTTP/2 until one connection is successfully established, // which could lead to hundreds of connections (and threads with the default executor) Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); return initialRequestLock .acquire(remainingTimeout) .handle( (initialRequestPermit, initialRequestEx) -> { if (initialRequestEx != null) { return this.timeoutFailedFuture(query, initialRequestEx); } else { return sendAsync11WithInitialRequestPermit( query, executor, startTime, requestBuilder, initialRequestPermit); } }) .thenCompose(Function.identity()); } private CompletionStage sendAsync11WithInitialRequestPermit( Message query, Executor executor, long startTime, Object requestBuilder, Permit initialRequestPermit) { long lastRequestTime = lastRequest.get(); boolean isInitialRequest = (lastRequestTime < System.nanoTime() - idleConnectionTimeout.toNanos()); if (!isInitialRequest) { initialRequestPermit.release(); } // check if we already exceeded the query timeout while checking the initial connection Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { if (isInitialRequest) { initialRequestPermit.release(); } return timeoutFailedFuture(query, null); } // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the // request, but fail with an IOException which also CLOSES the connection... *facepalm* return maxConcurrentRequests .acquire(remainingTimeout) .handle( (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> { if (maxConcurrentRequestEx != null) { if (isInitialRequest) { initialRequestPermit.release(); } return this.timeoutFailedFuture(query, maxConcurrentRequestEx); } else { return sendAsync11WithConcurrentRequestPermit( query, executor, startTime, requestBuilder, initialRequestPermit, isInitialRequest, maxConcurrentRequestPermit); } }) .thenCompose(Function.identity()); } private CompletionStage sendAsync11WithConcurrentRequestPermit( Message query, Executor executor, long startTime, Object requestBuilder, Permit initialRequestPermit, boolean isInitialRequest, Permit maxConcurrentRequestPermit) { // check if the stream lock acquisition took too long Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS); if (remainingTimeout.isNegative()) { if (isInitialRequest) { initialRequestPermit.release(); } maxConcurrentRequestPermit.release(); return timeoutFailedFuture(query, null); } // var httpRequest = requestBuilder.timeout(remainingTimeout).build(); // var bodyHandler = HttpResponse.BodyHandlers.ofByteArray(); // return getHttpClient(executor).sendAsync(httpRequest, bodyHandler) try { Object httpClient = getHttpClient(executor); requestBuilderTimeoutMethod.invoke(requestBuilder, remainingTimeout); Object httpRequest = requestBuilderBuildMethod.invoke(requestBuilder); Object bodyHandler = byteArrayBodyPublisherMethod.invoke(null); CompletableFuture f = ((CompletableFuture) httpClientSendAsyncMethod.invoke(httpClient, httpRequest, bodyHandler)) .whenComplete( (result, ex) -> { if (ex == null) { lastRequest.set(startTime); } maxConcurrentRequestPermit.release(); if (isInitialRequest) { initialRequestPermit.release(); } }) .handleAsync( (response, ex) -> { if (ex != null) { if (ex.getCause().getClass().getSimpleName().equals("HttpTimeoutException")) { return this.timeoutFailedFuture(query, ex.getCause()); } else { return this.failedFuture(ex); } } else { try { Message responseMessage; // int rc = response.statusCode(); int rc = (int) httpResponseStatusCodeMethod.invoke(response); if (rc >= 200 && rc < 300) { // byte[] responseBytes = response.body(); byte[] responseBytes = (byte[]) httpResponseBodyMethod.invoke(response); responseMessage = new Message(responseBytes); verifyTSIG(query, responseMessage, responseBytes, tsig); } else { responseMessage = new Message(); responseMessage.getHeader().setRcode(Rcode.SERVFAIL); } responseMessage.setResolver(this); return CompletableFuture.completedFuture(responseMessage); } catch (IOException | IllegalAccessException | InvocationTargetException e) { return this.failedFuture(e); } } }, executor) .thenCompose(Function.identity()); return TimeoutCompletableFuture.compatTimeout( f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS); } catch (IllegalAccessException | InvocationTargetException e) { return failedFuture(e); } } private CompletableFuture failedFuture(Throwable e) { CompletableFuture f = new CompletableFuture<>(); f.completeExceptionally(e); return f; } private CompletableFuture timeoutFailedFuture(Message query, Throwable inner) { return failedFuture( new IOException( "Query " + query.getHeader().getID() + " for " + query.getQuestion().getName() + "/" + Type.string(query.getQuestion().getType()) + " timed out", inner)); } private String getUrl(byte[] queryBytes) { String url = uriTemplate; if (!usePost) { url += "?dns=" + base64.toString(queryBytes, true); } return url; } private Message prepareQuery(Message query) { Message preparedQuery = query.clone(); preparedQuery.getHeader().setID(0); if (queryOPT != null && preparedQuery.getOPT() == null) { preparedQuery.addRecord(queryOPT, Section.ADDITIONAL); } if (tsig != null) { tsig.apply(preparedQuery, null); } return preparedQuery; } private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) { if (tsig == null) { return; } int error = tsig.verify(response, b, query.getTSIG()); log.debug( "TSIG verify for query {}, {}/{}: {}", query.getHeader().getID(), query.getQuestion().getName(), Type.string(query.getQuestion().getType()), Rcode.TSIGstring(error)); } /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */ public boolean isUsePost() { return usePost; } /** * Sets the HTTP method to use for resolving. * * @param usePost {@code true} to use POST, {@code false} to use GET (the default). */ public void setUsePost(boolean usePost) { this.usePost = usePost; } /** Gets the current URI used for resolving. */ public String getUriTemplate() { return uriTemplate; } /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ public void setUriTemplate(String uriTemplate) { this.uriTemplate = uriTemplate; } /** * Gets the default {@link Executor} for request handling, defaults to {@link * ForkJoinPool#commonPool()}. * * @since 3.3 * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used. */ @Deprecated public Executor getExecutor() { return defaultExecutor; } /** * Sets the default {@link Executor} for request handling. * * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link * ForkJoinPool#commonPool()}). * @since 3.3 * @deprecated Use {@link #sendAsync(Message, Executor)}. */ @Deprecated public void setExecutor(Executor executor) { this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor; httpClients.clear(); } @Override public String toString() { return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}"; } }





  • © 2015 - 2024 Weber Informatics LLC | Privacy Policy