All Downloads are FREE. Search and download functionalities are using the official Maven repository.

vendor.github.com.pion.dtls.v2.handshaker.go Maven / Gradle / Ivy

There is a newer version: 2.9.1
Show newest version
package dtls

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
	"github.com/pion/dtls/v2/pkg/protocol/alert"
	"github.com/pion/dtls/v2/pkg/protocol/handshake"
	"github.com/pion/logging"
)

// [RFC6347 Section-4.2.4]
//                      +-----------+
//                +---> | PREPARING | <--------------------+
//                |     +-----------+                      |
//                |           |                            |
//                |           | Buffer next flight         |
//                |           |                            |
//                |          \|/                           |
//                |     +-----------+                      |
//                |     |  SENDING  |<------------------+  | Send
//                |     +-----------+                   |  | HelloRequest
//        Receive |           |                         |  |
//           next |           | Send flight             |  | or
//         flight |  +--------+                         |  |
//                |  |        | Set retransmit timer    |  | Receive
//                |  |       \|/                        |  | HelloRequest
//                |  |  +-----------+                   |  | Send
//                +--)--|  WAITING  |-------------------+  | ClientHello
//                |  |  +-----------+   Timer expires   |  |
//                |  |         |                        |  |
//                |  |         +------------------------+  |
//        Receive |  | Send           Read retransmit      |
//           last |  | last                                |
//         flight |  | flight                              |
//                |  |                                     |
//               \|/\|/                                    |
//            +-----------+                                |
//            | FINISHED  | -------------------------------+
//            +-----------+
//                 |  /|\
//                 |   |
//                 +---+
//              Read retransmit
//           Retransmit last flight

// handshakeState is a state of the handshake FSM sketched in the RFC 6347
// diagram above. handshakeErrored is the zero value and is what the state
// handlers return together with a non-nil error.
type handshakeState uint8

// Handshake FSM states. Run dispatches on Preparing, Sending, Waiting and
// Finished; any other value makes Run fail with errInvalidFSMTransition.
const (
	handshakeErrored handshakeState = iota
	handshakePreparing
	handshakeSending
	handshakeWaiting
	handshakeFinished
)

// String returns the human-readable name of the state, or "Unknown" for a
// value outside the defined set.
func (s handshakeState) String() string {
	names := [...]string{
		handshakeErrored:   "Errored",
		handshakePreparing: "Preparing",
		handshakeSending:   "Sending",
		handshakeWaiting:   "Waiting",
		handshakeFinished:  "Finished",
	}
	if int(s) < len(names) {
		return names[s]
	}
	return "Unknown"
}

// handshakeFSM drives one side of the DTLS handshake through the
// Preparing/Sending/Waiting/Finished states; see Run for the dispatch loop.
type handshakeFSM struct {
	currentFlight flightVal        // flight being prepared, sent or awaited; advanced by wait
	flights       []*packet        // packets of the current flight: rebuilt by prepare, written by send
	retransmit    bool             // set by prepare; whether wait re-sends the flight when its timer fires
	state         *State           // connection state (isClient, handshakeSendSequence are used here)
	cache         *handshakeCache  // passed to the flight generators and parsers
	cfg           *handshakeConfig // handshake configuration (timers, logging, hooks, initial epoch)
	closed        chan struct{}    // closed when Run returns; exposed through Done
}

// handshakeConfig holds the settings shared by the handshake FSM and the
// flight generators/parsers it invokes (the whole config is passed to them).
//
// Uses visible in this file:
//   - retransmitInterval is the period of the timer started by wait and
//     finish before a flight is retransmitted;
//   - onFlightState, when non-nil, is called by Run with the current flight
//     and state before each state handler runs;
//   - log receives the FSM's trace output and key-log write failures;
//   - keyLogWriter, when non-nil, receives "<label> <clientRandom hex>
//     <secret hex>" lines from writeKeyLog, which serializes writes with mu;
//   - initialEpoch is added to the epoch of every outgoing record in prepare.
type handshakeConfig struct {
	localPSKCallback            PSKCallback
	localPSKIdentityHint        []byte
	localCipherSuites           []CipherSuite             // Available CipherSuites
	localSignatureSchemes       []signaturehash.Algorithm // Available signature schemes
	extendedMasterSecret        ExtendedMasterSecretType  // Policy for the Extended Master Secret extension
	localSRTPProtectionProfiles []SRTPProtectionProfile   // Available SRTPProtectionProfiles, if empty no SRTP support
	serverName                  string
	supportedProtocols          []string
	clientAuth                  ClientAuthType // Client certificate request policy; NOTE(review): presumably only consulted when acting as server — confirm
	localCertificates           []tls.Certificate
	nameToCertificate           map[string]*tls.Certificate
	insecureSkipVerify          bool
	verifyPeerCertificate       func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
	sessionStore                SessionStore
	rootCAs                     *x509.CertPool
	clientCAs                   *x509.CertPool
	retransmitInterval          time.Duration // timer period used by wait/finish before retransmitting
	customCipherSuites          func() []CipherSuite

	onFlightState func(flightVal, handshakeState)
	log           logging.LeveledLogger
	keyLogWriter  io.Writer

	initialEpoch uint16 // added to each outgoing record's epoch in prepare

	mu sync.Mutex // serializes writes to keyLogWriter in writeKeyLog
}

// flightConn is the connection surface the handshake FSM operates on.
type flightConn interface {
	// notify is used to emit an alert with the given level and description
	// whenever a flight generator or parser (or the FSM itself) requests one.
	notify(ctx context.Context, level alert.Level, desc alert.Description) error
	// writePackets writes a prepared flight; send passes it the current
	// flight's packets.
	writePackets(context.Context, []*packet) error
	// recvHandshake yields a "done" channel, presumably when incoming
	// handshake data is ready to parse; wait and finish close done once the
	// flight parser has returned.
	recvHandshake() <-chan chan struct{}
	// setLocalEpoch is called by prepare when an outgoing flight contains
	// records in a newer epoch than cfg.initialEpoch.
	setLocalEpoch(epoch uint16)
	// handleQueuedPackets and sessionKey are not called by the FSM methods
	// in this file.
	handleQueuedPackets(context.Context) error
	sessionKey() []byte
}

// writeKeyLog emits one line of the form "<label> <clientRandom hex>
// <secret hex>" to c.keyLogWriter.
//
// It does nothing when no key-log writer is configured. Writes are
// serialized with c.mu, and a failed write is only reported at debug level
// so key logging never interrupts the handshake.
func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
	if c.keyLogWriter == nil {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if _, err := fmt.Fprintf(c.keyLogWriter, "%s %x %x\n", label, clientRandom, secret); err != nil {
		c.log.Debugf("failed to write key log file: %s", err)
	}
}

// srvCliStr names the local role for log prefixes: "client" when isClient
// is true, otherwise "server".
func srvCliStr(isClient bool) string {
	role := "server"
	if isClient {
		role = "client"
	}
	return role
}

// newHandshakeFSM returns a handshake FSM positioned at initialFlight that
// operates on the given connection state, handshake cache and configuration.
// Its Done channel is open until Run returns.
func newHandshakeFSM(
	s *State, cache *handshakeCache, cfg *handshakeConfig,
	initialFlight flightVal,
) *handshakeFSM {
	fsm := new(handshakeFSM)
	fsm.currentFlight = initialFlight
	fsm.state = s
	fsm.cache = cache
	fsm.cfg = cfg
	fsm.closed = make(chan struct{})
	return fsm
}

// Run executes the handshake state machine starting at initialState and
// only returns when a state handler reports an error (which includes ctx
// cancellation) or the state is not one it knows, in which case it returns
// errInvalidFSMTransition. The channel from Done is closed on return.
//
// Before every handler invocation the current flight and state are traced
// and, when configured, reported through cfg.onFlightState.
func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error {
	defer func() { close(s.closed) }()

	state := initialState
	for {
		s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
		if hook := s.cfg.onFlightState; hook != nil {
			hook(s.currentFlight, state)
		}

		var handler func(context.Context, flightConn) (handshakeState, error)
		switch state {
		case handshakePreparing:
			handler = s.prepare
		case handshakeSending:
			handler = s.send
		case handshakeWaiting:
			handler = s.wait
		case handshakeFinished:
			handler = s.finish
		default:
			return errInvalidFSMTransition
		}

		next, err := handler(ctx, c)
		if err != nil {
			return err
		}
		state = next
	}
}

// Done returns a channel that is closed once Run has returned. Run only
// returns with an error, so a closed channel means the FSM has stopped.
func (s *handshakeFSM) Done() <-chan struct{} {
	return s.closed
}

// prepare generates the packets of s.currentFlight and readies them for
// sending, returning handshakeSending on success.
//
// An alert requested by the flight generator, or an InternalError alert when
// no generator exists for the flight, is sent through c.notify. On success
// every record's epoch is offset by cfg.initialEpoch, each handshake message
// is stamped with the next outbound message sequence number, and the local
// epoch is advanced when the flight reaches a newer epoch.
func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) {
	s.flights = nil

	var (
		pkts   []*packet
		toSend *alert.Alert
		err    error
	)
	if gen, retransmit, genErr := s.currentFlight.getFlightGenerator(); genErr != nil {
		err = genErr
		toSend = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
	} else {
		pkts, toSend, err = gen(c, s.state, s.cache, s.cfg)
		s.retransmit = retransmit
	}

	if toSend != nil {
		// NOTE(review): a notify failure only replaces err when err is already
		// set (hiding the original cause) and is ignored otherwise. Behaviour
		// kept as-is; confirm this precedence is intended.
		if alertErr := c.notify(ctx, toSend.Level, toSend.Description); alertErr != nil && err != nil {
			err = alertErr
		}
	}
	if err != nil {
		return handshakeErrored, err
	}

	s.flights = pkts
	baseEpoch := s.cfg.initialEpoch
	maxEpoch := baseEpoch
	for _, p := range s.flights {
		p.record.Header.Epoch += baseEpoch
		if e := p.record.Header.Epoch; e > maxEpoch {
			maxEpoch = e
		}
		if msg, isHandshake := p.record.Content.(*handshake.Handshake); isHandshake {
			msg.Header.MessageSequence = uint16(s.state.handshakeSendSequence)
			s.state.handshakeSendSequence++
		}
	}

	if maxEpoch != baseEpoch {
		s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), maxEpoch)
		c.setLocalEpoch(maxEpoch)
	}
	return handshakeSending, nil
}

// send writes the prepared flight to the connection. It moves to Finished
// after the last flight this side sends, to Waiting otherwise, and to
// Errored if the write fails.
func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) {
	err := c.writePackets(ctx, s.flights)
	switch {
	case err != nil:
		return handshakeErrored, err
	case s.currentFlight.isLastSendFlight():
		return handshakeFinished, nil
	default:
		return handshakeWaiting, nil
	}
}

// wait blocks until the peer's next flight has been received and parsed,
// the retransmission timer (cfg.retransmitInterval) fires, or ctx is done.
//
// It returns:
//   - handshakePreparing after advancing s.currentFlight to the parsed flight;
//   - handshakeFinished when the parsed flight is the last flight to receive
//     and equals the current flight;
//   - handshakeSending on timer expiry when the current flight is
//     retransmittable, otherwise handshakeWaiting (re-entering wait with a
//     fresh timer);
//   - handshakeErrored with the error on a parser/notify failure or when ctx
//     is done.
//
// A parse result of flight 0 leaves the FSM waiting on the same timer.
func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
	parse, errFlight := s.currentFlight.getFlightParser()
	if errFlight != nil {
		// If the InternalError alert itself cannot be sent, that failure is
		// reported instead of errFlight.
		if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
			return handshakeErrored, alertErr
		}
		return handshakeErrored, errFlight
	}

	retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
	// Release the timer on every exit path instead of letting it linger
	// until it fires.
	defer retransmitTimer.Stop()

	for {
		select {
		case done := <-c.recvHandshake():
			nextFlight, a, err := parse(ctx, c, s.state, s.cache, s.cfg)
			// Parsing is complete; release whoever handed us the done channel.
			close(done)
			if a != nil {
				// NOTE(review): a notify failure only replaces err when err is
				// already set and is ignored otherwise; kept as-is — confirm.
				if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil && err != nil {
					err = alertErr
				}
			}
			if err != nil {
				return handshakeErrored, err
			}
			if nextFlight == 0 {
				// Nothing to advance to yet; keep waiting on the same timer.
				continue
			}
			s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
			if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
				return handshakeFinished, nil
			}
			s.currentFlight = nextFlight
			return handshakePreparing, nil

		case <-retransmitTimer.C:
			if !s.retransmit {
				return handshakeWaiting, nil
			}
			return handshakeSending, nil

		case <-ctx.Done():
			return handshakeErrored, ctx.Err()
		}
	}
}

// finish runs in the Finished state: it waits for another incoming flight or
// for ctx to be done.
//
// If the parsed flight is 0, or is the last flight to receive and equals the
// current flight, the FSM stays Finished. Otherwise the peer is treated as
// retransmitting: finish waits until the retransmission timer started on
// entry fires and returns handshakeSending so the last flight is resent.
// That wait is abandoned with ctx.Err() if ctx is done first, rather than
// blocking for the rest of the interval.
func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
	parse, errFlight := s.currentFlight.getFlightParser()
	if errFlight != nil {
		// If the InternalError alert itself cannot be sent, that failure is
		// reported instead of errFlight.
		if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
			return handshakeErrored, alertErr
		}
		return handshakeErrored, errFlight
	}

	retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
	// Release the timer on every exit path, including the early returns.
	defer retransmitTimer.Stop()

	select {
	case done := <-c.recvHandshake():
		nextFlight, a, err := parse(ctx, c, s.state, s.cache, s.cfg)
		// Parsing is complete; release whoever handed us the done channel.
		close(done)
		if a != nil {
			// NOTE(review): a notify failure only replaces err when err is
			// already set and is ignored otherwise; kept as-is — confirm.
			if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil && err != nil {
				err = alertErr
			}
		}
		if err != nil {
			return handshakeErrored, err
		}
		if nextFlight == 0 {
			return handshakeFinished, nil
		}
		if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
			return handshakeFinished, nil
		}
		// Pace the retransmission with the timer, but honour cancellation.
		select {
		case <-retransmitTimer.C:
		case <-ctx.Done():
			return handshakeErrored, ctx.Err()
		}
		// Retransmit last flight
		return handshakeSending, nil

	case <-ctx.Done():
		return handshakeErrored, ctx.Err()
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy