All downloads are free. Search and download functionality uses the official Maven repository.

com.wavesplatform.network.HandshakeHandler.scala Maven / Gradle / Ivy

The newest version!
package com.wavesplatform.network

import com.wavesplatform.network.Handshake.InvalidHandshakeException
import com.wavesplatform.utils.ScorexLogging
import io.netty.buffer.ByteBuf
import io.netty.channel.*
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.group.ChannelGroup
import io.netty.handler.codec.ReplayingDecoder
import io.netty.util.AttributeKey
import io.netty.util.concurrent.ScheduledFuture

import java.net.InetSocketAddress
import java.util
import java.util.concurrent.{ConcurrentMap, TimeUnit}
import scala.concurrent.duration.FiniteDuration

/** Decodes a single [[Handshake]] from the inbound bytes and then takes itself out of the pipeline.
  *
  * If the bytes do not form a valid handshake (`InvalidHandshakeException`), the peer is handed to [[block]], which
  * blacklists it and closes the channel.
  *
  * @param peerDatabase database used to blacklist peers that send an invalid handshake
  */
class HandshakeDecoder(peerDatabase: PeerDatabase) extends ReplayingDecoder[Void] with ScorexLogging {
  override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: util.List[AnyRef]): Unit = {
    try {
      val decoded = Handshake.decode(in)
      out.add(decoded)
      // Only one handshake is ever expected per connection, so this decoder is no longer needed.
      ctx.pipeline().remove(this)
    } catch {
      case invalid: InvalidHandshakeException => block(ctx, invalid)
    }
  }

  /** Blacklists the remote peer (using the exception message as the reason) and closes its channel. */
  protected def block(ctx: ChannelHandlerContext, e: Throwable): Unit =
    peerDatabase.blacklistAndClose(ctx.channel(), e.getMessage)
}

/** Pipeline event fired by [[HandshakeTimeoutHandler]] when no handshake arrived within the configured timeout.
  * [[HandshakeHandler]] reacts to it by suspending the remote address and closing the channel.
  */
case object HandshakeTimeoutExpired

/** Fires [[HandshakeTimeoutExpired]] down the pipeline if no [[Handshake]] passes through this handler within
  * `handshakeTimeout` after the channel is registered.
  *
  * The handler keeps per-channel state (the pending timer) and is not `@Sharable`, so one instance is needed per
  * channel.
  *
  * @param handshakeTimeout how long to wait for the remote handshake before firing the timeout event
  */
class HandshakeTimeoutHandler(handshakeTimeout: FiniteDuration) extends ChannelInboundHandlerAdapter with ScorexLogging {
  // Pending timeout task, if any. The task itself is scheduled on the channel's event loop.
  private var timeout: Option[ScheduledFuture[?]] = None

  /** Cancels the pending timeout (if any) and forgets it; safe to call more than once. */
  private def cancelTimeout(): Unit = {
    timeout.foreach(_.cancel(true))
    timeout = None
  }

  override def channelRegistered(ctx: ChannelHandlerContext): Unit = {
    // A re-registration must not leave an earlier timer running next to the new one.
    cancelTimeout()
    log.trace(s"${id(ctx)} Scheduling handshake timeout in $handshakeTimeout")
    timeout = Some(
      ctx
        .channel()
        .eventLoop()
        .schedule(
          { () =>
            log.trace(s"${id(ctx)} Firing handshake timeout expired")
            ctx.fireChannelRead(HandshakeTimeoutExpired)
          },
          handshakeTimeout.toMillis,
          TimeUnit.MILLISECONDS
        )
    )

    super.channelRegistered(ctx)
  }

  override def channelUnregistered(ctx: ChannelHandlerContext): Unit = {
    cancelTimeout()
    super.channelUnregistered(ctx)
  }

  // Taking this handler out of the pipeline (e.g. once the handshake is accepted) must also stop the timer;
  // otherwise it could still fire HandshakeTimeoutExpired through the detached context.
  override def handlerRemoved(ctx: ChannelHandlerContext): Unit = {
    cancelTimeout()
    super.handlerRemoved(ctx)
  }

  override def channelRead(ctx: ChannelHandlerContext, msg: AnyRef): Unit = msg match {
    case hs: Handshake =>
      // The handshake arrived in time: the timer is no longer needed.
      cancelTimeout()
      super.channelRead(ctx, hs)
    case other =>
      super.channelRead(ctx, other)
  }
}

/** Validates the remote peer's [[Handshake]] and, on success, registers the connection.
  *
  * On [[HandshakeTimeoutExpired]] the remote socket address is suspended and the channel closed. On a remote
  * handshake, the application name and version are checked. Peers that fail a check, or that duplicate an already
  * registered [[PeerKey]], are suspended (when their declared address is verified) and closed. Otherwise the peer is
  * recorded in `peerConnections`/`establishedConnections` and the channel is tagged with node attributes. The
  * handshake handlers are removed, and the handshake is forwarded down the pipeline. Any other message is passed on
  * unchanged.
  *
  * @param localHandshake         this node's handshake, written by [[sendLocalHandshake]]
  * @param establishedConnections channel -> peer info for connections whose handshake was accepted
  * @param peerConnections        peer key -> channel; used to reject a second connection to the same peer
  * @param peerDatabase           database used to suspend and touch peer addresses
  * @param allChannels            group that accepted channels are added to by [[connectionNegotiated]]
  */
abstract class HandshakeHandler(
    localHandshake: Handshake,
    establishedConnections: ConcurrentMap[Channel, PeerInfo],
    peerConnections: ConcurrentMap[PeerKey, Channel],
    peerDatabase: PeerDatabase,
    allChannels: ChannelGroup
) extends ChannelInboundHandlerAdapter
    with ScorexLogging {

  import HandshakeHandler.*

  /** Logs `msg` at debug level (prefixed with the channel id), suspends `verifiedRemoteAddress` if present, and
    * closes the channel.
    */
  protected def suspendAndClose(msg: => String, verifiedRemoteAddress: Option[InetSocketAddress], ctx: ChannelHandlerContext): Unit = {
    log.debug(s"${id(ctx)} $msg")
    verifiedRemoteAddress.foreach(peerDatabase.suspend)
    ctx.close()
  }

  override def channelRead(ctx: ChannelHandlerContext, msg: AnyRef): Unit = msg match {
    case HandshakeTimeoutExpired    => onHandshakeTimeout(ctx)
    case remoteHandshake: Handshake => onRemoteHandshake(ctx, remoteHandshake)
    case _                          => super.channelRead(ctx, msg)
  }

  /** No handshake arrived in time: suspend the remote address (when it is an `InetSocketAddress`) and close. */
  private def onHandshakeTimeout(ctx: ChannelHandlerContext): Unit = {
    log.trace(s"Timeout expired while waiting for handshake: ${id(ctx.channel())}")
    ctx.channel().remoteAddress() match {
      case isa: InetSocketAddress => peerDatabase.suspend(isa)
      case _                      =>
    }
    ctx.close()
  }

  /** Checks application name and version, then tries to register the peer under its [[PeerKey]]. */
  private def onRemoteHandshake(ctx: ChannelHandlerContext, remoteHandshake: Handshake): Unit = {
    // The declared address is only used (suspended / touched / recorded) when it equals the socket's remote address.
    val verifiedDeclaredAddress = remoteHandshake.declaredAddress.filter(_ == ctx.channel().remoteAddress())

    if (localHandshake.applicationName != remoteHandshake.applicationName)
      suspendAndClose(
        s"Remote application name ${remoteHandshake.applicationName} does not match local ${localHandshake.applicationName}",
        verifiedDeclaredAddress,
        ctx
      )
    else if (!versionIsSupported(remoteHandshake.applicationVersion))
      suspendAndClose(s"Remote application version ${remoteHandshake.applicationVersion} is not supported", verifiedDeclaredAddress, ctx)
    else {
      verifiedDeclaredAddress.foreach { vda =>
        ctx.channel().attr(NodeDeclaredAddressAttributeKey).set(vda)
        peerDatabase.touch(vda)
      }

      PeerKey(ctx, remoteHandshake.nodeNonce) match {
        case None =>
          log.warn(s"Can't get PeerKey from ${id(ctx)}")
          ctx.close()

        case Some(key) =>
          // putIfAbsent returns null when no other channel is registered under this key yet.
          Option(peerConnections.putIfAbsent(key, ctx.channel())) match {
            case None =>
              acceptHandshake(ctx, key, remoteHandshake)
            case Some(previousPeer) =>
              // suspendAndClose already prefixes the channel id, so it is not repeated in the message.
              suspendAndClose(
                s"Already connected to peer with nonce ${remoteHandshake.nodeNonce} on channel ${id(previousPeer)}",
                verifiedDeclaredAddress,
                ctx
              )
          }
      }
    }
  }

  /** Finalises an accepted handshake. `key` must already map to this channel in `peerConnections`. */
  private def acceptHandshake(ctx: ChannelHandlerContext, key: PeerKey, remoteHandshake: Handshake): Unit = {
    val channel = ctx.channel()
    log.info(s"${id(ctx)} Accepted handshake $remoteHandshake")
    removeHandshakeHandlers(ctx, this)
    establishedConnections.put(channel, peerInfo(remoteHandshake, channel))

    channel.attr(NodeNameAttributeKey).set(remoteHandshake.nodeName)
    channel.attr(NodeVersionAttributeKey).set(remoteHandshake.applicationVersion)

    // ConnectionStartAttributeKey holds a scala.Long, so reading it while unset (only Client sets it) unboxes null
    // to 0 rather than yielding None. Check presence explicitly to avoid logging a bogus duration.
    if (channel.hasAttr(ConnectionStartAttributeKey)) {
      val start = channel.attr(ConnectionStartAttributeKey).get()
      log.trace(s"Time taken to accept handshake = ${System.currentTimeMillis() - start} ms")
    }

    // Undo the peer registrations once the connection is closed.
    channel.closeFuture().addListener { (f: ChannelFuture) =>
      peerConnections.remove(key, f.channel())
      establishedConnections.remove(f.channel())
      log.trace(s"${id(f.channel())} was closed")
    }

    connectionNegotiated(ctx)
    // Pass the handshake on to the rest of the pipeline.
    ctx.fireChannelRead(remoteHandshake)
  }

  /** Called after a handshake is accepted: adds the channel to `allChannels` and removes it again on close. */
  protected def connectionNegotiated(ctx: ChannelHandlerContext): Unit = {
    ctx.channel().closeFuture().addListener((_: ChannelFuture) => allChannels.remove(ctx.channel()))
    allChannels.add(ctx.channel())
  }

  /** Encodes `localHandshake` into a freshly allocated buffer, then writes and flushes it to the channel. */
  protected def sendLocalHandshake(ctx: ChannelHandlerContext): Unit = {
    ctx.writeAndFlush(localHandshake.encode(ctx.alloc().buffer()))
  }
}

