All Downloads are FREE. Search and download functionalities are using the official Maven repository.

vendor.github.com.pion.transport.vnet.net.go Maven / Gradle / Ivy

There is a newer version: 2.9.1
Show newest version
package vnet

import (
	"encoding/binary"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
)

// String constants shared by the virtual network stack.
const (
	// udpString is the network name compared against Chunk.Network() and
	// net.Addr.Network() to recognise UDP traffic.
	udpString = "udp"

	// lo0String is the interface name NewNet gives the virtual loopback device.
	// NOTE(review): NewNet's documentation calls this interface "lo0", but the
	// name actually assigned is "lo0String" — confirm which is intended.
	lo0String = "lo0String"
)

// macAddrCounter is the value newMACAddress turns into the next hardware
// address; it is advanced by one per call.
var macAddrCounter uint64 = 0xBEEFED910200 //nolint:gochecknoglobals

// Sentinel errors produced by the virtual network stack.
var (
	errNoInterface                 = errors.New("no interface is available")
	errNotFound                    = errors.New("not found")
	errUnexpectedNetwork           = errors.New("unexpected network")
	errCantAssignRequestedAddr     = errors.New("can't assign requested address")
	errUnknownNetwork              = errors.New("unknown network")
	errNoRouterLinked              = errors.New("no router linked")
	errInvalidPortNumber           = errors.New("invalid port number")
	errUnexpectedTypeSwitchFailure = errors.New("unexpected type-switch failure")
	errBindFailerFor               = errors.New("bind failed for")
	errEndPortLessThanStart        = errors.New("end port is less than the start")
	errPortSpaceExhausted          = errors.New("port space exhausted")
	errVNetDisabled                = errors.New("vnet is not enabled")
)

// newMACAddress returns a fresh 6-byte hardware address built from the
// package-level macAddrCounter, and advances the counter by one.
//
// The counter is advanced atomically, so concurrent NewNet calls (for
// example from parallel tests) neither race on it nor receive duplicate
// addresses.
func newMACAddress() net.HardwareAddr {
	// AddUint64 returns the incremented value; subtract one to use the value
	// the counter held before this call, keeping the original sequence.
	next := atomic.AddUint64(&macAddrCounter, 1) - 1

	b := make([]byte, 8)
	binary.BigEndian.PutUint64(b, next)

	// A MAC address is 6 bytes: keep the low-order 48 bits.
	return b[2:]
}

// vNet holds the state of a virtual Net: its interfaces, its configured
// static IPs, the router it is attached to and its open UDP connections.
// Helpers documented "caller must hold the mutex" expect mutex to be held.
type vNet struct {
	interfaces []*Interface // set by NewNet; not reassigned in this file
	staticIPs  []net.IP     // set by NewNet; not reassigned in this file
	router     *Router      // written by setRouter while holding mutex
	udpConns   *udpConnMap  // set by NewNet; lookups by local UDP address
	mutex      sync.RWMutex // taken by the locking (non-underscore) helpers
}

// _getInterfaces returns every interface of this vNet, or errNoInterface
// when it has none.
//
// caller must hold the mutex (read)
func (v *vNet) _getInterfaces() ([]*Interface, error) {
	if ifs := v.interfaces; len(ifs) > 0 {
		return ifs, nil
	}

	return nil, errNoInterface
}

// getInterfaces is the read-locking wrapper around _getInterfaces.
func (v *vNet) getInterfaces() (ifs []*Interface, err error) {
	v.mutex.RLock()
	defer v.mutex.RUnlock()

	ifs, err = v._getInterfaces()
	return ifs, err
}

// _getInterface finds the interface whose Name equals ifName. It propagates
// the error from _getInterfaces, and returns an error wrapping errNotFound
// when no interface has that name.
//
// caller must hold the mutex (read)
func (v *vNet) _getInterface(ifName string) (*Interface, error) {
	all, err := v._getInterfaces()
	if err != nil {
		return nil, err
	}

	for i := range all {
		if all[i].Name == ifName {
			return all[i], nil
		}
	}

	return nil, fmt.Errorf("interface %s %w", ifName, errNotFound)
}

// getInterface is the read-locking wrapper around _getInterface.
func (v *vNet) getInterface(ifName string) (ifc *Interface, err error) {
	v.mutex.RLock()
	defer v.mutex.RUnlock()

	ifc, err = v._getInterface(ifName)
	return ifc, err
}

// getAllIPAddrs collects the addresses assigned to this vNet's interfaces
// that belong to one address family. When ipv6 is true it returns the IPv6
// addresses (To4() == nil); otherwise it returns the IPv4 addresses.
// Interfaces whose address list cannot be read are skipped, as are addresses
// that are neither *net.IPNet nor *net.IPAddr. The result is never nil.
//
// caller must hold the mutex
func (v *vNet) getAllIPAddrs(ipv6 bool) []net.IP {
	ips := []net.IP{}

	for _, ifc := range v.interfaces {
		addrs, err := ifc.Addrs()
		if err != nil {
			continue
		}

		for _, addr := range addrs {
			var ip net.IP
			switch a := addr.(type) {
			case *net.IPNet:
				ip = a.IP
			case *net.IPAddr:
				ip = a.IP
			default:
				continue
			}

			// Honour the requested family in both directions, so that an
			// IPv6 wildcard bind ("::") can find the IPv6 addresses it covers
			// instead of always coming back empty.
			isIPv4 := ip.To4() != nil
			if isIPv4 != ipv6 {
				ips = append(ips, ip)
			}
		}
	}

	return ips
}

// setRouter records r as the router this vNet pushes outbound chunks to.
// The write happens under the mutex. It always returns nil.
func (v *vNet) setRouter(r *Router) error {
	v.mutex.Lock()
	v.router = r
	v.mutex.Unlock()

	return nil
}

// onInboundChunk delivers a chunk received from the router to the UDP
// connection bound to the chunk's destination address. Non-UDP chunks and
// chunks with no matching connection are dropped silently.
func (v *vNet) onInboundChunk(c Chunk) {
	v.mutex.Lock()
	defer v.mutex.Unlock()

	if c.Network() != udpString {
		return
	}

	conn, found := v.udpConns.find(c.DestinationAddr())
	if !found {
		return
	}

	conn.onInboundChunk(c)
}

// _dialUDP creates a virtual UDP connection bound to locAddr and, when
// remAddr is non-nil, associated with remAddr. listenPacket, listenUDP,
// dialUDP and dial all use it.
//
// Only the "udp" and "udp4" networks are accepted. A nil locAddr, or one
// with a nil IP, binds to 0.0.0.0. A zero port is replaced by a free port
// chosen at random from 5000-5999. The caller's locAddr is never modified:
// the default IP and any assigned port are applied to a private copy, which
// is what the connection is created with.
//
// Errors: a wrapped errUnexpectedNetwork for other networks; a *net.OpError
// wrapping errCantAssignRequestedAddr when the IP is not local, the
// assignPort error when no port is free, or errAddressAlreadyInUse when the
// requested port is taken; errors from newUDPConn and udpConns.insert are
// returned unchanged.
//
// caller must hold the mutex
func (v *vNet) _dialUDP(network string, locAddr, remAddr *net.UDPAddr) (UDPPacketConn, error) {
	// validate network
	if network != udpString && network != "udp4" {
		return nil, fmt.Errorf("%w: %s", errUnexpectedNetwork, network)
	}

	// Work on a copy: ListenUDP/DialUDP callers pass in their own
	// *net.UDPAddr and must not see it rewritten behind their back.
	bindAddr := &net.UDPAddr{IP: net.IPv4zero}
	if locAddr != nil {
		copied := *locAddr
		if copied.IP == nil {
			copied.IP = net.IPv4zero
		}
		bindAddr = &copied
	}

	// validate address. do we have that address?
	if !v.hasIPAddr(bindAddr.IP) {
		return nil, &net.OpError{
			Op:   "listen",
			Net:  network,
			Addr: bindAddr,
			Err:  fmt.Errorf("bind: %w", errCantAssignRequestedAddr),
		}
	}

	if bindAddr.Port == 0 {
		// choose randomly from the range between 5000 and 5999
		port, err := v.assignPort(bindAddr.IP, 5000, 5999)
		if err != nil {
			return nil, &net.OpError{
				Op:   "listen",
				Net:  network,
				Addr: bindAddr,
				Err:  err,
			}
		}
		bindAddr.Port = port
	} else if _, ok := v.udpConns.find(bindAddr); ok {
		return nil, &net.OpError{
			Op:   "listen",
			Net:  network,
			Addr: bindAddr,
			Err:  fmt.Errorf("bind: %w", errAddressAlreadyInUse),
		}
	}

	conn, err := newUDPConn(bindAddr, remAddr, v)
	if err != nil {
		return nil, err
	}

	if err := v.udpConns.insert(conn); err != nil {
		return nil, err
	}

	return conn, nil
}

// listenPacket resolves address ("host:port") and opens an unconnected UDP
// endpoint bound to it. The mutex is held for the whole operation.
func (v *vNet) listenPacket(network string, address string) (UDPPacketConn, error) {
	v.mutex.Lock()
	defer v.mutex.Unlock()

	bindAddr, err := v.resolveUDPAddr(network, address)
	if err != nil {
		return nil, err
	}

	var noRemote *net.UDPAddr // a listening endpoint has no fixed peer
	return v._dialUDP(network, bindAddr, noRemote)
}

// listenUDP opens an unconnected UDP endpoint bound to locAddr while
// holding the mutex.
func (v *vNet) listenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) {
	v.mutex.Lock()
	defer v.mutex.Unlock()

	var remAddr *net.UDPAddr // a listening endpoint has no fixed peer
	return v._dialUDP(network, locAddr, remAddr)
}

// dialUDP opens a UDP endpoint bound to locAddr and associated with remAddr
// while holding the mutex.
func (v *vNet) dialUDP(network string, locAddr, remAddr *net.UDPAddr) (conn UDPPacketConn, err error) {
	v.mutex.Lock()
	defer v.mutex.Unlock()

	conn, err = v._dialUDP(network, locAddr, remAddr)
	return conn, err
}

// dial resolves address as the remote end and opens a UDP endpoint
// associated with it. The local IP is chosen by determineSourceIP, and the
// local port is assigned automatically. The mutex is held throughout.
func (v *vNet) dial(network string, address string) (UDPPacketConn, error) {
	v.mutex.Lock()
	defer v.mutex.Unlock()

	remAddr, err := v.resolveUDPAddr(network, address)
	if err != nil {
		return nil, err
	}

	// Port 0 asks _dialUDP to pick a free port.
	local := &net.UDPAddr{
		IP:   v.determineSourceIP(nil, remAddr.IP),
		Port: 0,
	}

	return v._dialUDP(network, local, remAddr)
}

// resolveUDPAddr parses address ("host:port") into a *net.UDPAddr for the
// "udp" or "udp4" network. host may be an IP literal, "localhost" (in any
// case, mapped to 127.0.0.1), or a name resolved through the attached
// router's resolver. The port must be a decimal number in 0-65535.
//
// Errors: a wrapped errUnknownNetwork for other networks, the error from
// net.SplitHostPort, errNoRouterLinked when a name must be resolved but no
// router is attached, the resolver's own error, or errInvalidPortNumber.
//
// The mutex is not taken here: listenPacket and dial call this while they
// already hold it.
// NOTE(review): Net.ResolveUDPAddr calls this without the mutex while it
// reads v.router — confirm setRouter cannot run concurrently.
func (v *vNet) resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
	const maxPort = 65535

	if network != udpString && network != "udp4" {
		return nil, fmt.Errorf("%w %s", errUnknownNetwork, network)
	}

	host, sPort, err := net.SplitHostPort(address)
	if err != nil {
		return nil, err
	}

	// Check if host is a domain name
	ip := net.ParseIP(host)
	if ip == nil {
		host = strings.ToLower(host)
		if host == "localhost" {
			ip = net.IPv4(127, 0, 0, 1)
		} else {
			// host is a domain name. resolve IP address by the name
			if v.router == nil {
				return nil, errNoRouterLinked
			}

			ip, err = v.router.resolver.lookUp(host)
			if err != nil {
				return nil, err
			}
		}
	}

	// Reject values that cannot be UDP ports (negative or above 65535)
	// rather than building an impossible address.
	port, err := strconv.Atoi(sPort)
	if err != nil || port < 0 || port > maxPort {
		return nil, errInvalidPortNumber
	}

	return &net.UDPAddr{
		IP:   ip,
		Port: port,
	}, nil
}

// write sends an outbound chunk. A UDP chunk whose destination IP is a
// loopback address is handed directly to the local connection bound to that
// address, or dropped if none is bound. Every other chunk is pushed to the
// attached router.
//
// Returns errUnexpectedTypeSwitchFailure for a "udp" chunk that is not a
// *chunkUDP, and errNoRouterLinked when the chunk must go to the router but
// none is attached.
func (v *vNet) write(c Chunk) error {
	if c.Network() == udpString {
		udp, isUDP := c.(*chunkUDP)
		if !isUDP {
			return errUnexpectedTypeSwitchFailure
		}

		if c.getDestinationIP().IsLoopback() {
			if conn, found := v.udpConns.find(udp.DestinationAddr()); found {
				conn.onInboundChunk(udp)
			}
			return nil
		}
	}

	if v.router == nil {
		return errNoRouterLinked
	}
	v.router.push(c)

	return nil
}

// onClosed is notified with the local address of a connection that has
// closed. UDP addresses are removed from udpConns. Any error from the
// removal is deliberately ignored; other networks are ignored too.
func (v *vNet) onClosed(addr net.Addr) {
	if addr.Network() != udpString {
		return
	}

	//nolint:errcheck
	v.udpConns.delete(addr) // #nosec
}

// determineSourceIP picks the local source IP to use when sending to dstIP.
//
// If locIP is non-nil and not the unspecified address ("0.0.0.0" or "::"),
// it is returned unchanged. Otherwise a loopback dstIP yields 127.0.0.1, and
// any other dstIP yields the first address on "eth0" of the wanted family.
// The wanted family is locIP's family when locIP is non-nil, and dstIP's
// family otherwise. Returns nil when eth0 is missing, its addresses cannot be
// read, or it has no address of the wanted family.
//
// caller must hold the mutex
func (v *vNet) determineSourceIP(locIP, dstIP net.IP) net.IP {
	if locIP != nil && !locIP.IsUnspecified() {
		return locIP
	}

	if dstIP.IsLoopback() {
		return net.ParseIP("127.0.0.1")
	}

	ifc, err := v._getInterface("eth0")
	if err != nil {
		return nil
	}

	addrs, err := ifc.Addrs()
	if err != nil || len(addrs) == 0 {
		return nil
	}

	var findIPv4 bool
	if locIP != nil {
		findIPv4 = locIP.To4() != nil
	} else {
		findIPv4 = dstIP.To4() != nil
	}

	for _, addr := range addrs {
		// Accept both address kinds the other helpers accept; a bare
		// addr.(*net.IPNet) assertion would panic on a *net.IPAddr.
		var ip net.IP
		switch a := addr.(type) {
		case *net.IPNet:
			ip = a.IP
		case *net.IPAddr:
			ip = a.IP
		default:
			continue
		}

		if (ip.To4() != nil) == findIPv4 {
			return ip
		}
	}

	return nil
}

// hasIPAddr reports whether ip can be bound on this vNet:
//   - the IPv4 unspecified address matches if any interface has an IPv4 address,
//   - the IPv6 unspecified address ("::") matches if any interface has an IPv6 address,
//   - any other ip must be equal to an address assigned to some interface.
//
// Interfaces whose addresses cannot be read, and addresses that are neither
// *net.IPNet nor *net.IPAddr, are skipped.
//
// caller must hold the mutex
func (v *vNet) hasIPAddr(ip net.IP) bool {
	// Classify ip once up front instead of formatting it with ip.String() for
	// every address of every interface. IsUnspecified plus To4 distinguishes
	// exactly the "0.0.0.0" and "::" cases.
	anyIPv4 := ip.IsUnspecified() && ip.To4() != nil
	anyIPv6 := ip.IsUnspecified() && ip.To4() == nil

	for _, ifc := range v.interfaces {
		addrs, err := ifc.Addrs()
		if err != nil {
			continue
		}

		for _, addr := range addrs {
			var locIP net.IP
			switch a := addr.(type) {
			case *net.IPNet:
				locIP = a.IP
			case *net.IPAddr:
				locIP = a.IP
			default:
				continue
			}

			isIPv4 := locIP.To4() != nil
			switch {
			case anyIPv4:
				if isIPv4 {
					return true
				}
			case anyIPv6:
				if !isIPv4 {
					return true
				}
			default:
				if locIP.Equal(ip) {
					return true
				}
			}
		}
	}

	return false
}

// allocateLocalAddr checks whether UDP port can be bound on ip. It only
// checks; nothing is recorded.
//
// An unspecified ip expands to the addresses getAllIPAddrs returns for ip's
// family. Any other ip is used only if hasIPAddr accepts it. Returns an error
// wrapping errBindFailerFor when there is nothing to bind to, and a
// *net.OpError wrapping errAddressAlreadyInUse when any candidate address
// already has a connection on that port.
//
// caller must hold the mutex
func (v *vNet) allocateLocalAddr(ip net.IP, port int) error {
	var candidates []net.IP
	switch {
	case ip.IsUnspecified():
		candidates = v.getAllIPAddrs(ip.To4() == nil)
	case v.hasIPAddr(ip):
		candidates = []net.IP{ip}
	}

	if len(candidates) == 0 {
		return fmt.Errorf("%w %s", errBindFailerFor, ip.String())
	}

	for _, candidate := range candidates {
		probe := &net.UDPAddr{IP: candidate, Port: port}
		if _, inUse := v.udpConns.find(probe); !inUse {
			continue
		}

		return &net.OpError{
			Op:   "bind",
			Net:  udpString,
			Addr: probe,
			Err:  fmt.Errorf("bind: %w", errAddressAlreadyInUse),
		}
	}

	return nil
}

// assignPort finds a port in [start, end] (inclusive) that allocateLocalAddr
// accepts for ip. The scan starts at a random offset and wraps around, so each
// port in the range is tried at most once. On error the port result is -1:
// errEndPortLessThanStart for an inverted range, errPortSpaceExhausted when
// every port is taken.
//
// caller must hold the mutex
func (v *vNet) assignPort(ip net.IP, start, end int) (int, error) {
	if start > end {
		return -1, errEndPortLessThanStart
	}

	size := end - start + 1
	first := rand.Intn(size) //nolint:gosec
	for tried := 0; tried < size; tried++ {
		candidate := start + (first+tried)%size
		if v.allocateLocalAddr(ip, candidate) == nil {
			return candidate, nil
		}
	}

	return -1, errPortSpaceExhausted
}

// NetConfig is a bag of configuration parameters passed to NewNet().
type NetConfig struct {
	// StaticIPs lists static IP addresses to assign to this Net. NewNet
	// ignores entries that do not parse as IP addresses. If no static IP
	// address is given, the router will automatically assign an IP address.
	StaticIPs []string

	// StaticIP is a single static IP address, added after StaticIPs.
	//
	// Deprecated: Use StaticIPs.
	StaticIP string
}

// Net represents a local network stack equivalent to a set of layers from
// the NIC up to the transport (UDP / TCP) layer.
//
// A Net is either virtual (v is non-nil) or a pass-through to the host's
// real network, in which case ifs holds the host interfaces captured by
// NewNet.
type Net struct {
	v   *vNet        // virtual stack; nil when the virtual network is disabled
	ifs []*Interface // host interfaces; used only when v is nil
}

// NewNet creates an instance of Net.
//
// If config is nil, the virtual network is disabled: the returned Net uses
// the corresponding net.Xxxx() operations, and its interface list is a copy
// of the host's interfaces (with their addresses) taken at this moment.
//
// Otherwise the Net is virtual and, by design, always has two interfaces:
// a loopback interface (named by lo0String) with 127.0.0.1/8 assigned, and
// "eth0" with a freshly generated MAC address. The IP address for eth0 is
// assigned when this Net is added to a router. The parseable entries of
// config.StaticIPs, followed by config.StaticIP, become the Net's static IPs.
func NewNet(config *NetConfig) *Net {
	if config == nil {
		hostIfs := []*Interface{}
		orgIfs, err := net.Interfaces()
		if err == nil {
			for _, orgIfc := range orgIfs {
				ifc := NewInterface(orgIfc)
				if addrs, addrErr := orgIfc.Addrs(); addrErr == nil {
					for _, addr := range addrs {
						ifc.AddAddr(addr)
					}
				}
				hostIfs = append(hostIfs, ifc)
			}
		}

		return &Net{ifs: hostIfs}
	}

	lo0 := NewInterface(net.Interface{
		Index:        1,
		MTU:          16384,
		Name:         lo0String,
		HardwareAddr: nil,
		Flags:        net.FlagUp | net.FlagLoopback | net.FlagMulticast,
	})
	lo0.AddAddr(&net.IPNet{
		IP:   net.ParseIP("127.0.0.1"),
		Mask: net.CIDRMask(8, 32),
	})

	eth0 := NewInterface(net.Interface{
		Index:        2,
		MTU:          1500,
		Name:         "eth0",
		HardwareAddr: newMACAddress(),
		Flags:        net.FlagUp | net.FlagMulticast,
	})

	// Unparseable strings are silently skipped.
	var staticIPs []net.IP
	addStatic := func(s string) {
		if ip := net.ParseIP(s); ip != nil {
			staticIPs = append(staticIPs, ip)
		}
	}
	for _, s := range config.StaticIPs {
		addStatic(s)
	}
	if config.StaticIP != "" {
		addStatic(config.StaticIP)
	}

	return &Net{
		v: &vNet{
			interfaces: []*Interface{lo0, eth0},
			staticIPs:  staticIPs,
			udpConns:   newUDPConnMap(),
		},
	}
}

// Interfaces returns a list of the system's network interfaces: the virtual
// interfaces when the virtual network is enabled, otherwise the host
// interfaces captured by NewNet.
func (n *Net) Interfaces() ([]*Interface, error) {
	if n.v != nil {
		return n.v.getInterfaces()
	}

	return n.ifs, nil
}

// InterfaceByName returns the interface specified by name. It returns an
// error wrapping errNotFound when no interface has that name.
func (n *Net) InterfaceByName(name string) (*Interface, error) {
	if n.v != nil {
		return n.v.getInterface(name)
	}

	for i := range n.ifs {
		if n.ifs[i].Name == name {
			return n.ifs[i], nil
		}
	}

	return nil, fmt.Errorf("interface %s %w", name, errNotFound)
}

// ListenPacket announces on the local network address. It goes through the
// virtual network when enabled, and through net.ListenPacket otherwise.
func (n *Net) ListenPacket(network string, address string) (net.PacketConn, error) {
	if n.v != nil {
		return n.v.listenPacket(network, address)
	}

	return net.ListenPacket(network, address)
}

// ListenUDP acts like ListenPacket for UDP networks.
//
// When the virtual network is disabled it delegates to net.ListenUDP. On
// failure the returned UDPPacketConn is a true nil interface, not an
// interface wrapping a nil *net.UDPConn, so comparing it with nil behaves as
// callers expect.
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) {
	if n.v == nil {
		conn, err := net.ListenUDP(network, locAddr)
		if err != nil {
			// Returning conn here would box a typed nil pointer into a
			// non-nil interface value.
			return nil, err
		}
		return conn, nil
	}

	return n.v.listenUDP(network, locAddr)
}

// Dial connects to the address on the named network. It goes through the
// virtual network when enabled, and through net.Dial otherwise.
func (n *Net) Dial(network, address string) (net.Conn, error) {
	if n.v != nil {
		return n.v.dial(network, address)
	}

	return net.Dial(network, address)
}

// CreateDialer creates an instance of vnet.Dialer that wraps dialer. When the
// virtual network is enabled, its Dial goes through this Net; otherwise Dial
// is forwarded to dialer.
func (n *Net) CreateDialer(dialer *net.Dialer) Dialer {
	// n.v is nil exactly when the virtual network is disabled, and a nil v is
	// what makes vDialer fall back to the real dialer.
	return &vDialer{dialer: dialer, v: n.v}
}

// DialUDP acts like Dial for UDP networks.
//
// When the virtual network is disabled it delegates to net.DialUDP. On
// failure the returned UDPPacketConn is a true nil interface, not an
// interface wrapping a nil *net.UDPConn, so comparing it with nil behaves as
// callers expect.
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (UDPPacketConn, error) {
	if n.v == nil {
		conn, err := net.DialUDP(network, laddr, raddr)
		if err != nil {
			// Returning conn here would box a typed nil pointer into a
			// non-nil interface value.
			return nil, err
		}
		return conn, nil
	}

	return n.v.dialUDP(network, laddr, raddr)
}

// ResolveUDPAddr returns an address of UDP end point. It uses the virtual
// network's resolver when enabled, and net.ResolveUDPAddr otherwise.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
	if n.v != nil {
		return n.v.resolveUDPAddr(network, address)
	}

	return net.ResolveUDPAddr(network, address)
}

// getInterface looks up a virtual interface by name. It returns
// errVNetDisabled when the virtual network is not enabled.
func (n *Net) getInterface(ifName string) (*Interface, error) {
	if n.v != nil {
		return n.v.getInterface(ifName)
	}

	return nil, errVNetDisabled
}

// setRouter attaches r to the virtual stack. It returns errVNetDisabled
// when the virtual network is not enabled.
func (n *Net) setRouter(r *Router) error {
	if n.v != nil {
		return n.v.setRouter(r)
	}

	return errVNetDisabled
}

// onInboundChunk forwards c to the virtual stack. It is a no-op when the
// virtual network is disabled.
func (n *Net) onInboundChunk(c Chunk) {
	if n.v != nil {
		n.v.onInboundChunk(c)
	}
}

// getStaticIPs returns the static IPs that NewNet parsed from the config, or
// nil when the virtual network is disabled. The internal slice is returned
// without copying.
func (n *Net) getStaticIPs() []net.IP {
	if n.v != nil {
		return n.v.staticIPs
	}

	return nil
}

// IsVirtual tests if the virtual network is enabled, i.e. whether this Net
// was created by NewNet with a non-nil config.
func (n *Net) IsVirtual() bool {
	if n.v == nil {
		return false
	}

	return true
}

// Dialer is a subset of net.Dialer whose Dial method is overridden to use
// the virtual network. Use Net.CreateDialer() to create an instance of this
// Dialer.
type Dialer interface {
	// Dial connects to address on the named network.
	Dial(network, address string) (net.Conn, error)
}

// vDialer implements Dialer. With v nil it forwards to dialer; otherwise it
// dials through the virtual network v.
type vDialer struct {
	dialer *net.Dialer // fallback used when v is nil
	v      *vNet       // virtual stack, or nil when disabled
}

// Dial connects to address on network: through the virtual network when one
// is attached, and through the wrapped net.Dialer otherwise.
func (d *vDialer) Dial(network, address string) (net.Conn, error) {
	if d.v != nil {
		return d.v.dial(network, address)
	}

	return d.dialer.Dial(network, address)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy