com.twitter.finagle.socks.SocksConnectHandler.scala Maven / Gradle / Ivy
The newest version!
package com.twitter.finagle.socks
import java.net.{Inet4Address, Inet6Address, InetSocketAddress, SocketAddress}
import java.util.concurrent.atomic.AtomicReference
import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers}
import org.jboss.netty.channel._
import com.twitter.finagle.{ChannelClosedException, ConnectionFailedException, InconsistentStateException}
import com.twitter.io.Charsets
/**
 * An authentication method that a client may offer to a SOCKS proxy.
 *
 * @param typeByte the one-byte method identifier; it is sent to the proxy in
 *                 the initial greeting and used to look up the setting the
 *                 proxy selected in its reply.
 */
sealed abstract class AuthenticationSetting(val typeByte: Byte)
/** The "no authentication" method (method byte 0x00). */
object Unauthenticated extends AuthenticationSetting(0x00)
/**
 * The username/password method (method byte 0x02).
 *
 * @param username the username presented to the proxy
 * @param password the password presented to the proxy
 */
case class UsernamePassAuthenticationSetting(username: String, password: String)
extends AuthenticationSetting(0x02)
/**
 * Wire-level constants and shared failure causes used by the
 * `SocksConnectHandler` class.
 */
object SocksConnectHandler {

  // ---- Failure causes -----------------------------------------------------
  // These are handed to ConnectionFailedException as its `cause`.

  /** The proxy's reply to the greeting had a bad version or an unsupported auth method. */
  private[socks] val InvalidInit =
    new Throwable(
      "unexpected SOCKS version or authentication level specified in connect response from proxy")

  /** The proxy's reply to an authentication or connect request was not a success. */
  private[socks] val InvalidResponse =
    new Throwable(
      "unexpected SOCKS version or response status specified in connect response from proxy")

  // ---- Protocol versions --------------------------------------------------

  private val Version5: Byte = 0x05
  private val Version1: Byte = 0x01

  // ---- Address-type indicators --------------------------------------------

  private val IpV4Indicator: Byte = 0x01
  private val HostnameIndicator: Byte = 0x03
  private val IpV6Indicator: Byte = 0x04

  // ---- Request / response fields ------------------------------------------

  private val Connect: Byte = 0x01
  private val Reserved: Byte = 0x00
  private val SuccessResponse: Byte = 0x00
}
/**
 * Handle connections through a SOCKS proxy.
 *
 * See http://www.ietf.org/rfc/rfc1928.txt
 *
 * Only username and password authentication is implemented;
 * See https://tools.ietf.org/rfc/rfc1929.txt
 *
 * We assume the proxy is provided by ssh -D.
 *
 * The handler rewrites the downstream connect request so that the channel
 * connects to `proxyAddr`, performs the SOCKS handshake for `addr` once the
 * channel is connected, and only then completes the caller's connect future
 * and propagates `channelConnected` upstream. On success the handler removes
 * itself from the pipeline.
 *
 * @param proxyAddr              address of the SOCKS proxy to connect to
 * @param addr                   final destination requested from the proxy
 * @param authenticationSettings authentication methods offered to the proxy
 */
class SocksConnectHandler(
  proxyAddr: SocketAddress,
  addr: InetSocketAddress,
  authenticationSettings: Seq[AuthenticationSetting] = Seq(Unauthenticated))
  extends SimpleChannelHandler {
  import SocksConnectHandler._

  /** Handshake progress; advanced in `channelConnected` and `messageReceived`. */
  object State extends Enumeration {
    val Start, Connected, Requested, Authenticating = Value
  }
  import State._

  private[this] var state = Start

  // Accumulates bytes received from the proxy until a full reply is available.
  private[this] val buf = ChannelBuffers.dynamicBuffer()

  // Scratch space for the fixed-size prefix of each reply (at most 4 bytes).
  private[this] val bytes = new Array[Byte](4)

  // The caller's original connect future; null until connectRequested is seen.
  private[this] val connectFuture = new AtomicReference[ChannelFuture](null)

  private[this] val authenticationMap =
    authenticationSettings.map { setting => setting.typeByte -> setting }.toMap

  private[this] val supportedTypes = authenticationMap.keys.toArray.sorted

  // following Netty's ReplayingDecoderBuffer, we throw this when we run out of bytes
  object ReplayError extends scala.Error

  /** Fails the caller's connect future (if one was captured) and closes the channel. */
  private[this] def fail(c: Channel, t: Throwable): Unit = {
    Option(connectFuture.get) foreach { _.setFailure(t) }
    Channels.close(c)
  }

  /** Writes `msg` downstream from this handler's position in the pipeline. */
  private[this] def write(ctx: ChannelHandlerContext, msg: Any): Unit = {
    Channels.write(ctx, Channels.future(ctx.getChannel), msg, null)
  }

  /** Sends the greeting: version, number of offered methods, method bytes. */
  private[this] def writeInit(ctx: ChannelHandlerContext): Unit = {
    val out = ChannelBuffers.dynamicBuffer(1024)
    out.writeByte(Version5)
    out.writeByte(supportedTypes.size.toByte)
    out.writeBytes(supportedTypes)
    write(ctx, out)
  }

  /**
   * Reads the proxy's method-selection reply.
   *
   * @return the offered setting the proxy selected, or `None` when the version
   *         is wrong or the selected method was not one we offered
   */
  private[this] def readInit(): Option[AuthenticationSetting] = {
    checkReadableBytes(2)
    buf.readBytes(bytes, 0, 2)
    if (bytes(0) == Version5) authenticationMap.get(bytes(1))
    else None
  }

  /** Sends the CONNECT request for `addr` (IPv4, IPv6 or unresolved hostname). */
  private[this] def writeRequest(ctx: ChannelHandlerContext): Unit = {
    val out = ChannelBuffers.dynamicBuffer(1024)
    out.writeBytes(Array[Byte](Version5, Connect, Reserved))
    addr.getAddress match {
      case v4Addr: Inet4Address =>
        out.writeByte(IpV4Indicator)
        out.writeBytes(v4Addr.getAddress)
      case v6Addr: Inet6Address =>
        out.writeByte(IpV6Indicator)
        out.writeBytes(v6Addr.getAddress)
      case _ => // unresolved host
        out.writeByte(HostnameIndicator)
        val hostnameBytes = addr.getHostName.getBytes(Charsets.UsAscii)
        out.writeByte(hostnameBytes.size)
        out.writeBytes(hostnameBytes)
    }
    out.writeShort(addr.getPort)
    write(ctx, out)
  }

  /**
   * Sends the username/password sub-negotiation request.
   *
   * NOTE(review): each length is written as a single byte, so a username or
   * password longer than 255 encoded bytes is sent with a truncated length.
   */
  private[this] def writeUserNameAndPass(
    ctx: ChannelHandlerContext,
    username: String,
    pass: String
  ): Unit = {
    // RFC does not specify an encoding. Assume UTF8
    val usernameBytes = username.getBytes(Charsets.Utf8)
    val passBytes = pass.getBytes(Charsets.Utf8)
    // Dynamic buffer so that long credentials cannot overflow a fixed capacity.
    val out = ChannelBuffers.dynamicBuffer(3 + usernameBytes.length + passBytes.length)
    out.writeByte(Version1)
    out.writeByte(usernameBytes.size.toByte)
    out.writeBytes(usernameBytes)
    out.writeByte(passBytes.size.toByte)
    out.writeBytes(passBytes)
    write(ctx, out)
  }

  /** Reads the authentication reply; true when the proxy reports success. */
  private[this] def readAuthenticated(): Boolean = {
    checkReadableBytes(2)
    buf.readBytes(bytes, 0, 2)
    bytes(0) == Version1 && bytes(1) == SuccessResponse
  }

  /**
   * Reads the reply to the CONNECT request and consumes the bound address and
   * port that follow the 4-byte header.
   *
   * @return true on a well-formed success reply; false when the version,
   *         status, reserved byte or address type is not what we expect
   */
  private[this] def readResponse(): Boolean = {
    checkReadableBytes(4)
    buf.readBytes(bytes, 0, 4)
    val headerOk =
      bytes(0) == Version5 && bytes(1) == SuccessResponse && bytes(2) == Reserved
    if (!headerOk) {
      false
    } else {
      val addressLength: Option[Int] = bytes(3) match {
        case IpV4Indicator => Some(4)
        case IpV6Indicator => Some(16)
        case HostnameIndicator =>
          checkReadableBytes(1)
          Some(buf.readUnsignedByte().toInt)
        // Unknown address type: treat as an invalid reply rather than letting
        // a MatchError escape and leave the connect future unresolved.
        case _ => None
      }
      addressLength match {
        case Some(length) =>
          discardBytes(length)
          discardBytes(2) // bound port
          true
        case None =>
          false
      }
    }
  }

  /** Skips `numBytes` of buffered input, or throws ReplayError if not yet available. */
  private[this] def discardBytes(numBytes: Int): Unit = {
    checkReadableBytes(numBytes)
    buf.skipBytes(numBytes)
  }

  /** Throws ReplayError when fewer than `numBytes` are buffered. */
  private[this] def checkReadableBytes(numBytes: Int): Unit = {
    if (buf.readableBytes < numBytes)
      throw ReplayError
  }

  /**
   * Captures the caller's connect future and forwards a connect request for
   * the proxy address instead of the final destination.
   */
  override def connectRequested(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
    e match {
      case de: DownstreamChannelStateEvent =>
        if (!connectFuture.compareAndSet(null, e.getFuture)) {
          fail(ctx.getChannel, new InconsistentStateException(addr))
        } else {
          // proxy cancellation
          val wrappedConnectFuture = Channels.future(de.getChannel, true)
          de.getFuture.addListener(new ChannelFutureListener {
            def operationComplete(f: ChannelFuture): Unit = {
              if (f.isCancelled)
                wrappedConnectFuture.cancel()
            }
          })

          // Proxy failures here so that if the connect fails, it is
          // propagated to the listener, not just on the channel.
          wrappedConnectFuture.addListener(new ChannelFutureListener {
            def operationComplete(f: ChannelFuture): Unit = {
              if (!f.isSuccess && !f.isCancelled)
                fail(f.getChannel, f.getCause)
            }
          })

          val wrappedEvent = new DownstreamChannelStateEvent(
            de.getChannel, wrappedConnectFuture,
            de.getState, proxyAddr)
          super.connectRequested(ctx, wrappedEvent)
        }

      case _ =>
        fail(ctx.getChannel, new InconsistentStateException(addr))
    }
  }

  // we delay propagating connection upstream until we've completed the proxy connection.
  override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
    if (connectFuture.get eq null) {
      fail(ctx.getChannel, new InconsistentStateException(addr))
    } else {
      // proxy cancellations again.
      connectFuture.get.addListener(new ChannelFutureListener {
        def operationComplete(f: ChannelFuture): Unit = {
          if (f.isSuccess)
            SocksConnectHandler.super.channelConnected(ctx, e)
          else if (f.isCancelled)
            fail(ctx.getChannel, new ChannelClosedException(addr))
        }
      })

      state = Connected
      writeInit(ctx)
    }
  }

  /**
   * Buffers bytes from the proxy and advances the handshake once a complete
   * reply for the current state is available. Incomplete replies rewind the
   * buffer and wait for more data.
   */
  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
    if (connectFuture.get eq null) {
      fail(ctx.getChannel, new InconsistentStateException(addr))
    } else {
      buf.writeBytes(e.getMessage.asInstanceOf[ChannelBuffer])
      // Remember where this reply starts so a partial read can be replayed.
      buf.markReaderIndex()
      try {
        state match {
          case Start =>
            // Data arrived before channelConnected started the handshake.
            fail(e.getChannel, new InconsistentStateException(addr))

          case Connected =>
            readInit() match {
              case Some(Unauthenticated) =>
                state = Requested
                writeRequest(ctx)
              case Some(UsernamePassAuthenticationSetting(username, pass)) =>
                state = Authenticating
                writeUserNameAndPass(ctx, username, pass)
              case None =>
                fail(e.getChannel, new ConnectionFailedException(InvalidInit, addr))
            }

          case Authenticating =>
            if (readAuthenticated()) {
              state = Requested
              writeRequest(ctx)
            } else {
              fail(e.getChannel, new ConnectionFailedException(InvalidResponse, addr))
            }

          case Requested =>
            if (readResponse()) {
              // NOTE(review): any bytes still buffered after the reply are not
              // forwarded upstream — confirm the proxy never coalesces payload.
              ctx.getPipeline.remove(this)
              connectFuture.get.setSuccess()
            } else {
              fail(e.getChannel, new ConnectionFailedException(InvalidResponse, addr))
            }
        }
      } catch {
        case ReplayError => buf.resetReaderIndex()
      }
    }
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy