All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.cxf.jaxrs.provider.ServerProviderFactory Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.cxf.jaxrs.provider;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.ws.rs.BeanParam;
import javax.ws.rs.Priorities;
import javax.ws.rs.RuntimeType;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.container.DynamicFeature;
import javax.ws.rs.container.PreMatching;
import javax.ws.rs.core.Configuration;
import javax.ws.rs.core.Feature;
import javax.ws.rs.core.FeatureContext;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.ReaderInterceptor;
import javax.ws.rs.ext.WriterInterceptor;

import org.apache.cxf.Bus;
import org.apache.cxf.BusFactory;
import org.apache.cxf.common.util.ClassHelper;
import org.apache.cxf.endpoint.Endpoint;
import org.apache.cxf.helpers.CastUtils;
import org.apache.cxf.jaxrs.impl.ConfigurableImpl;
import org.apache.cxf.jaxrs.impl.RequestPreprocessor;
import org.apache.cxf.jaxrs.impl.ResourceInfoImpl;
import org.apache.cxf.jaxrs.impl.WebApplicationExceptionMapper;
import org.apache.cxf.jaxrs.lifecycle.ResourceProvider;
import org.apache.cxf.jaxrs.model.ApplicationInfo;
import org.apache.cxf.jaxrs.model.BeanParamInfo;
import org.apache.cxf.jaxrs.model.ClassResourceInfo;
import org.apache.cxf.jaxrs.model.FilterProviderInfo;
import org.apache.cxf.jaxrs.model.OperationResourceInfo;
import org.apache.cxf.jaxrs.model.ProviderInfo;
import org.apache.cxf.jaxrs.utils.AnnotationUtils;
import org.apache.cxf.jaxrs.utils.InjectionUtils;
import org.apache.cxf.jaxrs.utils.JAXRSUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageUtils;

public final class ServerProviderFactory extends ProviderFactory {
    private static final Set> SERVER_FILTER_INTERCEPTOR_CLASSES = 
        new HashSet>(Arrays.>asList(ContainerRequestFilter.class,
                                                      ContainerResponseFilter.class,
                                                      ReaderInterceptor.class,
                                                      WriterInterceptor.class));
    
    private static final String WADL_PROVIDER_NAME = "org.apache.cxf.jaxrs.model.wadl.WadlGenerator";
    private static final String MAKE_DEFAULT_WAE_LEAST_SPECIFIC = "default.wae.mapper.least.specific";
    private List>> exceptionMappers = 
        new ArrayList>>(1);
    
    private List> preMatchContainerRequestFilters = 
        new ArrayList>(1);
    private Map> postMatchContainerRequestFilters = 
        new NameKeyMap>(true);
    private Map> containerResponseFilters = 
        new NameKeyMap>(false);
    private RequestPreprocessor requestPreprocessor;
    private ApplicationInfo application;
    private Set dynamicFeatures = new LinkedHashSet();
    
    private Map, BeanParamInfo> beanParams = new HashMap, BeanParamInfo>();
    private ProviderInfo wadlGenerator;
        
    private ServerProviderFactory(Bus bus) {
        super(bus);
        wadlGenerator = createWadlGenerator(bus);
    }
    
    private static ProviderInfo createWadlGenerator(Bus bus) {
        Object provider = createProvider(WADL_PROVIDER_NAME);
        if (provider == null) {
            return null;
        } else {
            return new ProviderInfo((ContainerRequestFilter)provider, bus, true);
        }
    }
    
    public static ServerProviderFactory getInstance() {
        return createInstance(null);
    }
    
    public static ServerProviderFactory createInstance(Bus bus) {
        if (bus == null) {
            bus = BusFactory.getThreadDefaultBus();
        }
        ServerProviderFactory factory = new ServerProviderFactory(bus);
        ProviderFactory.initFactory(factory);
        factory.setProviders(false, false, new WebApplicationExceptionMapper());
        factory.setBusProviders();
        return factory;
    }
    
    public static ServerProviderFactory getInstance(Message m) {
        Endpoint e = m.getExchange().getEndpoint();
        return (ServerProviderFactory)e.get(SERVER_FACTORY_NAME);
    }
    
    public List> getPreMatchContainerRequestFilters() {
        return getContainerRequestFilters(preMatchContainerRequestFilters, true);
    }
    
    public List> getPostMatchContainerRequestFilters(Set names) {
        return getBoundFilters(postMatchContainerRequestFilters, names);
        
    }
    
    private List> getContainerRequestFilters(
        List> filters, boolean syncNeeded) {
        
        if (wadlGenerator == null) { 
            return filters;
        }
        if (filters.size() == 0) {
            return Collections.singletonList(wadlGenerator);
        } else if (!syncNeeded) {
            filters.add(0, wadlGenerator);
            return filters;
        } else {
            synchronized (filters) {
                if (filters.get(0) != wadlGenerator) {
                    filters.add(0, wadlGenerator);
                }
            }
            return filters;
        }
    }
    
    public List> getContainerResponseFilters(Set names) {
        return getBoundFilters(containerResponseFilters, names);
    }
    
    public void addBeanParamInfo(BeanParamInfo bpi) {
        beanParams.put(bpi.getResourceClass(), bpi);
        for (Method m : bpi.getResourceClass().getMethods()) {
            if (m.getAnnotation(BeanParam.class) != null) {
                BeanParamInfo methodBpi = new BeanParamInfo(m.getParameterTypes()[0], getBus());
                addBeanParamInfo(methodBpi);
            }
        }
        for (Field f : bpi.getResourceClass().getDeclaredFields()) {
            if (f.getAnnotation(BeanParam.class) != null) {
                BeanParamInfo fieldBpi = new BeanParamInfo(f.getType(), getBus());
                addBeanParamInfo(fieldBpi);
            }
        }
    }
    
    public BeanParamInfo getBeanParamInfo(Class beanClass) {
        return beanParams.get(beanClass);
    }

    @SuppressWarnings("unchecked")
    public  ExceptionMapper createExceptionMapper(Class exceptionType,
                                                                          Message m) {
        List>> candidates = new LinkedList>>();
        for (ProviderInfo> em : exceptionMappers) {
            if (handleMapper(em, exceptionType, m, ExceptionMapper.class, true)) {
                candidates.add(em);
            }
        }
        if (candidates.size() == 0) {
            return null;
        }
        boolean makeDefaultWaeLeastSpecific = 
            MessageUtils.getContextualBoolean(m, MAKE_DEFAULT_WAE_LEAST_SPECIFIC, false);
        Collections.sort(candidates, new ExceptionProviderInfoComparator(exceptionType,
                                                                         makeDefaultWaeLeastSpecific));
        return (ExceptionMapper) candidates.get(0).getProvider();
    }
    
    
    @SuppressWarnings("unchecked")
    @Override
    protected void setProviders(boolean custom, boolean busGlobal, Object... providers) {
        List> postMatchRequestFilters = 
            new LinkedList>();
        List> postMatchResponseFilters = 
            new LinkedList>();
        
        List> theProviders = 
            prepareProviders(custom, busGlobal, (Object[])providers, application);
        super.setCommonProviders(theProviders);
        for (ProviderInfo provider : theProviders) {
            Class providerCls = ClassHelper.getRealClass(getBus(), provider.getProvider());
            
            if (filterContractSupported(provider, providerCls, ContainerRequestFilter.class)) {
                addContainerRequestFilter(postMatchRequestFilters, 
                                          (ProviderInfo)provider);
            }
            
            if (filterContractSupported(provider, providerCls, ContainerResponseFilter.class)) {
                postMatchResponseFilters.add((ProviderInfo)provider); 
            }
            
            if (DynamicFeature.class.isAssignableFrom(providerCls)) {
                //TODO: review the possibility of DynamicFeatures needing to have Contexts injected
                Object feature = provider.getProvider();
                dynamicFeatures.add((DynamicFeature)feature);
            }
            
            
            if (ExceptionMapper.class.isAssignableFrom(providerCls)) {
                addProviderToList(exceptionMappers, provider); 
            }
            
        }
        
        Collections.sort(preMatchContainerRequestFilters, 
            new BindingPriorityComparator(ContainerRequestFilter.class, true));
        mapInterceptorFilters(postMatchContainerRequestFilters, postMatchRequestFilters,
                              ContainerRequestFilter.class, true);
        mapInterceptorFilters(containerResponseFilters, postMatchResponseFilters,
                              ContainerResponseFilter.class, false);
        
        injectContextProxies(exceptionMappers,
            postMatchContainerRequestFilters.values(), preMatchContainerRequestFilters,
            containerResponseFilters.values());
    }
    
    @Override
    protected void injectContextProxiesIntoProvider(ProviderInfo pi) {
        injectContextProxiesIntoProvider(pi, application == null ? null : application.getProvider());
    }
    
    @Override
    protected void injectContextValues(ProviderInfo pi, Message m) {
        if (m != null) {
            InjectionUtils.injectContexts(pi.getProvider(), pi, m);
            if (application != null && application.contextsAvailable()) {
                InjectionUtils.injectContexts(application.getProvider(), application, m);
            }
        }
    }
    
    private void addContainerRequestFilter(
        List> postMatchFilters,
        ProviderInfo p) {
        ContainerRequestFilter filter = p.getProvider();
        if (isWadlGenerator(filter.getClass())) {
            wadlGenerator = p; 
        } else {
            if (isPrematching(filter.getClass())) {
                addProviderToList(preMatchContainerRequestFilters, p);
            } else {
                postMatchFilters.add(p);
            }
        }
        
    }
    
    private static boolean isWadlGenerator(Class filterCls) {
        if (filterCls == null || filterCls == Object.class) {
            return false;
        }
        if (WADL_PROVIDER_NAME.equals(filterCls.getName())) {
            return true;
        } else {
            return isWadlGenerator(filterCls.getSuperclass());
        }
    }
    
    public RequestPreprocessor getRequestPreprocessor() {
        return requestPreprocessor;
    }
    
    public void setApplicationProvider(ApplicationInfo app) {
        application = app;
    }
    
    public ApplicationInfo getApplicationProvider() {
        return application;
    }
    
    public void setRequestPreprocessor(RequestPreprocessor rp) {
        this.requestPreprocessor = rp;
    }
    
    public void clearExceptionMapperProxies() {
        clearProxies(exceptionMappers);
    }
    
    @Override
    public void clearProviders() {
        super.clearProviders();
        exceptionMappers.clear();
        preMatchContainerRequestFilters.clear();
        postMatchContainerRequestFilters.clear();
        containerResponseFilters.clear();
    }
    
    @Override
    public void clearThreadLocalProxies() {
        if (application != null) {
            application.clearThreadLocalProxies();
        }
        super.clearThreadLocalProxies();
    }
    
    public void applyDynamicFeatures(List list) {
        if (dynamicFeatures.size() > 0) {
            for (ClassResourceInfo cri : list) {
                doApplyDynamicFeatures(cri);
            }
        }
    }
    
    public Configuration getConfiguration(Message m) {
        return new ServerConfigurationImpl();
    }
    
    private void doApplyDynamicFeatures(ClassResourceInfo cri) {
        Set oris = cri.getMethodDispatcher().getOperationResourceInfos();
        for (OperationResourceInfo ori : oris) {
            for (DynamicFeature feature : dynamicFeatures) {
                FeatureContext featureContext = new MethodFeatureContextImpl(ori);
                feature.configure(new ResourceInfoImpl(ori), featureContext);
            }
        }
        Collection subs = cri.getSubResources();
        for (ClassResourceInfo sub : subs) {
            if (sub != cri) {
                doApplyDynamicFeatures(sub);    
            }
        }
    }
    
    protected static boolean isPrematching(Class filterCls) {
        return AnnotationUtils.getClassAnnotation(filterCls, PreMatching.class) != null;
    }
    
    
    
    private class MethodFeatureContextImpl implements FeatureContext {
        private MethodFeatureContextConfigurable configImpl;    
        private OperationResourceInfo ori;
        private String nameBinding;
        
        MethodFeatureContextImpl(OperationResourceInfo ori) {
            this.ori = ori;
            configImpl = new MethodFeatureContextConfigurable(this);
            if (application != null) {
                Map appProps = application.getProvider().getProperties();
                for (Map.Entry entry : appProps.entrySet()) {
                    configImpl.property(entry.getKey(), entry.getValue());
                }
            }
            nameBinding = DEFAULT_FILTER_NAME_BINDING 
                + ori.getClassResourceInfo().getServiceClass().getName()
                + "."
                + ori.getMethodToInvoke().toString();
        }
        

        @Override
        public Configuration getConfiguration() {
            return configImpl.getConfiguration();
        }
        
        @Override
        public FeatureContext property(String name, Object value) {
            return configImpl.property(name, value);
        }

        @Override
        public FeatureContext register(Class cls) {
            return configImpl.register(cls);
        }

        @Override
        public FeatureContext register(Object object) {
            return configImpl.register(object);
        }

        @Override
        public FeatureContext register(Class cls, int index) {
            return configImpl.register(cls, index);
        }

        @Override
        public FeatureContext register(Class cls, Class... contracts) {
            return configImpl.register(cls, contracts);
        }

        @Override
        public FeatureContext register(Class cls, Map, Integer> map) {
            return configImpl.register(cls, map);
        }

        @Override
        public FeatureContext register(Object object, int index) {
            return configImpl.register(object, index);
        }

        @Override
        public FeatureContext register(Object object, Class... contracts) {
            return configImpl.register(object, contracts);
        }

        @Override
        public FeatureContext register(Object object, Map, Integer> map) {
            return configImpl.register(object, map);
        }
        
        FeatureContext doRegister(Object provider, Map, Integer> contracts) {
        
            Map, Integer> actualContracts = new HashMap, Integer>();
            
            for (Class contract : contracts.keySet()) {
                if (SERVER_FILTER_INTERCEPTOR_CLASSES.contains(contract)
                    && contract.isAssignableFrom(provider.getClass())) {
                    actualContracts.put(contract, contracts.get(contract));
                }
            }
            if (!actualContracts.isEmpty()) {
                registerUserProvider(new FilterProviderInfo(provider, 
                    getBus(),
                    nameBinding,
                    true,
                    actualContracts));
                ori.addNameBindings(Collections.singletonList(nameBinding));
            }
            return this;
        }
        
    }
    
    private static class MethodFeatureContextConfigurable extends ConfigurableImpl {
        protected MethodFeatureContextConfigurable(MethodFeatureContextImpl mc) {
            super(mc, RuntimeType.SERVER, SERVER_FILTER_INTERCEPTOR_CLASSES.toArray(new Class[]{}));
        }
        @Override
        public FeatureContext register(Object provider, Map, Integer> contracts) {
            super.register(provider, contracts);
            return ((MethodFeatureContextImpl)super.getConfigurable())
                .doRegister(provider, contracts);
        }
        
    }
    
    public static void clearThreadLocalProxies(Message message) {
        clearThreadLocalProxies(ServerProviderFactory.getInstance(message), message);
    }
    public static void clearThreadLocalProxies(ServerProviderFactory factory, Message message) {
        factory.clearThreadLocalProxies();
        ClassResourceInfo cri =
            (ClassResourceInfo)message.getExchange().get(JAXRSUtils.ROOT_RESOURCE_CLASS);
        if (cri != null) {
            cri.clearThreadLocalProxies();
        }    
    }
    public static void releaseRequestState(Message message) {
        releaseRequestState(ServerProviderFactory.getInstance(message), message);
    }
    public static void releaseRequestState(ServerProviderFactory factory, Message message) {
        Object rootInstance = message.getExchange().remove(JAXRSUtils.ROOT_INSTANCE);
        if (rootInstance != null) {
            Object rootProvider = message.getExchange().remove(JAXRSUtils.ROOT_PROVIDER);
            if (rootProvider != null) {
                try {
                    ((ResourceProvider)rootProvider).releaseInstance(message, rootInstance);
                } catch (Throwable tex) {
                    // ignore
                }
            }
        }
        
        clearThreadLocalProxies(factory, message);
    }
    
    
    private class ServerConfigurationImpl implements Configuration {
        ServerConfigurationImpl() {
            
        }
        
        @Override
        public Set> getClasses() {
            return application != null ? application.getProvider().getClasses() 
                : Collections.>emptySet();
        }

        @Override
        public Set getInstances() {
            return application != null ? application.getProvider().getSingletons() 
                : Collections.emptySet();
        }

        @Override
        public boolean isEnabled(Feature f) {
            return dynamicFeatures.contains(f);
        }

        @Override
        public boolean isEnabled(Class featureCls) {
            for (DynamicFeature f : dynamicFeatures) {
                if (featureCls.isAssignableFrom(f.getClass())) { 
                    return true;
                }
            }
            return false;
        }

        @Override
        public boolean isRegistered(Object o) {
            return isRegistered(preMatchContainerRequestFilters, o)
                || isRegistered(postMatchContainerRequestFilters.values(), o)
                || isRegistered(containerResponseFilters.values(), o)
                || isRegistered(readerInterceptors.values(), o)
                || isRegistered(writerInterceptors.values(), o);
        }

        @Override
        public boolean isRegistered(Class cls) {
            return isRegistered(preMatchContainerRequestFilters, cls)
                || isRegistered(postMatchContainerRequestFilters.values(), cls)
                || isRegistered(containerResponseFilters.values(), cls)
                || isRegistered(readerInterceptors.values(), cls)
                || isRegistered(writerInterceptors.values(), cls);
        }

        @Override
        public Map, Integer> getContracts(Class cls) {
            Map, Integer> map = new HashMap, Integer>();
            if (isRegistered(cls)) {
                if (ContainerRequestFilter.class.isAssignableFrom(cls)) {
                    boolean isPreMatch = cls.getAnnotation(PreMatching.class) != null;
                    map.put(ContainerRequestFilter.class, 
                            getPriority(isPreMatch ? preMatchContainerRequestFilters
                                : postMatchContainerRequestFilters.values(), cls, ContainerRequestFilter.class));    
                }
                if (ContainerResponseFilter.class.isAssignableFrom(cls)) {
                    map.put(ContainerResponseFilter.class, 
                            getPriority(containerResponseFilters.values(), cls, ContainerResponseFilter.class));    
                }
                if (WriterInterceptor.class.isAssignableFrom(cls)) {
                    map.put(WriterInterceptor.class, 
                            getPriority(writerInterceptors.values(), cls, WriterInterceptor.class));    
                }
                if (ReaderInterceptor.class.isAssignableFrom(cls)) {
                    map.put(ReaderInterceptor.class, 
                            getPriority(readerInterceptors.values(), cls, ReaderInterceptor.class));    
                }
            }
            return map;
        }
        
        @Override
        public Map getProperties() {
            return application != null ? application.getProperties() 
                : Collections.emptyMap();
        }

        @Override
        public Object getProperty(String name) {
            return getProperties().get(name);
        }

        @Override
        public Collection getPropertyNames() {
            return getProperties().keySet();
        }

        @Override
        public RuntimeType getRuntimeType() {
            return RuntimeType.SERVER;
        }
        
        private boolean isRegistered(Collection list, Object o) {
            Collection> list2 = CastUtils.cast(list);
            for (ProviderInfo pi : list2) {
                if (pi.getProvider() == o) {
                    return true;
                }
            }
            return false;
        }
        private boolean isRegistered(Collection list, Class cls) {
            Collection> list2 = CastUtils.cast(list);
            for (ProviderInfo pi : list2) {
                if (cls.isAssignableFrom(pi.getProvider().getClass())) {
                    return true;
                }
            }
            return false;
        }
        private Integer getPriority(Collection list, Class cls, Class filterClass) {
            Collection> list2 = CastUtils.cast(list);
            for (ProviderInfo pi : list2) {
                if (pi instanceof FilterProviderInfo && pi.getProvider().getClass().isAssignableFrom(cls)) {
                    return ((FilterProviderInfo)pi).getPriority(filterClass);
                }
            }
            return Priorities.USER;
        }
    }
    public static class ExceptionProviderInfoComparator extends ProviderInfoClassComparator {
        private boolean makeDefaultWaeLeastSpecific;
        public ExceptionProviderInfoComparator(Class expectedCls, boolean makeDefaultWaeLeastSpecific) {
            super(expectedCls);
            this.makeDefaultWaeLeastSpecific = makeDefaultWaeLeastSpecific;
        }
        public int compare(ProviderInfo p1, ProviderInfo p2) {
            if (makeDefaultWaeLeastSpecific) {
                if (p1.getProvider() instanceof WebApplicationExceptionMapper
                    && !p1.isCustom()) {
                    return 1;
                } else if (p2.getProvider() instanceof WebApplicationExceptionMapper
                    && !p2.isCustom()) {
                    return -1;
                } 
            }
            return super.compare(p1, p2);
        }
    }
    
}