All downloads are free. Search and download functionality uses the official Maven repository.

com.linkedin.restli.client.GetRequestGroup Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * Copyright 2016 LinkedIn, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.linkedin.restli.client;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linkedin.common.callback.Callback;
import com.linkedin.data.DataMap;
import com.linkedin.data.schema.PathSpec;
import com.linkedin.data.template.RecordTemplate;
import com.linkedin.parseq.batching.Batch;
import com.linkedin.parseq.batching.BatchImpl.BatchEntry;
import com.linkedin.parseq.function.Tuple3;
import com.linkedin.parseq.function.Tuples;
import com.linkedin.r2.RemoteInvocationException;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.rest.RestResponseBuilder;
import com.linkedin.restli.client.response.BatchKVResponse;
import com.linkedin.restli.common.BatchResponse;
import com.linkedin.restli.common.EntityResponse;
import com.linkedin.restli.common.ErrorResponse;
import com.linkedin.restli.common.HttpStatus;
import com.linkedin.restli.common.ProtocolVersion;
import com.linkedin.restli.common.ResourceMethod;
import com.linkedin.restli.common.ResourceSpec;
import com.linkedin.restli.common.RestConstants;
import com.linkedin.restli.internal.client.ResponseImpl;
import com.linkedin.restli.internal.client.response.BatchEntityResponse;
import com.linkedin.restli.internal.common.ProtocolVersionUtil;
import com.linkedin.restli.internal.common.ResponseUtils;

class GetRequestGroup implements RequestGroup {

  // Class-wide logger; used for the debug output in unbatchResponse and executeBatch.
  private static final Logger LOGGER = LoggerFactory.getLogger(GetRequestGroup.class);
  // Pre-built 404 exception thrown by unbatchResponse when the batch response contains neither
  // a result nor an error for the requested id.
  // NOTE(review): a single instance is reused for every throw, so its stack trace presumably
  // reflects class initialization rather than the throw site - confirm this is intended.
  private static final RestLiResponseException NOT_FOUND_EXCEPTION =
      new RestLiResponseException(new RestResponseBuilder().setStatus(HttpStatus.S_404_NOT_FOUND.getCode()).build(),
          null, new ErrorResponse().setStatus(HttpStatus.S_404_NOT_FOUND.getCode()));

  // The fields below are copied from the request passed to the constructor and are reused when
  // the outgoing GET / BATCH_GET request is built.
  private final String _baseUriTemplate; //taken from first request, used to differentiate between groups
  // Only the key class of the spec takes part in equals(); the spec is not part of hashCode().
  private final ResourceSpec _resourceSpec;  //taken from first request
  private final Map _headers; //taken from first request, used to differentiate between groups
  private final RestliRequestOptions _requestOptions; //taken from first request, used to differentiate between groups
  // Query parameters with the ids and fields parameters removed (see getQueryParamsForBatchingKey).
  private final Map _queryParams; //taken from first request, used to differentiate between groups
  private final Map _pathKeys; //taken from first request, used to differentiate between groups
  // Returned by getMaxBatchSize(); not part of equals()/hashCode().
  private final int _maxBatchSize;

  /**
   * Creates a group whose identity and shared request settings are copied from the given request.
   *
   * @param request the request whose base URI template, headers, query parameters, resource spec,
   *                request options and path keys are captured
   * @param maxBatchSize value later reported by {@link #getMaxBatchSize()}
   */
  @SuppressWarnings("deprecation")
  public GetRequestGroup(Request request, int maxBatchSize) {
    this._maxBatchSize = maxBatchSize;
    this._baseUriTemplate = request.getBaseUriTemplate();
    this._resourceSpec = request.getResourceSpec();
    this._requestOptions = request.getRequestOptions();
    this._headers = request.getHeaders();
    this._pathKeys = request.getPathKeys();
    // ids and fields are stripped so that requests differing only in those still share a group
    this._queryParams = getQueryParamsForBatchingKey(request);
  }

  /**
   * Returns a copy of the request's query parameters without the ids and fields parameters.
   *
   * @param request the request whose query parameter objects are copied
   * @return a new mutable map; the request's own map is not modified
   */
  private static Map getQueryParamsForBatchingKey(Request request) {
    final Map batchingParams = new HashMap<>(request.getQueryParamsObjects());
    final String[] perRequestParams = { RestConstants.QUERY_BATCH_IDS_PARAM, RestConstants.FIELDS_PARAM };
    for (String perRequestParam : perRequestParams) {
      batchingParams.remove(perRequestParam);
    }
    return batchingParams;
  }

  /**
   * Extracts the response for a single id out of a batch response.
   *
   * NOTE(review): the generic type arguments of this method appear to be missing from this
   * listing ({@code RT} is used below but not declared in the visible signature) - confirm
   * against the original source.
   *
   * @param request the batch request that was sent; only its base URI template is read, for logging
   * @param batchResponse the response received for the batch request
   * @param id the key to look up in the batch response's error and result maps
   * @return a response wrapping {@code batchResponse} and the entity found for {@code id}
   * @throws RemoteInvocationException a {@link RestLiResponseException} built from the error
   *         recorded for {@code id}, or the shared 404 exception when there is neither an error
   *         nor a non-null entity for {@code id}
   */
  private static  Response unbatchResponse(BatchGetEntityRequest request,
      Response>> batchResponse, Object id) throws RemoteInvocationException {
    final BatchKVResponse> batchEntity = batchResponse.getEntity();
    // An error recorded for this id takes precedence over any result.
    final ErrorResponse errorResponse = batchEntity.getErrors().get(id);
    if (errorResponse != null) {
      throw new RestLiResponseException(errorResponse);
    }

    final EntityResponse entityResponse = batchEntity.getResults().get(id);
    if (entityResponse != null) {
      final RT entityResult = entityResponse.getEntity();
      if (entityResult != null) {
        return new ResponseImpl<>(batchResponse, entityResult);
      }
    }

    // Neither an error nor a usable entity: log and report as not found.
    LOGGER.debug("No result or error for base URI : {}, id: {}. Verify that the batchGet endpoint returns response keys that match batchGet request IDs.",
        request.getBaseUriTemplate(), id);

    throw NOT_FOUND_EXCEPTION;
  }

  /**
   * Copies a batch response data map, keeping only the given ids inside its errors, results and
   * statuses sections. Every other top-level entry is copied unchanged.
   *
   * @param data the top-level data map of a batch response
   * @param ids the string keys to keep inside the per-id sections
   * @return a new data map; {@code data} itself is not modified
   */
  private DataMap filterIdsInBatchResult(DataMap data, Set ids) {
    final DataMap filtered = new DataMap(data.size());
    for (Entry<String, Object> section : data.entrySet()) {
      final String sectionName = section.getKey();
      final Object sectionContent = section.getValue();
      switch (sectionName) {
        // The three per-id sections are all narrowed the same way.
        case BatchResponse.ERRORS:
        case BatchResponse.RESULTS:
        case BatchResponse.STATUSES:
          filtered.put(sectionName, filterIds((DataMap) sectionContent, ids));
          break;
        default:
          filtered.put(sectionName, sectionContent);
          break;
      }
    }
    return filtered;
  }

  /**
   * Returns a new data map holding only those entries of {@code data} whose key is in {@code ids}.
   *
   * @param data the map to filter; not modified
   * @param ids the keys to retain
   * @return the filtered copy (a {@link DataMap}, declared as {@code Object})
   */
  private Object filterIds(DataMap data, Set ids) {
    final DataMap retained = new DataMap(data.size());
    for (Entry<String, Object> candidate : data.entrySet()) {
      if (!ids.contains(candidate.getKey())) {
        continue;
      }
      retained.put(candidate.getKey(), candidate.getValue());
    }
    return retained;
  }


  //Tuple3: (keys, fields, contains-batch-get)
  private static Tuple3, Set, Boolean> reduceRequests(final Tuple3, Set, Boolean> state,
      final Request rq) {
    return reduceContainsBatch(reduceIds(reduceFields(state, rq), rq), rq);
  }

  //Tuple3: (keys, fields, contains-batch-get)
  private static Tuple3, Set, Boolean> reduceContainsBatch(Tuple3, Set, Boolean> state,
      Request request) {
    if (request instanceof GetRequest) {
      return state;
    } else if (request instanceof BatchRequest) {
      return Tuples.tuple(state._1(), state._2(), true);
    } else {
      throw unsupportedGetRequestType(request);
    }
  }

  //Tuple3: (keys, fields, contains-batch-get)
  private static Tuple3, Set, Boolean> reduceIds(Tuple3, Set, Boolean> state,
      Request request) {
    if (request instanceof GetRequest) {
      GetRequest getRequest = (GetRequest)request;
      state._1().add(getRequest.getObjectId());
      return state;
    } else if (request instanceof BatchRequest) {
      BatchRequest batchRequest = (BatchRequest)request;
      state._1().addAll(batchRequest.getObjectIds());
      return state;
    } else {
      throw unsupportedGetRequestType(request);
    }
  }

  //Tuple3: (keys, fields, contains-batch-get)
  private static Tuple3, Set, Boolean> reduceFields(Tuple3, Set, Boolean> state,
      Request request) {
    if (request instanceof GetRequest || request instanceof BatchRequest) {
      final Set requestFields = request.getFields();
      if (requestFields != null && !requestFields.isEmpty()) {
        if (state._2() != null) {
          state._2().addAll(requestFields);
        }
        return state;
      } else {
        return Tuples.tuple(state._1(), null, state._3());
      }
    } else {
      throw unsupportedGetRequestType(request);
    }
  }

  /**
   * Sends one batch-get-entity request covering all given ids and distributes the response to
   * every entry of the batch.
   *
   * The request is built from this group's base URI template, resource spec, request options,
   * headers, query parameters and path keys. On success each batch entry is completed according
   * to the type of its original request; on error the whole batch is failed.
   *
   * NOTE(review): generic type arguments appear to be missing from this listing (for example
   * "Batch> batch" and "Callback>>>") - confirm against the original source.
   *
   * @param client the client used to send the request
   * @param batch the batch whose entries are completed from the response
   * @param ids the ids to request
   * @param fields the projection to request; no projection is set when null or empty
   * @param requestContextProvider supplies the request context for the outgoing request
   */
  @SuppressWarnings({ "rawtypes", "unchecked" })
  private  void doExecuteBatchGet(final Client client,
    final Batch> batch, final Set ids, final Set fields,
    Function, RequestContext> requestContextProvider) {
    final BatchGetEntityRequestBuilder builder = new BatchGetEntityRequestBuilder<>(_baseUriTemplate, _resourceSpec, _requestOptions);
    builder.setHeaders(_headers);
    _queryParams.forEach((key, value) -> builder.setParam(key, value));
    _pathKeys.forEach((key, value) -> builder.pathKey(key, value));

    builder.ids((Set)ids);
    if (fields != null && !fields.isEmpty()) {
      builder.fields(fields.toArray(new PathSpec[fields.size()]));
    }

    final BatchGetEntityRequest batchGet = builder.build();

    client.sendRequest(batchGet, requestContextProvider.apply(batchGet), new Callback>>>() {

      // Completes every entry of the batch from the single batch response.
      // NOTE(review): only RemoteInvocationException is caught per entry; an unchecked exception
      // raised while handling one entry would leave this loop before later entries are
      // completed - confirm whether that can happen.
      @Override
      public void onSuccess(Response>> responseToBatch) {
        final ProtocolVersion version = ProtocolVersionUtil.extractProtocolVersion(responseToBatch.getHeaders());
        batch.entries().stream()
        .forEach(entry -> {
          try {
            RestRequestBatchKey rrbk = entry.getKey();
            Request request = rrbk.getRequest();
            // Dispatch on the type of the original request.
            if (request instanceof GetRequest) {
              successGet((GetRequest) request, responseToBatch, batchGet, entry, version);
            } else if (request instanceof BatchGetKVRequest) {
              successBatchGetKV((BatchGetKVRequest) request, responseToBatch, entry, version);
            } else if (request instanceof BatchGetRequest) {
              successBatchGet((BatchGetRequest) request, responseToBatch, entry, version);
            } else if (request instanceof BatchGetEntityRequest) {
              successBatchGetEntity((BatchGetEntityRequest) request, responseToBatch, entry, version);
            } else {
              entry.getValue().getPromise().fail(unsupportedGetRequestType(request));
            }
          } catch (RemoteInvocationException e) {
            entry.getValue().getPromise().fail(e);
          }
        });
      }

      // Completes an entry whose original request was a BatchGetEntityRequest: the response data
      // is narrowed to that request's ids and wrapped in a BatchEntityResponse.
      @SuppressWarnings({ "deprecation" })
      private void successBatchGetEntity(BatchGetEntityRequest request,
          Response>> responseToBatch,
          Entry>> entry, final ProtocolVersion version) {
        Set ids = (Set) request.getObjectIds().stream()
            .map(o -> BatchResponse.keyToString(o, version))
            .collect(Collectors.toSet());
        DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
        BatchKVResponse br = new BatchEntityResponse<>(dm, request.getResourceSpec().getKeyType(),
            request.getResourceSpec().getValueType(), request.getResourceSpec().getKeyParts(),
            request.getResourceSpec().getComplexKeyType(), version);
        Response rsp = new ResponseImpl(responseToBatch, br);
        entry.getValue().getPromise().done(rsp);
      }

      // Completes an entry whose original request was a BatchGetRequest: the response data is
      // narrowed to that request's ids and wrapped in a BatchResponse.
      private void successBatchGet(BatchGetRequest request, Response>> responseToBatch,
          Entry>> entry, final ProtocolVersion version) {
        Set ids = (Set) request.getObjectIds().stream()
            .map(o -> BatchResponse.keyToString(o, version))
            .collect(Collectors.toSet());
        DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
        BatchResponse br = new BatchResponse<>(dm, request.getResponseDecoder().getEntityClass());
        Response rsp = new ResponseImpl(responseToBatch, br);
        entry.getValue().getPromise().done(rsp);
      }

      // Completes an entry whose original request was a BatchGetKVRequest: the response data is
      // narrowed to that request's ids and wrapped in a BatchKVResponse.
      @SuppressWarnings({ "deprecation" })
      private void successBatchGetKV(BatchGetKVRequest request, Response>> responseToBatch,
          Entry>> entry,
          final ProtocolVersion version) {
        Set ids = (Set) request.getObjectIds().stream()
            .map(o -> BatchResponse.keyToString(o, version))
            .collect(Collectors.toSet());
        DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
        BatchKVResponse br = new BatchKVResponse(dm, request.getResourceSpec().getKeyType(),
            request.getResourceSpec().getValueType(), request.getResourceSpec().getKeyParts(),
            request.getResourceSpec().getComplexKeyType(), version);
        Response rsp = new ResponseImpl(responseToBatch, br);
        entry.getValue().getPromise().done(rsp);
      }

      // Completes an entry whose original request was a single GetRequest: the id is converted to
      // its string form and back to a key object, then looked up via unbatchResponse.
      @SuppressWarnings({ "deprecation" })
      private void successGet(GetRequest request,
          Response>> responseToBatch, final BatchGetEntityRequest batchGet,
          Entry>> entry, final ProtocolVersion version)
              throws RemoteInvocationException {
        String idString = BatchResponse.keyToString(request.getObjectId(), version);
        Object id = ResponseUtils.convertKey(idString, request.getResourceSpec().getKeyType(),
            request.getResourceSpec().getKeyParts(), request.getResourceSpec().getComplexKeyType(), version);
        Response rsp = unbatchResponse(batchGet, responseToBatch, id);
        entry.getValue().getPromise().done(rsp);
      }

      // A failure of the batch request fails every entry of the batch.
      @Override
      public void onError(Throwable e) {
        batch.failAll(e);
      }

    });
  }

  /**
   * Builds (but does not throw) the exception used for requests this group cannot handle.
   *
   * @param request the offending request; its class name is included in the message
   * @return a new {@link RuntimeException} describing the unsupported request type
   */
  private static RuntimeException unsupportedGetRequestType(Request request) {
    final String requestTypeName = request.getClass().getName();
    return new RuntimeException("ParSeqRestliClient could not handle this type of GET request: " + requestTypeName);
  }

  /**
   * Sends one plain GET request for the first id in {@code ids} and completes every entry of the
   * batch with its response.
   *
   * The request is built from this group's base URI template, resource spec, request options,
   * headers, query parameters and path keys.
   *
   * NOTE(review): generic type arguments appear to be missing from this listing (for example
   * "Batch> batch" and "new Callback>()"; {@code K} is used below but not declared in the
   * visible signature) - confirm against the original source.
   *
   * @param client the client used to send the request
   * @param batch the batch whose entries are completed from the response
   * @param ids the ids of the batch; only the first element returned by its iterator is used
   * @param fields the projection to request; no projection is set when null or empty
   * @param requestContextProvider supplies the request context for the outgoing request
   */
  @SuppressWarnings({ "rawtypes", "unchecked" })
  private  void doExecuteGet(final Client client,
      final Batch> batch, final Set ids, final Set fields,
      Function, RequestContext> requestContextProvider) {

    final GetRequestBuilder builder = (GetRequestBuilder) new GetRequestBuilder<>(_baseUriTemplate,
        _resourceSpec.getValueClass(), _resourceSpec, _requestOptions);
    builder.setHeaders(_headers);
    _queryParams.forEach((key, value) -> builder.setParam(key, value));
    _pathKeys.forEach((key, value) -> builder.pathKey(key, value));

    // executeBatch calls this method only when ids holds exactly one element.
    builder.id((K) ids.iterator().next());
    if (fields != null && !fields.isEmpty()) {
      builder.fields(fields.toArray(new PathSpec[fields.size()]));
    }

    final GetRequest get = builder.build();

    client.sendRequest(get, requestContextProvider.apply(get), new Callback>() {

      // A failure of the GET request fails every entry of the batch.
      @Override
      public void onError(Throwable e) {
        batch.failAll(e);
      }

      // Every entry that originated from a GetRequest receives a response wrapping the same
      // response and entity; any other request type fails its own entry.
      @Override
      public void onSuccess(Response responseToGet) {
        batch.entries().stream().forEach(entry -> {
          Request request = entry.getKey().getRequest();
          if (request instanceof GetRequest) {
            entry.getValue().getPromise().done(new ResponseImpl<>(responseToGet, responseToGet.getEntity()));
          } else {
            entry.getValue().getPromise().fail(unsupportedGetRequestType(request));
          }
        });
      }

    });
  }

  // Reduces all requests of the batch into one tuple.
  //Tuple3: (keys, fields, contains-batch-get)
  // The reduction starts from two empty sets and 'false'; each request is folded in with the
  // static two-argument reduceRequests, and combine is supplied as the combiner.
  // NOTE(review): generic type arguments appear to be missing from this listing
  // ("Tuple3, Set, Boolean>", "Batch> batch") - confirm against the original source.
  private Tuple3, Set, Boolean> reduceRequests(
      final Batch> batch) {
    return batch.entries().stream()
      .map(Entry::getKey)
      .map(RestRequestBatchKey::getRequest)
      .reduce(Tuples.tuple(new HashSet<>(), new HashSet<>(), false),
          GetRequestGroup::reduceRequests,
          GetRequestGroup::combine);
  }

  private static Tuple3, Set, Boolean> combine(Tuple3, Set, Boolean> a,
      Tuple3, Set, Boolean> b) {
    Set ids = a._1();
    ids.addAll(b._1());
    Set paths = a._2();
    paths.addAll(b._2());
    return Tuples.tuple(ids, paths, a._3() || b._3());
  }

  /**
   * Executes all requests of the batch with a single outgoing request.
   *
   * The batch is first reduced to the ids to fetch, the fields to project and a flag telling
   * whether any batch request is part of the batch. A plain GET is sent when there is exactly one
   * id and no batch request; otherwise a batch get is sent.
   *
   * NOTE(review): generic type arguments appear to be missing from this listing (for example
   * "Batch> batch" and "Function, RequestContext>") - confirm against the original source.
   *
   * @param client the client used to send the request
   * @param batch the batch to execute
   * @param requestContextProvider supplies the request context for the outgoing request
   */
  @Override
  public  void executeBatch(final Client client, final Batch> batch,
      Function, RequestContext> requestContextProvider) {
    final Tuple3, Set, Boolean> reductionResults = reduceRequests(batch);
    final Set ids = reductionResults._1();
    final Set fields = reductionResults._2();
    final boolean containsBatchGet = reductionResults._3();

    LOGGER.debug("executeBatch, ids: '{}', fields: {}", ids, fields);

    // A batch request in the batch forces a batch get even when only one id is requested.
    if (ids.size() == 1 && !containsBatchGet) {
      doExecuteGet(client, batch, ids, fields, requestContextProvider);
    } else {
      doExecuteBatchGet(client, batch, ids, fields, requestContextProvider);
    }
  }

  /**
   * @return the base URI template captured from the request this group was created from
   */
  @Override
  public String getBaseUriTemplate() {
    return this._baseUriTemplate;
  }

  /**
   * @return the headers captured from the request this group was created from (not a copy)
   */
  public Map getHeaders() {
    return this._headers;
  }

  /**
   * @return the group's query parameters, i.e. the original request's parameters without the
   *         ids and fields parameters (not a copy)
   */
  public Map getQueryParams() {
    return this._queryParams;
  }

  /**
   * @return the path keys captured from the request this group was created from (not a copy)
   */
  public Map getPathKeys() {
    return this._pathKeys;
  }

  /**
   * @return the resource spec captured from the request this group was created from
   */
  public ResourceSpec getResourceSpec() {
    return this._resourceSpec;
  }

  /**
   * @return the request options captured from the request this group was created from
   */
  public RestliRequestOptions getRequestOptions() {
    return this._requestOptions;
  }


  /**
   * Hashes the base URI template, headers, query parameters, path keys and request options, in
   * that order, using the usual 31-multiplier scheme; a {@code null} field contributes 0.
   * The resource spec and the maximum batch size do not take part.
   */
  @Override
  public int hashCode() {
    final Object[] identityParts = { _baseUriTemplate, _headers, _queryParams, _pathKeys, _requestOptions };
    int hash = 1;
    for (Object part : identityParts) {
      hash = 31 * hash + (part == null ? 0 : part.hashCode());
    }
    return hash;
  }

  /**
   * Two groups are equal when they have the same class, equal base URI template, headers, query
   * parameters, path keys and request options, and resource specs with the same key class.
   * Resource specs that are both {@code null} are treated as equal; when exactly one is
   * {@code null} the groups are not equal. The maximum batch size does not take part.
   *
   * @param obj the object to compare with
   * @return {@code true} if {@code obj} identifies the same group
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (obj == null || getClass() != obj.getClass()) {
      return false;
    }
    final GetRequestGroup other = (GetRequestGroup) obj;
    if (!Objects.equals(_baseUriTemplate, other._baseUriTemplate)
        || !Objects.equals(_headers, other._headers)
        || !Objects.equals(_queryParams, other._queryParams)
        || !Objects.equals(_pathKeys, other._pathKeys)
        || !Objects.equals(_requestOptions, other._requestOptions)) {
      return false;
    }
    // Guard both sides: dereferencing other._resourceSpec unconditionally would throw a
    // NullPointerException when only this group has a resource spec.
    if (_resourceSpec == null || other._resourceSpec == null) {
      return _resourceSpec == other._resourceSpec;
    }
    return _resourceSpec.getKeyClass() == other._resourceSpec.getKeyClass();
  }

  /**
   * Renders the base URI template, query parameters, path keys, request options, headers and
   * maximum batch size of this group, in that order.
   */
  @Override
  public String toString() {
    final StringBuilder text = new StringBuilder("GetRequestGroup [_baseUriTemplate=");
    text.append(_baseUriTemplate);
    text.append(", _queryParams=").append(_queryParams);
    text.append(", _pathKeys=").append(_pathKeys);
    text.append(", _requestOptions=").append(_requestOptions);
    text.append(", _headers=").append(_headers);
    text.append(", _maxBatchSize=").append(_maxBatchSize);
    text.append("]");
    return text.toString();
  }

  /**
   * Builds a display name for the batch: the base URI template followed by the GET method when
   * the batch size is 1, or by the BATCH_GET method with the key size ("reqs") and batch size
   * ("ids") otherwise.
   *
   * @param batch the batch to describe
   * @return the display name
   */
  @Override
  public String getBatchName(final Batch batch) {
    if (batch.batchSize() == 1) {
      return _baseUriTemplate + " " + ResourceMethod.GET;
    }
    return _baseUriTemplate + " " + ResourceMethod.BATCH_GET
        + "(reqs: " + batch.keySize() + ", ids: " + batch.batchSize() + ")";
  }

  /**
   * @return the maximum batch size passed to the constructor
   */
  @Override
  public int getMaxBatchSize() {
    return this._maxBatchSize;
  }

  /**
   * Reports the size of a batch key as the number of ids it carries.
   *
   * @param key the batch key to measure
   * @return the size of {@code key.ids()}
   */
  @Override
  public int keySize(RestRequestBatchKey key) {
    final int idCount = key.ids().size();
    return idCount;
  }

}