All Downloads are FREE. Search and download functionalities are using the official Maven repository.

vendor.github.com.pion.stun.client.go Maven / Gradle / Ivy

The newest version!
// SPDX-FileCopyrightText: 2023 The Pion community 
// SPDX-License-Identifier: MIT

package stun

import (
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/dtls/v2"
	"github.com/pion/transport/v2"
	"github.com/pion/transport/v2/stdnet"
)

// ErrUnsupportedURI is returned by DialURI when the scheme/transport
// combination of the given STUN or TURN URI is not supported.
var ErrUnsupportedURI = errors.New("invalid schema or transport")

// Dial opens a connection to address on the named network with net.Dial and
// wraps it in a Client built with default options.
//
// A dial failure is returned unchanged; otherwise the result of NewClient is
// returned as-is.
func Dial(network, address string) (*Client, error) {
	conn, dialErr := net.Dial(network, address)
	if dialErr == nil {
		return NewClient(conn)
	}
	return nil, dialErr
}

// DialConfig is used to pass configuration to DialURI().
//
// DTLSConfig is the base configuration for the DTLS client that DialURI
// creates for SchemeTypeTURNS over ProtoTypeUDP, and TLSConfig is the base
// configuration for the TLS client it creates for SchemeTypeTURNS or
// SchemeTypeSTUNS over ProtoTypeTCP. DialURI sets ServerName on its own copy
// of either configuration. Net is the network implementation used for
// dialing; when it is nil DialURI falls back to stdnet.NewNet().
type DialConfig struct {
	DTLSConfig dtls.Config
	TLSConfig  tls.Config

	Net transport.Net
}

// DialURI connects to the server described by uri and initializes a Client
// on the resulting connection, returning an error if any step fails.
//
// The transport is chosen from the URI scheme and protocol:
//   - SchemeTypeSTUN: plain UDP.
//   - SchemeTypeTURN: UDP, or TCP when uri.Proto is ProtoTypeTCP.
//   - SchemeTypeTURNS with ProtoTypeUDP: DTLS based on cfg.DTLSConfig.
//   - SchemeTypeTURNS or SchemeTypeSTUNS with ProtoTypeTCP: TLS based on
//     cfg.TLSConfig.
//
// Every other combination yields ErrUnsupportedURI. For the secured
// transports ServerName is set to uri.Host on a copy of the supplied
// configuration, so cfg itself is never modified. A nil cfg is treated as an
// empty DialConfig, and a nil cfg.Net is replaced by stdnet.NewNet().
func DialURI(uri *URI, cfg *DialConfig) (*Client, error) {
	if cfg == nil {
		// Allow callers without special requirements to pass nil.
		cfg = &DialConfig{}
	}

	var conn Connection
	var err error

	nw := cfg.Net
	if nw == nil {
		nw, err = stdnet.NewNet()
		if err != nil {
			return nil, fmt.Errorf("failed to create net: %w", err)
		}
	}

	addr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port))

	switch {
	case uri.Scheme == SchemeTypeSTUN:
		if conn, err = nw.Dial("udp", addr); err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}

	case uri.Scheme == SchemeTypeTURN:
		network := "udp" //nolint:goconst
		if uri.Proto == ProtoTypeTCP {
			network = "tcp" //nolint:goconst
		}

		if conn, err = nw.Dial(network, addr); err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}

	case uri.Scheme == SchemeTypeTURNS && uri.Proto == ProtoTypeUDP:
		dtlsCfg := cfg.DTLSConfig // Copy
		dtlsCfg.ServerName = uri.Host

		udpConn, err := nw.Dial("udp", addr)
		if err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}

		dtlsConn, err := dtls.Client(udpConn, &dtlsCfg)
		if err != nil {
			// No Client will ever own this socket, so release it here. The
			// close error is dropped on purpose: the handshake failure is
			// the error worth reporting.
			_ = udpConn.Close()
			return nil, fmt.Errorf("failed to connect to '%s': %w", addr, err)
		}
		conn = dtlsConn

	case (uri.Scheme == SchemeTypeTURNS || uri.Scheme == SchemeTypeSTUNS) && uri.Proto == ProtoTypeTCP:
		// Clone instead of copying by value: tls.Config carries internal
		// synchronization state that must not be copied.
		tlsCfg := cfg.TLSConfig.Clone()
		tlsCfg.ServerName = uri.Host

		tcpConn, err := nw.Dial("tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("failed to dial: %w", err)
		}

		conn = tls.Client(tcpConn, tlsCfg)

	default:
		return nil, ErrUnsupportedURI
	}

	return NewClient(conn)
}

// ErrNoConnection is returned by NewClient when the client's connection is
// still nil after all options have been applied.
var ErrNoConnection = errors.New("no connection provided")

// ClientOption is a functional option that modifies a Client while it is
// being constructed; NewClient applies the options in the order given.
type ClientOption func(c *Client)

// WithHandler installs h as the client-level handler. It receives events
// emitted by the agent whose TransactionID is not registered by the Client,
// which is useful for handling Data indications from a TURN server.
func WithHandler(h Handler) ClientOption {
	apply := func(client *Client) {
		client.handler = h
	}
	return apply
}

// WithRTO sets the client retransmission timeout (RTO) as defined in the
// STUN RFC. The value is stored as a nanosecond count.
func WithRTO(rto time.Duration) ClientOption {
	apply := func(client *Client) {
		client.rto = rto.Nanoseconds()
	}
	return apply
}

// WithClock sets the Clock the client uses as its source of current time.
// The same clock is handed to the default collector when NewClient has to
// create one.
func WithClock(clock Clock) ClientOption {
	apply := func(client *Client) {
		client.clock = clock
	}
	return apply
}

// WithTimeoutRate sets the minimum resolution of the RTO timer, i.e. the
// rate at which the collector is started by NewClient.
func WithTimeoutRate(d time.Duration) ClientOption {
	apply := func(client *Client) {
		client.rtoRate = d
	}
	return apply
}

// WithAgent sets the STUN agent used by the client.
//
// When this option is not used NewClient creates the agent with
// NewAgent(nil); see agent.go for that implementation.
func WithAgent(a ClientAgent) ClientOption {
	apply := func(client *Client) {
		client.a = a
	}
	return apply
}

// WithCollector sets the client timeout collector, the ticker-like component
// that calls a function on each tick. Without this option NewClient falls
// back to a ticker-based collector.
func WithCollector(coll Collector) ClientOption {
	apply := func(client *Client) {
		client.collector = coll
	}
	return apply
}

// WithNoConnClose keeps the client from closing the underlying connection
// when its Close() method is called.
func WithNoConnClose() ClientOption {
	apply := func(client *Client) {
		client.closeConn = false
	}
	return apply
}

// WithNoRetransmit disables retransmissions by setting the attempt limit to
// zero. If the client RTO is zero at that point, it is replaced by
// defaultMaxAttempts * defaultRTO so that a transaction still times out.
//
// The function has the ClientOption signature itself, so it is passed to
// NewClient without being called.
//
// Useful for TCP connections where transport handles RTO.
func WithNoRetransmit(c *Client) {
	c.maxAttempts = 0
	if c.rto != 0 {
		// An RTO is already configured; leave it alone.
		return
	}
	c.rto = int64(defaultRTO) * defaultMaxAttempts
}

// Defaults that NewClient applies before running the options:
// defaultTimeoutRate is the rate the collector is started with, defaultRTO
// is the initial retransmission timeout and defaultMaxAttempts is the
// attempt limit checked before a retransmission.
const (
	defaultTimeoutRate = time.Millisecond * 5
	defaultRTO         = time.Millisecond * 300
	defaultMaxAttempts = 7
)

// NewClient builds a Client around conn, applies the given options on top of
// the defaults and starts the internal goroutines. Call Close once the
// Client is no longer needed to release its resources.
//
// Close also closes conn unless the WithNoConnClose option is used.
//
// ErrNoConnection is returned when the connection is nil after the options
// ran; errors from the agent's SetHandler and the collector's Start are
// returned unchanged.
//
// Note that user should handle the protocol multiplexing, client does not
// provide any API for it, so if you need to read application data, wrap the
// connection with your (de-)multiplexer and pass the wrapper as conn.
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
	client := &Client{
		c:           conn,
		close:       make(chan struct{}),
		t:           make(map[transactionID]*clientTransaction, 100),
		clock:       systemClock(),
		rto:         int64(defaultRTO),
		rtoRate:     defaultTimeoutRate,
		maxAttempts: defaultMaxAttempts,
		closeConn:   true,
	}
	for i := range options {
		options[i](client)
	}

	if client.c == nil {
		return nil, ErrNoConnection
	}

	// Fall back to the package agent when no option provided one.
	if client.a == nil {
		client.a = NewAgent(nil)
	}
	if err := client.a.SetHandler(client.handleAgentCallback); err != nil {
		return nil, err
	}

	// Fall back to the ticker-based collector, sharing the client clock.
	if client.collector == nil {
		client.collector = &tickerCollector{
			clock: client.clock,
			close: make(chan struct{}),
		}
	}
	collect := func(now time.Time) {
		closedOrPanic(client.a.Collect(now))
	}
	if err := client.collector.Start(client.rtoRate, collect); err != nil {
		return nil, err
	}

	client.wg.Add(1)
	go client.readUntilClosed()
	runtime.SetFinalizer(client, clientFinalizer)
	return client, nil
}

// clientFinalizer is registered by NewClient as the runtime finalizer of a
// Client. It closes the client and logs when the client had not been closed
// before; a client that reports ErrClientClosed is left alone silently.
func clientFinalizer(c *Client) {
	if c == nil {
		return
	}
	switch err := c.Close(); {
	case errors.Is(err, ErrClientClosed):
		// Already closed by the user; nothing to report.
	case err == nil:
		log.Println("client: called finalizer on non-closed client") // nolint
	default:
		log.Println("client: called finalizer on non-closed client:", err) // nolint
	}
}

// Connection wraps Reader, Writer and Closer interfaces.
//
// It is the transport a Client works on: messages are read from it and
// written to it, and Client.Close closes it unless WithNoConnClose was used.
type Connection interface {
	io.Reader
	io.Writer
	io.Closer
}

// ClientAgent is Agent implementation that is used by Client to
// process transactions.
type ClientAgent interface {
	// Process is called by the client for every message successfully read
	// from the connection.
	Process(*Message) error
	// Close is called once from Client.Close.
	Close() error
	// Start is called when the client begins (or retries) the transaction
	// with the given id; deadline is the time computed for its timeout.
	Start(id [TransactionIDSize]byte, deadline time.Time) error
	// Stop is called when writing the message failed, so that the client
	// does not have to wait for the deadline.
	Stop(id [TransactionIDSize]byte) error
	// Collect is called on every collector tick with the current time.
	Collect(time.Time) error
	// SetHandler is called by NewClient to install the client's callback
	// for agent events.
	SetHandler(h Handler) error
}

// Client simulates "connection" to STUN server.
//
// It owns the transport connection (c), delegates transaction bookkeeping to
// the agent (a) and keeps its own table of in-flight transactions (t) so
// that requests can be written again on retransmission. rto is written by
// SetRTO and read by Start through sync/atomic, and maxAttempts is read
// atomically by handleAgentCallback. The close channel and wg are used by
// Close to stop and wait for the reader goroutine.
type Client struct {
	rto         int64 // time.Duration
	a           ClientAgent
	c           Connection
	close       chan struct{}
	rtoRate     time.Duration
	maxAttempts int32
	closed      bool
	closeConn   bool // should call c.Close() while closing
	wg          sync.WaitGroup
	clock       Clock
	handler     Handler
	collector   Collector
	t           map[transactionID]*clientTransaction

	// mux guards closed and t
	mux sync.RWMutex
}

// clientTransaction represents a transaction in progress.
// Once the transaction has succeeded or failed, the handler h is invoked
// with the resulting event; calls counts the invocation attempts so that h
// runs at most once. attempt is the number of retransmissions done so far,
// rto the timeout unit used by nextTimeout, and raw a copy of the encoded
// message that is written again on retransmission.
// Concurrent access is invalid.
type clientTransaction struct {
	id      transactionID
	attempt int32
	calls   int32
	h       Handler
	start   time.Time
	rto     time.Duration
	raw     []byte
}

// handle passes e to the transaction handler, but only on the first call;
// later calls merely increment the call counter.
func (t *clientTransaction) handle(e Event) {
	if atomic.AddInt32(&t.calls, 1) != 1 {
		return
	}
	t.h(e)
}

// clientTransactionPool recycles clientTransaction values; new entries come
// with a 1500-byte raw buffer.
var clientTransactionPool = &sync.Pool{ //nolint:gochecknoglobals
	New: func() interface{} {
		t := new(clientTransaction)
		t.raw = make([]byte, 1500)
		return t
	},
}

// acquireClientTransaction takes a transaction from clientTransactionPool.
// The caller is expected to initialize the fields it relies on.
func acquireClientTransaction() *clientTransaction {
	pooled := clientTransactionPool.Get()
	return pooled.(*clientTransaction) //nolint:forcetypeassert
}

// putClientTransaction clears the id, attempt counter, start time and raw
// length of t (keeping the raw capacity) and returns it to
// clientTransactionPool. The handler, call counter and rto are left as they
// are; Start overwrites them on reuse.
func putClientTransaction(t *clientTransaction) {
	t.id = transactionID{}
	t.attempt = 0
	t.start = time.Time{}
	t.raw = t.raw[:0]
	clientTransactionPool.Put(t)
}

// nextTimeout returns the deadline for the upcoming attempt: now plus
// (attempt+1) times the transaction rto.
func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
	multiplier := time.Duration(t.attempt + 1)
	return now.Add(multiplier * t.rto)
}

// start adds t to the table of in-flight transactions under the client lock.
//
// It returns ErrClientClosed when the client is closed and
// ErrTransactionExists when a transaction with the same id is already
// registered; the closed check takes precedence.
func (c *Client) start(t *clientTransaction) error {
	c.mux.Lock()
	defer c.mux.Unlock()
	switch _, registered := c.t[t.id]; {
	case c.closed:
		return ErrClientClosed
	case registered:
		return ErrTransactionExists
	}
	c.t[t.id] = t
	return nil
}

// Clock abstracts the source of current time.
//
// The client uses it to timestamp transactions and compute their deadlines;
// a custom implementation can be supplied with WithClock.
type Clock interface {
	Now() time.Time
}

// systemClockService is the Clock implementation backed by time.Now.
type systemClockService struct{}

// Now reports the current time as given by time.Now.
func (systemClockService) Now() time.Time {
	return time.Now()
}

// systemClock returns the clock NewClient uses by default.
func systemClock() systemClockService {
	var svc systemClockService
	return svc
}

// SetRTO atomically replaces the current RTO value; transactions started
// afterwards pick up the new value.
func (c *Client) SetRTO(rto time.Duration) {
	atomic.StoreInt64(&c.rto, rto.Nanoseconds())
}

// StopErr is returned when the Client fails to stop a transaction while it
// is already handling another error.
//
//nolint:errname
type StopErr struct {
	Err   error // value returned by Stop()
	Cause error // error that caused Stop() call
}

// Error implements the error interface; nil members are rendered as empty
// strings.
func (e StopErr) Error() string {
	cause, stop := sprintErr(e.Cause), sprintErr(e.Err)
	return fmt.Sprintf("error while stopping due to %s: %s", cause, stop)
}

// CloseErr reports a failure to close the client; it carries the error from
// closing the agent and the error from closing the connection, either of
// which may be nil.
//
//nolint:errname
type CloseErr struct {
	AgentErr      error
	ConnectionErr error
}

// sprintErr returns err.Error(), or the empty string for a nil error.
func sprintErr(err error) string {
	if err != nil {
		return err.Error()
	}
	return "" //nolint:goconst
}

// Error implements the error interface, listing the connection error first
// and the agent error second.
func (c CloseErr) Error() string {
	conn, agent := sprintErr(c.ConnectionErr), sprintErr(c.AgentErr)
	return fmt.Sprintf("failed to close: %s (connection), %s (agent)", conn, agent)
}

// readUntilClosed is the reader goroutine started by NewClient. It reads
// messages from the connection into a reusable Message and hands every
// successfully read one to the agent.
//
// The loop ends when the client's close channel is closed, when the agent
// reports ErrAgentClosed, or when the connection reports a terminal
// condition (io.EOF or io.ErrClosedPipe). Without the last rule a connection
// closed by the peer would make every read fail immediately and the loop
// would spin at full CPU until Close is called. Any other read error is
// skipped and reading continues.
func (c *Client) readUntilClosed() {
	defer c.wg.Done()
	m := new(Message)
	m.Raw = make([]byte, 1024)
	for {
		select {
		case <-c.close:
			return
		default:
		}
		_, err := m.ReadFrom(c.c)
		switch {
		case err == nil:
			if pErr := c.a.Process(m); errors.Is(pErr, ErrAgentClosed) {
				return
			}
		case errors.Is(err, io.EOF), errors.Is(err, io.ErrClosedPipe):
			// Nothing more will ever arrive on this connection.
			return
		}
	}
}

// closedOrPanic panics with err unless err is nil or ErrAgentClosed.
func closedOrPanic(err error) {
	if err != nil && !errors.Is(err, ErrAgentClosed) {
		panic(err) //nolint
	}
}

// tickerCollector is the Collector that NewClient creates when none is
// supplied. close signals the goroutine launched by Start to exit, wg lets
// Close wait for that goroutine, and clock provides the time passed to the
// callback on each tick.
type tickerCollector struct {
	close chan struct{}
	wg    sync.WaitGroup
	clock Clock
}

// Collector calls function f with constant rate.
//
// The simple Collector is ticker which calls function on each tick.
type Collector interface {
	// Start is called once by NewClient with the configured timeout rate
	// and the function to invoke on each tick.
	Start(rate time.Duration, f func(now time.Time)) error
	// Close is called by Client.Close.
	Close() error
}

// Start launches a goroutine that calls f with a.clock.Now() on every tick
// of a ticker firing at the given rate, until Close is called.
//
// A non-positive rate is rejected with an error instead of being handed to
// time.NewTicker, which panics on such durations.
func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
	if rate <= 0 {
		return fmt.Errorf("collector: non-positive tick rate %s", rate)
	}
	ticker := time.NewTicker(rate)
	a.wg.Add(1)
	go func() {
		defer a.wg.Done()
		defer ticker.Stop()
		for {
			select {
			case <-a.close:
				return
			case <-ticker.C:
				f(a.clock.Now())
			}
		}
	}()
	return nil
}

// Close signals the goroutine launched by Start to exit and blocks until it
// has returned. It always returns nil.
func (a *tickerCollector) Close() error {
	close(a.close)
	a.wg.Wait()
	return nil
}

// ErrClientClosed indicates that client is closed. It is returned by Close
// on an already closed client and by Start when the client is closed.
var ErrClientClosed = errors.New("client is closed")

// Close stops the collector, the agent and, unless WithNoConnClose was used,
// the connection, then waits for the reader goroutine to finish.
//
// It returns ErrClientNotInitialized for a client that was not set up and
// ErrClientClosed when the client is already closed. All components are shut
// down even when one of them fails, so that a failing collector cannot leave
// the agent, the connection and the reader goroutine behind on a client that
// is already marked closed. A collector error is returned as-is and takes
// precedence; otherwise agent and connection errors are reported together
// in a CloseErr.
func (c *Client) Close() error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.Lock()
	if c.closed {
		c.mux.Unlock()
		return ErrClientClosed
	}
	c.closed = true
	c.mux.Unlock()

	collectorErr := c.collector.Close()
	agentErr := c.a.Close()
	var connErr error
	if c.closeConn {
		connErr = c.c.Close()
	}
	close(c.close)
	c.wg.Wait()

	switch {
	case collectorErr != nil:
		return collectorErr
	case agentErr != nil || connErr != nil:
		return CloseErr{
			AgentErr:      agentErr,
			ConnectionErr: connErr,
		}
	default:
		return nil
	}
}

// Indicate sends indication m to the server. It is a shorthand for calling
// Start without a handler, so no transaction is registered for m.
func (c *Client) Indicate(m *Message) error {
	var noHandler Handler
	return c.Start(m, noHandler)
}

// callbackWaitHandler blocks on wait() call until callback is called.
//
// handler caches the HandleEvent method value that is passed to Start,
// callback is the user function for the current call, and processed is the
// flag, guarded by cond's lock, that wait() blocks on.
type callbackWaitHandler struct {
	handler   Handler
	callback  func(event Event)
	cond      *sync.Cond
	processed bool
}

// HandleEvent runs the stored callback with e while holding the lock, marks
// the handler as processed and wakes up the goroutines blocked in wait().
// It panics when no callback has been set.
func (s *callbackWaitHandler) HandleEvent(e Event) {
	s.cond.L.Lock()
	cb := s.callback
	if cb == nil {
		panic("s.callback is nil") //nolint
	}
	cb(e)
	s.processed = true
	s.cond.Broadcast()
	s.cond.L.Unlock()
}

// wait blocks until HandleEvent has marked the handler as processed, then
// resets the processed flag and the callback so the handler can be reused.
func (s *callbackWaitHandler) wait() {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	for {
		if s.processed {
			break
		}
		s.cond.Wait()
	}
	s.processed, s.callback = false, nil
}

// setCallback stores f as the callback for the next event and, on first use,
// caches the HandleEvent method value in s.handler. It panics when f is nil.
func (s *callbackWaitHandler) setCallback(f func(event Event)) {
	if f == nil {
		panic("f is nil") //nolint
	}
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	if s.handler == nil {
		s.handler = s.HandleEvent
	}
	s.callback = f
}

// callbackWaitHandlerPool recycles the wait handlers used by Client.Do; new
// entries get their own mutex-backed condition variable.
var callbackWaitHandlerPool = sync.Pool{ //nolint:gochecknoglobals
	New: func() interface{} {
		h := new(callbackWaitHandler)
		h.cond = sync.NewCond(&sync.Mutex{})
		return h
	},
}

// ErrClientNotInitialized means that the client itself, its connection, its
// agent or its close channel is nil.
var ErrClientNotInitialized = errors.New("client not initialized")

// checkInit returns ErrClientNotInitialized when the receiver is nil or when
// its connection, agent or close channel has not been set.
func (c *Client) checkInit() error {
	switch {
	case c == nil, c.c == nil, c.a == nil, c.close == nil:
		return ErrClientNotInitialized
	}
	return nil
}

// Do wraps Start and blocks until the callback f has been called. When f is
// nil, Indicate is called instead and Do returns without waiting.
//
// An error from Start is returned immediately, in which case f is not
// waited for.
//
// Do has cpu overhead due to blocking, see BenchmarkClient_Do.
// Use Start method for less overhead.
func (c *Client) Do(m *Message, f func(Event)) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	if f == nil {
		return c.Indicate(m)
	}
	waiter := callbackWaitHandlerPool.Get().(*callbackWaitHandler) //nolint:forcetypeassert
	waiter.setCallback(f)
	defer callbackWaitHandlerPool.Put(waiter)

	err := c.Start(m, waiter.handler)
	if err == nil {
		waiter.wait()
	}
	return err
}

// delete removes the transaction with the given id from the table of
// in-flight transactions, holding the client lock. A nil table is left
// untouched.
func (c *Client) delete(id transactionID) {
	c.mux.Lock()
	defer c.mux.Unlock()
	if c.t == nil {
		return
	}
	delete(c.t, id)
}

// buffer wraps a byte slice so that it can be stored in a sync.Pool by
// pointer.
type buffer struct {
	buf []byte
}

// bufferPool recycles scratch buffers; new entries hold 2048 bytes.
var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
	New: func() interface{} {
		b := new(buffer)
		b.buf = make([]byte, 2048)
		return b
	},
}

// handleAgentCallback is installed as the agent handler by NewClient and
// routes every agent Event to the client transaction it belongs to.
//
// Events are dropped once the client is closed. An event whose
// TransactionID is not registered is forwarded to c.handler (when set),
// except for ErrTransactionStopped events, which are ignored. For a
// registered transaction the event is final, and handed to the transaction
// handler, when it carries no error or when the attempt counter has reached
// maxAttempts; otherwise the stored raw message is written to the
// connection again under a new deadline. Any failure while setting up the
// retransmission is reported to the transaction handler through e.Error.
func (c *Client) handleAgentCallback(e Event) {
	c.mux.Lock()
	if c.closed {
		c.mux.Unlock()
		return
	}
	// Look up and unregister the transaction while holding the lock.
	t, found := c.t[e.TransactionID]
	if found {
		delete(c.t, t.id)
	}
	c.mux.Unlock()
	if !found {
		if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
			c.handler(e)
		}
		// Ignoring.
		return
	}
	if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
		// Transaction completed.
		t.handle(e)
		putClientTransaction(t)
		return
	}
	// Doing re-transmission.
	t.attempt++
	// Copy the raw message into a pooled scratch buffer before the
	// transaction is registered again.
	// NOTE(review): copy stops at the capacity of the pooled buffer, so a
	// raw message longer than that would be re-sent truncated — confirm
	// that messages cannot exceed the buffer size.
	b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
	b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
	defer bufferPool.Put(b)
	var (
		now     = c.clock.Now()
		timeOut = t.nextTimeout(now)
		id      = t.id
	)
	// Starting client transaction.
	if startErr := c.start(t); startErr != nil {
		c.delete(id)
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}
	// Starting agent transaction.
	if startErr := c.a.Start(id, timeOut); startErr != nil {
		c.delete(id)
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}
	// Writing message to connection again.
	_, writeErr := c.c.Write(b.buf)
	if writeErr != nil {
		c.delete(id)
		e.Error = writeErr
		// Stopping agent transaction instead of waiting until it's deadline.
		// This will call handleAgentCallback with "ErrTransactionStopped" error
		// which will be ignored.
		if stopErr := c.a.Stop(id); stopErr != nil {
			// Failed to stop agent transaction. Wrapping the error in StopError.
			e.Error = StopErr{
				Err:   stopErr,
				Cause: writeErr,
			}
		}
		t.handle(e)
		putClientTransaction(t)
		return
	}
}

// Start writes message m to the server. When h is set, a transaction is
// registered first, both in the client and in the agent, and h is called
// asynchronously with the outcome; with a nil h the message is simply
// written, which is what indications need.
//
// It returns ErrClientNotInitialized or ErrClientClosed when the client is
// unusable, the error from registering the transaction, or the write error.
// If registering with the agent fails, the client-side registration is
// rolled back, so the transaction does not linger in the table and the same
// transaction ID can be used again. If the write fails, the transaction is
// removed and stopped right away instead of waiting for its deadline; a
// failure to stop it is reported as StopErr with the write error as Cause.
func (c *Client) Start(m *Message, h Handler) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.RLock()
	closed := c.closed
	c.mux.RUnlock()
	if closed {
		return ErrClientClosed
	}
	if h != nil {
		// Starting transaction only if h is set. Useful for indications.
		t := acquireClientTransaction()
		t.id = m.TransactionID
		t.start = c.clock.Now()
		t.h = h
		t.rto = time.Duration(atomic.LoadInt64(&c.rto))
		t.attempt = 0
		t.raw = append(t.raw[:0], m.Raw...)
		t.calls = 0
		d := t.nextTimeout(t.start)
		if err := c.start(t); err != nil {
			// t was never registered, so nothing else can reference it and
			// it can go straight back to the pool.
			putClientTransaction(t)
			return err
		}
		if err := c.a.Start(m.TransactionID, d); err != nil {
			// The agent will never report on this transaction; undo the
			// client-side registration made just above.
			c.delete(m.TransactionID)
			return err
		}
	}
	_, err := m.WriteTo(c.c)
	if err != nil && h != nil {
		c.delete(m.TransactionID)
		// Stopping transaction instead of waiting until deadline.
		if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
			return StopErr{
				Err:   stopErr,
				Cause: err,
			}
		}
	}
	return err
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy