All Downloads are FREE. Search and download functionalities are using the official Maven repository.

vendor.github.com.pion.datachannel.datachannel.go Maven / Gradle / Ivy

The newest version!
// Package datachannel implements WebRTC Data Channels
package datachannel

import (
	"errors"
	"fmt"
	"io"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

// receiveMTU is the size, in bytes, of the buffer Server allocates to read the
// first message from a newly accepted stream.
const receiveMTU = 8192

// Reader is an extended io.Reader
// that also returns if the message is text.
type Reader interface {
	// ReadDataChannel reads a message into the supplied buffer and returns
	// the number of bytes read, whether the message is text, and any error.
	ReadDataChannel([]byte) (int, bool, error)
}

// ReadDeadliner extends an io.Reader to expose setting a read deadline.
type ReadDeadliner interface {
	// SetReadDeadline sets the deadline for reads to return.
	SetReadDeadline(time.Time) error
}

// Writer is an extended io.Writer
// that also allows indicating if a message is text.
type Writer interface {
	// WriteDataChannel writes the supplied bytes as one message; the bool
	// indicates whether the message is text. It returns the number of bytes
	// written and any error.
	WriteDataChannel([]byte, bool) (int, error)
}

// ReadWriteCloser is an extended io.ReadWriteCloser
// that also implements our Reader and Writer.
//
// It combines the plain byte-oriented io.Reader/io.Writer/io.Closer methods
// with the text-aware ReadDataChannel and WriteDataChannel methods.
type ReadWriteCloser interface {
	io.Reader
	io.Writer
	Reader
	Writer
	io.Closer
}

// DataChannel represents a data channel
// layered on top of a single SCTP stream.
type DataChannel struct {
	// Config holds the channel parameters; newDataChannel stores a copy of
	// the Config it is given.
	Config

	// stats: these counters are updated and read with sync/atomic.
	messagesSent     uint32
	messagesReceived uint32
	bytesSent        uint64
	bytesReceived    uint64

	// mu is held while onOpenCompleteHandler is read or replaced.
	// onOpenCompleteHandler is the callback registered through OnOpen, and
	// openCompleteHandlerOnce limits it to one invocation per registration.
	mu                      sync.Mutex
	onOpenCompleteHandler   func()
	openCompleteHandlerOnce sync.Once

	// stream is the underlying SCTP stream used for every read and write;
	// log is the logger created under the name "datachannel".
	stream *sctp.Stream
	log    logging.LeveledLogger
}

// Config is used to configure the data channel.
//
// Client sends ChannelType, Priority, ReliabilityParameter, Label and Protocol
// to the peer in the channel-open message; Server overwrites those same fields
// with the values received from the peer.
type Config struct {
	// ChannelType selects the ordering/reliability mode applied to the stream.
	ChannelType          ChannelType
	// Negotiated, when true, makes Client skip sending the channel-open message.
	Negotiated           bool
	// Priority is carried in the channel-open message.
	Priority             uint16
	// ReliabilityParameter is passed to the stream together with the
	// reliability mode derived from ChannelType.
	ReliabilityParameter uint32
	// Label is carried in the channel-open message.
	Label                string
	// Protocol is carried in the channel-open message.
	Protocol             string
	// LoggerFactory creates the logger used by the data channel.
	LoggerFactory        logging.LoggerFactory
}

// newDataChannel builds a DataChannel around stream, storing a copy of config.
//
// The channel's logger is created from config.LoggerFactory. When the caller
// left LoggerFactory unset, the default factory of the logging package is used
// instead, so construction does not panic on a nil interface. The stored
// Config copy is left exactly as the caller supplied it.
//
// The returned error is always nil; it is kept so the signature stays
// compatible with existing callers.
func newDataChannel(stream *sctp.Stream, config *Config) (*DataChannel, error) {
	loggerFactory := config.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}

	return &DataChannel{
		Config: *config,
		stream: stream,
		log:    loggerFactory.NewLogger("datachannel"),
	}, nil
}

// Dial opens a data channel over the SCTP association a.
//
// It opens the stream with identifier id, using the WebRTC binary payload
// type as the stream's payload type, and then hands the stream to Client.
// Any error from opening the stream or from Client is returned unchanged.
func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) {
	stream, openErr := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary)
	if openErr != nil {
		return nil, openErr
	}

	// Client already yields a nil channel whenever it reports an error.
	return Client(stream, config)
}

// Client opens a data channel over an existing SCTP stream.
//
// Unless config.Negotiated is set, a channel-open message built from config
// (channel type, priority, reliability parameter, label and protocol) is
// marshalled and written to the stream with the DCEP payload type. A
// negotiated channel sends nothing. In both cases the DataChannel is then
// created from stream and config.
func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) {
	// Negotiated channels skip the open handshake entirely.
	if config.Negotiated {
		return newDataChannel(stream, config)
	}

	openMsg := &channelOpen{
		ChannelType:          config.ChannelType,
		Priority:             config.Priority,
		ReliabilityParameter: config.ReliabilityParameter,

		Label:    []byte(config.Label),
		Protocol: []byte(config.Protocol),
	}

	payload, err := openMsg.Marshal()
	if err != nil {
		return nil, fmt.Errorf("failed to marshal ChannelOpen %w", err)
	}

	_, err = stream.WriteSCTP(payload, sctp.PayloadTypeWebRTCDCEP)
	if err != nil {
		return nil, fmt.Errorf("failed to send ChannelOpen %w", err)
	}

	return newDataChannel(stream, config)
}

// Accept is used to accept incoming data channels over SCTP.
//
// It waits for the next stream on association a. If one of existingChannels
// already has that stream's identifier, that channel is returned after its
// own stream's default payload type is set to WebRTC binary. Otherwise the
// new stream's default payload type is set to WebRTC binary and the stream is
// handed to Server, whose result is returned.
func Accept(a *sctp.Association, config *Config, existingChannels ...*DataChannel) (*DataChannel, error) {
	stream, err := a.AcceptStream()
	if err != nil {
		return nil, err
	}

	for i := range existingChannels {
		existing := existingChannels[i]
		if existing.StreamIdentifier() != stream.StreamIdentifier() {
			continue
		}
		existing.stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
		return existing, nil
	}

	stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)

	// Server already yields a nil channel whenever it reports an error.
	return Server(stream, config)
}

// Server accepts a data channel over an SCTP stream.
//
// It reads one message of up to receiveMTU bytes from stream, requires it to
// carry the DCEP payload type, and parses it as a channel-open message. The
// channel type, priority, reliability parameter, label and protocol from that
// message are written into config (the caller's Config is modified). A
// DataChannel is then created, an acknowledgement is written back, and the
// reliability parameters are applied to the stream, in that order.
//
// Any failure along the way is returned together with a nil channel.
func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) {
	raw := make([]byte, receiveMTU)
	size, ppi, err := stream.ReadSCTP(raw)
	switch {
	case err != nil:
		return nil, err
	case ppi != sctp.PayloadTypeWebRTCDCEP:
		return nil, fmt.Errorf("%w %s", ErrInvalidPayloadProtocolIdentifier, ppi)
	}

	request, err := parseExpectDataChannelOpen(raw[:size])
	if err != nil {
		return nil, fmt.Errorf("failed to parse DataChannelOpen packet %w", err)
	}

	// Adopt the parameters announced by the peer.
	config.ChannelType = request.ChannelType
	config.Priority = request.Priority
	config.ReliabilityParameter = request.ReliabilityParameter
	config.Label = string(request.Label)
	config.Protocol = string(request.Protocol)

	dc, err := newDataChannel(stream, config)
	if err != nil {
		return nil, err
	}

	if err = dc.writeDataChannelAck(); err != nil {
		return nil, err
	}
	if err = dc.commitReliabilityParams(); err != nil {
		return nil, err
	}

	return dc, nil
}

// Read reads a packet of len(p) bytes as binary data.
//
// It calls ReadDataChannel and discards the flag that says whether the
// message was text, returning only the byte count and the error.
func (c *DataChannel) Read(p []byte) (int, error) {
	n, _, err := c.ReadDataChannel(p)
	return n, err
}

// ReadDataChannel reads the next user message from the stream into p.
//
// It returns the number of bytes read, whether the message is a string
// (string or empty-string payload type), and any error.
//
// Messages carrying the DCEP payload type are handled internally and never
// returned; a failure to handle one is logged and reading continues. Messages
// with an "empty" payload type are reported with a length of zero. When the
// stream read reports io.EOF, the stream is closed before the error is
// returned; if closing fails, the close error is returned instead.
func (c *DataChannel) ReadDataChannel(p []byte) (int, bool, error) {
	for {
		n, ppi, err := c.stream.ReadSCTP(p)
		if err != nil {
			if errors.Is(err, io.EOF) {
				// When the peer sees that an incoming stream was
				// reset, it also resets its corresponding outgoing stream.
				if closeErr := c.stream.Close(); closeErr != nil {
					return 0, false, closeErr
				}
			}
			return 0, false, err
		}

		switch ppi {
		case sctp.PayloadTypeWebRTCDCEP:
			if dcepErr := c.handleDCEP(p[:n]); dcepErr != nil {
				c.log.Errorf("Failed to handle DCEP: %s", dcepErr.Error())
			}
			// Control messages are not surfaced; wait for the next message.
			continue
		case sctp.PayloadTypeWebRTCBinaryEmpty, sctp.PayloadTypeWebRTCStringEmpty:
			// The payload of an empty message is a placeholder; report no data.
			n = 0
		}

		atomic.AddUint32(&c.messagesReceived, 1)
		atomic.AddUint64(&c.bytesReceived, uint64(n))

		isText := ppi == sctp.PayloadTypeWebRTCString || ppi == sctp.PayloadTypeWebRTCStringEmpty
		return n, isText, nil
	}
}

// SetReadDeadline sets a deadline for reads to return.
//
// The deadline is forwarded to the underlying SCTP stream, and the stream's
// error is returned as is.
func (c *DataChannel) SetReadDeadline(t time.Time) error {
	return c.stream.SetReadDeadline(t)
}

// MessagesSent returns the number of messages sent.
//
// The counter is loaded atomically. WriteDataChannel increments it once per
// call, before the write is attempted.
func (c *DataChannel) MessagesSent() uint32 {
	return atomic.LoadUint32(&c.messagesSent)
}

// MessagesReceived returns the number of messages received.
//
// The counter is loaded atomically. ReadDataChannel increments it once for
// each message it returns to the caller; DCEP messages are not counted.
func (c *DataChannel) MessagesReceived() uint32 {
	return atomic.LoadUint32(&c.messagesReceived)
}

// OnOpen sets an event handler which is invoked when
// a DATA_CHANNEL_ACK message is received.
// The handler is called at most once per registration: registering a handler
// also resets the once-guard, so a newly registered handler may run again on
// the next acknowledgement.
// https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-data-protocol-09#section-5.2
func (c *DataChannel) OnOpen(f func()) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.onOpenCompleteHandler = f
	c.openCompleteHandlerOnce = sync.Once{}
}

// onOpenComplete fires the handler registered with OnOpen, if there is one.
//
// The handler is read under mu and then run on its own goroutine through
// openCompleteHandlerOnce, so it executes at most once per registration and
// never blocks the caller.
//
// NOTE(review): OnOpen replaces openCompleteHandlerOnce while holding mu, but
// the goroutine started here uses it without the lock. Confirm that handlers
// are not re-registered concurrently with an incoming acknowledgement.
func (c *DataChannel) onOpenComplete() {
	c.mu.Lock()
	handler := c.onOpenCompleteHandler
	c.mu.Unlock()

	if handler == nil {
		return
	}

	go c.openCompleteHandlerOnce.Do(handler)
}

// BytesSent returns the number of bytes sent.
//
// The counter is loaded atomically. WriteDataChannel adds len(p) to it on
// every call, before the write is attempted.
func (c *DataChannel) BytesSent() uint64 {
	return atomic.LoadUint64(&c.bytesSent)
}

// BytesReceived returns the number of bytes received.
//
// The counter is loaded atomically. ReadDataChannel adds the length of each
// message it returns to the caller; empty messages add zero.
func (c *DataChannel) BytesReceived() uint64 {
	return atomic.LoadUint64(&c.bytesReceived)
}

// StreamIdentifier returns the Stream identifier associated to the stream.
//
// The value is obtained from the underlying SCTP stream.
func (c *DataChannel) StreamIdentifier() uint16 {
	return c.stream.StreamIdentifier()
}

// handleDCEP processes one DCEP control message received on the stream.
//
// data is parsed into a message. Only an acknowledgement (*channelAck) is
// accepted: on receipt the reliability parameters are applied to the stream
// and the open-complete handler is triggered. Any other message type yields
// an error wrapping ErrInvalidMessageType, and a parse failure is returned
// wrapped as well.
func (c *DataChannel) handleDCEP(data []byte) error {
	parsed, err := parse(data)
	if err != nil {
		return fmt.Errorf("failed to parse DataChannel packet %w", err)
	}

	// Anything other than an acknowledgement is unexpected at this point.
	if _, isAck := parsed.(*channelAck); !isAck {
		return fmt.Errorf("%w %v", ErrInvalidMessageType, parsed)
	}

	c.log.Debug("Received DATA_CHANNEL_ACK")
	if err = c.commitReliabilityParams(); err != nil {
		return err
	}
	c.onOpenComplete()

	return nil
}

// Write writes len(p) bytes from p as binary data.
//
// It is equivalent to WriteDataChannel(p, false).
func (c *DataChannel) Write(p []byte) (n int, err error) {
	return c.WriteDataChannel(p, false)
}

// WriteDataChannel writes p to the stream as a single message, marked as a
// string when isString is true and as binary otherwise.
//
// It returns the byte count reported by the stream write, or 0 for an empty
// message, together with the write error. The sent-message and sent-byte
// counters are updated before the write is attempted.
func (c *DataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) {
	empty := len(p) == 0

	// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
	// SCTP does not support the sending of empty user messages.  Therefore,
	// if an empty message has to be sent, the appropriate PPID (WebRTC
	// String Empty or WebRTC Binary Empty) is used and the SCTP user
	// message of one zero byte is sent.  When receiving an SCTP user
	// message with one of these PPIDs, the receiver MUST ignore the SCTP
	// user message and process it as an empty message.
	var ppi sctp.PayloadProtocolIdentifier
	if isString {
		ppi = sctp.PayloadTypeWebRTCString
		if empty {
			ppi = sctp.PayloadTypeWebRTCStringEmpty
		}
	} else {
		ppi = sctp.PayloadTypeWebRTCBinary
		if empty {
			ppi = sctp.PayloadTypeWebRTCBinaryEmpty
		}
	}

	atomic.AddUint32(&c.messagesSent, 1)
	atomic.AddUint64(&c.bytesSent, uint64(len(p)))

	if !empty {
		return c.stream.WriteSCTP(p, ppi)
	}

	// The single zero byte is a placeholder; no payload bytes are reported.
	_, err = c.stream.WriteSCTP([]byte{0}, ppi)
	return 0, err
}

// writeDataChannelAck marshals an acknowledgement message and writes it to
// the stream with the DCEP payload type.
//
// A marshalling or write failure is returned wrapped with context; nil is
// returned on success.
func (c *DataChannel) writeDataChannelAck() error {
	var ack channelAck

	payload, err := ack.Marshal()
	if err != nil {
		return fmt.Errorf("failed to marshal ChannelOpen ACK: %w", err)
	}

	_, err = c.stream.WriteSCTP(payload, sctp.PayloadTypeWebRTCDCEP)
	if err != nil {
		return fmt.Errorf("failed to send ChannelOpen ACK: %w", err)
	}

	return nil
}

// Close closes the DataChannel and the underlying SCTP stream.
//
// It returns whatever error closing the stream reports.
func (c *DataChannel) Close() error {
	// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
	// Closing of a data channel MUST be signaled by resetting the
	// corresponding outgoing streams [RFC6525].  This means that if one
	// side decides to close the data channel, it resets the corresponding
	// outgoing stream.  When the peer sees that an incoming stream was
	// reset, it also resets its corresponding outgoing stream.  Once this
	// is completed, the data channel is closed.  Resetting a stream sets
	// the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
	// a corresponding notification to the application layer that the reset
	// has been performed.  Streams are available for reuse after a reset
	// has been performed.
	return c.stream.Close()
}

// BufferedAmount returns the number of bytes of data currently queued to be
// sent over this stream.
//
// The value is reported by the underlying SCTP stream.
func (c *DataChannel) BufferedAmount() uint64 {
	return c.stream.BufferedAmount()
}

// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
// data that is considered "low." Defaults to 0.
//
// The value is reported by the underlying SCTP stream.
func (c *DataChannel) BufferedAmountLowThreshold() uint64 {
	return c.stream.BufferedAmountLowThreshold()
}

// SetBufferedAmountLowThreshold is used to update the threshold.
// See BufferedAmountLowThreshold().
//
// The new threshold th is forwarded to the underlying SCTP stream.
func (c *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
	c.stream.SetBufferedAmountLowThreshold(th)
}

// OnBufferedAmountLow sets the callback handler which would be called when the
// number of bytes of outgoing data buffered is lower than the threshold.
//
// The handler f is registered on the underlying SCTP stream.
func (c *DataChannel) OnBufferedAmountLow(f func()) {
	c.stream.OnBufferedAmountLow(f)
}

// commitReliabilityParams applies the channel's configured reliability mode
// to the underlying SCTP stream.
//
// Config.ChannelType selects both the reliability type (reliable, limited
// retransmissions, or timed) and whether delivery is unordered;
// Config.ReliabilityParameter is passed through unchanged. An unrecognised
// channel type leaves the stream untouched and yields an error wrapping
// ErrInvalidChannelType.
func (c *DataChannel) commitReliabilityParams() error {
	kind := c.Config.ChannelType
	value := c.Config.ReliabilityParameter

	// Each "Unordered" variant differs from its ordered counterpart only in
	// this flag.
	unordered := kind == ChannelTypeReliableUnordered ||
		kind == ChannelTypePartialReliableRexmitUnordered ||
		kind == ChannelTypePartialReliableTimedUnordered

	switch kind {
	case ChannelTypeReliable, ChannelTypeReliableUnordered:
		c.stream.SetReliabilityParams(unordered, sctp.ReliabilityTypeReliable, value)
	case ChannelTypePartialReliableRexmit, ChannelTypePartialReliableRexmitUnordered:
		c.stream.SetReliabilityParams(unordered, sctp.ReliabilityTypeRexmit, value)
	case ChannelTypePartialReliableTimed, ChannelTypePartialReliableTimedUnordered:
		c.stream.SetReliabilityParams(unordered, sctp.ReliabilityTypeTimed, value)
	default:
		return fmt.Errorf("%w %v", ErrInvalidChannelType, kind)
	}

	return nil
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy