// com.netflix.zuul.context.RequestContext (Maven / Gradle / Ivy)
/*
* Copyright 2013 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.netflix.zuul.context;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.InputStream;
import java.io.NotSerializableException;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.netflix.util.Pair;
import com.netflix.zuul.constants.ZuulHeaders;
import com.netflix.zuul.util.DeepCopy;
/**
* The Request Context holds request, response, state information and data for ZuulFilters to access and share.
* The RequestContext lives for the duration of the request and is ThreadLocal.
* extensions of RequestContext can be substituted by setting the contextClass.
* Most methods here are convenience wrapper methods; the RequestContext is an extension of a ConcurrentHashMap
*
* @author Mikey Cohen
* Date: 10/13/11
* Time: 10:21 AM
*/
public class RequestContext extends ConcurrentHashMap {
private static final Logger LOG = LoggerFactory.getLogger(RequestContext.class);
protected static Class extends RequestContext> contextClass = RequestContext.class;
private static RequestContext testContext = null;
protected static final ThreadLocal extends RequestContext> threadLocal = new ThreadLocal() {
@Override
protected RequestContext initialValue() {
try {
return contextClass.newInstance();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}
};
public RequestContext() {
super();
}
/**
* Override the default RequestContext
*
* @param clazz
*/
public static void setContextClass(Class extends RequestContext> clazz) {
contextClass = clazz;
}
/**
* set an overriden "test" context
*
* @param context
*/
public static void testSetCurrentContext(RequestContext context) {
testContext = context;
}
/**
* Get the current RequestContext
*
* @return the current RequestContext
*/
public static RequestContext getCurrentContext() {
if (testContext != null) return testContext;
RequestContext context = threadLocal.get();
return context;
}
/**
* Convenience method to return a boolean value for a given key
*
* @param key
* @return true or false depending what was set. default is false
*/
public boolean getBoolean(String key) {
return getBoolean(key, false);
}
/**
* Convenience method to return a boolean value for a given key
*
* @param key
* @param defaultResponse
* @return true or false depending what was set. default defaultResponse
*/
public boolean getBoolean(String key, boolean defaultResponse) {
Boolean b = (Boolean) get(key);
if (b != null) {
return b.booleanValue();
}
return defaultResponse;
}
/**
* sets a key value to Boolen.TRUE
*
* @param key
*/
public void set(String key) {
put(key, Boolean.TRUE);
}
/**
* puts the key, value into the map. a null value will remove the key from the map
*
* @param key
* @param value
*/
public void set(String key, Object value) {
if (value != null) put(key, value);
else remove(key);
}
/**
* true if zuulEngineRan
*
* @return
*/
public boolean getZuulEngineRan() {
return getBoolean("zuulEngineRan");
}
/**
* sets zuulEngineRan to true
*/
public void setZuulEngineRan() {
put("zuulEngineRan", true);
}
/**
* @return the HttpServletRequest from the "request" key
*/
public HttpServletRequest getRequest() {
return (HttpServletRequest) get("request");
}
/**
* sets the HttpServletRequest into the "request" key
*
* @param request
*/
public void setRequest(HttpServletRequest request) {
put("request", request);
}
/**
* @return the HttpServletResponse from the "response" key
*/
public HttpServletResponse getResponse() {
return (HttpServletResponse) get("response");
}
/**
* sets the "response" key to the HttpServletResponse passed in
*
* @param response
*/
public void setResponse(HttpServletResponse response) {
set("response", response);
}
/**
* returns a set throwable
*
* @return a set throwable
*/
public Throwable getThrowable() {
return (Throwable) get("throwable");
}
/**
* sets a throwable
*
* @param th
*/
public void setThrowable(Throwable th) {
put("throwable", th);
}
/**
* sets debugRouting
*
* @param bDebug
*/
public void setDebugRouting(boolean bDebug) {
set("debugRouting", bDebug);
}
/**
* @return "debugRouting"
*/
public boolean debugRouting() {
return getBoolean("debugRouting");
}
/**
* sets "debugRequestHeadersOnly" to bHeadersOnly
*
* @param bHeadersOnly
*/
public void setDebugRequestHeadersOnly(boolean bHeadersOnly) {
set("debugRequestHeadersOnly", bHeadersOnly);
}
/**
* @return "debugRequestHeadersOnly"
*/
public boolean debugRequestHeadersOnly() {
return getBoolean("debugRequestHeadersOnly");
}
/**
* sets "debugRequest"
*
* @param bDebug
*/
public void setDebugRequest(boolean bDebug) {
set("debugRequest", bDebug);
}
/**
* gets debugRequest
*
* @return debugRequest
*/
public boolean debugRequest() {
return getBoolean("debugRequest");
}
/**
* removes "routeHost" key
*/
public void removeRouteHost() {
remove("routeHost");
}
/**
* sets routeHost
*
* @param routeHost a URL
*/
public void setRouteHost(URL routeHost) {
set("routeHost", routeHost);
}
/**
* @return "routeHost" URL
*/
public URL getRouteHost() {
return (URL) get("routeHost");
}
/**
* appends filter name and status to the filter execution history for the
* current request
*
* @param executedFilters - name of the filter
*/
public void addFilterExecutionSummary(String name, String status, long time) {
StringBuilder sb = getFilterExecutionSummary();
if (sb.length() > 0) sb.append(", ");
sb.append(name).append('[').append(status).append(']').append('[').append(time).append("ms]");
}
/**
* @return String that represents the filter execution history for the current request
*/
public StringBuilder getFilterExecutionSummary() {
if (get("executedFilters") == null) {
putIfAbsent("executedFilters", new StringBuilder());
}
return (StringBuilder) get("executedFilters");
}
/**
* sets the "responseBody" value as a String. This is the response sent back to the client.
*
* @param body
*/
public void setResponseBody(String body) {
set("responseBody", body);
}
/**
* @return the String response body to be snt back to the requesting client
*/
public String getResponseBody() {
return (String) get("responseBody");
}
/**
* sets the InputStream of the response into the responseDataStream
*
* @param responseDataStream
*/
public void setResponseDataStream(InputStream responseDataStream) {
set("responseDataStream", responseDataStream);
}
/**
* sets the flag responseGZipped if the response is gzipped
*
* @param gzipped
*/
public void setResponseGZipped(boolean gzipped) {
put("responseGZipped", gzipped);
}
/**
* @return true if responseGZipped is true (the response is gzipped)
*/
public boolean getResponseGZipped() {
return getBoolean("responseGZipped", true);
}
/**
* @return the InputStream Response
*/
public InputStream getResponseDataStream() {
return (InputStream) get("responseDataStream");
}
/**
* If this value is true then the response should be sent to the client.
*
* @return
*/
public boolean sendZuulResponse() {
return getBoolean("sendZuulResponse", true);
}
/**
* sets the sendZuulResponse boolean
*
* @param bSend
*/
public void setSendZuulResponse(boolean bSend) {
set("sendZuulResponse", Boolean.valueOf(bSend));
}
/**
* returns the response status code. Default is 200
*
* @return
*/
public int getResponseStatusCode() {
return get("responseStatusCode") != null ? (Integer) get("responseStatusCode") : 500;
}
/**
* Use this instead of response.setStatusCode()
*
* @param nStatusCode
*/
public void setResponseStatusCode(int nStatusCode) {
getResponse().setStatus(nStatusCode);
set("responseStatusCode", nStatusCode);
}
/**
* add a header to be sent to the origin
*
* @param name
* @param value
*/
public void addZuulRequestHeader(String name, String value) {
getZuulRequestHeaders().put(name.toLowerCase(), value);
}
/**
* return the list of requestHeaders to be sent to the origin
*
* @return the list of requestHeaders to be sent to the origin
*/
public Map getZuulRequestHeaders() {
if (get("zuulRequestHeaders") == null) {
HashMap zuulRequestHeaders = new HashMap();
putIfAbsent("zuulRequestHeaders", zuulRequestHeaders);
}
return (Map) get("zuulRequestHeaders");
}
/**
* add a header to be sent to the response
*
* @param name
* @param value
*/
public void addZuulResponseHeader(String name, String value) {
getZuulResponseHeaders().add(new Pair(name, value));
}
/**
* returns the current response header list
*
* @return a List> of response headers
*/
public List> getZuulResponseHeaders() {
if (get("zuulResponseHeaders") == null) {
List> zuulRequestHeaders = new ArrayList>();
putIfAbsent("zuulResponseHeaders", zuulRequestHeaders);
}
return (List>) get("zuulResponseHeaders");
}
/**
* the Origin response headers
*
* @return the List> of headers sent back from the origin
*/
public List> getOriginResponseHeaders() {
if (get("originResponseHeaders") == null) {
List> originResponseHeaders = new ArrayList>();
putIfAbsent("originResponseHeaders", originResponseHeaders);
}
return (List>) get("originResponseHeaders");
}
/**
* adds a header to the origin response headers
*
* @param name
* @param value
*/
public void addOriginResponseHeader(String name, String value) {
getOriginResponseHeaders().add(new Pair(name, value));
}
/**
* returns the content-length of the origin response
*
* @return the content-length of the origin response
*/
public Long getOriginContentLength() {
return (Long) get("originContentLength");
}
/**
* sets the content-length from the origin response
*
* @param v
*/
public void setOriginContentLength(Long v) {
set("originContentLength", v);
}
/**
* sets the content-length from the origin response
*
* @param v parses the string into an int
*/
public void setOriginContentLength(String v) {
try {
final Long i = Long.valueOf(v);
set("originContentLength", i);
} catch (NumberFormatException e) {
LOG.warn("error parsing origin content length", e);
}
}
/**
* @return true if the request body is chunked
*/
public boolean isChunkedRequestBody() {
final Object v = get("chunkedRequestBody");
return (v != null) ? (Boolean) v : false;
}
/**
* sets chunkedRequestBody to true
*/
public void setChunkedRequestBody() {
this.set("chunkedRequestBody", Boolean.TRUE);
}
/**
* @return true is the client request can accept gzip encoding. Checks the "accept-encoding" header
*/
public boolean isGzipRequested() {
final String requestEncoding = this.getRequest().getHeader(ZuulHeaders.ACCEPT_ENCODING);
return requestEncoding != null && requestEncoding.toLowerCase().contains("gzip");
}
/**
* unsets the threadLocal context. Done at the end of the request.
*/
public void unset() {
threadLocal.remove();
}
/**
* Mkaes a copy of the RequestContext. This is used for debugging.
*
* @return
*/
public RequestContext copy() {
RequestContext copy = new RequestContext();
Iterator it = keySet().iterator();
String key = it.next();
while (key != null) {
Object orig = get(key);
try {
Object copyValue = DeepCopy.copy(orig);
if (copyValue != null) {
copy.set(key, copyValue);
} else {
copy.set(key, orig);
}
} catch (NotSerializableException e) {
copy.set(key, orig);
}
if (it.hasNext()) {
key = it.next();
} else {
key = null;
}
}
return copy;
}
/**
* @return Map> of the request Query Parameters
*/
public Map> getRequestQueryParams() {
return (Map>) get("requestQueryParams");
}
/**
* sets the request query params list
*
* @param qp Map> qp
*/
public void setRequestQueryParams(Map> qp) {
put("requestQueryParams", qp);
}
@RunWith(MockitoJUnitRunner.class)
public static class UnitTest {
@Mock
HttpServletRequest request;
@Mock
HttpServletResponse response;
@Test
public void testGetContext() {
RequestContext context = RequestContext.getCurrentContext();
assertNotNull(context);
}
@Test
public void testSetContextVariable() {
RequestContext context = RequestContext.getCurrentContext();
assertNotNull(context);
context.set("test", "moo");
assertEquals(context.get("test"), "moo");
}
@Test
public void testSet() {
RequestContext context = RequestContext.getCurrentContext();
assertNotNull(context);
context.set("test");
assertEquals(context.get("test"), Boolean.TRUE);
}
@Test
public void testBoolean() {
RequestContext context = RequestContext.getCurrentContext();
assertEquals(context.getBoolean("boolean_test"), Boolean.FALSE);
assertEquals(context.getBoolean("boolean_test", true), true);
}
@Test
public void testCopy() {
RequestContext context = RequestContext.getCurrentContext();
context.put("test", "test");
context.put("test1", "test1");
context.put("test2", "test2");
RequestContext copy = context.copy();
assertEquals(copy.get("test"), "test");
assertEquals(copy.get("test1"), "test1");
assertEquals(copy.get("test2"), "test2");
// assertFalse(copy.get("test").hashCode() == context.get("test").hashCode());
}
@Test
public void testResponseHeaders() {
RequestContext context = RequestContext.getCurrentContext();
context.addZuulRequestHeader("header", "test");
Map headerMap = context.getZuulRequestHeaders();
assertNotNull(headerMap);
assertEquals(headerMap.get("header"), "test");
}
@Test
public void testAccessors() {
RequestContext context = new RequestContext();
RequestContext.testSetCurrentContext(context);
context.setRequest(request);
context.setResponse(response);
Throwable th = new Throwable();
context.setThrowable(th);
assertEquals(context.getThrowable(), th);
assertEquals(context.debugRouting(), false);
context.setDebugRouting(true);
assertEquals(context.debugRouting(), true);
assertEquals(context.debugRequest(), false);
context.setDebugRequest(true);
assertEquals(context.debugRequest(), true);
context.setDebugRequest(false);
assertEquals(context.debugRequest(), false);
context.setDebugRouting(false);
assertEquals(context.debugRouting(), false);
try {
URL url = new URL("http://www.moldfarm.com");
context.setRouteHost(url);
assertEquals(context.getRouteHost(), url);
} catch (MalformedURLException e) {
e.printStackTrace();
}
InputStream in = mock(InputStream.class);
context.setResponseDataStream(in);
assertEquals(context.getResponseDataStream(), in);
assertEquals(context.sendZuulResponse(), true);
context.setSendZuulResponse(false);
assertEquals(context.sendZuulResponse(), false);
context.setResponseStatusCode(100);
assertEquals(context.getResponseStatusCode(), 100);
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy