All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.brutusin.rpc.websocket.WebsocketEndpoint Maven / Gradle / Ivy

/*
 * Copyright 2016 Ignacio del Valle Alles [email protected].
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.brutusin.rpc.websocket;

import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import org.brutusin.json.spi.JsonCodec;
import org.brutusin.json.spi.JsonSchema;
import org.brutusin.rpc.RpcRequest;
import org.brutusin.rpc.RpcResponse;
import org.brutusin.rpc.RpcSpringContext;
import org.brutusin.rpc.RpcUtils;
import org.brutusin.rpc.exception.InvalidRequestException;
import org.brutusin.rpc.exception.ServiceNotFoundException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.security.core.context.SecurityContext;

/**
 * JSR-356 websocket endpoint that dispatches JSON-RPC 2.0 requests, received as
 * whole text messages, to the {@link WebsocketAction} services exposed by an
 * {@link RpcSpringContext}.
 * <p>
 * Each connection is expected to carry a {@code requestId} handshake request
 * parameter whose value is a key of {@link #getContextMap()}; the associated
 * {@link WebsocketContext} supplies the Spring context and security context used
 * for the session. Connections with a missing/unknown {@code requestId}, or for
 * which {@link #allowAccess(Session, WebsocketContext)} returns {@code false},
 * are closed with {@link CloseReason.CloseCodes#CANNOT_ACCEPT}.
 *
 * @author Ignacio del Valle Alles [email protected]
 */
public class WebsocketEndpoint extends Endpoint {

    private static final Logger LOGGER = Logger.getLogger(WebsocketEndpoint.class.getName());

    /** Handshake request parameter whose value keys {@link #contextMap}. */
    private static final String REQUEST_ID_PARAM = "requestId";

    /**
     * Contexts keyed by {@code requestId}. This class only reads and removes
     * entries; they are expected to be registered through {@link #getContextMap()}.
     */
    private final Map<String, WebsocketContext> contextMap
            = Collections.synchronizedMap(new HashMap<String, WebsocketContext>());

    /** Session wrappers created in {@link #onOpen}, keyed by {@link Session#getId()}. */
    private final Map<String, SessionImpl> wrapperMap
            = Collections.synchronizedMap(new HashMap<String, SessionImpl>());

    /**
     * Binds a newly opened session to its {@link WebsocketContext}, checks access,
     * and registers the text message handler that processes JSON-RPC requests.
     *
     * @param session the opened websocket session
     * @param config the endpoint configuration (not used)
     * @throws RuntimeException wrapping an {@link IOException} if closing a
     * rejected session fails
     */
    @Override
    public void onOpen(Session session, EndpointConfig config) {
        final String requestId = getRequestId(session);
        final WebsocketContext websocketContext = requestId == null ? null : contextMap.get(requestId);
        if (websocketContext == null) {
            // Without a registered context there is no Spring context to serve requests with.
            reject(session, "Unknown requestId");
            return;
        }
        if (!allowAccess(session, websocketContext)) {
            reject(session, "Authentication required");
            return;
        }
        final SessionImpl sessionImpl = new SessionImpl(session, websocketContext);
        sessionImpl.init();
        wrapperMap.put(session.getId(), sessionImpl);

        session.addMessageHandler(new MessageHandler.Whole<String>() {
            @Override
            public void onMessage(String message) {
                // Expose the current session to the executed action for the duration of the call.
                WebsocketActionSupportImpl.setInstance(new WebsocketActionSupportImpl(sessionImpl));
                try {
                    String response = process(message, sessionImpl);
                    if (response != null) {
                        sessionImpl.sendToPeerRaw(response);
                    }
                } finally {
                    WebsocketActionSupportImpl.clear();
                }
            }
        });
    }

    /**
     * Returns the live, synchronized map of websocket contexts keyed by
     * {@code requestId}. Callers register contexts here before the handshake.
     *
     * @return the mutable context map backing this endpoint
     */
    public Map<String, WebsocketContext> getContextMap() {
        return contextMap;
    }

    /**
     * Releases the resources bound to a closing session: removes its context
     * entry, unsubscribes it from every topic of its Spring context, and closes
     * its {@link SessionImpl} wrapper.
     *
     * @param session the closing websocket session
     * @param closeReason the reason reported by the container (not used)
     */
    @Override
    public void onClose(Session session, CloseReason closeReason) {
        final String requestId = getRequestId(session);
        if (requestId != null) {
            contextMap.remove(requestId);
        }
        // Null when onOpen rejected the session before creating a wrapper.
        final SessionImpl sessionImpl = wrapperMap.remove(session.getId());
        if (sessionImpl != null) {
            try {
                // Topic.unsubscribe() operates on the session exposed through WebsocketActionSupportImpl.
                WebsocketActionSupportImpl.setInstance(new WebsocketActionSupportImpl(sessionImpl));
                for (Topic topic : sessionImpl.getCtx().getSpringContext().getTopics().values()) {
                    try {
                        topic.unsubscribe();
                    } catch (InvalidSubscriptionException ignored) {
                        // Ignored: the session was not subscribed to this topic.
                    }
                }
            } finally {
                WebsocketActionSupportImpl.clear();
                sessionImpl.close();
            }
        }
    }

    /**
     * Logs an error raised by the container for a session.
     *
     * @param session the session where the error occurred (may be {@code null})
     * @param thr the error
     */
    @Override
    public void onError(Session session, Throwable thr) {
        final String sessionId = session == null ? null : session.getId();
        LOGGER.log(Level.SEVERE, "Websocket error in session " + sessionId, thr);
    }

    /**
     * Decides whether a session may be served. Access is always granted unless
     * the parent of the RPC Spring context defines a
     * {@code springSecurityFilterChain} bean; in that case the websocket
     * context's security context must hold an authenticated authentication.
     *
     * @param session the opened websocket session
     * @param websocketContext the context bound to the session (non-null)
     * @return {@code true} if the session may be served
     */
    protected boolean allowAccess(Session session, WebsocketContext websocketContext) {
        final RpcSpringContext rpcCtx = websocketContext.getSpringContext();
        if (rpcCtx.getParent() != null) {
            try {
                if (rpcCtx.getParent().getBean("springSecurityFilterChain") != null) { // Security active
                    final SecurityContext sc = (SecurityContext) websocketContext.getSecurityContext();
                    // Deny when there is no security context or no authenticated principal.
                    return sc != null
                            && sc.getAuthentication() != null
                            && sc.getAuthentication().isAuthenticated();
                }
            } catch (NoSuchBeanDefinitionException ex) {
                // No security filter chain bean: Spring Security is not active.
                return true;
            }
        }
        return true;
    }

    /**
     * Parses and executes a JSON-RPC request, returning the serialized response.
     * Any failure (parsing included) is reported in the response's error member;
     * when parsing failed, the response carries no id.
     *
     * @param message the raw JSON-RPC request text
     * @param sessionImpl the wrapper of the session that received the message
     * @return the serialized response, or {@code null} for a successfully parsed
     * request without id (a notification)
     */
    private String process(String message, SessionImpl sessionImpl) {
        RpcRequest req = null;
        Object result = null;
        Throwable throwable = null;
        try {
            req = JsonCodec.getInstance().parse(message, RpcRequest.class);
            result = execute(req, sessionImpl.getCtx().getSpringContext());
        } catch (Throwable th) {
            throwable = th;
        }
        if (req != null && req.getId() == null) {
            return null;
        }
        RpcResponse resp = new RpcResponse();
        if (req != null) {
            resp.setId(req.getId());
        }
        resp.setError(RpcResponse.Error.from(throwable));
        resp.setResult(result);
        return JsonCodec.getInstance().transform(resp);
    }

    /**
     * Validates a request and invokes the websocket service it targets.
     *
     * @param request the parsed JSON-RPC request
     * @param rpcCtx the Spring context exposing the websocket services
     * @return the service result
     * @throws InvalidRequestException if the request is not JSON-RPC 2.0
     * @throws ServiceNotFoundException if no service matches the request method
     * @throws Exception any validation or service execution error
     */
    private Object execute(RpcRequest request, RpcSpringContext rpcCtx) throws Exception {
        if (!"2.0".equals(request.getJsonrpc())) {
            throw new InvalidRequestException("Only jsonrpc 2.0 supported");
        }
        final String serviceId = request.getMethod();
        final WebsocketAction service = serviceId == null ? null : rpcCtx.getWebSocketServices().get(serviceId);
        if (service == null) {
            throw new ServiceNotFoundException();
        }
        Object input;
        if (request.getParams() == null) {
            input = null;
        } else {
            Type inputType = service.getInputType();
            JsonSchema inputSchema = JsonCodec.getInstance().getSchema(inputType);
            inputSchema.validate(request.getParams());
            if (inputType.equals(Object.class)) {
                // Untyped services receive the raw JSON params.
                input = request.getParams();
            } else {
                input = JsonCodec.getInstance().load(request.getParams(), RpcUtils.getClass(inputType));
            }
        }
        return service.execute(input);
    }

    /**
     * Returns the first value of the {@code requestId} handshake request
     * parameter, or {@code null} if it is absent.
     */
    private static String getRequestId(Session session) {
        final Map<String, List<String>> params = session.getRequestParameterMap();
        if (params == null) {
            return null;
        }
        final List<String> values = params.get(REQUEST_ID_PARAM);
        if (values == null || values.isEmpty()) {
            return null;
        }
        return values.get(0);
    }

    /** Closes a session with {@code CANNOT_ACCEPT} and the given reason phrase. */
    private static void reject(Session session, String reason) {
        try {
            session.close(new CloseReason(CloseReason.CloseCodes.CANNOT_ACCEPT, reason));
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy