All downloads are free. The search and download features use the official Maven repository.

org.apache.cxf.jaxrs.provider.ServerProviderFactory Maven / Gradle / Ivy

There is a newer version: 4.1.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.cxf.jaxrs.provider;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.ws.rs.BeanParam;
import jakarta.ws.rs.Priorities;
import jakarta.ws.rs.RuntimeType;
import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.container.ContainerResponseFilter;
import jakarta.ws.rs.container.DynamicFeature;
import jakarta.ws.rs.container.PreMatching;
import jakarta.ws.rs.core.Application;
import jakarta.ws.rs.core.Configurable;
import jakarta.ws.rs.core.Configuration;
import jakarta.ws.rs.core.Feature;
import jakarta.ws.rs.core.FeatureContext;
import jakarta.ws.rs.ext.ExceptionMapper;
import jakarta.ws.rs.ext.ReaderInterceptor;
import jakarta.ws.rs.ext.WriterInterceptor;
import org.apache.cxf.Bus;
import org.apache.cxf.BusFactory;
import org.apache.cxf.common.util.ClassHelper;
import org.apache.cxf.endpoint.Endpoint;
import org.apache.cxf.helpers.CastUtils;
import org.apache.cxf.jaxrs.impl.ConfigurableImpl;
import org.apache.cxf.jaxrs.impl.FeatureContextImpl;
import org.apache.cxf.jaxrs.impl.RequestPreprocessor;
import org.apache.cxf.jaxrs.impl.ResourceInfoImpl;
import org.apache.cxf.jaxrs.impl.WebApplicationExceptionMapper;
import org.apache.cxf.jaxrs.lifecycle.ResourceProvider;
import org.apache.cxf.jaxrs.model.AbstractResourceInfo;
import org.apache.cxf.jaxrs.model.ApplicationInfo;
import org.apache.cxf.jaxrs.model.BeanParamInfo;
import org.apache.cxf.jaxrs.model.ClassResourceInfo;
import org.apache.cxf.jaxrs.model.FilterProviderInfo;
import org.apache.cxf.jaxrs.model.OperationResourceInfo;
import org.apache.cxf.jaxrs.model.ProviderInfo;
import org.apache.cxf.jaxrs.nio.NioMessageBodyWriter;
import org.apache.cxf.jaxrs.utils.AnnotationUtils;
import org.apache.cxf.jaxrs.utils.InjectionUtils;
import org.apache.cxf.jaxrs.utils.JAXRSUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageUtils;

public final class ServerProviderFactory extends ProviderFactory {
    private static final String WADL_PROVIDER_NAME = "org.apache.cxf.jaxrs.model.wadl.WadlGenerator";
    private static final String MAKE_DEFAULT_WAE_LEAST_SPECIFIC = "default.wae.mapper.least.specific";
    private List>> exceptionMappers =
        new ArrayList<>(1);

    private List> preMatchContainerRequestFilters =
        new ArrayList<>(1);
    private Map> postMatchContainerRequestFilters =
        new NameKeyMap<>(true);
    private Map> containerResponseFilters =
        new NameKeyMap<>(false);
    private RequestPreprocessor requestPreprocessor;
    private ApplicationInfo application;
    private Set dynamicFeatures = new LinkedHashSet<>();

    private Map, BeanParamInfo> beanParams = new ConcurrentHashMap<>();
    private ProviderInfo wadlGenerator;

    /**
     * Creates a factory bound to the given bus and builds the default WADL generator for it.
     *
     * @param bus the bus passed to the superclass and to {@code createWadlGenerator}
     */
    private ServerProviderFactory(Bus bus) {
        super(bus);
        wadlGenerator = createWadlGenerator(bus);
    }

    /**
     * Builds the provider info for the WADL generator filter.
     *
     * @param bus the bus used to create the provider and recorded in the returned info
     * @return the wrapped WADL generator, or {@code null} when {@code createProvider}
     *         returns {@code null} for {@link #WADL_PROVIDER_NAME}
     */
    private static ProviderInfo<ContainerRequestFilter> createWadlGenerator(Bus bus) {
        Object provider = createProvider(WADL_PROVIDER_NAME, bus);
        if (provider == null) {
            return null;
        }
        return new ProviderInfo<ContainerRequestFilter>((ContainerRequestFilter)provider, bus, true);
    }

    /**
     * Creates a new factory without an explicit bus; equivalent to {@code createInstance(null)}.
     *
     * @return a newly created factory
     */
    public static ServerProviderFactory getInstance() {
        return createInstance(null);
    }

    /**
     * Creates and initialises a new server provider factory.
     *
     * @param bus the bus to bind to; when {@code null} the thread default bus from
     *            {@link BusFactory#getThreadDefaultBus()} is used instead
     * @return the initialised factory with the default exception mapper and NIO writer registered
     */
    public static ServerProviderFactory createInstance(Bus bus) {
        final Bus effectiveBus = bus != null ? bus : BusFactory.getThreadDefaultBus();
        final ServerProviderFactory spf = new ServerProviderFactory(effectiveBus);
        ProviderFactory.initFactory(spf);
        // Built-in providers are registered as non-custom, non-bus-global.
        spf.setProviders(false, false, new WebApplicationExceptionMapper(), new NioMessageBodyWriter());
        spf.setBusProviders();
        return spf;
    }

    /**
     * Looks up the factory stored on the endpoint of the message's exchange.
     *
     * @param m the current message
     * @return the value stored under {@code SERVER_FACTORY_NAME} on the endpoint, cast to this type
     */
    public static ServerProviderFactory getInstance(Message m) {
        return (ServerProviderFactory)m.getExchange().getEndpoint().get(SERVER_FACTORY_NAME);
    }

    public List> getPreMatchContainerRequestFilters() {
        return getContainerRequestFilters(preMatchContainerRequestFilters, true);
    }

    public List> getPostMatchContainerRequestFilters(Set names) {
        return getBoundFilters(postMatchContainerRequestFilters, names);

    }

    private List> getContainerRequestFilters(
        List> filters, boolean syncNeeded) {

        if (wadlGenerator == null) {
            return filters;
        }
        if (filters.isEmpty()) {
            return Collections.singletonList(wadlGenerator);
        } else if (!syncNeeded) {
            filters.add(0, wadlGenerator);
            return filters;
        } else {
            synchronized (filters) {
                if (filters.get(0) != wadlGenerator) {
                    filters.add(0, wadlGenerator);
                }
            }
            return filters;
        }
    }

    public List> getContainerResponseFilters(Set names) {
        return getBoundFilters(containerResponseFilters, names);
    }

    /**
     * Registers the given bean-param info and, recursively, infos for every nested bean param
     * found on its class: the first parameter type of each public method annotated with
     * {@link BeanParam} and the type of each declared field annotated with {@link BeanParam}.
     *
     * <p>Annotated methods without parameters are skipped, and nested classes that are already
     * registered are not visited again, which keeps cyclic bean graphs from recursing forever.
     *
     * @param bpi the bean-param info to register; replaces any info previously stored for its class
     */
    public void addBeanParamInfo(BeanParamInfo bpi) {
        beanParams.put(bpi.getResourceClass(), bpi);
        for (Method m : bpi.getResourceClass().getMethods()) {
            // A zero-argument method has no parameter type to inspect.
            if (m.getAnnotation(BeanParam.class) != null && m.getParameterCount() > 0) {
                addNestedBeanParamInfo(m.getParameterTypes()[0]);
            }
        }
        for (Field f : bpi.getResourceClass().getDeclaredFields()) {
            if (f.getAnnotation(BeanParam.class) != null) {
                addNestedBeanParamInfo(f.getType());
            }
        }
    }

    /** Registers a nested bean-param class unless an info for it is already present. */
    private void addNestedBeanParamInfo(Class<?> nestedClass) {
        // The enclosing class is stored before its members are scanned, so a class
        // referring back to itself (directly or indirectly) stops here.
        if (!beanParams.containsKey(nestedClass)) {
            addBeanParamInfo(new BeanParamInfo(nestedClass, getBus()));
        }
    }

    /**
     * Looks up the bean-param info registered for a class.
     *
     * @param beanClass the bean class used as the lookup key
     * @return the registered info, or {@code null} when none is stored for {@code beanClass}
     */
    public BeanParamInfo getBeanParamInfo(Class<?> beanClass) {
        return beanParams.get(beanClass);
    }

    /**
     * Selects the exception mapper to use for the given exception type.
     *
     * <p>Registered mappers accepted by {@code handleMapper} are ordered with an
     * {@link ExceptionProviderInfoComparator} and the first one is returned.
     *
     * @param <T> the exception type the caller expects the mapper to handle
     * @param exceptionType the class of the exception to map
     * @param m the current message; also the source of the contextual
     *          {@code default.wae.mapper.least.specific} flag (default {@code true})
     * @return the first matching mapper, or {@code null} when no mapper matches
     */
    @SuppressWarnings("unchecked")
    public <T extends Throwable> ExceptionMapper<T> createExceptionMapper(Class<?> exceptionType,
                                                                          Message m) {

        boolean makeDefaultWaeLeastSpecific =
            MessageUtils.getContextualBoolean(m, MAKE_DEFAULT_WAE_LEAST_SPECIFIC, true);

        return (ExceptionMapper<T>)exceptionMappers.stream()
                .filter(em -> handleMapper(em, exceptionType, m, ExceptionMapper.class, Throwable.class, true))
                .sorted(new ExceptionProviderInfoComparator(exceptionType,
                                                            makeDefaultWaeLeastSpecific))
                .map(ProviderInfo::getProvider)
                .findFirst()
                .orElse(null);

    }


    @SuppressWarnings("unchecked")
    @Override
    protected void setProviders(boolean custom, boolean busGlobal, Object... providers) {
        List allProviders = new LinkedList<>();
        for (Object p : providers) {
            if (p instanceof Feature) {
                FeatureContext featureContext = createServerFeatureContext();
                Feature feature = (Feature)p; 
                injectApplicationIntoFeature(feature);
                feature.configure(featureContext);
                Configuration cfg = featureContext.getConfiguration();

                for (Object featureProvider : cfg.getInstances()) {
                    Map, Integer> contracts = cfg.getContracts(featureProvider.getClass());
                    if (contracts != null && !contracts.isEmpty()) {
                        Class providerCls = ClassHelper.getRealClass(getBus(), featureProvider);
                        
                        allProviders.add(new FilterProviderInfo(featureProvider.getClass(),
                                                                        providerCls,
                                                                        featureProvider,
                                                                        getBus(),
                                                                        getFilterNameBindings(getBus(), 
                                                                                              featureProvider),
                                                                        false,
                                                                        contracts));
                    } else {
                        allProviders.add(featureProvider);
                    }
                }
            } else {
                allProviders.add(p);
            }
        }


        List> postMatchRequestFilters =
            new LinkedList<>();
        List> postMatchResponseFilters =
            new LinkedList<>();

        List> theProviders =
            prepareProviders(custom, busGlobal, allProviders.toArray(), application);
        super.setCommonProviders(theProviders, RuntimeType.SERVER);
        for (ProviderInfo provider : theProviders) {
            Class providerCls = ClassHelper.getRealClass(getBus(), provider.getProvider());

            // Check if provider is constrained to server
            if (!constrainedTo(providerCls, RuntimeType.SERVER)) {
                continue;
            }

            if (filterContractSupported(provider, providerCls, ContainerRequestFilter.class)) {
                addContainerRequestFilter(postMatchRequestFilters,
                                          (ProviderInfo)provider);
            }

            if (filterContractSupported(provider, providerCls, ContainerResponseFilter.class)) {
                postMatchResponseFilters.add((ProviderInfo)provider);
            }

            if (DynamicFeature.class.isAssignableFrom(providerCls)) {
                //TODO: review the possibility of DynamicFeatures needing to have Contexts injected
                Object feature = provider.getProvider();
                dynamicFeatures.add((DynamicFeature)feature);
            }

            if (filterContractSupported(provider, providerCls, ExceptionMapper.class)) {    
                addProviderToList(exceptionMappers, provider);
            }

        }

        Collections.sort(preMatchContainerRequestFilters,
            new BindingPriorityComparator(ContainerRequestFilter.class, true));
        mapInterceptorFilters(postMatchContainerRequestFilters, postMatchRequestFilters,
                              ContainerRequestFilter.class, true);
        mapInterceptorFilters(containerResponseFilters, postMatchResponseFilters,
                              ContainerResponseFilter.class, false);

        injectContextProxies(exceptionMappers,
            postMatchContainerRequestFilters.values(), preMatchContainerRequestFilters,
            containerResponseFilters.values());
    }

    /**
     * Injects the configured application into a feature, if the feature exposes an
     * {@link Application} context method or context field.
     *
     * <p>Does nothing when no application has been set. A context method takes precedence;
     * otherwise the first context field of type {@link Application} is populated.
     *
     * @param feature the feature to inject into
     */
    protected void injectApplicationIntoFeature(Feature feature) {
        if (application == null) {
            return;
        }

        final AbstractResourceInfo featureInfo = new AbstractResourceInfo(feature.getClass(),
                                                                          ClassHelper.getRealClass(feature),
                                                                          true,
                                                                          true,
                                                                          getBus()) {
            @Override
            public boolean isSingleton() {
                return false;
            }
        };

        final Method appSetter = featureInfo.getContextMethods().get(Application.class);
        if (appSetter != null) {
            InjectionUtils.injectThroughMethod(feature, appSetter, application.getProvider());
        } else {
            for (Field candidate : featureInfo.getContextFields()) {
                if (candidate.getType() == Application.class) {
                    InjectionUtils.injectContextField(featureInfo, candidate, feature,
                                                      application.getProvider());
                    // Only the first matching field is injected.
                    return;
                }
            }
        }
    }

    /**
     * Injects context proxies into a provider, passing the application instance
     * (or {@code null} when no application is set) to the two-argument overload.
     *
     * @param pi the provider to inject into
     */
    @Override
    protected void injectContextProxiesIntoProvider(ProviderInfo<?> pi) {
        injectContextProxiesIntoProvider(pi, application == null ? null : application.getProvider());
    }

    /**
     * Injects contexts for the current message into the provider and, when an application
     * is set and reports contexts as available, into the application as well.
     *
     * @param pi the provider to inject into
     * @param m the current message; nothing happens when it is {@code null}
     */
    @Override
    protected void injectContextValues(ProviderInfo<?> pi, Message m) {
        if (m != null) {
            InjectionUtils.injectContexts(pi.getProvider(), pi, m);
            if (application != null && application.contextsAvailable()) {
                InjectionUtils.injectContexts(application.getProvider(), application, m);
            }
        }
    }

    private void addContainerRequestFilter(
        List> postMatchFilters,
        ProviderInfo p) {
        ContainerRequestFilter filter = p.getProvider();
        if (isWadlGenerator(filter.getClass())) {
            wadlGenerator = p;
        } else {
            if (isPrematching(filter.getClass())) {
                addProviderToList(preMatchContainerRequestFilters, p);
            } else {
                postMatchFilters.add(p);
            }
        }

    }

    /**
     * Tells whether a class is the WADL generator or one of its subclasses, by comparing the
     * name of the class and of each superclass with {@link #WADL_PROVIDER_NAME}.
     *
     * @param filterCls the class to test; may be {@code null}
     * @return {@code true} when the class or an ancestor below {@code Object} has the WADL generator name
     */
    private static boolean isWadlGenerator(Class<?> filterCls) {
        // Walk up the superclass chain; stops at Object or when there is no superclass.
        for (Class<?> current = filterCls;
             current != null && current != Object.class;
             current = current.getSuperclass()) {
            if (WADL_PROVIDER_NAME.equals(current.getName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return the request preprocessor previously set, or {@code null} if none was set
     */
    public RequestPreprocessor getRequestPreprocessor() {
        return requestPreprocessor;
    }

    /**
     * Sets the application info used for feature configuration and context injection.
     *
     * @param app the application info; may be {@code null}, which other methods treat as "no application"
     */
    public void setApplicationProvider(ApplicationInfo app) {
        application = app;
    }

    /**
     * @return the application info previously set, or {@code null} if none was set
     */
    public ApplicationInfo getApplicationProvider() {
        return application;
    }

    /**
     * Sets the request preprocessor returned by {@code getRequestPreprocessor}.
     *
     * @param rp the request preprocessor to store
     */
    public void setRequestPreprocessor(RequestPreprocessor rp) {
        this.requestPreprocessor = rp;
    }

    /**
     * Passes the registered exception mappers to {@code clearProxies}.
     */
    public void clearExceptionMapperProxies() {
        clearProxies(exceptionMappers);
    }

    /**
     * Clears the providers held by the superclass and then the server-side exception mappers,
     * pre-match and post-match request filters, and response filters.
     */
    // NOTE(review): dynamicFeatures, beanParams and wadlGenerator are left untouched here - confirm intended.
    @Override
    public void clearProviders() {
        super.clearProviders();
        exceptionMappers.clear();
        preMatchContainerRequestFilters.clear();
        postMatchContainerRequestFilters.clear();
        containerResponseFilters.clear();
    }

    /**
     * Clears the thread-local proxies of the application (when one is set) and then
     * those handled by the superclass.
     */
    @Override
    public void clearThreadLocalProxies() {
        if (application != null) {
            application.clearThreadLocalProxies();
        }
        super.clearThreadLocalProxies();
    }

    /**
     * Applies every registered dynamic feature to each of the given resource classes.
     * Does nothing when no dynamic features are registered.
     *
     * @param list the resource classes to process
     */
    public void applyDynamicFeatures(List<ClassResourceInfo> list) {
        if (!dynamicFeatures.isEmpty()) {
            for (ClassResourceInfo cri : list) {
                doApplyDynamicFeatures(cri);
            }
        }
    }

    /**
     * Returns a new configuration view backed by this factory.
     *
     * @param m the current message; not used by this implementation
     * @return a new {@code ServerConfigurationImpl}
     */
    public Configuration getConfiguration(Message m) {
        return new ServerConfigurationImpl();
    }

    private void doApplyDynamicFeatures(ClassResourceInfo cri) {
        Set oris = cri.getMethodDispatcher().getOperationResourceInfos();
        for (OperationResourceInfo ori : oris) {
            String nameBinding = DEFAULT_FILTER_NAME_BINDING
                + ori.getClassResourceInfo().getServiceClass().getName()
                + "."
                + ori.getMethodToInvoke().toString();
            for (DynamicFeature feature : dynamicFeatures) {
                FeatureContext featureContext = createServerFeatureContext();
                feature.configure(new ResourceInfoImpl(ori), featureContext);
                Configuration cfg = featureContext.getConfiguration();
                for (Object provider : cfg.getInstances()) {
                    Map, Integer> contracts = cfg.getContracts(provider.getClass());
                    if (contracts != null && !contracts.isEmpty()) {
                        Class providerCls = ClassHelper.getRealClass(getBus(), provider);
                        registerUserProvider(new FilterProviderInfo(provider.getClass(),
                            providerCls,
                            provider,
                            getBus(),
                            Collections.singleton(nameBinding),
                            true,
                            contracts));
                        ori.addNameBindings(Collections.singletonList(nameBinding));
                    }
                }
            }
        }
        Collection subs = cri.getSubResources();
        for (ClassResourceInfo sub : subs) {
            if (sub != cri) {
                doApplyDynamicFeatures(sub);
            }
        }
    }

    /**
     * Creates a feature context for configuring server-side features.
     *
     * <p>The configurable comes from the bus's {@code ServerConfigurableFactory} extension when
     * present, otherwise a {@code ServerFeatureContextConfigurable} is used. When an
     * application is set, its properties are copied into the configurable.
     *
     * @return the new feature context
     */
    private FeatureContext createServerFeatureContext() {
        final FeatureContextImpl featureContext = new FeatureContextImpl();
        final ServerConfigurableFactory factory = getBus().getExtension(ServerConfigurableFactory.class);
        final Configurable<FeatureContext> configImpl = (factory == null)
            ? new ServerFeatureContextConfigurable(featureContext)
                : factory.create(featureContext);
        featureContext.setConfigurable(configImpl);

        if (application != null) {
            Map<String, Object> appProps = application.getProvider().getProperties();
            for (Map.Entry<String, Object> entry : appProps.entrySet()) {
                configImpl.property(entry.getKey(), entry.getValue());
            }
        }
        return featureContext;
    }

    /**
     * Tells whether {@code AnnotationUtils.getClassAnnotation} finds a {@link PreMatching}
     * annotation for the given filter class.
     *
     * @param filterCls the filter class to inspect
     * @return {@code true} when a {@code @PreMatching} annotation is found
     */
    protected static boolean isPrematching(Class<?> filterCls) {
        return AnnotationUtils.getClassAnnotation(filterCls, PreMatching.class) != null;
    }

    /** Configurable for a {@link FeatureContext} that is fixed to the server runtime type. */
    private static class ServerFeatureContextConfigurable extends ConfigurableImpl<FeatureContext> {
        /**
         * @param mc the feature context this configurable configures
         */
        protected ServerFeatureContextConfigurable(FeatureContext mc) {
            super(mc, RuntimeType.SERVER);
        }
    }

    /**
     * Clears thread-local proxies using the factory found on the message's endpoint.
     *
     * @param message the current message
     */
    public static void clearThreadLocalProxies(Message message) {
        clearThreadLocalProxies(ServerProviderFactory.getInstance(message), message);
    }
    /**
     * Clears the thread-local proxies of the factory and of the root resource class stored
     * on the exchange, when one is present.
     *
     * @param factory the factory whose proxies are cleared
     * @param message the current message; its exchange is read for {@code ROOT_RESOURCE_CLASS}
     */
    public static void clearThreadLocalProxies(ServerProviderFactory factory, Message message) {
        factory.clearThreadLocalProxies();
        ClassResourceInfo cri =
            (ClassResourceInfo)message.getExchange().get(JAXRSUtils.ROOT_RESOURCE_CLASS);
        if (cri != null) {
            cri.clearThreadLocalProxies();
        }
    }
    /**
     * Releases request state using the factory found on the message's endpoint.
     *
     * @param message the current message
     */
    public static void releaseRequestState(Message message) {
        releaseRequestState(ServerProviderFactory.getInstance(message), message);
    }
    /**
     * Releases per-request state held on the exchange and clears thread-local proxies.
     *
     * <p>The root instance is removed from the exchange; only when one was present is the root
     * provider removed too and asked to release that instance. Any failure thrown while
     * releasing is deliberately ignored so that proxy clean-up always runs.
     *
     * @param factory the factory whose thread-local proxies are cleared afterwards
     * @param message the current message
     */
    public static void releaseRequestState(ServerProviderFactory factory, Message message) {
        final Object instance = message.getExchange().remove(JAXRSUtils.ROOT_INSTANCE);
        // The provider entry is only consumed when there was an instance to release.
        final Object provider = instance == null
            ? null : message.getExchange().remove(JAXRSUtils.ROOT_PROVIDER);
        if (provider != null) {
            try {
                ((ResourceProvider)provider).releaseInstance(message, instance);
            } catch (Throwable ignored) {
                // ignore
            }
        }

        clearThreadLocalProxies(factory, message);
    }


    private class ServerConfigurationImpl implements Configuration {
        ServerConfigurationImpl() {

        }

        @Override
        public Set> getClasses() {
            return application != null ? application.getProvider().getClasses()
                : Collections.>emptySet();
        }

        @Override
        public Set getInstances() {
            return application != null ? application.getProvider().getSingletons()
                : Collections.emptySet();
        }

        @Override
        public boolean isEnabled(Feature f) {
            return dynamicFeatures.contains((Object)f);
        }

        @Override
        public boolean isEnabled(Class featureCls) {
            for (DynamicFeature f : dynamicFeatures) {
                if (featureCls.isAssignableFrom(f.getClass())) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public boolean isRegistered(Object o) {
            return isRegistered(preMatchContainerRequestFilters, o)
                || isRegistered(postMatchContainerRequestFilters.values(), o)
                || isRegistered(containerResponseFilters.values(), o)
                || isRegistered(readerInterceptors.values(), o)
                || isRegistered(writerInterceptors.values(), o);
        }

        @Override
        public boolean isRegistered(Class cls) {
            return isRegistered(preMatchContainerRequestFilters, cls)
                || isRegistered(postMatchContainerRequestFilters.values(), cls)
                || isRegistered(containerResponseFilters.values(), cls)
                || isRegistered(readerInterceptors.values(), cls)
                || isRegistered(writerInterceptors.values(), cls);
        }

        @Override
        public Map, Integer> getContracts(Class cls) {
            Map, Integer> map = new HashMap<>();
            if (isRegistered(cls)) {
                if (ContainerRequestFilter.class.isAssignableFrom(cls)) {
                    boolean isPreMatch = cls.getAnnotation(PreMatching.class) != null;
                    map.put(ContainerRequestFilter.class,
                            getPriority(isPreMatch ? preMatchContainerRequestFilters
                                : postMatchContainerRequestFilters.values(), cls, ContainerRequestFilter.class));
                }
                if (ContainerResponseFilter.class.isAssignableFrom(cls)) {
                    map.put(ContainerResponseFilter.class,
                            getPriority(containerResponseFilters.values(), cls, ContainerResponseFilter.class));
                }
                if (WriterInterceptor.class.isAssignableFrom(cls)) {
                    map.put(WriterInterceptor.class,
                            getPriority(writerInterceptors.values(), cls, WriterInterceptor.class));
                }
                if (ReaderInterceptor.class.isAssignableFrom(cls)) {
                    map.put(ReaderInterceptor.class,
                            getPriority(readerInterceptors.values(), cls, ReaderInterceptor.class));
                }
            }
            return map;
        }

        @Override
        public Map getProperties() {
            return application != null ? application.getProperties()
                : Collections.emptyMap();
        }

        @Override
        public Object getProperty(String name) {
            return getProperties().get(name);
        }

        @Override
        public Collection getPropertyNames() {
            return getProperties().keySet();
        }

        @Override
        public RuntimeType getRuntimeType() {
            return RuntimeType.SERVER;
        }

        private boolean isRegistered(Collection list, Object o) {
            Collection> list2 = CastUtils.cast(list);
            for (ProviderInfo pi : list2) {
                if (pi.getProvider() == o) {
                    return true;
                }
            }
            return false;
        }
        private boolean isRegistered(Collection list, Class cls) {
            Collection> list2 = CastUtils.cast(list);
            for (ProviderInfo p : list2) {
                Class pClass = ClassHelper.getRealClass(p.getBus(), p.getProvider());
                if (cls.isAssignableFrom(pClass)) {
                    return true;
                }
            }
            return false;
        }
        private Integer getPriority(Collection list, Class cls, Class filterClass) {
            Collection> list2 = CastUtils.cast(list);
            for (ProviderInfo p : list2) {
                if (p instanceof FilterProviderInfo) {
                    Class pClass = ClassHelper.getRealClass(p.getBus(), p.getProvider());
                    if (cls.isAssignableFrom(pClass)) {
                        return ((FilterProviderInfo)p).getPriority(filterClass);
                    }
                }
            }
            return Priorities.USER;
        }
    }
    public static class ExceptionProviderInfoComparator extends ProviderInfoClassComparator {
        private boolean makeDefaultWaeLeastSpecific;
        public ExceptionProviderInfoComparator(Class expectedCls, boolean makeDefaultWaeLeastSpecific) {
            super(expectedCls);
            this.makeDefaultWaeLeastSpecific = makeDefaultWaeLeastSpecific;
        }
        public int compare(ProviderInfo p1, ProviderInfo p2) {
            if (makeDefaultWaeLeastSpecific) {
                if (p1.getProvider() instanceof WebApplicationExceptionMapper
                    && !p1.isCustom()) {
                    return 1;
                } else if (p2.getProvider() instanceof WebApplicationExceptionMapper
                    && !p2.isCustom()) {
                    return -1;
                }
            }
            int result = super.compare(p1, p2);
            if (result == 0) {
                result = comparePriorityStatus(p1.getProvider().getClass(), p2.getProvider().getClass());
            }
            return result;
        }
    }
}