All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.fakeservlet.FakeServletRequest Maven / Gradle / Ivy

package org.fakeservlet;

import javax.servlet.AsyncContext;
import javax.servlet.DispatcherType;
import javax.servlet.ReadListener;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpUpgradeHandler;
import javax.servlet.http.Part;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.time.Instant;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Vector;
import java.util.stream.Collectors;

/**
 * DANGER! Unfinished class! Implement methods as you go!
 */
public class FakeServletRequest implements HttpServletRequest {
    private final String scheme;
    private final int port;
    private final String host;
    private String method;
    private final String contextPath;
    private final String servletPath;
    private String pathInfo;

    private final Map> headers = new LinkedHashMap<>();
    private final Map> parameters = new LinkedHashMap<>();
    private HttpSession httpSession;
    private final Map attributes = new HashMap<>();
    private byte[] requestBody;
    private Principal userPrincipal;
    private final String protocol = "HTTP/1.1";
    private List userRoles = new ArrayList<>();

    public FakeServletRequest(String method, URL contextRoot, String servletPath, String pathInfo) {
        setMethod(method);
        this.contextPath = contextRoot.getPath();
        this.servletPath = servletPath;
        this.pathInfo = pathInfo;
        this.scheme = contextRoot.getProtocol();
        this.port = contextRoot.getPort() != -1 ? contextRoot.getPort() : defaultPort(scheme);
        this.host = contextRoot.getHost();
    }

    private int defaultPort(String scheme) {
        return scheme.equals("https") ? 443 : 80;
    }

    @Override
    public String getAuthType() {
        throw unimplemented();
    }

    private List cookies = new ArrayList<>();

    public void addCookie(String key, String value) {
        cookies.add(new Cookie(key, URLEncoder.encode(value, StandardCharsets.UTF_8)));
    }

    @Override
    public Cookie[] getCookies() {
        return cookies.isEmpty() ? null : cookies.toArray(new Cookie[0]);
    }

    public void setCookies(List cookies) {
        this.cookies = cookies;
    }

    @Override
    public long getDateHeader(String name) {
        String s = getHeader(name);
        return s != null ? DateTimeFormatter.RFC_1123_DATE_TIME.parse(s, Instant::from).toEpochMilli() : -1;
    }

    @Override
    public String getHeader(String s) {
        return Optional.ofNullable(headers.get(s)).map(l -> l.get(0)).orElse(null);
    }

    @Override
    public Enumeration getHeaders(String s) {
        return Optional.ofNullable(headers.get(s)).map(Collections::enumeration).orElse(null);
    }

    @Override
    public Enumeration getHeaderNames() {
        return Collections.enumeration(headers.keySet());
    }

    @Override
    public int getIntHeader(String name) {
        return Optional.ofNullable(getHeader(name)).map(Integer::parseInt).orElse(-1);
    }

    @Override
    public String getMethod() {
        return method;
    }

    @Override
    public String getPathInfo() {
        return pathInfo;
    }

    public void setPathInfo(String pathInfo) {
        this.pathInfo = pathInfo;
    }

    @Override
    public String getPathTranslated() {
        throw unimplemented();
    }

    private AssertionError unimplemented() {
        return new AssertionError("called unexpected method");
    }

    @Override
    public String getContextPath() {
        return contextPath;
    }

    @Override
    public String getQueryString() {
        if (parameters.isEmpty()) {
            return null;
        } else {
            return parameters.entrySet().stream()
                    .map(entry -> entry.getValue().stream().map(v -> urlEncode(entry.getKey()) + "=" + urlEncode(v)).collect(Collectors.joining("&")))
                    .collect(Collectors.joining("&"));
        }
    }

    private String urlEncode(String value) {
        return URLEncoder.encode(value, StandardCharsets.US_ASCII);
    }

    @Override
    public String getRemoteUser() {
        return getUserPrincipal() != null ? getUserPrincipal().getName() : null;
    }

    @Override
    public boolean isUserInRole(String s) {
        return userRoles.contains(s);
    }

    @Override
    public Principal getUserPrincipal() {
        return userPrincipal;
    }

    @Override
    public String getRequestedSessionId() {
        throw unimplemented();
    }

    @Override
    public String getRequestURI() {
        return getContextPath() + getServletPath() + (getPathInfo() != null ? getPathInfo() : "");
    }

    @Override
    public StringBuffer getRequestURL() {
        final StringBuffer url = getAuthority();
        url.append(getRequestURI());
        return url;
    }

    public StringBuffer getAuthority() {
        StringBuffer url = new StringBuffer(128);
        url.append(getScheme()).append("://").append(getRemoteHost());
        if (port > 0 && port != defaultPort(scheme)) {
            url.append(':').append(port);
        }
        return url;
    }

    @Override
    public String getServletPath() {
        return servletPath;
    }

    @Override
    public HttpSession getSession(boolean create) {
        if (httpSession == null && create) {
            httpSession = new FakeHttpSession();
        }
        return httpSession;
    }

    @Override
    public HttpSession getSession() {
        return getSession(true);
    }

    @Override
    public String changeSessionId() {
        throw unimplemented();
    }

    @Override
    public boolean isRequestedSessionIdValid() {
        throw unimplemented();
    }

    @Override
    public boolean isRequestedSessionIdFromCookie() {
        throw unimplemented();
    }

    @Override
    public boolean isRequestedSessionIdFromURL() {
        throw unimplemented();
    }

    @Override
    public boolean isRequestedSessionIdFromUrl() {
        throw unimplemented();
    }

    @Override
    public boolean authenticate(HttpServletResponse response) {
        response.setStatus(401);
        return false;
    }

    @Override
    public void login(String s, String s1) {

    }

    @Override
    public void logout() {

    }

    @Override
    public Collection getParts() {
        throw unimplemented();
    }

    @Override
    public Part getPart(String s) {
        throw unimplemented();
    }

    @Override
    public  T upgrade(Class aClass) {
        throw unimplemented();
    }

    @Override
    public Object getAttribute(String name) {
        return attributes.get(name);
    }

    @Override
    public Enumeration getAttributeNames() {
        return new Vector<>(attributes.keySet()).elements();
    }

    @Override
    public String getCharacterEncoding() {
        // TODO
        throw unimplemented();
    }

    @Override
    public void setCharacterEncoding(String s) {
        // TODO

    }

    @Override
    public int getContentLength() {
        // TODO
        throw unimplemented();
    }

    @Override
    public long getContentLengthLong() {
        throw unimplemented();
    }

    @Override
    public String getContentType() {
        // TODO
        throw unimplemented();
    }

    @Override
    public ServletInputStream getInputStream() {
        if (requestBody != null) {
            ByteArrayInputStream inputStream = new ByteArrayInputStream(requestBody);
            return new ServletInputStream() {
                @Override
                public boolean isFinished() {
                    return inputStream.available() == 0;
                }

                @Override
                public boolean isReady() {
                    return true;
                }

                @Override
                public void setReadListener(ReadListener readListener) {
                    // TODO
                    throw unimplemented();
                }

                @Override
                public int read() {
                    return inputStream.read();
                }

                @Override
                public int read(byte[] b, int off, int len) {
                    return inputStream.read(b, off, len);
                }
            };
        } else {
            throw new AssertionError("Call setRequestBody first");
        }
    }

    @Override
    public String getParameter(String s) {
        return parameters.containsKey(s) ? parameters.get(s).get(0) : null;
    }

    @Override
    public Enumeration getParameterNames() {
        return new Vector<>(parameters.keySet()).elements();
    }

    @Override
    public String[] getParameterValues(String s) {
        return parameters.containsKey(s) ? parameters.get(s).toArray(new String[0]) : null;
    }

    @Override
    public Map getParameterMap() {
        HashMap map = new HashMap<>();
        parameters.forEach((k,v) -> map.put(k, getParameterValues(k)));
        return map;
    }

    @Override
    public String getProtocol() {
        return protocol;
    }

    @Override
    public String getScheme() {
        return scheme;
    }

    @Override
    public String getServerName() {
        return host;
    }

    @Override
    public int getServerPort() {
        return port;
    }

    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream()));
    }

    @Override
    public String getRemoteAddr() {
        return "127.0.0.1";
    }

    @Override
    public String getRemoteHost() {
        return getServerName();
    }

    @Override
    public void setAttribute(String name, Object o) {
        this.attributes.put(name, o);
    }

    @Override
    public void removeAttribute(String name) {
        this.attributes.remove(name);
    }

    @Override
    public Locale getLocale() {
        // TODO
        throw unimplemented();
    }

    @Override
    public Enumeration getLocales() {
        // TODO
        throw unimplemented();
    }

    @Override
    public boolean isSecure() {
        // TODO
        throw unimplemented();
    }

    @Override
    public RequestDispatcher getRequestDispatcher(String s) {
        throw unimplemented();
    }

    @Override
    public String getRealPath(String s) {
        throw unimplemented();
    }

    @Override
    public int getRemotePort() {
        throw unimplemented();
    }

    @Override
    public String getLocalName() {
        throw unimplemented();
    }

    @Override
    public String getLocalAddr() {
        // TODO
        throw unimplemented();
    }

    @Override
    public int getLocalPort() {
        throw unimplemented();
    }

    @Override
    public ServletContext getServletContext() {
        throw unimplemented();
    }

    @Override
    public AsyncContext startAsync() throws IllegalStateException {
        throw new IllegalStateException("async context not supported");
    }

    @Override
    public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
        return startAsync();
    }

    @Override
    public boolean isAsyncStarted() {
        return false;
    }

    @Override
    public boolean isAsyncSupported() {
        return false;
    }

    @Override
    public AsyncContext getAsyncContext() {
        throw new IllegalStateException("async context not supported");
    }

    @Override
    public DispatcherType getDispatcherType() {
        throw unimplemented();
    }

    public void addParameter(String key, String value) {
        parameters.computeIfAbsent(key, n -> new ArrayList<>()).add(value);
    }

    public void setSession(HttpSession httpSession) {
        this.httpSession = httpSession;
    }

    public void setMethod(String method) {
        this.method = method;
    }

    public void addHeader(String name, String value) {
        headers.computeIfAbsent(name, n -> new ArrayList<>()).add(value);
    }

    public void setRequestBody(byte[] requestBody) {
        this.requestBody = requestBody;
    }

    public void setRequestBody(String requestBody) {
        setRequestBody(requestBody.getBytes());
    }

    public void setUserPrincipal(Principal userPrincipal) {
        this.userPrincipal = userPrincipal;
    }

    public FakeServletResponse service(HttpServlet servlet) throws ServletException, IOException {
        FakeServletResponse response = new FakeServletResponse(this);
        servlet.service(this, response);
        return response;
    }

    public void setUser(String username, List userRoles) {
        setUserPrincipal(() -> username);
        this.userRoles = userRoles;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy