common.websocketconn.websocketconn_test.go Maven / Gradle / Ivy
The newest version!
package websocketconn
import (
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
)
// Returns a (server, client) pair of websocketconn.Conns.
func connPair() (*Conn, *Conn, error) {
// Will be assigned inside server.Handler.
var serverConn *Conn
// Start up a web server to receive the request.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, nil, err
}
defer ln.Close()
errCh := make(chan error)
server := http.Server{
Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
upgrader := websocket.Upgrader{
CheckOrigin: func(*http.Request) bool { return true },
}
ws, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
errCh <- err
return
}
serverConn = New(ws)
close(errCh)
}),
}
defer server.Close()
go func() {
err := server.Serve(ln)
if err != nil && err != http.ErrServerClosed {
errCh <- err
}
}()
// Make a request to the web server.
urlStr := (&url.URL{Scheme: "ws", Host: ln.Addr().String()}).String()
ws, _, err := (&websocket.Dialer{}).Dial(urlStr, nil)
if err != nil {
return nil, nil, err
}
clientConn := New(ws)
// The server is finished when errCh is written to or closed.
err = <-errCh
if err != nil {
return nil, nil, err
}
return serverConn, clientConn, nil
}
// Test that you can write in chunks and read the result concatenated.
func TestWrite(t *testing.T) {
tests := [][][]byte{
{},
{[]byte("foo")},
{[]byte("foo"), []byte("bar")},
{{}, []byte("foo"), {}, {}, []byte("bar")},
}
for _, test := range tests {
s, c, err := connPair()
if err != nil {
t.Fatal(err)
}
// This is a little awkward because we need to read to and write
// from both ends of the Conn, and we need to do it in separate
// goroutines because otherwise a Write may block waiting for
// someone to Read it. Here we set up a loop in a separate
// goroutine, reading from the Conn s and writing to the dataCh
// and errCh channels, whose ultimate effect in the select loop
// below is like
// data, err := io.ReadAll(s)
dataCh := make(chan []byte)
errCh := make(chan error)
go func() {
for {
var buf [1024]byte
n, err := s.Read(buf[:])
if err != nil {
errCh <- err
return
}
p := make([]byte, n)
copy(p, buf[:])
dataCh <- p
}
}()
// Write the data to the client side of the Conn, one chunk at a
// time.
for i, chunk := range test {
n, err := c.Write(chunk)
if err != nil || n != len(chunk) {
t.Fatalf("%+q Write chunk %d: got (%d, %v), expected (%d, %v)",
test, i, n, err, len(chunk), nil)
}
}
// We cannot immediately c.Close here, because that closes the
// connection right away, without waiting for buffered data to
// be sent.
// Pull data and err from the server goroutine above.
var data []byte
err = nil
loop:
for {
select {
case p := <-dataCh:
data = append(data, p...)
case err = <-errCh:
break loop
case <-time.After(100 * time.Millisecond):
break loop
}
}
s.Close()
c.Close()
// Now data and err contain the result of reading everything
// from s.
expected := bytes.Join(test, []byte{})
if err != nil || !bytes.Equal(data, expected) {
t.Fatalf("%+q ReadAll: got (%+q, %v), expected (%+q, %v)",
test, data, err, expected, nil)
}
}
}
// Test that multiple goroutines may call Read on a Conn simultaneously. Run
// this with
//
// go test -race
func TestConcurrentRead(t *testing.T) {
s, c, err := connPair()
if err != nil {
t.Fatal(err)
}
defer s.Close()
// Set up multiple threads reading from the same conn.
errCh := make(chan error, 2)
var wg sync.WaitGroup
wg.Add(2)
for i := 0; i < 2; i++ {
go func() {
defer wg.Done()
_, err := io.Copy(io.Discard, s)
if err != nil {
errCh <- err
}
}()
}
// Write a bunch of data to the other end.
for i := 0; i < 2000; i++ {
_, err := fmt.Fprintf(c, "%d", i)
if err != nil {
c.Close()
t.Fatalf("Write: %v", err)
}
}
c.Close()
wg.Wait()
close(errCh)
err = <-errCh
if err != nil {
t.Fatalf("Read: %v", err)
}
}
// Test that multiple goroutines may call Write on a Conn simultaneously. Run
// this with
//
// go test -race
func TestConcurrentWrite(t *testing.T) {
s, c, err := connPair()
if err != nil {
t.Fatal(err)
}
// Set up multiple threads writing to the same conn.
errCh := make(chan error, 3)
var wg sync.WaitGroup
wg.Add(2)
for i := 0; i < 2; i++ {
go func() {
defer wg.Done()
for j := 0; j < 1000; j++ {
_, err := fmt.Fprintf(s, "%d", j)
if err != nil {
errCh <- err
break
}
}
}()
}
go func() {
wg.Wait()
err := s.Close()
if err != nil {
errCh <- err
}
close(errCh)
}()
// Read from the other end.
_, err = io.Copy(io.Discard, c)
c.Close()
if err != nil {
t.Fatalf("Read: %v", err)
}
err = <-errCh
if err != nil {
t.Fatalf("Write: %v", err)
}
}
// Test that Read and Write methods return errors after Close.
func TestClose(t *testing.T) {
s, c, err := connPair()
if err != nil {
t.Fatal(err)
}
defer c.Close()
err = s.Close()
if err != nil {
t.Fatal(err)
}
var buf [10]byte
n, err := s.Read(buf[:])
if n != 0 || err == nil {
t.Fatalf("Read after Close returned (%v, %v), expected (%v, non-nil)", n, err, 0)
}
_, err = s.Write([]byte{1, 2, 3})
// Here we break the abstraction a little and look for a specific error,
// io.ErrClosedPipe. This is because we know the Conn uses an io.Pipe
// internally.
if err != io.ErrClosedPipe {
t.Fatalf("Write after Close returned %v, expected %v", err, io.ErrClosedPipe)
}
}
// Benchmark creating a server websocket.Conn (without the websocketconn.Conn
// wrapper) for different read/write buffer sizes.
func BenchmarkUpgradeBufferSize(b *testing.B) {
// Buffer size of 0 would mean the default of 4096:
// https://github.com/gorilla/websocket/blob/v1.5.0/conn.go#L37
// But a size of zero also has the effect of causing reuse of the HTTP
// server's buffers. So we test 4096 separately from 0.
// https://github.com/gorilla/websocket/blob/v1.5.0/server.go#L32
for _, bufSize := range []int{0, 128, 1024, 2048, 4096, 8192} {
upgrader := websocket.Upgrader{
CheckOrigin: func(*http.Request) bool { return true },
ReadBufferSize: bufSize,
WriteBufferSize: bufSize,
}
b.Run(fmt.Sprintf("%d", bufSize), func(b *testing.B) {
// Start up a web server to receive the request.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatal(err)
}
defer ln.Close()
wsCh := make(chan *websocket.Conn)
errCh := make(chan error)
server := http.Server{
Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
ws, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
errCh <- err
return
}
wsCh <- ws
}),
}
defer server.Close()
go func() {
err := server.Serve(ln)
if err != nil && err != http.ErrServerClosed {
errCh <- err
}
}()
// Make a request to the web server.
dialer := &websocket.Dialer{
ReadBufferSize: bufSize,
WriteBufferSize: bufSize,
}
urlStr := (&url.URL{Scheme: "ws", Host: ln.Addr().String()}).String()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ws, _, err := dialer.Dial(urlStr, nil)
if err != nil {
b.Fatal(err)
}
ws.Close()
select {
case <-wsCh:
case err := <-errCh:
b.Fatal(err)
}
}
b.StopTimer()
})
}
}
// Benchmark read/write in the client←server and server←client directions, with
// messages of different sizes. Run with -benchmem to see memory allocations.
func BenchmarkReadWrite(b *testing.B) {
trial := func(b *testing.B, readConn, writeConn *Conn, msgSize int) {
go func() {
io.Copy(io.Discard, readConn)
}()
data := make([]byte, msgSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := writeConn.Write(data[:])
b.SetBytes(int64(n))
if err != nil {
b.Fatal(err)
}
}
}
for _, msgSize := range []int{150, 3000} {
s, c, err := connPair()
if err != nil {
b.Fatal(err)
}
b.Run(fmt.Sprintf("c←s %d", msgSize), func(b *testing.B) {
trial(b, c, s, msgSize)
})
b.Run(fmt.Sprintf("s←c %d", msgSize), func(b *testing.B) {
trial(b, s, c, msgSize)
})
err = s.Close()
if err != nil {
b.Fatal(err)
}
err = c.Close()
if err != nil {
b.Fatal(err)
}
}
}