All Downloads are FREE. Search and download functionality uses the official Maven repository.

vendor.github.com.pion.dtls.v2.conn.go Maven / Gradle / Ivy

There is a newer version: 2.9.1
Show newest version
package dtls

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/dtls/v2/internal/closer"
	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
	"github.com/pion/dtls/v2/pkg/protocol"
	"github.com/pion/dtls/v2/pkg/protocol/alert"
	"github.com/pion/dtls/v2/pkg/protocol/handshake"
	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
	"github.com/pion/logging"
	"github.com/pion/transport/connctx"
	"github.com/pion/transport/deadline"
	"github.com/pion/transport/replaydetector"
)

const (
	// initialTickerInterval is the flight retransmission interval used when
	// Config.FlightInterval is left at zero.
	initialTickerInterval = time.Second

	// cookieLength and sessionLength are byte lengths; presumably the
	// HelloVerifyRequest cookie and the session ID (used outside this file).
	cookieLength  = 20
	sessionLength = 32

	// defaultNamedCurve is the package's default elliptic curve (X25519);
	// where it is applied is outside this file.
	defaultNamedCurve = elliptic.X25519

	// inboundBufferSize is the size in bytes of each pooled read buffer.
	inboundBufferSize = 8192

	// defaultReplayProtectionWindow is used when Config.ReplayProtectionWindow
	// is not positive. The default window is specified by RFC 6347 Section 4.1.2.6.
	defaultReplayProtectionWindow = 64
)

// invalidKeyingLabels returns a fresh set of the PRF labels used internally by
// the TLS 1.2 handshake. Presumably callers reject these as keying-material
// exporter labels so exported keys cannot collide with handshake secrets.
// A new map is built on every call, so callers may modify the result.
func invalidKeyingLabels() map[string]bool {
	reserved := []string{
		"client finished",
		"server finished",
		"master secret",
		"key expansion",
	}
	set := make(map[string]bool, len(reserved))
	for _, label := range reserved {
		set[label] = true
	}
	return set
}

// Conn represents a DTLS connection.
//
// A Conn is returned by Dial, Client, Server and their *WithContext variants,
// all of which run the DTLS handshake before returning.
type Conn struct {
	// lock guards c.state and serializes record writes (writePackets takes the
	// write lock); ConnectionState and SelectedSRTPProtectionProfile take the
	// read lock.
	lock           sync.RWMutex     // Internal lock (must not be public)
	nextConn       connctx.ConnCtx  // Embedded Conn, typically a udpconn we read/write from
	fragmentBuffer *fragmentBuffer  // out-of-order and missing fragment handling
	handshakeCache *handshakeCache  // caching of handshake messages for verifyData generation
	decrypted      chan interface{} // Decrypted Application Data or error, pull by calling `Read`

	state State // Internal state

	// maximumTransmissionUnit bounds datagram packing (compactRawPackets) and
	// handshake fragmentation (fragmentHandshake).
	maximumTransmissionUnit int

	// handshakeCompletedSuccessfully holds a struct{ bool }; see
	// setHandshakeCompletedSuccessfully / isHandshakeCompletedSuccessfully.
	handshakeCompletedSuccessfully atomic.Value

	// encryptedPackets queues records that arrived for the next epoch or before
	// the cipher suite was ready; drained by handleQueuedPackets.
	encryptedPackets [][]byte

	// connectionClosedByUser records whether the user already called Close, so
	// a second Close returns ErrConnClosed. Guarded by closeLock.
	connectionClosedByUser bool
	closeLock              sync.Mutex
	closed                 *closer.Closer
	// handshakeLoopsFinished counts the two goroutines started by handshake;
	// Close waits on it.
	handshakeLoopsFinished sync.WaitGroup

	readDeadline  *deadline.Deadline
	writeDeadline *deadline.Deadline

	log logging.LeveledLogger

	// NOTE(review): `reading` is initialized in createConn but not referenced
	// anywhere in this file — confirm whether it is still used.
	reading chan struct{}
	// handshakeRecv hands a "done" channel to the handshake FSM each time a
	// datagram with handshake content has been buffered (see readAndBuffer).
	handshakeRecv         chan chan struct{}
	cancelHandshaker      func()
	cancelHandshakeReader func()

	fsm *handshakeFSM

	// replayProtectionWindow is the window size given to each per-epoch
	// replay detector.
	replayProtectionWindow uint
}

// createConn validates config, builds a Conn on top of nextConn and runs the
// DTLS handshake, blocking until it finishes, fails, or ctx is done.
//
// isClient selects the client or server role. When initialState is non-nil,
// the handshake FSM starts in the finished state from that saved State
// (flight5 for clients, flight6 for servers) instead of negotiating from
// scratch. On any error the returned *Conn is nil. nextConn is not closed on
// error, so the caller keeps ownership of it.
func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
	err := validateConfig(config)
	if err != nil {
		return nil, err
	}

	if nextConn == nil {
		return nil, errNilNextConn
	}

	// NOTE(review): the two boolean flags presumably enable certificate-based
	// and PSK cipher suites respectively — confirm against parseCipherSuites.
	cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
	if err != nil {
		return nil, err
	}

	signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
	if err != nil {
		return nil, err
	}

	// Fall back to package defaults for unset (zero or non-positive) options.
	workerInterval := initialTickerInterval
	if config.FlightInterval != 0 {
		workerInterval = config.FlightInterval
	}

	loggerFactory := config.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}

	logger := loggerFactory.NewLogger("dtls")

	mtu := config.MTU
	if mtu <= 0 {
		mtu = defaultMTU
	}

	replayProtectionWindow := config.ReplayProtectionWindow
	if replayProtectionWindow <= 0 {
		replayProtectionWindow = defaultReplayProtectionWindow
	}

	c := &Conn{
		nextConn:                connctx.New(nextConn),
		fragmentBuffer:          newFragmentBuffer(),
		handshakeCache:          newHandshakeCache(),
		maximumTransmissionUnit: mtu,

		decrypted: make(chan interface{}, 1),
		log:       logger,

		readDeadline:  deadline.New(),
		writeDeadline: deadline.New(),

		reading:          make(chan struct{}, 1),
		handshakeRecv:    make(chan chan struct{}),
		closed:           closer.NewCloser(),
		cancelHandshaker: func() {},

		replayProtectionWindow: uint(replayProtectionWindow),

		state: State{
			isClient: isClient,
		},
	}

	c.setRemoteEpoch(0)
	c.setLocalEpoch(0)

	serverName := config.ServerName
	// Do not allow the use of an IP address literal as an SNI value.
	// See RFC 6066, Section 3.
	if net.ParseIP(serverName) != nil {
		serverName = ""
	}

	hsCfg := &handshakeConfig{
		localPSKCallback:            config.PSK,
		localPSKIdentityHint:        config.PSKIdentityHint,
		localCipherSuites:           cipherSuites,
		localSignatureSchemes:       signatureSchemes,
		extendedMasterSecret:        config.ExtendedMasterSecret,
		localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
		serverName:                  serverName,
		supportedProtocols:          config.SupportedProtocols,
		clientAuth:                  config.ClientAuth,
		localCertificates:           config.Certificates,
		insecureSkipVerify:          config.InsecureSkipVerify,
		verifyPeerCertificate:       config.VerifyPeerCertificate,
		rootCAs:                     config.RootCAs,
		clientCAs:                   config.ClientCAs,
		customCipherSuites:          config.CustomCipherSuites,
		retransmitInterval:          workerInterval,
		log:                         logger,
		initialEpoch:                0,
		keyLogWriter:                config.KeyLogWriter,
		sessionStore:                config.SessionStore,
	}

	// rfc5246#section-7.4.3
	// In addition, the hash and signature algorithms MUST be compatible
	// with the key in the server's end-entity certificate.
	// A server without certificates (errNoCertificates) is allowed here.
	if !isClient {
		cert, err := hsCfg.getCertificate("")
		if err != nil && !errors.Is(err, errNoCertificates) {
			return nil, err
		}
		hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
	}

	var initialFlight flightVal
	var initialFSMState handshakeState

	if initialState != nil {
		// Resuming from a saved State: start at the role's final flight with
		// the FSM already finished.
		if c.state.isClient {
			initialFlight = flight5
		} else {
			initialFlight = flight6
		}
		initialFSMState = handshakeFinished

		// NOTE(review): this replaces the whole state, including the isClient
		// flag and the epochs set above, with the saved one.
		c.state = *initialState
	} else {
		if c.state.isClient {
			initialFlight = flight1
		} else {
			initialFlight = flight0
		}
		initialFSMState = handshakePreparing
	}
	// Do handshake
	if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
		return nil, err
	}

	c.log.Trace("Handshake Completed")

	return c, nil
}

// Dial connects to the given network address and establishes a DTLS
// connection on top. The handshake is bounded by the context produced by the
// Config's ConnectContextMaker; to pick the timeout yourself, call
// DialWithContext instead.
func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
	handshakeCtx, stop := config.connectContextMaker()
	defer stop()

	conn, err := DialWithContext(handshakeCtx, network, raddr, config)
	return conn, err
}

// Client establishes a DTLS connection over an existing connection.
// The handshake is bounded by the context produced by the Config's
// ConnectContextMaker; to pick the timeout yourself, call ClientWithContext
// instead.
func Client(conn net.Conn, config *Config) (*Conn, error) {
	handshakeCtx, stop := config.connectContextMaker()
	defer stop()

	dtlsConn, err := ClientWithContext(handshakeCtx, conn, config)
	return dtlsConn, err
}

// Server listens for incoming DTLS connections.
// The handshake is bounded by the context produced by the Config's
// ConnectContextMaker; to pick the timeout yourself, call ServerWithContext
// instead.
func Server(conn net.Conn, config *Config) (*Conn, error) {
	handshakeCtx, stop := config.connectContextMaker()
	defer stop()

	dtlsConn, err := ServerWithContext(handshakeCtx, conn, config)
	return dtlsConn, err
}

// DialWithContext connects to the given network address and establishes a DTLS connection on top.
//
// The UDP socket is created here, so it is closed again if setting up the
// DTLS connection fails (the caller never sees it and could not close it).
// On success the socket is owned by the returned Conn and is closed by Close.
func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
	pConn, err := net.DialUDP(network, nil, raddr)
	if err != nil {
		return nil, err
	}

	conn, err := ClientWithContext(ctx, pConn, config)
	if err != nil {
		// Best-effort cleanup: the handshake error is the one worth reporting.
		_ = pConn.Close()
		return nil, err
	}
	return conn, nil
}

// ClientWithContext establishes a DTLS connection over an existing connection,
// acting as the client. ctx bounds the handshake. A nil config is rejected,
// and so is a config that sets PSK without a PSKIdentityHint.
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
	if config == nil {
		return nil, errNoConfigProvided
	}
	if config.PSK != nil && config.PSKIdentityHint == nil {
		return nil, errPSKAndIdentityMustBeSetForClient
	}

	const asClient = true
	return createConn(ctx, conn, config, asClient, nil)
}

// ServerWithContext listens for incoming DTLS connections, acting as the
// server over an existing connection. ctx bounds the handshake. A nil config
// is rejected.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
	if config != nil {
		const asClient = false
		return createConn(ctx, conn, config, asClient, nil)
	}
	return nil, errNoConfigProvided
}

// Read reads data from the connection.
//
// Each call returns the contents of at most one decrypted application-data
// record. If p is shorter than that record, errBufferTooSmall is returned and
// the record is discarded. Read fails with errHandshakeInProgress before the
// handshake completes, errDeadlineExceeded once the read deadline passes,
// io.EOF after the read loop has stopped, or any error the read loop passed on.
func (c *Conn) Read(p []byte) (n int, err error) {
	if !c.isHandshakeCompletedSuccessfully() {
		return 0, errHandshakeInProgress
	}

	// An already-expired deadline takes priority over queued data.
	select {
	case <-c.readDeadline.Done():
		return 0, errDeadlineExceeded
	default:
	}

	for {
		select {
		case <-c.readDeadline.Done():
			return 0, errDeadlineExceeded
		case item, open := <-c.decrypted:
			if !open {
				return 0, io.EOF
			}
			if data, isData := item.([]byte); isData {
				if len(data) > len(p) {
					return 0, errBufferTooSmall
				}
				return copy(p, data), nil
			}
			if readErr, isErr := item.(error); isErr {
				return 0, readErr
			}
			// Any other value is ignored; keep waiting.
		}
	}
}

// Write writes len(p) bytes from p to the DTLS connection as a single
// encrypted application-data record. The returned count is len(p) even when
// the write fails; check the error. Writes fail with ErrConnClosed after
// close, errDeadlineExceeded once the write deadline has passed, and
// errHandshakeInProgress before the handshake completes.
func (c *Conn) Write(p []byte) (int, error) {
	if c.isConnectionClosed() {
		return 0, ErrConnClosed
	}

	select {
	case <-c.writeDeadline.Done():
		return 0, errDeadlineExceeded
	default:
	}

	if !c.isHandshakeCompletedSuccessfully() {
		return 0, errHandshakeInProgress
	}

	appDataRecord := &recordlayer.RecordLayer{
		Header: recordlayer.Header{
			Epoch:   c.state.getLocalEpoch(),
			Version: protocol.Version1_2,
		},
		Content: &protocol.ApplicationData{Data: p},
	}
	writeErr := c.writePackets(c.writeDeadline, []*packet{
		{record: appDataRecord, shouldEncrypt: true},
	})
	return len(p), writeErr
}

// Close closes the connection. It sends close_notify to the peer if the
// handshake completed, then blocks until the background handshake and read
// goroutines have exited. A second Close returns ErrConnClosed.
func (c *Conn) Close() error {
	closeErr := c.close(true) //nolint:contextcheck

	// Don't return until no goroutine started by handshake can still touch c.
	c.handshakeLoopsFinished.Wait()

	return closeErr
}

// ConnectionState returns basic DTLS details about the connection as a copy
// taken under the read lock.
// Note that this replaced the `Export` function of v1.
func (c *Conn) ConnectionState() State {
	c.lock.RLock()
	defer c.lock.RUnlock()

	snapshot := c.state.clone()
	return *snapshot
}

// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile.
// The boolean is false (and the profile zero) when none was negotiated.
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
	c.lock.RLock()
	defer c.lock.RUnlock()

	profile := c.state.srtpProtectionProfile
	return profile, profile != 0
}

// writePackets serializes pkts into records and writes them to the underlying
// connection, packing several records into one datagram where the MTU allows
// (see compactRawPackets).
//
// Handshake records are also pushed into c.handshakeCache (used for verifyData
// generation) and may be split into several fragments by
// processHandshakePacket. Every other record is sequenced and, if requested,
// encrypted by processPacket. c.lock is held for the whole call, which
// serializes sequence-number allocation and writes. ctx bounds each
// underlying write; write errors are passed through netError.
func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	var rawPackets [][]byte

	for _, p := range pkts {
		h, isHandshake := p.record.Content.(*handshake.Handshake)
		if !isHandshake {
			rawPacket, err := c.processPacket(p)
			if err != nil {
				return err
			}
			rawPackets = append(rawPackets, rawPacket)
			continue
		}

		handshakeRaw, err := p.record.Marshal()
		if err != nil {
			return err
		}

		c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
			srvCliStr(c.state.isClient), h.Header.Type.String(),
			p.record.Header.Epoch, h.Header.MessageSequence)
		// Cache the handshake message without its record-layer header.
		c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)

		rawHandshakePackets, err := c.processHandshakePacket(p, h)
		if err != nil {
			return err
		}
		rawPackets = append(rawPackets, rawHandshakePackets...)
	}
	if len(rawPackets) == 0 {
		return nil
	}

	for _, datagram := range c.compactRawPackets(rawPackets) {
		if _, err := c.nextConn.WriteContext(ctx, datagram); err != nil {
			return netError(err)
		}
	}

	return nil
}

// compactRawPackets packs consecutive raw records, in order, into as few
// datagrams as possible. A new datagram is started whenever appending the
// next record would make the current one reach c.maximumTransmissionUnit
// bytes. A record that is already that large is emitted on its own; it is
// never split here.
//
// An empty input yields no datagrams, instead of a single zero-length one.
func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
	if len(rawPackets) == 0 {
		return nil
	}

	combinedRawPackets := make([][]byte, 0, len(rawPackets))
	var currentCombinedRawPacket []byte

	for _, rawPacket := range rawPackets {
		if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
			combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
			currentCombinedRawPacket = nil
		}
		currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
	}

	return append(combinedRawPackets, currentCombinedRawPacket)
}

// processPacket assigns the next local sequence number for the packet's epoch,
// writes it into p.record.Header, marshals the record, and encrypts it with
// the current cipher suite when p.shouldEncrypt is set.
//
// It may grow c.state.localSequenceNumber. In this file it is called only
// from writePackets, which holds c.lock while it runs.
func (c *Conn) processPacket(p *packet) ([]byte, error) {
	epoch := p.record.Header.Epoch
	for len(c.state.localSequenceNumber) <= int(epoch) {
		c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
	}
	seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
	if seq > recordlayer.MaxSequenceNumber {
		// RFC 6347 Section 4.1
		// The implementation must either abandon an association or rehandshake
		// prior to allowing the sequence number to wrap.
		return nil, errSequenceNumberOverflow
	}
	p.record.Header.SequenceNumber = seq

	rawPacket, err := p.record.Marshal()
	if err != nil {
		return nil, err
	}

	if !p.shouldEncrypt {
		return rawPacket, nil
	}

	encrypted, err := c.state.cipherSuite.Encrypt(p.record, rawPacket)
	if err != nil {
		return nil, err
	}
	return encrypted, nil
}

// processHandshakePacket splits handshake message h into fragments (see
// fragmentHandshake) and wraps each one in its own record. Each record gets
// the next local sequence number for the packet's epoch, and is encrypted
// when p.shouldEncrypt is set.
//
// p.record.Header is overwritten with each fragment's header before that
// fragment is encrypted. Presumably this lets Encrypt see the fragment's
// sequence number and length; confirm in the cipher suite. In this file it is
// called only from writePackets, which holds c.lock.
func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
	handshakeFragments, err := c.fragmentHandshake(h)
	if err != nil {
		return nil, err
	}
	rawPackets := make([][]byte, 0, len(handshakeFragments))

	epoch := p.record.Header.Epoch
	for len(c.state.localSequenceNumber) <= int(epoch) {
		c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
	}

	for _, handshakeFragment := range handshakeFragments {
		seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
		if seq > recordlayer.MaxSequenceNumber {
			return nil, errSequenceNumberOverflow
		}

		recordlayerHeader := &recordlayer.Header{
			Version:        p.record.Header.Version,
			ContentType:    p.record.Header.ContentType,
			ContentLen:     uint16(len(handshakeFragment)),
			Epoch:          p.record.Header.Epoch,
			SequenceNumber: seq,
		}

		rawPacket, err := recordlayerHeader.Marshal()
		if err != nil {
			return nil, err
		}

		p.record.Header = *recordlayerHeader

		rawPacket = append(rawPacket, handshakeFragment...)
		if p.shouldEncrypt {
			rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
			if err != nil {
				return nil, err
			}
		}

		rawPackets = append(rawPackets, rawPacket)
	}

	return rawPackets, nil
}

// fragmentHandshake marshals the body of handshake message h and splits it
// with splitBytes using c.maximumTransmissionUnit as the chunk size
// (presumably chunks of at most that many bytes). Each chunk is prefixed
// with its own handshake header, which carries the message's type, total
// length and sequence plus the chunk's offset and length. An empty body still
// yields one zero-length fragment, so the message is always sent.
func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
	body, err := h.Message.Marshal()
	if err != nil {
		return nil, err
	}

	chunks := splitBytes(body, c.maximumTransmissionUnit)
	if len(chunks) == 0 {
		chunks = [][]byte{{}}
	}

	fragments := make([][]byte, 0, len(chunks))
	var offset int
	for _, chunk := range chunks {
		hdr := &handshake.Header{
			Type:            h.Header.Type,
			Length:          h.Header.Length,
			MessageSequence: h.Header.MessageSequence,
			FragmentOffset:  uint32(offset),
			FragmentLength:  uint32(len(chunk)),
		}

		encoded, err := hdr.Marshal()
		if err != nil {
			return nil, err
		}
		fragments = append(fragments, append(encoded, chunk...))

		offset += len(chunk)
	}

	return fragments, nil
}

// poolReadBuffer recycles inboundBufferSize-byte receive buffers between calls
// to readAndBuffer. Pooled values are *[]byte, and readAndBuffer type-asserts
// them accordingly.
var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
	New: func() interface{} {
		bufPtr := new([]byte)
		*bufPtr = make([]byte, inboundBufferSize)
		return bufPtr
	},
}

// readAndBuffer reads one datagram from the underlying connection into a
// pooled buffer, splits it into records, and passes each record to
// handleIncomingPacket with enqueue=true. Alerts requested by record
// processing are sent to the peer with notify.
//
// If any record carried handshake data, a new done channel is handed to the
// handshake FSM through c.handshakeRecv. readAndBuffer then waits until done
// is signalled, or until the FSM is done, so the FSM can respond even to
// retransmitted flights.
//
// It returns nil on success, the read error (passed through netError), an
// unpack error, a fatal or close_notify *alertError, or the first non-alert
// processing error. Non-fatal alerts do not stop processing and are not
// returned.
func (c *Conn) readAndBuffer(ctx context.Context) error {
	bufptr, ok := poolReadBuffer.Get().(*[]byte)
	if !ok {
		return errFailedToAccessPoolReadBuffer
	}
	defer poolReadBuffer.Put(bufptr)

	b := *bufptr
	i, err := c.nextConn.ReadContext(ctx, b)
	if err != nil {
		return netError(err)
	}

	pkts, err := recordlayer.UnpackDatagram(b[:i])
	if err != nil {
		return err
	}

	var hasHandshake bool
	for _, p := range pkts {
		// a, not alert: avoid shadowing the alert package.
		hs, a, err := c.handleIncomingPacket(ctx, p, true)
		if a != nil {
			if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil && err == nil {
				err = alertErr
			}
		}
		if hs {
			hasHandshake = true
		}

		var e *alertError
		if errors.As(err, &e) {
			if e.IsFatalOrCloseNotify() {
				return e
			}
		} else if err != nil {
			// Return err itself. e is a nil *alertError on this path, and
			// returning it would produce a non-nil error holding a nil pointer.
			// The read loop would then call e.IsFatalOrCloseNotify() on that
			// nil pointer.
			return err
		}
	}
	if hasHandshake {
		done := make(chan struct{})
		select {
		case c.handshakeRecv <- done:
			// If the other party may retransmit the flight,
			// we should respond even if it is not a new message.
			<-done
		case <-c.fsm.Done():
		}
	}
	return nil
}

// handleQueuedPackets re-processes the records that handleIncomingPacket
// queued in c.encryptedPackets (next-epoch records, or records that arrived
// before the cipher suite was ready). The queue is cleared first, and records
// are handled with enqueue=false, so a record that still cannot be processed
// is dropped rather than queued again.
//
// Alerts requested while handling a record are sent to the peer. A fatal or
// close_notify *alertError, or any non-alert error, stops processing and is
// returned. Non-fatal alerts are ignored.
func (c *Conn) handleQueuedPackets(ctx context.Context) error {
	pkts := c.encryptedPackets
	c.encryptedPackets = nil

	for _, p := range pkts {
		// a, not alert: avoid shadowing the alert package.
		_, a, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue
		if a != nil {
			if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil && err == nil {
				err = alertErr
			}
		}

		var e *alertError
		if errors.As(err, &e) {
			if e.IsFatalOrCloseNotify() {
				return e
			}
		} else if err != nil {
			// Return err itself. e is a nil *alertError on this path, and
			// returning it would yield a non-nil error holding a nil pointer.
			return err
		}
	}
	return nil
}

// handleIncomingPacket processes a single raw DTLS record.
//
// It returns whether the record carried handshake data, an alert the caller
// should send to the peer (may be nil), and an error. Undecodable, duplicated,
// undecryptable and far-future records are silently dropped (false, nil, nil).
// Records for the next epoch, and records that arrive before the cipher suite
// is initialized, are queued in c.encryptedPackets when enqueue is true and
// dropped otherwise. Received alerts are returned as *alertError. Decrypted
// application data is delivered on c.decrypted.
//
// buf may alias the read buffer that readAndBuffer returns to poolReadBuffer
// when it finishes. Any record kept for later is therefore copied, so the
// next read cannot overwrite it.
func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
	h := &recordlayer.Header{}
	if err := h.Unmarshal(buf); err != nil {
		// Decode error must be silently discarded
		// [RFC6347 Section-4.1.2.7]
		c.log.Debugf("discarded broken packet: %v", err)
		return false, nil, nil
	}

	// Validate epoch
	remoteEpoch := c.state.getRemoteEpoch()
	if h.Epoch > remoteEpoch {
		if h.Epoch > remoteEpoch+1 {
			c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
				h.Epoch, h.SequenceNumber,
			)
			return false, nil, nil
		}
		if enqueue {
			c.log.Debug("received packet of next epoch, queuing packet")
			// Copy: buf is only valid until the pooled read buffer is reused.
			c.encryptedPackets = append(c.encryptedPackets, append([]byte{}, buf...))
		}
		return false, nil, nil
	}

	// Anti-replay protection
	for len(c.state.replayDetector) <= int(h.Epoch) {
		c.state.replayDetector = append(c.state.replayDetector,
			replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
		)
	}
	// markPacketAsValid is called below only once a record is accepted, so
	// dropped records presumably do not advance the replay window.
	markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
	if !ok {
		c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
			h.Epoch, h.SequenceNumber,
		)
		return false, nil, nil
	}

	// Decrypt
	if h.Epoch != 0 {
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				// Copy: buf is only valid until the pooled read buffer is reused.
				c.encryptedPackets = append(c.encryptedPackets, append([]byte{}, buf...))
				c.log.Debug("handshake not finished, queuing packet")
			}
			return false, nil, nil
		}

		var err error
		buf, err = c.state.cipherSuite.Decrypt(buf)
		if err != nil {
			c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
			return false, nil, nil
		}
	}

	isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
	if err != nil {
		// Decode error must be silently discarded
		// [RFC6347 Section-4.1.2.7]
		c.log.Debugf("defragment failed: %s", err)
		return false, nil, nil
	} else if isHandshake {
		markPacketAsValid()
		for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
			header := &handshake.Header{}
			if err := header.Unmarshal(out); err != nil {
				c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
				continue
			}
			c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
		}

		return true, nil, nil
	}

	r := &recordlayer.RecordLayer{}
	if err := r.Unmarshal(buf); err != nil {
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
	}

	switch content := r.Content.(type) {
	case *alert.Alert:
		c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
		var a *alert.Alert
		if content.Description == alert.CloseNotify {
			// Respond with a close_notify [RFC5246 Section 7.2.1]
			a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
		}
		markPacketAsValid()
		return false, a, &alertError{content}
	case *protocol.ChangeCipherSpec:
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				// Copy: buf is only valid until the pooled read buffer is reused.
				c.encryptedPackets = append(c.encryptedPackets, append([]byte{}, buf...))
				c.log.Debugf("CipherSuite not initialized, queuing packet")
			}
			return false, nil, nil
		}

		newRemoteEpoch := h.Epoch + 1
		c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)

		// Only advance the remote epoch by exactly one; a repeated
		// ChangeCipherSpec has no effect.
		if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
			c.setRemoteEpoch(newRemoteEpoch)
			markPacketAsValid()
		}
	case *protocol.ApplicationData:
		if h.Epoch == 0 {
			return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
		}

		markPacketAsValid()

		// Deliver to Read; the data is dropped if the connection is closed or
		// ctx is done before a reader takes it.
		select {
		case c.decrypted <- content.Data:
		case <-c.closed.Done():
		case <-ctx.Done():
		}

	default:
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
	}
	return false, nil, nil
}

// recvHandshake exposes c.handshakeRecv as a receive-only channel. Each value
// received is a done channel sent by readAndBuffer after it has buffered
// handshake data; readAndBuffer waits until that channel is signalled.
func (c *Conn) recvHandshake() <-chan chan struct{} {
	var recvOnly <-chan chan struct{} = c.handshakeRecv
	return recvOnly
}

// notify sends an alert with the given level and description to the peer.
// The alert is encrypted only once the handshake has completed.
//
// For a fatal alert on a connection that has a session ID, the stored session
// is first removed from the configured session store (RFC 5246 Section 7.2).
// An error from that removal is returned and the alert is not sent.
func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
	mustForgetSession := level == alert.Fatal && len(c.state.SessionID) > 0
	if mustForgetSession {
		// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
		if store := c.fsm.cfg.sessionStore; store != nil {
			c.log.Tracef("clean invalid session: %s", c.state.SessionID)
			if err := store.Del(c.sessionKey()); err != nil {
				return err
			}
		}
	}

	alertRecord := &recordlayer.RecordLayer{
		Header: recordlayer.Header{
			Epoch:   c.state.getLocalEpoch(),
			Version: protocol.Version1_2,
		},
		Content: &alert.Alert{Level: level, Description: desc},
	}
	pkt := &packet{
		record:        alertRecord,
		shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
	}
	return c.writePackets(ctx, []*packet{pkt})
}

// setHandshakeCompletedSuccessfully records, atomically, that the handshake
// has finished. The flag is stored as a struct{ bool } in an atomic.Value.
func (c *Conn) setHandshakeCompletedSuccessfully() {
	completed := struct{ bool }{bool: true}
	c.handshakeCompletedSuccessfully.Store(completed)
}

// isHandshakeCompletedSuccessfully reports whether
// setHandshakeCompletedSuccessfully has been called. It returns false while
// nothing has been stored yet.
func (c *Conn) isHandshakeCompletedSuccessfully() bool {
	flag, stored := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
	return stored && flag.bool
}

// handshake creates the handshake FSM and starts two goroutines that keep
// running until the connection is closed:
//   - the FSM loop (c.fsm.Run), which can also retransmit the last flight;
//   - the read loop, which calls readAndBuffer repeatedly and, when it exits,
//     closes c.decrypted and cancels the FSM.
//
// It blocks until one of three things happens. If the FSM reports
// handshakeFinished, it returns nil. If either goroutine reports an error
// first, or ctx is done, both goroutines are cancelled and the error is
// returned through translateHandshakeCtxError.
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)

	done := make(chan struct{})
	// The goroutines use their own contexts (not ctx) because they must
	// outlive the handshake call.
	ctxRead, cancelRead := context.WithCancel(context.Background())
	c.cancelHandshakeReader = cancelRead
	cfg.onFlightState = func(f flightVal, s handshakeState) {
		// Signal completion exactly once, on the first transition to finished.
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}

	ctxHs, cancel := context.WithCancel(context.Background())
	c.cancelHandshaker = cancel

	// Buffered so the first error is kept; later errors are dropped by the
	// non-blocking sends below.
	firstErr := make(chan error, 1)

	c.handshakeLoopsFinished.Add(2)

	// Handshake routine should be live until close.
	// The other party may request retransmission of the last flight to cope with packet drop.
	go func() {
		defer c.handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()
	go func() {
		defer func() {
			// Escaping read loop.
			// It's safe to close decrypted channel now.
			close(c.decrypted)

			// Force stop handshaker when the underlying connection is closed.
			cancel()
		}()
		defer c.handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				var e *alertError
				if errors.As(err, &e) {
					if !e.IsFatalOrCloseNotify() {
						if c.isHandshakeCompletedSuccessfully() {
							// Pass the error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						continue // non-fatal alert must not stop read loop
					}
				} else {
					switch {
					case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF):
					default:
						if c.isHandshakeCompletedSuccessfully() {
							// Keep read loop and pass the read error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue // after the handshake, other non-alert errors must not stop the read loop
						}
					}
				}

				// Terminal error: report it (if first) and stop reading.
				select {
				case firstErr <- err:
				default:
				}

				if e != nil {
					if e.IsFatalOrCloseNotify() {
						_ = c.close(false) //nolint:contextcheck
					}
				}
				return
			}
		}
	}()

	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		return nil
	}
}

// translateHandshakeCtxError converts an error that ended the handshake wait
// into the value handshake returns. nil stays nil. A cancellation that raced
// with a completed handshake is not an error. Anything else is wrapped in a
// *HandshakeError.
func (c *Conn) translateHandshakeCtxError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully():
		return nil
	default:
		return &HandshakeError{Err: err}
	}
}

// close cancels both handshake goroutines and marks the connection closed.
// When called by the user after a completed handshake, it first sends
// close_notify, ignoring any error. The first user-initiated close returns the
// result of closing the underlying connection. Once the user has closed the
// connection, any later call returns ErrConnClosed.
func (c *Conn) close(byUser bool) error {
	c.cancelHandshaker()
	c.cancelHandshakeReader()

	if byUser && c.isHandshakeCompletedSuccessfully() {
		// Discard error from notify() to return non-error on the first user call of Close()
		// even if the underlying connection is already closed.
		_ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
	}

	c.closeLock.Lock()
	alreadyClosedByUser := c.connectionClosedByUser
	c.connectionClosedByUser = alreadyClosedByUser || byUser
	c.closed.Close()
	c.closeLock.Unlock()

	if !alreadyClosedByUser {
		return c.nextConn.Close()
	}
	return ErrConnClosed
}

// isConnectionClosed reports, without blocking, whether close has already
// fired c.closed.
func (c *Conn) isConnectionClosed() bool {
	closedCh := c.closed.Done()
	select {
	case <-closedCh:
		return true
	default:
	}
	return false
}

// setLocalEpoch atomically stores the epoch used for outgoing records.
func (c *Conn) setLocalEpoch(epoch uint16) {
	st := &c.state
	st.localEpoch.Store(epoch)
}

// setRemoteEpoch atomically stores the epoch expected for incoming records.
func (c *Conn) setRemoteEpoch(epoch uint16) {
	st := &c.state
	st.remoteEpoch.Store(epoch)
}

// LocalAddr implements net.Conn.LocalAddr by returning the local address of
// the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	underlying := c.nextConn
	return underlying.LocalAddr()
}

// RemoteAddr implements net.Conn.RemoteAddr by returning the remote address
// of the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	underlying := c.nextConn
	return underlying.RemoteAddr()
}

// sessionKey returns the key under which this connection's session is kept in
// the session store. Servers key by session ID. Clients key by
// "<remote address>_<server name>".
func (c *Conn) sessionKey() []byte {
	if !c.state.isClient {
		return c.state.SessionID
	}
	// As ServerName can be like 0.example.com, it's better to add a delimiter
	// character that is not allowed in either an address or a domain name.
	key := c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName
	return []byte(key)
}

// SetDeadline implements net.Conn.SetDeadline by setting both the read and
// the write deadline to t.
func (c *Conn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}

// SetReadDeadline implements net.Conn.SetReadDeadline.
// The read deadline is enforced entirely by this layer (see Read). It is not
// forwarded to the underlying connection, so this always returns nil.
func (c *Conn) SetReadDeadline(t time.Time) error {
	c.readDeadline.Set(t)
	return nil
}

// SetWriteDeadline implements net.Conn.SetWriteDeadline.
// The write deadline is also enforced by this layer: Write checks it and
// passes it as the context of the underlying write. Always returns nil.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	c.writeDeadline.Set(t)
	return nil
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy