All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.grails.plugins.testing.GrailsMockHttpServletRequest.groovy Maven / Gradle / Ivy

There is a newer version: 7.0.0-M1
Show newest version
/*
 * Copyright 2008 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.grails.plugins.testing

import grails.converters.JSON
import grails.converters.XML

import javax.servlet.AsyncContext
import javax.servlet.AsyncEvent
import javax.servlet.AsyncListener
import javax.servlet.DispatcherType
import javax.servlet.ServletContext
import javax.servlet.ServletInputStream
import javax.servlet.ServletRequest
import javax.servlet.ServletResponse
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import javax.servlet.http.Part
import grails.web.mime.MimeType
import org.grails.web.util.GrailsApplicationAttributes
import org.grails.web.servlet.mvc.GrailsWebRequest
import org.grails.web.util.WebUtils
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpMethod
import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.util.Assert
import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
import org.springframework.web.multipart.MultipartFile
import org.springframework.web.multipart.MultipartHttpServletRequest

/**
 * A custom mock HTTP servlet request that provides the extra properties
 * and methods normally injected by the "servlets" plugin.
 */
class GrailsMockHttpServletRequest extends MockHttpServletRequest implements MultipartHttpServletRequest{

    boolean invalidToken
    MultiValueMap multipartFiles = new LinkedMultiValueMap()

    private Map multipartContentTypes = Collections.emptyMap()
    private Map multipartHeaders = Collections.emptyMap()
    private HttpHeaders httpHeaders = new HttpHeaders()

    HttpHeaders requestHeaders
    HttpMethod requestMethod = HttpMethod.GET

    private cachedJson
    private cachedXml
    DispatcherType dispatcherType
    AsyncContext asyncContext
    private ServletInputStream cachedInputStream


    public GrailsMockHttpServletRequest() {
        super();
        method = 'GET'
    }

    public GrailsMockHttpServletRequest(ServletContext servletContext) {
        super(servletContext);
        method = 'GET'
    }

    /**
     * Sets the request format to use
     * @param format The request format
     */
    void setFormat(String format) {
        setAttribute(GrailsApplicationAttributes.CONTENT_FORMAT, format)
    }
    
    @Override
    void setContentType(String newContentType) {
        super.setContentType(newContentType)
        def webRequest = getAttribute(GrailsApplicationAttributes.WEB_REQUEST)
        def mimeType = MimeType.configuredMimeTypes?.find { mt ->
            mt?.name == newContentType
        } 
        
        if(!mimeType) {
            mimeType = new MimeType(newContentType)
        }
        setAttribute(GrailsApplicationAttributes.REQUEST_FORMATS, [mimeType] as MimeType[])
    }

    /**
     * Sets the body of the request to be a json packet
     *
     * @param sourceJson The source json
     */
    void setJson(Object sourceJson) {
        setContentType('application/json; charset=UTF-8')
        setFormat('json')
        if (sourceJson instanceof String) {
            setContent(sourceJson.getBytes("UTF-8"))
        }
        else if (sourceJson instanceof JSON) {
            setContent(sourceJson.toString().getBytes("UTF-8"))
        }
        else {
            setContent(new JSON(sourceJson).toString().getBytes("UTF-8"))
        }
        getAttribute("org.codehaus.groovy.grails.WEB_REQUEST")?.informParameterCreationListeners()
    }

    /**
     * Sets the body of the request to be an XML packet
     *
     * @param sourceXml
     */
    void setXml(Object sourceXml) {
        setContentType("text/xml; charset=UTF-8")
        setFormat("xml")

        if (sourceXml instanceof String) {
            setContent(sourceXml.getBytes("UTF-8"))
        }
        else {
            XML xml
            if(sourceXml instanceof XML) {
                xml = (XML)sourceXml
            } else {
                xml = new XML(sourceXml)
            }
            setContent(xml.toString().getBytes("UTF-8"))
        }

        getAttribute("org.codehaus.groovy.grails.WEB_REQUEST")?.informParameterCreationListeners()
    }

    void setXML(Object sourceXml) {
        setXml(sourceXml)
    }

    void setJSON(Object sourceJson) {
        setJson(sourceJson)
    }

    /**
     * Implementation of the dynamic "forwardURI" property.
     */
    String getForwardURI() {
        def result = getAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE)
        if (!result) result = requestURI
        return result
    }

    /**
     * Sets the "forwardURI" property for the request.
     */
    void setForwardURI(String uri) {
        setAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE, uri)
    }

    /**
     * Indicates whether this is an AJAX request or not (as far as
     * Grails is concerned). Returns true if it is an
     * AJAX request, otherwise false.
     */
    boolean isXhr() {
        return getHeader("X-Requested-With") == "XMLHttpRequest"
    }

    /**
     * Makes this request an AJAX request as Grails understands it.
     * This cannot be undone, so if you need a non-AJAX request you
     * will have to create a new instance.
     */
    void makeAjaxRequest() {
        addHeader("X-Requested-With", "XMLHttpRequest")
    }

    /**
     * Map-like access to request attributes, e.g. request["count"].
     */
    def getAt(String key) {
        return getAttribute(key)
    }

    /**
     * Map-like setting of request attributes, e.g. request["count"] = 10.
     */
    void putAt(String key, Object val) {
        setAttribute(key, val)
    }

    /**
     * Property access for request attributes.
     */
    def getProperty(String name) {
        def mp = getClass().metaClass.getMetaProperty(name)
        return mp ? mp.getProperty(this) : getAttribute(name)
    }

    /**
     * Property setting of request attributes.
     */
    void setProperty(String name, value) {
        def mp = getClass().metaClass.getMetaProperty(name)
        if (mp) {
            mp.setProperty(this, value)
        }
        else {
            setAttribute(name, value)
        }
    }

    boolean isGet() { method == "GET" }
    boolean isPost() { method == "POST" }

    /**
    * Parses the request content as XML using XmlSlurper and returns
    * the GPath result object. Throws an exception if there is no
    * content or the content is not valid XML.
    */
    def getXML() {
        if (!cachedXml) {
            cachedXml = GrailsMockHttpServletRequest.classLoader.loadClass("grails.converters.XML").parse(this)
        }
        return cachedXml
    }

    /**
     * Parses the request content as JSON using the JSON converter.
     * Throws an exception if there is no content or the content is
     * not valid JSON.
     */
    def getJSON() {
        if (!cachedJson) {
            cachedJson = GrailsMockHttpServletRequest.classLoader.loadClass("grails.converters.JSON").parse(this)
        }
        return cachedJson
    }

    /**
     * Adds a "find()" method to the request that searches the request's
     * attributes. Returns the first attribute for which the closure
     * returns true, just like the normal Groovy find() method.
     */
    def find(Closure c) {
        def result = [:]
        for (String name in attributeNames) {
            def match = false
            switch (c.parameterTypes.length) {
                case 0:
                    match = c.call()
                    break

                case 1:
                    match = c.call(key:name, value:getAttribute(name))
                    break

                default:
                    match =  c.call(name, getAttribute(name))
            }
            if (match) {
                result[name] = getAttribute(name)
                break
            }
        }

        return result
    }

    /**
     * Like the {@link #find(Closure)} method, this searches the request
     * attributes. Returns all the attributes that match the closure
     * conditions.
     */
    def findAll(Closure c) {
        def results = [:]
        for (String name in attributeNames) {
            def match = false
            switch (c.parameterTypes.length) {
                case 0:
                    match = c.call()
                    break

                case 1:
                    match = c.call(key:name, value:getAttribute(name))
                    break

                default:
                    match =  c.call(name, getAttribute(name))
            }
            if (match) { results[name] = getAttribute(name) }
        }
        results
    }

    /**
     * Iterates over the request attributes.
     */
    def each(Closure c) {
        for (String name in attributeNames) {
            switch (c.parameterTypes.length) {
                case 0:
                    c.call()
                    break

                case 1:
                    c.call(key:name, value:getAttribute(name))
                    break

                default:
                    c.call(name, getAttribute(name))
            }
        }
    }

    /**
     * {@inheritDoc }
     */
    Iterator getFileNames() {
        return multipartFiles.keySet().iterator()
    }

    /**
     * {@inheritDoc }
     */
    MultipartFile getFile(String name) {
        return multipartFiles.getFirst(name)
    }

    /**
     * {@inheritDoc }
     */
    List getFiles(String name) {
        return multipartFiles.get(name)
    }

    /**
     * {@inheritDoc }
     */
    Map getFileMap() {
        return multipartFiles.toSingleValueMap()
    }

    /**
     * {@inheritDoc }
     */
    MultiValueMap getMultiFileMap() {
        return multipartFiles
    }

    /**
     * Add a file to this request. The parameter name from the multipart
     * form is taken from the {@link MultipartFile#getName()}.
     * @param file multipart file to be added
     */
    void addFile(MultipartFile file) {
        setMethod("POST")
        setContentType("multipart/form-data")

        Assert.notNull(file, "MultipartFile must not be null")
        multipartFiles.add(file.getName(), file)
    }

    /**
     * Add a file for the given location and bytes
     *
     * @param location The location
     * @param contents The bytes
     */
    void addFile(String location, byte[] contents) {
        setMethod("POST")
        setContentType("multipart/form-data")

        multipartFiles.add(location, new GrailsMockMultipartFile(location, contents))
    }

    @Override
    void clearAttributes() {
        super.clearAttributes()
        multipartFiles.clear()
        multipartContentTypes.clear()
        httpHeaders.clear()
    }

    String getMultipartContentType(String paramOrFileName) {
        return multipartContentTypes.get(paramOrFileName)
    }

    String setMultipartContentType(String paramOrFileName, String contentType) {
        return multipartContentTypes.put(paramOrFileName, contentType)
    }

    HttpHeaders getMultipartHeaders(String paramOrFileName) {
        return multipartHeaders.get(paramOrFileName)
    }

    HttpHeaders setMultipartHeaders(String paramOrFileName, HttpHeaders headers) {
        return multipartHeaders.put(paramOrFileName, headers)
    }

    Collection getParts() {
        getMultiFileMap().values().flatten().collect {new MockPart(it)}
    }

    Part getPart(String name) {
        MultipartFile file = getFile(name)
        if (file) {
            return new MockPart(file)
        }
    }

    AsyncContext startAsync() {
        def webRequest = GrailsWebRequest.lookup()
        def response = webRequest?.currentResponse
        if (response == null) {
            response = new GrailsMockHttpServletResponse()
        }
        startAsync(this,response)
    }

    AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) {
        return new MockAsyncContext(servletRequest, servletResponse)
    }

    boolean isAsyncStarted() { asyncContext != null }

    boolean isAsyncSupported() { true }

    @Override
    public ServletInputStream getInputStream() {
        if (cachedInputStream == null) {
            cachedInputStream = super.getInputStream()
        }
        cachedInputStream
    }
}

class MockPart implements Part {

    MultipartFile file
    MultiValueMap headers = new LinkedMultiValueMap()

    MockPart(MultipartFile file) {
        this.file = file
    }

    @Override
    InputStream getInputStream() {
        file.inputStream
    }

    @Override
    String getContentType() {
        file.contentType
    }

    @Override
    String getName() {
        file.name
    }

    @Override
    String getSubmittedFileName() {
        "N/A"
    }

    @Override
    long getSize() {
        file.size
    }

    @Override
    void write(String fileName) {
        file.transferTo(new File(fileName))
    }

    @Override
    void delete() {
        // no-op
    }

    @Override
    String getHeader(String name) {
        return headers.getFirst(name)
    }

    @Override
    Collection getHeaders(String name) {
        return headers[name]
    }

    @Override
    Collection getHeaderNames() {
        return headers.keySet()
    }
}

class MockAsyncContext implements AsyncContext {

    ServletRequest request
    ServletResponse response
    String dispatchUri
    long timeout = -1
    List asyncListeners = []

    MockAsyncContext(HttpServletRequest request, HttpServletResponse response) {
        this.request = request
        this.response = response
    }

    boolean hasOriginalRequestAndResponse() {
        return true
    }

    void dispatch() {
        dispatchUri = request.requestURI
    }

    void dispatch(String path) {
        dispatchUri = path
    }

    void dispatch(ServletContext context, String path) {
        dispatchUri = path
    }

    void complete() {
        // no op
    }

    void start(Runnable run) {
        try {
            for (listener in asyncListeners) {
                AsyncListener al = listener.listener
                al.onStartAsync(listener.event)
            }
            run.run()
            for (listener in asyncListeners) {
                AsyncListener al = listener.listener
                al.onComplete(listener.event)
            }

        } catch (e) {
            for (listener in asyncListeners) {
                AsyncListener al = listener.listener
                al.onError(new AsyncEvent(this, e))
            }
        }
    }

    void addListener(AsyncListener listener) {
        asyncListeners << [listener:listener, event:new AsyncEvent(this, request, response)]
    }

    void addListener(AsyncListener listener, ServletRequest servletRequest, ServletResponse servletResponse) {
        asyncListeners << [listener:listener, event:new AsyncEvent(this, servletRequest, servletResponse)]
    }

    def  T createListener(Class clazz) {
        return clazz.newInstance()
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy