All Downloads are FREE. Search and download functionalities are using the official Maven repository.

common.turbotunnel.redialpacketconn.go Maven / Gradle / Ivy

The newest version!
package turbotunnel

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// RedialPacketConn implements a long-lived net.PacketConn atop a sequence of
// other, transient net.PacketConns. RedialPacketConn creates a new
// net.PacketConn by calling a provided dialContext function. Whenever the
// net.PacketConn experiences a ReadFrom or WriteTo error, RedialPacketConn
// calls the dialContext function again and starts sending and receiving packets
// on the new net.PacketConn. RedialPacketConn's own ReadFrom and WriteTo
// methods return an error only after the RedialPacketConn has been closed,
// either by Close or because the dialContext function returned an error.
//
// RedialPacketConn uses static local and remote addresses that are independent
// of those of any dialed net.PacketConn.
type RedialPacketConn struct {
	// localAddr and remoteAddr are the static addresses given to
	// NewRedialPacketConn. dialContext produces each transient
	// net.PacketConn. recvQueue holds packets read from the transient
	// net.PacketConn until ReadFrom collects them; sendQueue holds packets
	// accepted by WriteTo until they are written to the transient
	// net.PacketConn. closed is closed by closeWithError, at most once
	// (guarded by closeOnce), to signal that the RedialPacketConn is closed.
	localAddr   net.Addr
	remoteAddr  net.Addr
	dialContext func(context.Context) (net.PacketConn, error)
	recvQueue   chan []byte
	sendQueue   chan []byte
	closed      chan struct{}
	closeOnce   sync.Once
	// The first error passed to closeWithError (a dial error, or a generic
	// "closed connection" error when Close is called), which causes the
	// RedialPacketConn to be closed and is returned from future read/write
	// operations. It is stored before closed is closed. Compare to the rerr
	// and werr in io.Pipe.
	err atomic.Value
}

// NewRedialPacketConn returns a RedialPacketConn that reports localAddr and
// remoteAddr as its fixed addresses and obtains every underlying
// net.PacketConn from dialContext. The send and receive queues each have room
// for queueSize packets. The dial loop is started in its own goroutine before
// NewRedialPacketConn returns.
func NewRedialPacketConn(
	localAddr, remoteAddr net.Addr,
	dialContext func(context.Context) (net.PacketConn, error),
) *RedialPacketConn {
	rc := new(RedialPacketConn)
	rc.localAddr, rc.remoteAddr = localAddr, remoteAddr
	rc.dialContext = dialContext
	rc.recvQueue = make(chan []byte, queueSize)
	rc.sendQueue = make(chan []byte, queueSize)
	rc.closed = make(chan struct{})
	// Explicit zero value: no error is recorded until the conn is closed.
	rc.err = atomic.Value{}

	go rc.dialLoop()
	return rc
}

// dialLoop repeatedly calls c.dialContext and passes the resulting
// net.PacketConn to c.exchange, closing that net.PacketConn once exchange
// returns. It returns only when c is closed or dialContext returns an error;
// in the latter case c is closed with that error.
//
// The context handed to dialContext is canceled as soon as c is closed, so a
// dial that is still in progress when Close is called is interrupted instead
// of being left to run to completion.
func (c *RedialPacketConn) dialLoop() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Tie the dial context to the lifetime of c. The ctx.Done case lets this
	// watcher exit when dialLoop returns for any other reason.
	go func() {
		select {
		case <-c.closed:
			cancel()
		case <-ctx.Done():
		}
	}()

	for {
		select {
		case <-c.closed:
			return
		default:
		}
		conn, err := c.dialContext(ctx)
		if err != nil {
			// Has no effect on the stored error if c was already closed
			// (for example when the dial failed because of the
			// cancellation above); the first close error wins.
			c.closeWithError(err)
			return
		}
		c.exchange(conn)
		// The transient conn is finished with; a Close error is not
		// actionable here because the loop redials regardless.
		conn.Close()
	}
}

// exchange calls ReadFrom on the given net.PacketConn and places the resulting
// packets in the receive queue, and takes packets from the send queue and calls
// WriteTo on them, making the current net.PacketConn active.
//
// It runs one reader and one writer goroutine and returns as soon as either of
// them stops, which happens on a ReadFrom or WriteTo error or when c is
// closed. exchange does not close conn. A goroutine that is still blocked
// inside conn.ReadFrom or conn.WriteTo when exchange returns finishes once
// that call returns, which the caller brings about by closing conn.
func (c *RedialPacketConn) exchange(conn net.PacketConn) {
	// Each goroutine sends at most one error on its channel before closing
	// it. A capacity of 1 guarantees that this send never blocks, even when
	// exchange has already returned and nobody is receiving any more;
	// with unbuffered channels the sender would be stuck forever.
	readErrCh := make(chan error, 1)
	writeErrCh := make(chan error, 1)

	go func() {
		defer close(readErrCh)
		// Scratch buffer reused across reads; every packet is copied out
		// of it before the next ReadFrom.
		var buf [1500]byte
		for {
			select {
			case <-c.closed:
				return
			case <-writeErrCh:
				return
			default:
			}

			n, _, err := conn.ReadFrom(buf[:])
			if err != nil {
				readErrCh <- err
				return
			}
			p := make([]byte, n)
			copy(p, buf[:n])
			select {
			case c.recvQueue <- p:
			default: // OK to drop packets.
			}
		}
	}()

	go func() {
		defer close(writeErrCh)
		for {
			select {
			case <-c.closed:
				return
			case <-readErrCh:
				return
			case p := <-c.sendQueue:
				_, err := conn.WriteTo(p, c.remoteAddr)
				if err != nil {
					writeErrCh <- err
					return
				}
			}
		}
	}()

	// Either an error value or the channel being closed wakes this select;
	// the error itself is not needed because the caller redials regardless.
	select {
	case <-readErrCh:
	case <-writeErrCh:
	}
}

// ReadFrom waits for the next packet in the receive queue and copies it into
// p, returning the number of bytes copied. The address returned is always the
// RedialPacketConn's own remote address, not the address the packet actually
// came from. If the RedialPacketConn is closed, already or while waiting,
// ReadFrom returns a *net.OpError wrapping the error the conn was closed with.
func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
	closedErr := func() error {
		return &net.OpError{
			Op:     "read",
			Net:    c.LocalAddr().Network(),
			Source: c.LocalAddr(),
			Addr:   c.remoteAddr,
			Err:    c.err.Load().(error),
		}
	}

	// Check for closure first so that a closed conn always reports an
	// error, even if packets are still sitting in the queue.
	select {
	case <-c.closed:
		return 0, nil, closedErr()
	default:
	}

	select {
	case pkt := <-c.recvQueue:
		return copy(p, pkt), c.remoteAddr, nil
	case <-c.closed:
		return 0, nil, closedErr()
	}
}

// WriteTo queues a copy of p for transmission on the currently active
// net.PacketConn. The addr argument is ignored; packets are always sent to the
// RedialPacketConn's own remote address. If the send queue is full the packet
// is silently dropped, and WriteTo still reports len(p) bytes written. If the
// RedialPacketConn is closed, WriteTo returns a *net.OpError wrapping the
// error the conn was closed with.
func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	select {
	case <-c.closed:
		return 0, &net.OpError{
			Op:     "write",
			Net:    c.LocalAddr().Network(),
			Source: c.LocalAddr(),
			Addr:   c.remoteAddr,
			Err:    c.err.Load().(error),
		}
	default:
	}

	// Queue a private copy so the caller is free to reuse p right away.
	pkt := make([]byte, len(p))
	n := copy(pkt, p)

	select {
	case c.sendQueue <- pkt:
	default:
		// Queue full: drop the outgoing packet without reporting it.
	}
	return n, nil
}

// closeWithError closes c, unblocking pending operations and making future
// operations fail with err. A nil err is replaced by a generic "operation on
// closed connection" error. Only the first call has any effect and returns
// nil; every later call leaves the recorded error untouched and returns a
// *net.OpError wrapping it.
func (c *RedialPacketConn) closeWithError(err error) error {
	alreadyClosed := true
	c.closeOnce.Do(func() {
		alreadyClosed = false
		if err == nil {
			err = errors.New("operation on closed connection")
		}
		// Record the error before closing the channel, so that whoever
		// observes closed can load it.
		c.err.Store(err)
		close(c.closed)
	})
	if !alreadyClosed {
		return nil
	}
	return &net.OpError{
		Op:   "close",
		Net:  c.LocalAddr().Network(),
		Addr: c.LocalAddr(),
		Err:  c.err.Load().(error),
	}
}

// Close closes c through closeWithError without a specific cause, so pending
// operations are unblocked and future operations fail with a "closed
// connection" error. It returns whatever closeWithError returns.
func (c *RedialPacketConn) Close() error {
	var noCause error
	return c.closeWithError(noCause)
}

// LocalAddr reports the static local address that was supplied to
// NewRedialPacketConn as localAddr.
func (c *RedialPacketConn) LocalAddr() net.Addr {
	return c.localAddr
}

// SetDeadline is not supported; it always returns errNotImplemented.
func (c *RedialPacketConn) SetDeadline(t time.Time) error {
	return errNotImplemented
}

// SetReadDeadline is not supported; it always returns errNotImplemented.
func (c *RedialPacketConn) SetReadDeadline(t time.Time) error {
	return errNotImplemented
}

// SetWriteDeadline is not supported; it always returns errNotImplemented.
func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error {
	return errNotImplemented
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy