
// com.flyfish.oauth.client.RestClient — Maven / Gradle / Ivy
package com.flyfish.oauth.client;
import com.fasterxml.jackson.core.type.TypeReference;
import com.flyfish.oauth.builder.TypedMapBuilder;
import com.flyfish.oauth.common.Consumer;
import com.flyfish.oauth.common.OAuthContext;
import com.flyfish.oauth.domain.OAuthSSOToken;
import com.flyfish.oauth.entry.AbstractAuthenticationEntryPoint;
import com.flyfish.oauth.utils.JacksonUtil;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpEntity;
import org.apache.http.NameValuePair;
import org.apache.http.StatusLine;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.*;
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.entity.BufferedHttpEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.TrustStrategy;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.InputStream;
import java.net.UnknownHostException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Rest请求客户端
*
* @author Mr.Wang
*
* 1. 全builder调用,用户系统内部相互通信
* 2. 支持异步回调
* 3. 多样性组合
* 4. 解耦实现
*/
public final class RestClient {
private static final Logger logger = LoggerFactory.getLogger(RestClient.class);
// 内置的客户端
private static final CloseableHttpClient client = HttpClients.createDefault();
// 内置ssl客户端
private static final CloseableHttpClient sslClient = createSSLClient();
// 请求配置
private static final RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(3000).setAuthenticationEnabled(true).build();
// 请求解析器
private static final Map resolverMap = resolverBuilder()
.with(HttpMethod.GET, new HttpGetResolver())
.with(HttpMethod.POST, new HttpPostResolver())
.with(HttpMethod.PUT, new HttpPutResolver())
.build();
private final HttpRequestBase request;
private Consumer consumer;
private Consumer errorConsumer;
private TypeReference> typeReference;
private ResponseType responseType = ResponseType.NORMAL;
/**
* 内部构造方法,不对外公开
*
* @param request 请求信息
*/
private RestClient(HttpRequestBase request) {
this.request = request;
}
public static RestClientBuilder create() {
return new RestClientBuilder();
}
private static TypedMapBuilder resolverBuilder() {
return TypedMapBuilder.builder();
}
/**
* 不信任的证书请求客户端
*
* @return 结果
*/
private static CloseableHttpClient createSSLClient() {
//信任所有
try {
SSLContext context = SSLContextBuilder.create().loadTrustMaterial(null, new TrustStrategy() {
@Override
public boolean isTrusted(X509Certificate[] x509Certificates, String s) {
return true;
}
}).build();
SSLConnectionSocketFactory factory = new SSLConnectionSocketFactory(context);
return HttpClients.custom().setSSLSocketFactory(factory).build();
} catch (NoSuchAlgorithmException | KeyManagementException | KeyStoreException e) {
e.printStackTrace();
}
return null;
}
/**
* 销毁单例资源
*/
@SneakyThrows
public static void destroy() {
if (null != client) {
client.close();
}
if (null != sslClient) {
sslClient.close();
}
}
/**
* 设置请求失败时的回调
*
* @param errorConsumer 错误回调
* @return 结果
*/
public RestClient onError(Consumer errorConsumer) {
this.errorConsumer = errorConsumer;
return this;
}
/**
* 错误处理
*
* @param e 异常
*/
private void handleError(RestClientException e) {
if (null != errorConsumer) {
errorConsumer.accept(e);
} else {
throw e;
}
}
/**
* 设置响应类型
*
* @param responseType 响应类型
* @return 结果
*/
public RestClient responseType(ResponseType responseType) {
this.responseType = responseType;
return this;
}
/**
* 异步执行,接收结果
*
* @param consumer 结果
*/
public void execute(Consumer consumer) {
this.consumer = consumer;
try {
execute();
} catch (IOException e) {
handleError(new RestClientException(e.getMessage(), e, null));
}
}
/**
* 执行请求,返回Map
*
* @return map
* @throws IOException 异常
*/
public Map executeForMap() throws IOException {
this.responseType = ResponseType.JSON;
return innerExecute();
}
/**
* 执行请求,返回字符串
*
* @return 字符串
* @throws IOException 异常
*/
public String executeForString() throws IOException {
this.responseType = ResponseType.TEXT;
return innerExecute();
}
/**
* 执行请求,返回响应实体,自行处理
*
* @return 响应实体
* @throws IOException 异常
*/
public T execute() throws IOException {
return innerExecute();
}
/**
* 执行请求,根据type引用实例化
*
* @param typeReference 类型引用
* @param 泛型
* @return 结果
*/
public T execute(TypeReference typeReference) {
this.responseType = ResponseType.JSON;
this.typeReference = typeReference;
try {
return innerExecute();
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
/**
* 内部执行方法,预处理结果
*
* @param 泛型
* @return 结果
*/
private T innerExecute() throws IOException {
CloseableHttpClient client = getClient();
try (CloseableHttpResponse response = client.execute(request)) {
StatusLine statusLine = response.getStatusLine();
if (200 == statusLine.getStatusCode()) {
HttpEntity entity = response.getEntity();
if (consumer != null) {
consumer.accept(entity);
}
return resolveResponse(entity);
} else {
int requestCode = response.getStatusLine().getStatusCode();
logger.error(request.getURI() + "接口调用失败,code:" + requestCode);
handleError(new RestClientException("网络请求状态异常!代码:" + requestCode, null, statusLine));
}
} catch (UnknownHostException e) {
handleError(new RestClientException(e.getMessage(), e, null));
logger.error("未知的请求地址!");
} finally {
request.releaseConnection();
}
return null;
}
/**
* 解析结果
*
* @param entity 响应体
* @param 泛型
* @return 结果
*/
@SuppressWarnings("unchecked")
private T resolveResponse(HttpEntity entity) throws IOException {
switch (responseType) {
case NORMAL:
return (T) entity;
case TEXT:
return (T) EntityUtils.toString(entity);
case JSON:
if (typeReference != null) {
return (T) JacksonUtil.fromJson(EntityUtils.toString(entity), typeReference);
}
return (T) JacksonUtil.fromJson(EntityUtils.toString(entity));
case BINARY:
try (InputStream in = entity.getContent()) {
return (T) IOUtils.toByteArray(in);
}
default:
return (T) entity;
}
}
/**
* 获取请求客户端
*
* @return 结果
*/
private CloseableHttpClient getClient() {
return request.getURI().getScheme().equals("https") ? sslClient : client;
}
public enum HttpMethod {
GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS, TRACE;
}
/**
* 响应类型
*/
private enum ResponseType {
NORMAL, TEXT, JSON, BINARY
}
/**
* Http请求解析器
*/
private interface HttpMethodResolver {
HttpRequestBase resolve(RestClientBuilder builder);
}
/**
* 异常类,用于包装异常
*/
public static class RestClientException extends RuntimeException {
private static final long serialVersionUID = 4741281547788724661L;
private Exception nested;
private Object bind;
public RestClientException(String message, Exception nested) {
super(message);
this.nested = nested;
}
public RestClientException(String message, Exception nested, Object bind) {
super(message);
this.nested = nested;
this.bind = bind;
}
public RestClientException(String message) {
super(message);
}
@SuppressWarnings("unchecked")
public T getBind() {
return (T) bind;
}
public void setBind(Object bind) {
this.bind = bind;
}
public Exception getNested() {
return nested;
}
}
/**
* 主要的builder,核心构建
*/
public static class RestClientBuilder {
private String url;
private HttpMethod method = HttpMethod.GET;
private Map params;
private String body;
private Map headers;
private List multipartList;
private boolean multipart;
private String charset;
public String getUrl() {
return url;
}
public RestClientBuilder url(String url) {
this.url = url;
return this;
}
public HttpMethod getMethod() {
return method;
}
public RestClientBuilder method(HttpMethod method) {
this.method = method;
return this;
}
public RestClientBuilder get() {
this.method = HttpMethod.GET;
return this;
}
public RestClientBuilder withCredentials() {
this.addHeader(AbstractAuthenticationEntryPoint.AUTH_HEADER,
OAuthSSOToken.BEARER_TYPE + " " + OAuthContext.clientToken().getAccessToken());
return this;
}
public RestClientBuilder post() {
this.method = HttpMethod.POST;
return this;
}
public RestClientBuilder multipart() {
this.multipart = true;
return this;
}
public boolean isMultipart() {
return multipart;
}
public Map getParams() {
if (null == params) {
params = new HashMap<>();
}
return params;
}
public RestClientBuilder queryParams(Map params) {
this.params = params;
return this;
}
public RestClientBuilder addParam(String key, Object value) {
if (null == this.params) {
this.params = new HashMap<>();
}
this.params.put(key, value);
return this;
}
public RestClientBuilder charset(String charset) {
this.charset = charset;
return this;
}
public Charset getCharset() {
return StringUtils.isBlank(charset) ? Charset.defaultCharset() : Charset.forName(charset);
}
public RestClientBuilder addMultipartBody(String name, String filename, Object data) {
if (null == this.multipartList) {
this.multipartList = new ArrayList<>();
}
this.multipartList.add(new Multipart(name, filename, data));
return this;
}
public List getMultipartList() {
if (null == multipartList) {
multipartList = new ArrayList<>();
}
return multipartList;
}
public String getBody() {
return body;
}
public RestClientBuilder body(String body) {
this.body = body;
return this;
}
public RestClientBuilder body(Object body) {
this.body = JacksonUtil.toJson(body);
return this;
}
public Map getHeaders() {
if (null == headers) {
headers = new HashMap<>();
}
return headers;
}
public RestClientBuilder headers(Map headers) {
this.headers = headers;
return this;
}
public RestClientBuilder addHeader(String key, String value) {
if (null == this.headers) {
this.headers = new HashMap<>();
}
this.headers.put(key, value);
return this;
}
/**
* 匹配解析器
*
* @return 结果
*/
private HttpRequestBase buildRequest() {
HttpRequestBase request = MapUtils.getObject(resolverMap, this.method, resolverMap.get(HttpMethod.GET))
.resolve(this);
// 添加头
for (Map.Entry header : getHeaders().entrySet()) {
request.addHeader(header.getKey(), header.getValue());
}
// 设置公共设置
request.setConfig(requestConfig);
// 返回
return request;
}
/**
* 构建client
*
* @return 结果
*/
public RestClient build() {
// 创建请求
HttpRequestBase request = buildRequest();
return new RestClient(request);
}
/**
* 存储文件上传的part
*/
@AllArgsConstructor
private static class Multipart {
private final String name;
private final String filename;
private final Object data;
}
}
/**
* Get方法解析参数的解析器
*/
private static class HttpGetResolver implements HttpMethodResolver {
@Override
public HttpRequestBase resolve(RestClientBuilder builder) {
if (MapUtils.isNotEmpty(builder.getParams())) {
StringBuilder paramBuilder = new StringBuilder(builder.getUrl().contains("?") ? "&" : "?");
for (Map.Entry entry : builder.getParams().entrySet()) {
paramBuilder.append(entry.getKey()).append("=").append(entry.getValue());
}
builder.url(builder.getUrl() + paramBuilder.toString());
}
return new HttpGet(builder.getUrl());
}
}
/**
* Post方法解析参数的解析器,包括上传
*/
private static class HttpPostResolver implements HttpMethodResolver {
@Override
public HttpRequestBase resolve(RestClientBuilder builder) {
HttpPost post = new HttpPost(builder.getUrl());
HttpEntity entity = StringUtils.isNotBlank(builder.getBody()) ? buildJson(builder)
: buildFormData(builder);
post.setEntity(entity);
return post;
}
/**
* 构建JSON方式的POST
*
* @param clientBuilder builder
* @return 结果
*/
protected HttpEntity buildJson(RestClientBuilder clientBuilder) {
clientBuilder.addHeader("Content-Type", "application/json;charset=UTF-8");
Charset charset = clientBuilder.getCharset();
StringEntity entity = new StringEntity(clientBuilder.getBody(), charset);
entity.setContentEncoding(charset.toString());
entity.setContentType("application/json");
return entity;
}
/**
* 构建formdata
*
* @param clientBuilder builder
* @return 结果
*/
protected HttpEntity buildFormData(RestClientBuilder clientBuilder) {
// 设置参数
Map params = clientBuilder.getParams();
List list = new ArrayList<>();
for (Map.Entry entry : params.entrySet()) {
if (null != entry.getValue()) {
list.add(new BasicNameValuePair(entry.getKey(), String.valueOf(entry.getValue())));
}
}
if (CollectionUtils.isNotEmpty(list)) {
return new UrlEncodedFormEntity(list, clientBuilder.getCharset());
}
return null;
}
}
/**
* Put方法解析参数的解析器,包括上传
*/
private static class HttpPutResolver extends HttpPostResolver {
@Override
public HttpRequestBase resolve(RestClientBuilder builder) {
HttpPut put = new HttpPut(builder.getUrl());
HttpEntity entity = StringUtils.isNotBlank(builder.getBody()) ? buildJson(builder)
: buildFormData(builder);
put.setEntity(entity);
return put;
}
}
}