// com.signalfx.shaded.jetty.client.ProxyProtocolClientConnectionFactory (Maven / Gradle / Ivy)
//
// ========================================================================
// Copyright (c) 1995-2022 Mort Bay Consulting Pty Ltd and others.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package com.signalfx.shaded.jetty.client;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import com.signalfx.shaded.jetty.io.AbstractConnection;
import com.signalfx.shaded.jetty.io.ClientConnectionFactory;
import com.signalfx.shaded.jetty.io.Connection;
import com.signalfx.shaded.jetty.io.EndPoint;
import com.signalfx.shaded.jetty.util.Callback;
import com.signalfx.shaded.jetty.util.Promise;
import com.signalfx.shaded.jetty.util.log.Log;
import com.signalfx.shaded.jetty.util.log.Logger;
/**
 * <p>ClientConnectionFactory for the
 * <a href="http://www.haproxy.org/download/2.1/doc/proxy-protocol.txt">PROXY protocol</a>.</p>
 * <p>Use the {@link V1} or {@link V2} versions of this class to specify what version of the
 * PROXY protocol you want to use.</p>
 */
public abstract class ProxyProtocolClientConnectionFactory implements ClientConnectionFactory
{
/**
 * A ClientConnectionFactory for the PROXY protocol version 1.
 * <p>
 * Each connection created through this factory first writes a text PROXY line
 * and then upgrades the endpoint to a connection created by the factory given
 * to the constructor.
 */
public static class V1 extends ProxyProtocolClientConnectionFactory
{
    /**
     * @param factory the factory for the connection that replaces the PROXY
     *                connection once the PROXY line has been written
     */
    public V1(ClientConnectionFactory factory)
    {
        super(factory);
    }

    @Override
    protected ProxyProtocolConnection newProxyProtocolConnection(EndPoint endPoint, Map context)
    {
        HttpDestination httpDestination = (HttpDestination)context.get(HttpClientTransport.HTTP_DESTINATION_CONTEXT_KEY);
        Executor httpExecutor = httpDestination.getHttpClient().getExecutor();
        // A Tag attached to the destination's origin wins; otherwise describe the real connection.
        Tag originTag = (Tag)httpDestination.getOrigin().getTag();
        Tag effectiveTag = originTag == null ? tagFrom(endPoint) : originTag;
        return new ProxyProtocolConnectionV1(endPoint, httpExecutor, getClientConnectionFactory(), context, effectiveTag);
    }

    /**
     * Builds a Tag from the endpoint's local (source) and remote (destination) addresses,
     * choosing TCP4 when the local address is IPv4 and TCP6 otherwise.
     */
    private static Tag tagFrom(EndPoint endPoint)
    {
        InetSocketAddress localAddress = endPoint.getLocalAddress();
        InetSocketAddress remoteAddress = endPoint.getRemoteAddress();
        InetAddress localIP = localAddress.getAddress();
        String family = localIP instanceof Inet4Address ? "TCP4" : "TCP6";
        return new Tag(family,
            localIP.getHostAddress(), localAddress.getPort(),
            remoteAddress.getAddress().getHostAddress(), remoteAddress.getPort());
    }

    /**
     * PROXY protocol version 1 metadata holder to be used in conjunction
     * with {@link com.signalfx.shaded.jetty.client.api.Request#tag(Object)}.
     * Instances of this class are associated to a destination so that
     * all connections of that destination will initiate the communication
     * with the PROXY protocol version 1 bytes specified by this metadata.
     * <p>
     * A {@code null} family or address, or a port {@code <= 0}, is filled in
     * from the underlying EndPoint when the PROXY line is written.
     */
    public static class Tag implements ClientConnectionFactory.Decorator
    {
        /**
         * The PROXY V1 Tag typically used to "ping" the server.
         */
        public static final Tag UNKNOWN = new Tag("UNKNOWN", null, 0, null, 0);

        private final String family;
        private final String srcIP;
        private final int srcPort;
        private final String dstIP;
        private final int dstPort;

        /**
         * Creates a Tag whose metadata will be derived from the underlying EndPoint.
         */
        public Tag()
        {
            this(null, null, 0, null, 0);
        }

        /**
         * Creates a Tag with the given source metadata.
         * The destination metadata will be derived from the underlying EndPoint.
         *
         * @param srcIP the source IP address
         * @param srcPort the source port
         */
        public Tag(String srcIP, int srcPort)
        {
            this(null, srcIP, srcPort, null, 0);
        }

        /**
         * Creates a Tag with the given metadata.
         *
         * @param family the protocol family ("TCP4", "TCP6" or "UNKNOWN")
         * @param srcIP the source IP address
         * @param srcPort the source port
         * @param dstIP the destination IP address
         * @param dstPort the destination port
         */
        public Tag(String family, String srcIP, int srcPort, String dstIP, int dstPort)
        {
            this.family = family;
            this.srcIP = srcIP;
            this.srcPort = srcPort;
            this.dstIP = dstIP;
            this.dstPort = dstPort;
        }

        /** @return the protocol family, or null to derive it from the EndPoint */
        public String getFamily()
        {
            return family;
        }

        /** @return the source IP address, or null to use the EndPoint local address */
        public String getSourceAddress()
        {
            return srcIP;
        }

        /** @return the source port, or a value {@code <= 0} to use the EndPoint local port */
        public int getSourcePort()
        {
            return srcPort;
        }

        /** @return the destination IP address, or null to use the EndPoint remote address */
        public String getDestinationAddress()
        {
            return dstIP;
        }

        /** @return the destination port, or a value {@code <= 0} to use the EndPoint remote port */
        public int getDestinationPort()
        {
            return dstPort;
        }

        /**
         * Wraps the given factory in a {@link V1} PROXY protocol factory.
         */
        @Override
        public ClientConnectionFactory apply(ClientConnectionFactory factory)
        {
            return new V1(factory);
        }

        @Override
        public boolean equals(Object obj)
        {
            if (obj == this)
                return true;
            if (obj == null || obj.getClass() != getClass())
                return false;
            Tag other = (Tag)obj;
            return srcPort == other.srcPort
                && dstPort == other.dstPort
                && Objects.equals(family, other.family)
                && Objects.equals(srcIP, other.srcIP)
                && Objects.equals(dstIP, other.dstIP);
        }

        @Override
        public int hashCode()
        {
            return Arrays.hashCode(new Object[]{family, srcIP, srcPort, dstIP, dstPort});
        }
    }
}
/**
 * A ClientConnectionFactory for the PROXY protocol version 2.
 * <p>
 * Each connection created through this factory first writes the binary PROXY v2
 * header and then upgrades the endpoint to a connection created by the factory
 * given to the constructor.
 */
public static class V2 extends ProxyProtocolClientConnectionFactory
{
    /**
     * @param factory the factory for the connection that replaces the PROXY
     *                connection once the PROXY header has been written
     */
    public V2(ClientConnectionFactory factory)
    {
        super(factory);
    }

    @Override
    protected ProxyProtocolConnection newProxyProtocolConnection(EndPoint endPoint, Map context)
    {
        HttpDestination destination = (HttpDestination)context.get(HttpClientTransport.HTTP_DESTINATION_CONTEXT_KEY);
        Executor executor = destination.getHttpClient().getExecutor();
        Tag tag = (Tag)destination.getOrigin().getTag();
        if (tag == null)
        {
            // No explicit metadata: describe the actual TCP connection of this endpoint.
            InetSocketAddress local = endPoint.getLocalAddress();
            InetSocketAddress remote = endPoint.getRemoteAddress();
            boolean ipv4 = local.getAddress() instanceof Inet4Address;
            tag = new Tag(Tag.Command.PROXY, ipv4 ? Tag.Family.INET4 : Tag.Family.INET6, Tag.Protocol.STREAM, local.getAddress().getHostAddress(), local.getPort(), remote.getAddress().getHostAddress(), remote.getPort(), null);
        }
        return new ProxyProtocolConnectionV2(endPoint, executor, getClientConnectionFactory(), context, tag);
    }

    /**
     * PROXY protocol version 2 metadata holder to be used in conjunction
     * with {@link com.signalfx.shaded.jetty.client.api.Request#tag(Object)}.
     * Instances of this class are associated to a destination so that
     * all connections of that destination will initiate the communication
     * with the PROXY protocol version 2 bytes specified by this metadata.
     * <p>
     * Instances are immutable: the TLV list passed to the constructor is copied,
     * so later changes to the caller's list cannot alter {@link #equals(Object)}
     * or {@link #hashCode()} while the tag is in use.
     */
    public static class Tag implements ClientConnectionFactory.Decorator
    {
        /**
         * The PROXY V2 Tag typically used to "ping" the server.
         */
        public static final Tag LOCAL = new Tag(Command.LOCAL, Family.UNSPEC, Protocol.UNSPEC, null, 0, null, 0, null);

        private final Command command;
        private final Family family;
        private final Protocol protocol;
        private final String srcIP;
        private final int srcPort;
        private final String dstIP;
        private final int dstPort;
        private final List<TLV> tlvs;

        /**
         * Creates a Tag whose metadata will be derived from the underlying EndPoint.
         */
        public Tag()
        {
            this(null, 0);
        }

        /**
         * Creates a Tag with the given source metadata.
         * The destination metadata will be derived from the underlying EndPoint.
         *
         * @param srcIP the source IP address
         * @param srcPort the source port
         */
        public Tag(String srcIP, int srcPort)
        {
            this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0, null);
        }

        /**
         * Creates a Tag with the given source metadata and Type-Length-Value (TLV) objects.
         * The destination metadata will be derived from the underlying EndPoint.
         *
         * @param srcIP the source IP address
         * @param srcPort the source port
         * @param tlvs the TLV objects, may be null
         */
        public Tag(String srcIP, int srcPort, List<TLV> tlvs)
        {
            this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0, tlvs);
        }

        /**
         * Creates a Tag with the given metadata.
         *
         * @param command the LOCAL or PROXY command
         * @param family the protocol family, or null to derive it from the addresses
         * @param protocol the protocol type, or null for STREAM
         * @param srcIP the source IP address, or null to use the EndPoint local address
         * @param srcPort the source port, or a value {@code <= 0} to use the EndPoint local port
         * @param dstIP the destination IP address, or null to use the EndPoint remote address
         * @param dstPort the destination port, or a value {@code <= 0} to use the EndPoint remote port
         * @param tlvs the TLV objects, may be null; the list is copied
         */
        public Tag(Command command, Family family, Protocol protocol, String srcIP, int srcPort, String dstIP, int dstPort, List<TLV> tlvs)
        {
            this.command = command;
            this.family = family;
            this.protocol = protocol;
            this.srcIP = srcIP;
            this.srcPort = srcPort;
            this.dstIP = dstIP;
            this.dstPort = dstPort;
            // Snapshot the list: this Tag participates in equals()/hashCode().
            this.tlvs = tlvs == null ? null : Collections.unmodifiableList(new ArrayList<>(tlvs));
        }

        public Command getCommand()
        {
            return command;
        }

        public Family getFamily()
        {
            return family;
        }

        public Protocol getProtocol()
        {
            return protocol;
        }

        public String getSourceAddress()
        {
            return srcIP;
        }

        public int getSourcePort()
        {
            return srcPort;
        }

        public String getDestinationAddress()
        {
            return dstIP;
        }

        public int getDestinationPort()
        {
            return dstPort;
        }

        /**
         * @return an unmodifiable copy of the TLVs given at construction, or null if none were given
         */
        public List<TLV> getTLVs()
        {
            return tlvs;
        }

        /**
         * Wraps the given factory in a {@link V2} PROXY protocol factory.
         */
        @Override
        public ClientConnectionFactory apply(ClientConnectionFactory factory)
        {
            return new V2(factory);
        }

        @Override
        public boolean equals(Object obj)
        {
            if (this == obj)
                return true;
            if (obj == null || getClass() != obj.getClass())
                return false;
            Tag that = (Tag)obj;
            return command == that.command &&
                family == that.family &&
                protocol == that.protocol &&
                Objects.equals(srcIP, that.srcIP) &&
                srcPort == that.srcPort &&
                Objects.equals(dstIP, that.dstIP) &&
                dstPort == that.dstPort &&
                Objects.equals(tlvs, that.tlvs);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(command, family, protocol, srcIP, srcPort, dstIP, dstPort, tlvs);
        }

        // The ordinal of each constant is the value written on the wire, so the declaration order matters.
        public enum Command
        {
            LOCAL, PROXY
        }

        // The ordinal of each constant is the value written on the wire, so the declaration order matters.
        public enum Family
        {
            UNSPEC, INET4, INET6, UNIX
        }

        // The ordinal of each constant is the value written on the wire, so the declaration order matters.
        public enum Protocol
        {
            UNSPEC, STREAM, DGRAM
        }

        /**
         * A Type-Length-Value vector appended after the address block.
         */
        public static class TLV
        {
            private final int type;
            private final byte[] value;

            /**
             * @param type the TLV type, in the range [0, 255]
             * @param value the TLV value, not null, at most 65535 bytes
             * @throws IllegalArgumentException if the type or the value length is out of range
             * @throws NullPointerException if the value is null
             */
            public TLV(int type, byte[] value)
            {
                if (type < 0 || type > 255)
                    throw new IllegalArgumentException("Invalid type: " + type);
                Objects.requireNonNull(value);
                if (value.length > 65535)
                    throw new IllegalArgumentException("Invalid value length: " + value.length);
                this.type = type;
                this.value = value;
            }

            public int getType()
            {
                return type;
            }

            public byte[] getValue()
            {
                return value;
            }

            @Override
            public boolean equals(Object obj)
            {
                if (this == obj)
                    return true;
                if (obj == null || getClass() != obj.getClass())
                    return false;
                TLV that = (TLV)obj;
                return type == that.type && Arrays.equals(value, that.value);
            }

            @Override
            public int hashCode()
            {
                int result = Objects.hash(type);
                result = 31 * result + Arrays.hashCode(value);
                return result;
            }
        }
    }
}
// Creates the connection that the endpoint is upgraded to once the PROXY bytes are written.
private final ClientConnectionFactory connectionFactory;

private ProxyProtocolClientConnectionFactory(ClientConnectionFactory factory)
{
    connectionFactory = factory;
}

/**
 * @return the factory whose connection replaces the PROXY connection
 *         after the PROXY bytes have been written
 */
public ClientConnectionFactory getClientConnectionFactory()
{
    return connectionFactory;
}
/**
 * Creates the version-specific PROXY connection for the endpoint and passes it
 * through {@code customize(...)} before returning it.
 *
 * @param endPoint the endpoint the PROXY bytes will be written to
 * @param context the connection context
 * @return the customized PROXY protocol connection
 */
@Override
public Connection newConnection(EndPoint endPoint, Map<String, Object> context)
{
    ProxyProtocolConnection connection = newProxyProtocolConnection(endPoint, context);
    return customize(connection, context);
}

/**
 * Creates the connection that writes the version-specific PROXY bytes.
 *
 * @param endPoint the endpoint the PROXY bytes will be written to
 * @param context the connection context
 * @return a new PROXY protocol connection
 */
protected abstract ProxyProtocolConnection newProxyProtocolConnection(EndPoint endPoint, Map<String, Object> context);
protected abstract static class ProxyProtocolConnection extends AbstractConnection implements Callback
{
protected static final Logger LOG = Log.getLogger(ProxyProtocolConnection.class);
private final ClientConnectionFactory factory;
private final Map context;
private ProxyProtocolConnection(EndPoint endPoint, Executor executor, ClientConnectionFactory factory, Map context)
{
super(endPoint, executor);
this.factory = factory;
this.context = context;
}
@Override
public void onOpen()
{
super.onOpen();
writePROXYBytes(getEndPoint(), this);
}
protected abstract void writePROXYBytes(EndPoint endPoint, Callback callback);
@Override
public void succeeded()
{
try
{
EndPoint endPoint = getEndPoint();
Connection connection = factory.newConnection(endPoint, context);
if (LOG.isDebugEnabled())
LOG.debug("Written PROXY line, upgrading to {}", connection);
endPoint.upgrade(connection);
}
catch (Throwable x)
{
failed(x);
}
}
@Override
public void failed(Throwable x)
{
close();
Promise> promise = (Promise>)context.get(HttpClientTransport.HTTP_CONNECTION_PROMISE_CONTEXT_KEY);
promise.failed(x);
}
@Override
public InvocationType getInvocationType()
{
return InvocationType.NON_BLOCKING;
}
@Override
public void onFillable()
{
}
}
/**
 * Writes the PROXY protocol version 1 text line, e.g.
 * {@code PROXY TCP4 srcIP dstIP srcPort dstPort} followed by CRLF.
 */
private static class ProxyProtocolConnectionV1 extends ProxyProtocolConnection
{
    private final V1.Tag tag;

    public ProxyProtocolConnectionV1(EndPoint endPoint, Executor executor, ClientConnectionFactory factory, Map context, V1.Tag tag)
    {
        super(endPoint, executor, factory, context);
        this.tag = tag;
    }

    /**
     * Builds the PROXY line from the tag, filling missing values from the endpoint.
     * A null family or address, or a port {@code <= 0}, means "use the endpoint's value".
     * For the UNKNOWN family only {@code PROXY UNKNOWN} is written.
     * Any failure is reported to {@code callback}.
     */
    @Override
    protected void writePROXYBytes(EndPoint endPoint, Callback callback)
    {
        try
        {
            InetSocketAddress local = endPoint.getLocalAddress();
            InetSocketAddress remote = endPoint.getRemoteAddress();

            String family = tag.getFamily();
            if (family == null)
                family = local.getAddress() instanceof Inet4Address ? "TCP4" : "TCP6";
            family = family.toUpperCase(Locale.ENGLISH);

            StringBuilder header = new StringBuilder(64).append("PROXY ").append(family);
            if (!"UNKNOWN".equals(family))
            {
                String srcIP = tag.getSourceAddress();
                String dstIP = tag.getDestinationAddress();
                int srcPort = tag.getSourcePort();
                int dstPort = tag.getDestinationPort();
                // The v1 field order is: source IP, destination IP, source port, destination port.
                header.append(' ').append(srcIP != null ? srcIP : local.getAddress().getHostAddress())
                    .append(' ').append(dstIP != null ? dstIP : remote.getAddress().getHostAddress())
                    .append(' ').append(srcPort > 0 ? srcPort : local.getPort())
                    .append(' ').append(dstPort > 0 ? dstPort : remote.getPort());
            }
            String text = header.append("\r\n").toString();

            if (LOG.isDebugEnabled())
                LOG.debug("Writing PROXY bytes: {}", text.trim());
            endPoint.write(callback, ByteBuffer.wrap(text.getBytes(StandardCharsets.US_ASCII)));
        }
        catch (Throwable x)
        {
            callback.failed(x);
        }
    }
}
/**
 * Writes the binary PROXY protocol version 2 header.
 */
private static class ProxyProtocolConnectionV2 extends ProxyProtocolConnection
{
    // The 12-byte signature that starts every PROXY protocol v2 header.
    private static final byte[] MAGIC = {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
    // Size of each NUL-padded UNIX socket path field in the address block.
    private static final int UNIX_PATH_LENGTH = 108;
    private final V2.Tag tag;

    public ProxyProtocolConnectionV2(EndPoint endPoint, Executor executor, ClientConnectionFactory factory, Map<String, Object> context, V2.Tag tag)
    {
        super(endPoint, executor, factory, context);
        this.tag = tag;
    }

    /**
     * Writes the signature, the version/command byte, the family/protocol byte,
     * the 16-bit payload length, the address block and the optional TLVs.
     * <p>
     * Values missing from the tag are filled in or defaulted:
     * <ul>
     * <li>null addresses and ports {@code <= 0} are taken from the endpoint;</li>
     * <li>a null family becomes INET4 only when both addresses are IPv4, otherwise INET6;</li>
     * <li>a null protocol becomes STREAM.</li>
     * </ul>
     * For INET6, IPv4 addresses are written in IPv4-mapped form so the address
     * bytes always match the announced length.
     * <p>
     * Any failure, including tag metadata inconsistent with the family, is
     * reported to {@code callback}.
     */
    @Override
    protected void writePROXYBytes(EndPoint endPoint, Callback callback)
    {
        try
        {
            V2.Tag.Command command = tag.getCommand();

            String srcAddr = tag.getSourceAddress();
            if (srcAddr == null)
                srcAddr = endPoint.getLocalAddress().getAddress().getHostAddress();
            int srcPort = tag.getSourcePort();
            if (srcPort <= 0)
                srcPort = endPoint.getLocalAddress().getPort();
            String dstAddr = tag.getDestinationAddress();
            if (dstAddr == null)
                dstAddr = endPoint.getRemoteAddress().getAddress().getHostAddress();
            int dstPort = tag.getDestinationPort();
            if (dstPort <= 0)
                dstPort = endPoint.getRemoteAddress().getPort();

            V2.Tag.Family family = tag.getFamily();
            byte[] srcBytes = null;
            byte[] dstBytes = null;
            if (family == null || family == V2.Tag.Family.INET4 || family == V2.Tag.Family.INET6)
            {
                srcBytes = InetAddress.getByName(srcAddr).getAddress();
                dstBytes = InetAddress.getByName(dstAddr).getAddress();
                // Use INET4 only if both sides are IPv4, otherwise the length field would not match the bytes.
                if (family == null)
                    family = srcBytes.length == 4 && dstBytes.length == 4 ? V2.Tag.Family.INET4 : V2.Tag.Family.INET6;
                if (family == V2.Tag.Family.INET4)
                {
                    if (srcBytes.length != 4 || dstBytes.length != 4)
                        throw new IllegalArgumentException("INET4 family requires IPv4 addresses: " + srcAddr + ", " + dstAddr);
                }
                else
                {
                    srcBytes = toIPv6(srcBytes);
                    dstBytes = toIPv6(dstBytes);
                }
            }

            V2.Tag.Protocol protocol = tag.getProtocol();
            if (protocol == null)
                protocol = V2.Tag.Protocol.STREAM;

            List<V2.Tag.TLV> tlvs = tag.getTLVs();
            int vectorsLength = tlvs == null ? 0 : tlvs.stream()
                .mapToInt(tlv -> 1 + 2 + tlv.getValue().length)
                .sum();

            int length = addressLength(family) + vectorsLength;
            // The length field is an unsigned 16-bit value; larger payloads cannot be represented.
            if (length > 0xFFFF)
                throw new IllegalArgumentException("PROXY v2 payload too long: " + length + " bytes");

            // Signature + version/command + family/protocol + 16-bit length, then the payload.
            ByteBuffer buffer = ByteBuffer.allocateDirect(MAGIC.length + 1 + 1 + 2 + length);
            buffer.put(MAGIC);
            // High nibble: protocol version 2; low nibble: command ordinal (LOCAL=0, PROXY=1).
            buffer.put((byte)((2 << 4) | (command.ordinal() & 0x0F)));
            buffer.put((byte)((family.ordinal() << 4) | protocol.ordinal()));
            buffer.putShort((short)length);

            switch (family)
            {
                case UNSPEC:
                    break;
                case INET4:
                case INET6:
                    buffer.put(srcBytes);
                    buffer.put(dstBytes);
                    buffer.putShort((short)srcPort);
                    buffer.putShort((short)dstPort);
                    break;
                case UNIX:
                    putUnixPaths(buffer, srcAddr, dstAddr);
                    break;
                default:
                    throw new IllegalStateException();
            }

            if (tlvs != null)
            {
                for (V2.Tag.TLV tlv : tlvs)
                {
                    buffer.put((byte)tlv.getType());
                    byte[] data = tlv.getValue();
                    buffer.putShort((short)data.length);
                    buffer.put(data);
                }
            }

            buffer.flip();
            endPoint.write(callback, buffer);
        }
        catch (Throwable x)
        {
            callback.failed(x);
        }
    }

    /**
     * Returns the size in bytes of the address block for the given family.
     */
    private static int addressLength(V2.Tag.Family family)
    {
        switch (family)
        {
            case UNSPEC:
                return 0;
            case INET4:
                return 4 + 4 + 2 + 2;
            case INET6:
                return 16 + 16 + 2 + 2;
            case UNIX:
                return 2 * UNIX_PATH_LENGTH;
            default:
                throw new IllegalStateException();
        }
    }

    /**
     * Returns a 16-byte address: IPv6 addresses unchanged, IPv4 addresses in
     * IPv4-mapped form (::ffff:a.b.c.d).
     */
    private static byte[] toIPv6(byte[] address)
    {
        if (address.length == 16)
            return address;
        byte[] mapped = new byte[16];
        mapped[10] = (byte)0xFF;
        mapped[11] = (byte)0xFF;
        System.arraycopy(address, 0, mapped, 12, 4);
        return mapped;
    }

    /**
     * Writes the two fixed-size UNIX path fields and leaves the position after both.
     * Padding relies on allocateDirect() zero-initializing the buffer.
     */
    private static void putUnixPaths(ByteBuffer buffer, String srcPath, String dstPath)
    {
        byte[] src = srcPath.getBytes(StandardCharsets.US_ASCII);
        byte[] dst = dstPath.getBytes(StandardCharsets.US_ASCII);
        if (src.length > UNIX_PATH_LENGTH || dst.length > UNIX_PATH_LENGTH)
            throw new IllegalArgumentException("UNIX socket paths must be at most " + UNIX_PATH_LENGTH + " bytes");
        int start = buffer.position();
        buffer.put(src);
        buffer.position(start + UNIX_PATH_LENGTH);
        buffer.put(dst);
        // Skip the unused tail of the destination field so TLVs start after the full address block.
        buffer.position(start + 2 * UNIX_PATH_LENGTH);
    }
}
}