/** Channel attribute keys, handshake helper functions and the concrete server/client handshake handlers. */
object HandshakeHandler {

  /** Remote node name, set on the channel once its handshake is accepted. */
  val NodeNameAttributeKey: AttributeKey[String] = AttributeKey.newInstance[String]("name")

  /** Remote application version tuple (presumably major, minor, patch), set once the handshake is accepted. */
  val NodeVersionAttributeKey: AttributeKey[(Int, Int, Int)] = AttributeKey.newInstance[(Int, Int, Int)]("version")

  /** Remote declared address, set only when it matches the socket's actual remote address. */
  val NodeDeclaredAddressAttributeKey: AttributeKey[InetSocketAddress] = AttributeKey.newInstance[InetSocketAddress]("declaredAddress")

  // Millisecond timestamp set by Client.channelActive. The value type is scala.Long, so reading it while unset
  // unboxes null to 0L.
  private val ConnectionStartAttributeKey: AttributeKey[Long] = AttributeKey.newInstance[Long]("connectionStart")

  /** Accepts versions 0.x with x >= 13, and 1.x with any non-negative x; everything else is rejected. */
  def versionIsSupported(remoteVersion: (Int, Int, Int)): Boolean = remoteVersion match {
    case (0, minor, _) => minor >= 13
    case (1, minor, _) => minor >= 0
    case _             => false
  }

  /** Removes the [[HandshakeTimeoutHandler]] and `thisHandler` from the channel's pipeline. */
  def removeHandshakeHandlers(ctx: ChannelHandlerContext, thisHandler: ChannelHandler): Unit = {
    val pipeline = ctx.pipeline()
    pipeline.remove(classOf[HandshakeTimeoutHandler])
    pipeline.remove(thisHandler)
  }

  /** Builds a [[PeerInfo]] from the channel's remote address and the fields of the remote handshake. */
  def peerInfo(remoteHandshake: Handshake, channel: Channel): PeerInfo = {
    val hs = remoteHandshake
    PeerInfo(channel.remoteAddress(), hs.declaredAddress, hs.applicationName, hs.applicationVersion, hs.nodeName, hs.nodeNonce)
  }

  /** Server-side handler: replies with the local handshake only after the remote handshake has been accepted. */
  @Sharable
  class Server(
      handshake: Handshake,
      establishedConnections: ConcurrentMap[Channel, PeerInfo],
      peerConnections: ConcurrentMap[PeerKey, Channel],
      peerDatabase: PeerDatabase,
      allChannels: ChannelGroup
  ) extends HandshakeHandler(handshake, establishedConnections, peerConnections, peerDatabase, allChannels) {
    override protected def connectionNegotiated(ctx: ChannelHandlerContext): Unit = {
      // Answer first, then register the channel in the group.
      sendLocalHandshake(ctx)
      super.connectionNegotiated(ctx)
    }
  }

  /** Client-side handler: sends the local handshake as soon as the channel becomes active and records that moment. */
  @Sharable
  class Client(
      handshake: Handshake,
      establishedConnections: ConcurrentMap[Channel, PeerInfo],
      peerConnections: ConcurrentMap[PeerKey, Channel],
      peerDatabase: PeerDatabase,
      allChannels: ChannelGroup
  ) extends HandshakeHandler(handshake, establishedConnections, peerConnections, peerDatabase, allChannels) {
    override def channelActive(ctx: ChannelHandlerContext): Unit = {
      sendLocalHandshake(ctx)
      val startedAt = System.currentTimeMillis()
      ctx.channel().attr(ConnectionStartAttributeKey).set(startedAt)
      super.channelActive(ctx)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy