All Downloads are FREE. Search and download functionalities are using the official Maven repository.

runtime.go.irt.transport_websocket_server.go Maven / Gradle / Ivy

The newest version!
package irt

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"github.com/gorilla/websocket"
)

// processorRpcRequestMessage is the unit of work queued on
// WebSocketServerTransport.rpcRequests: a decoded RPC request (message)
// paired with the connection it was read from (connection), so that the
// processor can dispatch it with that connection's context and address the
// response back to the same connection.
type processorRpcRequestMessage struct {
	connection *wsConnection
	message    *WebSocketRequestMessage
}

// processorRpcResponseMessage is an outgoing response (message) paired with
// the connection it must be written to (connection). Values of this type
// travel over WebSocketServerTransport.rpcResponses and over each
// connection's own messages channel.
type processorRpcResponseMessage struct {
	connection *wsConnection
	message    *WebSocketResponseMessage
}

// wsConnection holds the per-client state of one upgraded WebSocket.
//
// Fields:
//   - conn:     the underlying gorilla/websocket connection.
//   - header:   HTTP headers slot; NOTE(review): nothing in this file assigns
//     it, so it appears to stay nil — confirm whether it is still needed.
//   - messages: queue of outgoing responses, drained by runConnectionWriter.
//     Closing it makes the writer send a close frame and exit.
//   - context:  connection context passed to the dispatcher and to the
//     connection handlers.
type wsConnection struct {
	conn     *websocket.Conn
	header   http.Header
	messages chan *processorRpcResponseMessage
	context  *ConnectionContext
}

// WebSocketServerTransport serves RPC calls over WebSocket connections.
//
// Fields:
//   - dispatcher:   executes the requested service method (see runProcessor).
//   - marshaller:   decodes incoming frames and encodes outgoing responses.
//   - logger:       destination for all transport log output.
//   - started:      set by Start, cleared when the goroutine started by run
//     handles termination; ServeWS refuses to serve while it is false.
//   - stopping:     set by Stop, cleared together with started.
//   - handlers:     optional OnConnect / OnAuth / OnDisconnect callbacks; both
//     the struct pointer and each callback are nil-checked before use.
//   - upgrader:     upgrades HTTP requests to WebSocket connections.
//   - connections:  set of registered connections; in this file it is only
//     read and written from the goroutine started by run.
//   - register:     connections waiting to be added to connections.
//   - unregister:   connections whose reader has finished.
//   - terminate:    created by run and closed by Stop to request shutdown.
//   - rpcRequests:  decoded requests waiting for a processor goroutine.
//   - rpcResponses: processed responses waiting to be routed to the queue of
//     the connection that issued the request.
type WebSocketServerTransport struct {
	dispatcher *Dispatcher
	marshaller Marshaller
	logger     Logger
	started    bool
	stopping   bool

	handlers *ConnectionHandlers

	upgrader    websocket.Upgrader
	connections map[*wsConnection]bool
	register    chan *wsConnection
	unregister  chan *wsConnection
	terminate   chan bool

	rpcRequests  chan *processorRpcRequestMessage
	rpcResponses chan *processorRpcResponseMessage
}

// NewWebSocketServerTransportEx builds a WebSocketServerTransport that uses
// the supplied dispatcher, marshaller, logger and connection handlers.
//
// The upgrader is configured with the package default buffer sizes,
// compression enabled and an origin check that accepts every request. The
// transport is returned idle: the terminate channel is only created once the
// transport is started.
func NewWebSocketServerTransportEx(dispatcher *Dispatcher, marshaller Marshaller, logger Logger,
	handlers *ConnectionHandlers) *WebSocketServerTransport {
	// Origin checking is intentionally permissive: every origin is accepted.
	acceptAnyOrigin := func(r *http.Request) bool {
		return true
	}

	upgrader := websocket.Upgrader{
		ReadBufferSize:    DefaultReadBufferSize,
		WriteBufferSize:   DefaultWriteBufferSize,
		EnableCompression: true,
		CheckOrigin:       acceptAnyOrigin,
	}

	registrations := make(chan *wsConnection, DefaultRegistrationBuffer)
	unregistrations := make(chan *wsConnection, DefaultRegistrationBuffer)
	requests := make(chan *processorRpcRequestMessage, DefaultProcessingThreads)
	responses := make(chan *processorRpcResponseMessage, DefaultProcessingThreads)

	return &WebSocketServerTransport{
		dispatcher:   dispatcher,
		marshaller:   marshaller,
		logger:       logger,
		handlers:     handlers,
		upgrader:     upgrader,
		connections:  map[*wsConnection]bool{},
		register:     registrations,
		unregister:   unregistrations,
		rpcRequests:  requests,
		rpcResponses: responses,
	}
}

// NewWebSocketServerTransport builds a transport with the default
// collaborators: the marshaller returned by NewJSONMarshaller(false) and the
// logger returned by NewConsoleLogger(LogTrace).
func NewWebSocketServerTransport(dispatcher *Dispatcher, handlers *ConnectionHandlers) *WebSocketServerTransport {
	defaultMarshaller := NewJSONMarshaller(false)
	defaultLogger := NewConsoleLogger(LogTrace)

	return NewWebSocketServerTransportEx(dispatcher, defaultMarshaller, defaultLogger, handlers)
}

// runConnectionReader reads frames from connection c until a read fails, then
// queues c for unregistration and closes the socket.
//
// Only text frames are processed; any other frame type is logged and skipped.
// Each frame is decoded twice: first as a WebSocketMessageBase to find its
// kind, then as the concrete message for that kind. Frames that cannot be
// decoded, or that carry no kind, are logged and skipped.
//
// For RPC requests, an "Authorization" header updates the connection's auth
// state and triggers the optional OnAuth handler. A request that carries the
// header but no service or method is treated as an auth-only update and is
// acknowledged with an empty response; every other request is queued on
// t.rpcRequests for the processors.
func (t *WebSocketServerTransport) runConnectionReader(c *wsConnection) {
	t.logger.Logf(LogTrace, "WebSocketReader - Started")
	defer func() {
		t.unregister <- c
		c.conn.Close()
		t.logger.Logf(LogTrace, "WebSocketReader - Finished")
	}()
	c.conn.SetReadLimit(DefaultMaxPackageSize)
	c.conn.SetReadDeadline(time.Now().Add(DefaultPongWaitTime))
	// Every pong received from the peer pushes the read deadline forward.
	c.conn.SetPongHandler(func(string) error {
		c.conn.SetReadDeadline(time.Now().Add(DefaultPongWaitTime))
		return nil
	})
	for {
		msgtype, message, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				t.logger.Logf(LogError, "Unexpected connection close: %+v", err.Error())
			}
			break
		}

		if msgtype != websocket.TextMessage {
			t.logger.Logf(LogError, "Unsupported message type! %d", msgtype)
			continue
		}

		// Past this point the frame is always a text frame.
		t.logger.Logf(LogTrace, "Incoming message:\n\nType:text\nMessage:%s", string(message))

		msgBase := &WebSocketMessageBase{}
		if err := t.marshaller.Unmarshal(message, msgBase); err != nil {
			t.logger.Logf(LogError, "Can't unmarshal %s, body: %s", err.Error(), string(message))
			continue
		}
		// A frame may decode successfully and still carry no kind; there is no
		// error value to report in that case.
		if msgBase.Kind == "" {
			t.logger.Logf(LogError, "Can't unmarshal message without a kind, body: %s", string(message))
			continue
		}

		switch msgBase.Kind {
		case MessageKindRPCRequest:
			msg := &WebSocketRequestMessage{}
			if err := t.marshaller.Unmarshal(message, msg); err != nil {
				t.logger.Logf(LogError, "Can't unmarshal %s, body: %s", err.Error(), string(message))
				continue
			}

			// Looking up a key in a nil map is safe, so no separate nil check
			// on msg.Headers is required.
			if authorization, ok := msg.Headers["Authorization"]; ok {
				t.logger.Logf(LogTrace, "Incoming message: Auth Update\n\n%s", authorization)
				c.context.System.Auth.UpdateFromValue(authorization)
				if t.handlers != nil && t.handlers.OnAuth != nil {
					if err := t.handlers.OnAuth(c.context); err != nil {
						payload := []byte(fmt.Sprintf("{\"error\": \"%s\"}", err.Error()))
						// The writer goroutine may be in the middle of a write.
						// WriteControl is safe to call concurrently with other
						// connection methods, WriteMessage is not.
						deadline := time.Now().Add(DefaultWriteWaitTime)
						if werr := c.conn.WriteControl(websocket.CloseMessage, payload, deadline); werr != nil {
							t.logger.Logf(LogWarning, "Can't send close message: %s", werr.Error())
						}
						continue
					}
				}

				if msg.Service == "" || msg.Method == "" {
					// Auth-only update: acknowledge it without dispatching.
					// NOTE(review): c.messages is closed by the transport's run
					// loop; confirm this send cannot race with that close.
					c.messages <- &processorRpcResponseMessage{
						connection: c,
						message: &WebSocketResponseMessage{
							Kind: MessageKindRPCResponse,
							Ref:  msg.ID,
						},
					}
					continue
				}
			}

			t.rpcRequests <- &processorRpcRequestMessage{
				connection: c,
				message:    msg,
			}
		default:
			t.logger.Logf(LogError, "Unsupported request: %s", string(message))
		}
	}
}

// runConnectionWriter drains c.messages onto the socket and pings the peer
// every DefaultPingPeriod.
//
// It returns, stopping the ticker and closing the socket, when c.messages is
// closed (after sending a close frame) or when any write to the socket
// fails. A response that cannot be marshalled is logged and skipped.
func (t *WebSocketServerTransport) runConnectionWriter(c *wsConnection) {
	t.logger.Logf(LogTrace, "WebSocketWriter - Started")
	ticker := time.NewTicker(DefaultPingPeriod)
	defer func() {
		ticker.Stop()
		c.conn.Close()
		t.logger.Logf(LogTrace, "WebSocketWriter - Finished")
	}()
	for {
		select {
		case message, ok := <-c.messages:
			c.conn.SetWriteDeadline(time.Now().Add(DefaultWriteWaitTime))
			if !ok {
				// Channel is closed: tell the peer we are going away.
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			data, err := t.marshaller.Marshal(message.message)
			if err != nil {
				t.logger.Logf(LogError, "Can't marshal %s", err.Error())
				continue
			}

			t.logger.Logf(LogTrace, "Outgoing message:\n\nRef: %v\nData: %s", message.message.Ref, string(data))

			w, err := c.conn.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			if _, err := w.Write(data); err != nil {
				// The deferred Close tears the connection down.
				return
			}

			if err := w.Close(); err != nil {
				return
			}
		case <-ticker.C:
			t.logger.Logf(LogTrace, "WebSocket - Ping")
			c.conn.SetWriteDeadline(time.Now().Add(DefaultWriteWaitTime))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// runProcessor is the body of one worker goroutine. It takes requests from
// t.rpcRequests, runs them through the dispatcher and publishes the outcome
// on t.rpcResponses, until t.terminate is closed.
//
// A dispatch error becomes a response of kind MessageKindRPCFailure whose
// data is the JSON-encoded error text; otherwise the dispatcher output is
// sent back as a MessageKindRPCResponse.
func (t *WebSocketServerTransport) runProcessor() {
	t.logger.Logf(LogTrace, "WebSocket Requests Processor - Started")
	for {
		select {
		case <-t.terminate:
			t.logger.Logf(LogTrace, "WebSocket Requests Processor - Finished")
			return

		case task := <-t.rpcRequests:
			msg := &WebSocketResponseMessage{
				Kind: MessageKindRPCResponse,
				Ref:  task.message.ID,
			}
			result := &processorRpcResponseMessage{
				connection: task.connection,
				message:    msg,
			}

			data, err := t.dispatcher.Dispatch(task.connection.context, task.message.Service, task.message.Method, task.message.Data)
			if err != nil {
				msg.Kind = MessageKindRPCFailure
				// The input is a plain string; the marshal error is ignored
				// because the payload is best-effort diagnostic text.
				errStr, _ := json.Marshal(err.Error())
				msg.Data = errStr
			} else {
				msg.Data = data
			}

			// Once termination is requested nothing drains t.rpcResponses any
			// more, so a bare send could block this goroutine forever.
			select {
			case t.rpcResponses <- result:
			case <-t.terminate:
				t.logger.Logf(LogTrace, "WebSocket Requests Processor - Finished")
				return
			}
		}
	}
}

// run creates the terminate channel, starts the dispatcher goroutine that
// owns t.connections, and starts DefaultProcessingThreads processor
// goroutines.
//
// The dispatcher goroutine handles four events:
//   - terminate:    closes every registered socket, resets the started and
//     stopping flags, and exits.
//   - register:     records the connection and starts its reader and writer.
//   - unregister:   closes the connection's outgoing queue if it is still
//     registered, forgets it, and invokes the optional OnDisconnect handler.
//   - rpcResponses: forwards the response to the owning connection's queue
//     without blocking. A connection whose queue is full is dropped by
//     closing its queue.
func (t *WebSocketServerTransport) run() {

	t.terminate = make(chan bool)
	go func() {
		t.logger.Logf(LogDebug, "WebSocketTransport - Started")
		for {
			select {
			case <-t.terminate:
				t.logger.Logf(LogDebug, "Termination requested, closing connections...")
				for c := range t.connections {
					c.conn.Close()
				}
				t.stopping = false
				t.started = false
				t.logger.Logf(LogDebug, "WebSocketTransport - Finished")
				return
			case connection := <-t.register:
				t.connections[connection] = true
				go t.runConnectionReader(connection)
				go t.runConnectionWriter(connection)
				t.logger.Logf(LogDebug, "Connected (Connections %d)", len(t.connections))
			case connection := <-t.unregister:
				if _, ok := t.connections[connection]; ok {
					close(connection.messages)
					delete(t.connections, connection)
				}
				if t.handlers != nil && t.handlers.OnDisconnect != nil {
					t.handlers.OnDisconnect(connection.context)
				}

				t.logger.Logf(LogDebug, "Disconnected (Connections %d)", len(t.connections))
			case result := <-t.rpcResponses:
				// A connection missing from the set has already had its queue
				// closed (on unregister or on overflow below). Sending to a
				// closed channel panics, so late responses are discarded.
				if _, ok := t.connections[result.connection]; !ok {
					t.logger.Logf(LogDebug, "Dropping response for a closed connection")
					continue
				}

				select {
				case result.connection.messages <- result:
				default:
					// The queue is full: drop the connection. Closing the
					// queue makes its writer send a close frame and exit.
					close(result.connection.messages)
					delete(t.connections, result.connection)
				}
			}
		}
	}()

	for i := 0; i < DefaultProcessingThreads; i++ {
		go t.runProcessor()
	}
}

// Start launches the transport's background goroutines. Calling it on a
// transport that is already marked as started does nothing.
func (t *WebSocketServerTransport) Start() {
	if !t.started {
		t.started = true
		t.run()
	}
}

// Stop requests shutdown by closing the terminate channel. It does nothing
// when the transport is not started or a stop is already in progress.
func (t *WebSocketServerTransport) Stop() {
	if t.started && !t.stopping {
		t.stopping = true
		close(t.terminate)
	}
}

// ServeWS is the HTTP entry point: it upgrades the request to a WebSocket,
// builds the connection state and hands the connection over for
// registration.
//
// If the transport has not been started the request is answered with
// 503 Service Unavailable. If the upgrade fails the error is logged and the
// function returns. If the optional OnConnect handler rejects the connection,
// a close frame carrying the error text is sent and the socket is closed.
func (t *WebSocketServerTransport) ServeWS(w http.ResponseWriter, r *http.Request) {
	if !t.started {
		t.logger.Logf(LogError, "WebSocket Transport wasn't started and can't server requests.")
		// Give the client an explicit status instead of an implicit empty 200.
		http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
		return
	}

	// if r.URL.Scheme != "ws" && r.URL.Scheme != "wss" {
	// 	t.logger.Logf(LogError, "WebSocket Transport expects the scheme to be ws:// or wss://, got %s. rejecting.", r.URL.Scheme)
	// 	return
	// }

	conn, err := t.upgrader.Upgrade(w, r, nil)
	if err != nil {
		t.logger.Logf(LogWarning, "failed to upgrade connection: %s", err.Error())
		return
	}

	connection := &wsConnection{
		conn:     conn,
		messages: make(chan *processorRpcResponseMessage, DefaultConnectionBuffer),
		context: &ConnectionContext{
			System: &SystemContext{
				Auth:       &Authorization{},
				RemoteAddr: r.RemoteAddr,
			},
			User: nil,
		},
	}

	// A failed context update is logged but does not reject the connection.
	if err := connection.context.System.Update(r); err != nil {
		t.logger.Logf(LogWarning, "failed to properly update system context: %s", err.Error())
	}

	if t.handlers != nil && t.handlers.OnConnect != nil {
		if err := t.handlers.OnConnect(connection.context, r); err != nil {
			connection.conn.WriteMessage(websocket.CloseMessage, []byte(fmt.Sprintf("{\"error\": \"%s\"}", err.Error())))
			// The connection is never registered, so no reader or writer will
			// close it; close it here to release the socket.
			connection.conn.Close()
			return
		}
	}

	t.register <- connection
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy