
// org.kantega.reststop.servlets.ReststopInitializer (listing header: Maven / Gradle / Ivy)
/*
* Copyright 2018 Kantega AS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.kantega.reststop.servlets;
import org.kantega.reststop.api.PluginExport;
import org.kantega.reststop.classloaderutils.Artifact;
import org.kantega.reststop.classloaderutils.PluginClassLoader;
import org.kantega.reststop.classloaderutils.PluginInfo;
import org.kantega.reststop.core.ClassLoaderFactory;
import org.kantega.reststop.core.DefaultReststopPluginManager;
import org.kantega.reststop.servlet.api.FilterPhase;
import org.kantega.reststop.servlet.api.ServletBuilder;
import org.kantega.reststop.servlet.api.ServletDeployer;
import org.w3c.dom.Document;
import org.xml.sax.SAXException;
import javax.servlet.*;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.*;
import java.util.stream.Collectors;
import static java.util.Arrays.asList;
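/**
 * Servlet container initializer that boots Reststop inside a web application:
 * it creates the plugin manager, registers a {@link PluginDelegatingFilter} on
 * all request paths, and deploys the plugins described by the external and the
 * WAR-bundled plugins.xml descriptors.
 */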
public class ReststopInitializer implements ServletContainerInitializer{
@Override
public void onStartup(Set> classes, ServletContext servletContext) throws ServletException {
PluginDelegatingFilter pluginDelegatingFilter = new PluginDelegatingFilter();
DefaultServletBuilder servletBuilder = new DefaultServletBuilder(servletContext, pluginDelegatingFilter);
Map staticServices = new HashMap<>();
staticServices.put(ServletContext.class, servletContext);
staticServices.put(ServletBuilder.class, servletBuilder);
staticServices.put(ServletDeployer.class, pluginDelegatingFilter);
DefaultReststopPluginManager manager = new DefaultReststopPluginManager(getClass().getClassLoader(), findGlobalConfigFile(servletContext), staticServices);
servletContext.setAttribute("reststopPluginManager", manager);
FilterRegistration.Dynamic registration = servletContext.addFilter(PluginDelegatingFilter.class.getName(), pluginDelegatingFilter);
registration.setAsyncSupported(true);
registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/*");
servletContext.addListener(new ShutdownListener(manager));
deployPlugins(manager, servletContext);
}
/**
 * Collects the externally configured plugins followed by the WAR-bundled
 * plugins and hands them to the manager for deployment.
 *
 * @param manager        the plugin manager that performs the deployment
 * @param servletContext context used to locate the plugin descriptors
 * @throws ServletException if the external plugin configuration is invalid
 */
private void deployPlugins(DefaultReststopPluginManager manager, ServletContext servletContext) throws ServletException {
    List<PluginInfo> plugins = new ArrayList<>();
    plugins.addAll(getExternalPlugins(servletContext));
    plugins.addAll(getWarBundledPlugins(servletContext));
    manager.deploy(plugins, new DefaultClassLoaderFactory());
}
/**
 * Locates the global configuration file
 * {@code <pluginConfigurationDirectory>/<applicationName>.conf}.
 *
 * @param servletContext context supplying the two required parameters
 * @return the existing configuration file
 * @throws ServletException if a parameter is missing or the file does not exist
 */
private File findGlobalConfigFile(ServletContext servletContext) throws ServletException {
    String directory = requiredInitParam(servletContext, "pluginConfigurationDirectory");
    String appName = requiredInitParam(servletContext, "applicationName");

    File configFile = new File(directory, appName + ".conf");
    if (configFile.exists()) {
        return configFile;
    }
    throw new ServletException("Configuration file does not exist: " + configFile.getAbsolutePath());
}
/**
 * Resolves a parameter via {@code initParam} and fails when it has no value.
 *
 * @param servletContext context to read the parameter from
 * @param paramName      name of the parameter
 * @return the resolved, non-null value
 * @throws ServletException if no value could be resolved
 */
private String requiredInitParam(ServletContext servletContext, String paramName) throws ServletException {
    String resolved = initParam(servletContext, paramName);
    if (resolved != null) {
        return resolved;
    }
    throw new ServletException("You web application is missing a required servlet context-param '" + paramName + "'");
}
/**
 * Looks up a parameter as a servlet context init parameter, falling back to
 * the JVM system property of the same name.
 *
 * @param servletContext context to read the init parameter from
 * @param paramName      name of the parameter / system property
 * @return the value, or {@code null} if neither source defines it
 */
private String initParam(ServletContext servletContext, String paramName) throws ServletException {
    String fromContext = servletContext.getInitParameter(paramName);
    return fromContext != null ? fromContext : System.getProperty(paramName);
}
/**
 * Context listener whose only job is to stop the plugin manager when the
 * servlet context is destroyed.
 */
private static class ShutdownListener implements ServletContextListener {

    private final DefaultReststopPluginManager manager;

    public ShutdownListener(DefaultReststopPluginManager pluginManager) {
        manager = pluginManager;
    }

    /** Nothing to do on startup. */
    @Override
    public void contextInitialized(ServletContextEvent sce) {
        // intentionally empty
    }

    /** Stops the plugin manager. */
    @Override
    public void contextDestroyed(ServletContextEvent sce) {
        manager.stop();
    }
}
/**
 * Reads the plugins bundled inside the WAR, described by
 * {@code /WEB-INF/reststop/plugins.xml} with jars under
 * {@code /WEB-INF/reststop/repository/}.
 *
 * @param servletContext context used to translate the WAR paths to real paths
 * @return the bundled plugins, or an empty list when either path cannot be
 *         resolved to an existing file system location
 * @throws RuntimeException if the descriptor exists but cannot be parsed
 */
private List<PluginInfo> getWarBundledPlugins(ServletContext servletContext) {
    String pluginsPath = servletContext.getRealPath("/WEB-INF/reststop/plugins.xml");
    String repositoryPath = servletContext.getRealPath("/WEB-INF/reststop/repository/");
    // getRealPath yields null when the container cannot map the path to the file system.
    if (pluginsPath == null || repositoryPath == null) {
        return Collections.emptyList();
    }

    File pluginsFile = new File(pluginsPath);
    File repoDir = new File(repositoryPath);
    if (!pluginsFile.exists() || !repoDir.exists()) {
        return Collections.emptyList();
    }

    try {
        Document pluginsXml = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(pluginsFile);
        if (pluginsXml != null) {
            return getPluginInfos(repoDir, pluginsXml);
        }
    } catch (SAXException | IOException | ParserConfigurationException e) {
        // Same exception type as before; the message now names the offending file.
        throw new RuntimeException("Failed to parse bundled plugins descriptor " + pluginsFile, e);
    }
    return Collections.emptyList();
}
/**
 * Parses the plugin descriptor and resolves jar files for the plugins and
 * their runtime class path entries that do not have a file yet.
 *
 * @param repoDir    repository root used for resolution; may be {@code null}
 * @param pluginsXml the parsed plugins.xml document
 * @return the parsed plugin descriptions
 */
private List<PluginInfo> getPluginInfos(File repoDir, Document pluginsXml) {
    List<PluginInfo> infos = PluginInfo.parse(pluginsXml);
    resolve(infos, repoDir);
    return infos;
}
/**
 * Reads the plugins configured outside the WAR.
 *
 * <p>The descriptor is taken from the servlet context attribute
 * {@code "pluginsXml"} when present; otherwise the file named by the
 * {@code plugins.xml} parameter is parsed and cached in that attribute. The
 * optional {@code repositoryPath} parameter names the repository directory
 * used to resolve plugin jars.
 *
 * @param servletContext context supplying attributes and parameters
 * @return the external plugins, or an empty list when no descriptor is configured
 * @throws ServletException if {@code repositoryPath} is set but is not an existing directory
 * @throws RuntimeException if the descriptor file cannot be parsed
 */
private List<PluginInfo> getExternalPlugins(ServletContext servletContext) throws ServletException {
    Document pluginsXml = (Document) servletContext.getAttribute("pluginsXml");

    File repoDir = null;
    String repoPath = initParam(servletContext, "repositoryPath");
    if (repoPath != null) {
        repoDir = new File(repoPath);
        if (!repoDir.exists()) {
            throw new ServletException("repositoryPath does not exist: " + repoDir);
        }
        if (!repoDir.isDirectory()) {
            throw new ServletException("repositoryPath is not a directory: " + repoDir);
        }
    }

    if (pluginsXml == null) {
        String path = initParam(servletContext, "plugins.xml");
        if (path != null) {
            try {
                pluginsXml = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(new File(path));
                // Cache the parsed document so later lookups reuse it.
                servletContext.setAttribute("pluginsXml", pluginsXml);
            } catch (SAXException | IOException | ParserConfigurationException e) {
                // Same exception type as before; the message now names the offending file.
                throw new RuntimeException("Failed to parse plugins descriptor " + path, e);
            }
        }
    }

    if (pluginsXml == null) {
        return Collections.emptyList();
    }
    return getPluginInfos(repoDir, pluginsXml);
}
/**
 * Assigns a jar file to every plugin, and to every entry of its
 * {@code "runtime"} class path, that does not have one yet.
 *
 * @param infos   plugins to resolve; modified in place
 * @param repoDir repository root used to compute jar locations; may be {@code null}
 */
private void resolve(List<PluginInfo> infos, File repoDir) {
    for (PluginInfo info : infos) {
        if (info.getFile() == null) {
            info.setFile(getPluginFile(repoDir, info));
        }
        for (Artifact artifact : info.getClassPath("runtime")) {
            if (artifact.getFile() == null) {
                artifact.setFile(getPluginFile(repoDir, artifact));
            }
        }
    }
}
/**
 * Computes the jar location of an artifact inside a repository laid out as
 * {@code groupId-as-path/artifactId/version/artifactId-version.jar}.
 *
 * @param repoDir  repository root, or {@code null} when no repository is configured
 * @param artifact the artifact to locate
 * @return the computed file, or the artifact's own file when {@code repoDir} is {@code null}
 */
private File getPluginFile(File repoDir, Artifact artifact) {
    if (repoDir == null) {
        return artifact.getFile();
    }
    String groupPath = artifact.getGroupId().replace('.', '/');
    String artifactId = artifact.getArtifactId();
    String version = artifact.getVersion();
    String relativePath = groupPath + "/" + artifactId + "/" + version + "/" + artifactId + "-" + version + ".jar";
    return new File(repoDir, relativePath);
}
public static class DefaultServletBuilder implements ServletBuilder {
/** Context used for MIME type lookups and handed to generated configs. */
private final ServletContext servletContext;
/** Filter used to build new plugin filter chains; assigned once in the constructor. */
private final PluginDelegatingFilter pluginDelegatingFilter;

/**
 * Creates a builder bound to a servlet context and its delegating filter.
 *
 * @param servletContext          the web application's servlet context
 * @param pluginDelegatingFilter  the filter that dispatches to plugin filters
 */
public DefaultServletBuilder(ServletContext servletContext, PluginDelegatingFilter pluginDelegatingFilter) {
    this.servletContext = servletContext;
    this.pluginDelegatingFilter = pluginDelegatingFilter;
}
/**
 * Wraps a filter so that it only applies to the given path mappings and
 * carries the given phase.
 *
 * @param filter          the filter to wrap; must not be {@code null}
 * @param phase           the phase recorded on the wrapper
 * @param path            the first mapping; must not be {@code null}
 * @param additionalPaths further mappings; the array must not be {@code null}
 * @return a {@code MappingWrappedFilter} around {@code filter}
 * @throws IllegalArgumentException if {@code filter}, {@code path} or
 *                                  {@code additionalPaths} is {@code null}
 */
@Override
public Filter filter(Filter filter, FilterPhase phase, String path, String... additionalPaths) {
    if (filter == null) {
        throw new IllegalArgumentException("Filter cannot be null");
    }
    if (path == null) {
        throw new IllegalArgumentException("Paths for filter " + filter + " cannot be null");
    }
    if (additionalPaths == null) {
        throw new IllegalArgumentException("Additional paths for filter " + filter + " cannot be null");
    }

    // The primary path comes first, followed by the additional paths in order.
    String[] mappings = new String[additionalPaths.length + 1];
    mappings[0] = path;
    System.arraycopy(additionalPaths, 0, mappings, 1, additionalPaths.length);
    return new MappingWrappedFilter(filter, mappings, phase);
}
/**
 * Creates a mapped servlet that answers GET requests with the content behind
 * {@code url}.
 *
 * <p>The content type is the servlet context's MIME type for {@code path},
 * defaulting to {@code text/html}; for {@code text/html} the response
 * character encoding is set to {@code utf-8}.
 *
 * @param url             location of the resource to stream
 * @param path            the first mapping, also used for the MIME type lookup
 * @param additionalPaths further mappings
 * @return the filter produced by {@code servlet(...)} for the resource servlet
 */
@Override
public Filter resourceServlet(URL url, String path, String... additionalPaths) {
    HttpServlet resource = new HttpServlet() {
        @Override
        protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
            String contentType = servletContext.getMimeType(path);
            if (contentType == null) {
                contentType = "text/html";
            }
            if ("text/html".equals(contentType)) {
                resp.setCharacterEncoding("utf-8");
            }
            resp.setContentType(contentType);

            OutputStream out = resp.getOutputStream();
            try (InputStream in = url.openStream()) {
                byte[] chunk = new byte[1024];
                for (int read = in.read(chunk); read != -1; read = in.read(chunk)) {
                    out.write(chunk, 0, read);
                }
            }
        }
    };
    return servlet(resource, path, additionalPaths);
}
/**
 * Exposes a servlet as a filter in the {@code USER} phase, mapped to the
 * given paths.
 *
 * @param servlet         the servlet to expose; must not be {@code null}
 * @param path            the first mapping; must not be {@code null}
 * @param additionalPaths further mappings; the array must not be {@code null}
 * @return the mapped filter wrapping the servlet
 * @throws IllegalArgumentException if any argument is {@code null}
 */
@Override
public Filter servlet(HttpServlet servlet, String path, String... additionalPaths) {
    if (servlet == null) {
        throw new IllegalArgumentException("Servlet parameter cannot be null");
    }
    if (path == null) {
        throw new IllegalArgumentException("Path for servlet " + servlet + " cannot be null");
    }
    if (additionalPaths == null) {
        throw new IllegalArgumentException("Additional paths for servlet " + servlet + " cannot be null");
    }
    Filter servletAsFilter = new ServletWrapperFilter(servlet);
    return filter(servletAsFilter, FilterPhase.USER, path, additionalPaths);
}
/**
 * Starts building a redirect: the returned builder, given a target location,
 * produces a mapped servlet that answers GET requests on the source paths
 * with a redirect to that location.
 *
 * @param fromPath            the first path to redirect from
 * @param additionalFromPaths further paths to redirect from
 * @return a builder that completes the redirect once given the location
 */
@Override
public RedirectBuilder redirectFrom(String fromPath, String... additionalFromPaths) {
    return location -> {
        HttpServlet redirector = new HttpServlet() {
            @Override
            protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
                resp.sendRedirect(location);
            }
        };
        return servlet(redirector, fromPath, additionalFromPaths);
    };
}
/**
 * Creates a servlet config whose name is {@code name} and whose init
 * parameters are read from {@code properties}.
 *
 * @param name       the servlet name to report
 * @param properties source of the init parameters
 * @return a config bound to this builder's servlet context
 */
@Override
public ServletConfig servletConfig(String name, Properties properties) {
    PropertiesWebConfig config = new PropertiesWebConfig(name, properties, servletContext);
    return config;
}
/**
 * Creates a filter config whose name is {@code name} and whose init
 * parameters are read from {@code properties}.
 *
 * @param name       the filter name to report
 * @param properties source of the init parameters
 * @return a config bound to this builder's servlet context
 */
@Override
public FilterConfig filterConfig(String name, Properties properties) {
    PropertiesWebConfig config = new PropertiesWebConfig(name, properties, servletContext);
    return config;
}
/**
 * Builds a fresh plugin filter chain for the request of an existing one.
 *
 * @param filterChain the current chain; it is cast to {@code PluginFilterChain},
 *                    so any other implementation fails with a
 *                    {@code ClassCastException}
 * @return a new chain for the same request, ending in the same terminal chain
 */
@Override
public FilterChain newFilterChain(FilterChain filterChain) {
    PluginFilterChain current = (PluginFilterChain) filterChain;
    HttpServletRequest request = current.getRequest();
    // The chain that the current one falls through to once its filters are exhausted.
    FilterChain terminalChain = current.getFilterChain();
    return pluginDelegatingFilter.buildFilterChain(request, terminalChain);
}
/**
 * A combined {@link ServletConfig} and {@link FilterConfig} whose init
 * parameters come from a {@link Properties} object.
 */
private static class PropertiesWebConfig implements ServletConfig, FilterConfig {
    private final String name;
    private final Properties properties;
    private final ServletContext servletContext;

    /**
     * @param name           reported as both the servlet name and the filter name
     * @param properties     source of the init parameters
     * @param servletContext the context returned by {@link #getServletContext()}
     */
    public PropertiesWebConfig(String name, Properties properties, ServletContext servletContext) {
        this.name = name;
        this.properties = properties;
        this.servletContext = servletContext;
    }

    @Override
    public String getFilterName() {
        return name;
    }

    @Override
    public String getServletName() {
        return name;
    }

    @Override
    public ServletContext getServletContext() {
        return servletContext;
    }

    /** Returns the property with the given name, or {@code null} if absent. */
    @Override
    public String getInitParameter(String name) {
        return properties.getProperty(name);
    }

    /** Enumerates the string-valued property names. */
    @Override
    public Enumeration<String> getInitParameterNames() {
        return Collections.enumeration(properties.stringPropertyNames());
    }
}
/**
 * Adapts an {@link HttpServlet} to the {@link Filter} interface. The servlet
 * is invoked with a request wrapper whose servlet path and path info are
 * derived from the mapping stored in the
 * {@code MappingWrappedFilter.MATCHED_MAPPING} request attribute. The filter
 * chain is never continued.
 */
private static class ServletWrapperFilter implements Filter {

    private final HttpServlet servlet;

    public ServletWrapperFilter(final HttpServlet servlet) {
        this.servlet = servlet;
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // nothing to initialise
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) servletRequest;
        HttpServletResponse resp = (HttpServletResponse) servletResponse;

        HttpServletRequest mappedRequest = new HttpServletRequestWrapper(req) {
            @Override
            public String getServletPath() {
                return getMappedServletPath();
            }

            /** Everything in the request URI after the context path and the mapped servlet path. */
            @Override
            public String getPathInfo() {
                int prefixLength = super.getContextPath().length() + getMappedServletPath().length();
                return getRequestURI().substring(prefixLength);
            }

            /** The matched mapping with all trailing '*' and '/' characters removed. */
            String getMappedServletPath() {
                String mapping = (String) req.getAttribute(MappingWrappedFilter.MATCHED_MAPPING);
                int end = mapping.length();
                while (end > 0 && (mapping.charAt(end - 1) == '*' || mapping.charAt(end - 1) == '/')) {
                    end--;
                }
                return mapping.substring(0, end);
            }
        };
        servlet.service(mappedRequest, resp);
    }

    @Override
    public void destroy() {
        // nothing to release
    }
}
}
/**
 * Decorates a filter with path mappings and a phase. The wrapped filter is
 * only invoked for requests whose context-relative URI matches one of the
 * mappings; other requests continue down the chain untouched.
 */
static class MappingWrappedFilter implements Filter {
    /** Request attribute holding the mapping that matched the request. */
    static final String MATCHED_MAPPING = "MATCHED_MAPPING";

    private final Filter filter;
    private final String[] mappings;
    private final FilterPhase phase;

    public MappingWrappedFilter(Filter filter, String[] mappings, FilterPhase phase) {
        this.filter = filter;
        this.mappings = mappings;
        this.phase = phase;
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // nothing to initialise
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        boolean applies = mappingMatchesRequest((HttpServletRequest) servletRequest);
        if (!applies) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        filter.doFilter(servletRequest, servletResponse, filterChain);
    }

    /**
     * Checks the request against the mappings in order. On the first match
     * the mapping is stored in the {@link #MATCHED_MAPPING} request attribute.
     */
    private boolean mappingMatchesRequest(HttpServletRequest req) {
        String requestUri = req.getRequestURI();
        String contextRelative = requestUri.substring(req.getContextPath().length());
        for (int i = 0; i < mappings.length; i++) {
            String candidate = mappings[i];
            if (matches(candidate, contextRelative)) {
                req.setAttribute(MATCHED_MAPPING, candidate);
                return true;
            }
        }
        return false;
    }

    /**
     * A mapping matches when it equals the path exactly, or when it ends with
     * '*' and the path starts with everything before that final '*'.
     */
    private static boolean matches(String mapping, String contextRelative) {
        if (mapping.equals(contextRelative)) {
            return true;
        }
        if (!mapping.endsWith("*")) {
            return false;
        }
        String prefix = mapping.substring(0, mapping.length() - 1);
        return contextRelative.startsWith(prefix);
    }

    @Override
    public void destroy() {
        // nothing to release
    }
}
public static class PluginDelegatingFilter implements Filter, ServletDeployer {
private volatile Collection> filters = Collections.emptyList();
private final Comparator> comparator =
Comparator.comparing(e -> (e.getExport() instanceof MappingWrappedFilter) ? ((MappingWrappedFilter)e.getExport()).phase.ordinal() : FilterPhase.USER.ordinal());
@Override
public void init(FilterConfig filterConfig) throws ServletException {
}
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
servletResponse.setCharacterEncoding("utf-8");
buildFilterChain((HttpServletRequest) servletRequest, filterChain).doFilter(servletRequest, servletResponse);
}
protected FilterChain buildFilterChain(HttpServletRequest request, FilterChain filterChain) {
Iterator> matchingFilters = this.filters.stream()
.filter(e -> isMatching(request, e))
.iterator();
return new PluginFilterChain(request, matchingFilters, filterChain);
}
private boolean isMatching(HttpServletRequest request, PluginExport filterExport) {
if(filterExport.getExport() instanceof MappingWrappedFilter) {
return ((MappingWrappedFilter)filterExport.getExport()).mappingMatchesRequest(request);
} else {
return true;
}
}
@Override
public void destroy() {
}
@Override
public void deploy(Collection> filters) {
this.filters = filters.stream()
.sorted(comparator)
.collect(Collectors.toList());
}
}
private static class PluginFilterChain implements FilterChain {
private final FilterChain filterChain;
private final HttpServletRequest request;
private final Iterator> filters;
public PluginFilterChain(HttpServletRequest request, Iterator> filters, FilterChain filterChain) {
this.request = request;
this.filters = filters;
this.filterChain = filterChain;
}
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if(filters.hasNext()) {
PluginExport filterExport = filters.next();
Filter filter = filterExport.getExport() instanceof MappingWrappedFilter ? ((MappingWrappedFilter)filterExport.getExport()).filter : filterExport.getExport();
ClassLoader loader = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(filterExport.getClassLoader());
filter.doFilter(request, response, this);
} finally {
Thread.currentThread().setContextClassLoader(loader);
}
} else {
filterChain.doFilter(request, response);
}
}
private FilterChain getFilterChain() {
return filterChain;
}
public HttpServletRequest getRequest() {
return request;
}
}
private class DefaultClassLoaderFactory implements ClassLoaderFactory {
@Override
public PluginClassLoader createPluginClassLoader(PluginInfo pluginInfo, ClassLoader parentClassLoader, List allPlugins) {
try {
PluginClassLoader loader = new PluginClassLoader(pluginInfo, parentClassLoader);
loader.addURL(pluginInfo.getFile().toURI().toURL());
for (Artifact artifact : pluginInfo.getClassPath("runtime")) {
if(allPlugins.stream().noneMatch(p -> p.getPluginId().equals(artifact.getPluginId()))) {
loader.addURL(artifact.getFile().toURI().toURL());
}
}
return loader;
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy