All downloads are free. Search and download functionality uses the official Maven repository.

com.google.cloud.dataflow.sdk.util.state.WindmillStateReader Maven / Gradle / Ivy

/*
 * Copyright (C) 2015 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.dataflow.sdk.util.state;

import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.runners.worker.MetricTrackingWindmillServerStub;
import com.google.cloud.dataflow.sdk.runners.worker.StreamingDataflowWorker;
import com.google.cloud.dataflow.sdk.runners.worker.windmill.Windmill;
import com.google.cloud.dataflow.sdk.runners.worker.windmill.Windmill.TagList;
import com.google.cloud.dataflow.sdk.runners.worker.windmill.Windmill.TagValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ForwardingFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.protobuf.ByteString;

import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * Reads persistent state from {@link Windmill}. Returns {@code Future}s containing the data that
 * has been read. Will not initiate a read until {@link Future#get} is called, at which point all
 * the pending futures will be read.
 */
public class WindmillStateReader {

  private static class StateTag {
    private enum Kind {
      VALUE,
      LIST,
      WATERMARK;
    }

    private final Kind kind;
    private final ByteString tag;
    private final String stateFamily;

    private StateTag(Kind kind, ByteString tag, String stateFamily) {
      this.kind = kind;
      this.tag = tag;
      this.stateFamily = Preconditions.checkNotNull(stateFamily);
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }

      if (!(obj instanceof StateTag)) {
        return false;
      }

      StateTag that = (StateTag) obj;
      return Objects.equal(this.kind, that.kind)
          && Objects.equal(this.tag, that.tag)
          && Objects.equal(this.stateFamily, that.stateFamily);
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(kind, tag, stateFamily);
    }

    @Override
    public String toString() {
      return "Tag(" + kind + "," + tag.toStringUtf8() + "," + stateFamily + ")";
    }
  }

  private static final Logger LOG = LoggerFactory.getLogger(WindmillStateReader.class);

  private final String computation;
  private final ByteString key;
  private final long workToken;

  private final MetricTrackingWindmillServerStub metrics;

  public WindmillStateReader(
      MetricTrackingWindmillServerStub metrics,
      String computation, ByteString key, long workToken) {
    this.metrics = metrics;
    this.computation = computation;
    this.key = key;
    this.workToken = workToken;
  }

  @VisibleForTesting ConcurrentLinkedQueue pendingLookups = new ConcurrentLinkedQueue<>();
  private ConcurrentHashMap> coders = new ConcurrentHashMap<>();

  private ConcurrentHashMap> futures = new ConcurrentHashMap<>();

  private  Future stateFuture(StateTag tag, Coder coder) {
    SettableFuture wildcardFuture = futures.get(tag);
    if (wildcardFuture == null) {
      // If we don't yet have a future, try to create one.
      wildcardFuture = SettableFuture.create();
      SettableFuture old = futures.putIfAbsent(tag, wildcardFuture);

      if (old == null) {
        // We won the race, queue the lookup and coder.
        pendingLookups.add(tag);
        if (coder != null) {
          coders.putIfAbsent(tag, coder);
        }
      } else {
        // We lost the race, use the other future.
        wildcardFuture = old;
      }
    }

    @SuppressWarnings("unchecked")
    SettableFuture typedFuture = (SettableFuture) wildcardFuture;
    return wrappedFuture(typedFuture);
  }

  public Future watermarkFuture(ByteString encodedTag, String stateFamily) {
    return stateFuture(new StateTag(StateTag.Kind.WATERMARK, encodedTag, stateFamily), null);
  }

  public  Future valueFuture(ByteString encodedTag, String stateFamily, Coder coder) {
    return stateFuture(new StateTag(StateTag.Kind.VALUE, encodedTag, stateFamily), coder);
  }

  public  Future> listFuture(ByteString encodedTag, String stateFamily,
      Coder elemCoder) {
    return stateFuture(new StateTag(StateTag.Kind.LIST, encodedTag, stateFamily), elemCoder);
  }

  private  Future wrappedFuture(final Future future) {
    // If the underlying lookup is already complete, we don't need to create the wrapper.
    if (future.isDone()) {
      return future;
    }

    return new ForwardingFuture() {
      @Override
      protected Future delegate() {
        return future;
      }

      @Override
      public T get() throws InterruptedException, ExecutionException {
        if (!future.isDone()) {
          startBatchAndBlock();
        }
        return super.get();
      }

      @Override
      public T get(long timeout, TimeUnit unit)
          throws InterruptedException, ExecutionException, TimeoutException {
        if (!future.isDone()) {
          startBatchAndBlock();
        }
        return super.get(timeout, unit);
      }
    };
  }

  public void startBatchAndBlock() {
    // First, drain work out of the pending lookups into a set. These will be the items we fetch.
    HashSet toFetch = new HashSet<>();
    while (!pendingLookups.isEmpty()) {
      StateTag tag = pendingLookups.poll();
      if (tag == null) {
        break;
      }

      if (!toFetch.add(tag)) {
        throw new IllegalStateException("Duplicate tags being fetched.");
      }
    }

    // If we failed to drain anything, some other thread pulled it off the queue. We have no work
    // to do.
    if (toFetch.isEmpty()) {
      return;
    }

    Windmill.GetDataRequest request = createRequest(toFetch);
    Windmill.GetDataResponse response = metrics.getStateData(request);

    if (response == null) {
      throw new RuntimeException("Windmill unexpectedly returned null for request " + request);
    }

    consumeResponse(request, response, toFetch);
  }

  private Windmill.GetDataRequest createRequest(Iterable toFetch) {
    Windmill.GetDataRequest.Builder request = Windmill.GetDataRequest.newBuilder();
    Windmill.KeyedGetDataRequest.Builder keyedDataBuilder = request
        .addRequestsBuilder().setComputationId(computation)
        .addRequestsBuilder().setKey(key).setWorkToken(workToken);

    for (StateTag tag : toFetch) {
      switch (tag.kind) {
        case LIST:
          keyedDataBuilder
              .addListsToFetchBuilder()
              .setTag(tag.tag)
              .setStateFamily(tag.stateFamily)
              .setEndTimestamp(Long.MAX_VALUE);
          break;

        case WATERMARK:
          keyedDataBuilder
              .addWatermarkHoldsToFetchBuilder()
              .setTag(tag.tag)
              .setStateFamily(tag.stateFamily);
          break;

        case VALUE:
          keyedDataBuilder
              .addValuesToFetchBuilder()
              .setTag(tag.tag)
              .setStateFamily(tag.stateFamily);
          break;

        default:
          throw new RuntimeException("Unknown kind of tag requested: " + tag.kind);
      }
    }

    return request.build();
  }

  private void consumeResponse(Windmill.GetDataRequest request,
      Windmill.GetDataResponse getDataResponse, Set toFetch) {
    // Validate the response is for our computation/key.
    if (getDataResponse.getDataCount() == 0) {
      throw new RuntimeException(
          "No computation in response to request: " + request);
    } else if (getDataResponse.getDataCount() > 1) {
      throw new RuntimeException("Expected exactly one computation in response, but got: "
          + getDataResponse.getDataList());
    }

    Windmill.ComputationGetDataResponse computationResponse = getDataResponse.getData(0);

    if (!computation.equals(computationResponse.getComputationId())) {
      throw new RuntimeException("Expected data for computation " + computation
          + " but was " + computationResponse.getComputationId());
    }

    if (computationResponse.getDataCount() == 0) {
      throw new RuntimeException(
          "No key in response to request: " + request);
    } else if (computationResponse.getDataCount() > 1) {
      throw new RuntimeException(
          "Expected exactly one key in response, but was: " + computationResponse.getDataList());
    }

    Windmill.KeyedGetDataResponse response = computationResponse.getData(0);

    if (response.getFailed()) {
      // Set up all the futures for this key to throw an exception:
      StreamingDataflowWorker.KeyTokenInvalidException keyTokenInvalidException =
          new StreamingDataflowWorker.KeyTokenInvalidException(key.toStringUtf8());
      for (StateTag stateTag : toFetch) {
        futures.get(stateTag).setException(keyTokenInvalidException);
      }
      return;
    }

    if (!key.equals(response.getKey())) {
      throw new RuntimeException("Expected data for key " + key
          + " but was " + response.getKey());
    }


    for (Windmill.TagList list : response.getListsList()) {
      StateTag stateTag = new StateTag(
          StateTag.Kind.LIST, list.getTag(), list.getStateFamily());
      if (!toFetch.remove(stateTag)) {
        throw new IllegalStateException(
            "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch);
      }
      consumeTagList(list, stateTag);
    }

    for (Windmill.WatermarkHold hold : response.getWatermarkHoldsList()) {
      StateTag stateTag = new StateTag(
          StateTag.Kind.WATERMARK, hold.getTag(), hold.getStateFamily());
      if (!toFetch.remove(stateTag)) {
        throw new IllegalStateException(
            "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch);
      }
      consumeWatermark(hold, stateTag);
    }

    for (Windmill.TagValue value : response.getValuesList()) {
      StateTag stateTag = new StateTag(
          StateTag.Kind.VALUE, value.getTag(), value.getStateFamily());
      if (!toFetch.remove(stateTag)) {
        throw new IllegalStateException(
            "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch);
      }
      consumeTagValue(value, stateTag);
    }

    if (!toFetch.isEmpty()) {
      throw new IllegalStateException(
          "Didn't receive responses for all pending fetches. Missing: " + toFetch);
    }
  }

  private  void consumeTagList(TagList list, StateTag stateTag) {
    @SuppressWarnings("unchecked")
    SettableFuture> future = (SettableFuture>) futures.get(stateTag);
    if (future == null) {
      throw new IllegalStateException("Missing future for " + stateTag);
    } else if (future.isDone()) {
      LOG.error("Future for {} is already done", stateTag);
    }

    if (list.getValuesCount() == 0) {
      future.set(Collections.emptyList());
      return;
    }

    @SuppressWarnings("unchecked")
    Coder elemCoder = (Coder) coders.remove(stateTag);
    if (elemCoder == null) {
      throw new IllegalStateException("Missing element coder for " + stateTag);
    }

    List valueList = new ArrayList<>(list.getValuesCount());
    for (Windmill.Value value : list.getValuesList()) {
      if (value.hasData() && !value.getData().isEmpty()) {
        // Drop the first byte of the data; it's the zero byte we prependend to avoid writing
        // empty data.
        InputStream inputStream = value.getData().substring(1).newInput();
        try {
          valueList.add(elemCoder.decode(inputStream, Coder.Context.OUTER));
        } catch (IOException e) {
          throw new IllegalStateException(
              "Unable to decode tag list using " + elemCoder, e);
        }
      }
    }

    future.set(Collections.unmodifiableList(valueList));
  }

  private void consumeWatermark(Windmill.WatermarkHold watermarkHold, StateTag stateTag) {
    @SuppressWarnings("unchecked")
    SettableFuture future = (SettableFuture) futures.get(stateTag);
    if (future == null) {
      throw new IllegalStateException("Missing future for " + stateTag);
    } else if (future.isDone()) {
      LOG.error("Future for {} is already done", stateTag);
    }

    Instant hold = null;
    for (long timestamp : watermarkHold.getTimestampsList()) {
      Instant instant = new Instant(TimeUnit.MICROSECONDS.toMillis(timestamp));
      if (hold == null || instant.isBefore(hold)) {
        hold = instant;
      }
    }

    future.set(hold);
  }

  private  void consumeTagValue(TagValue tagValue, StateTag stateTag) {
    @SuppressWarnings("unchecked")
    SettableFuture future = (SettableFuture) futures.get(stateTag);
    if (future == null) {
      throw new IllegalStateException("Missing future for " + stateTag);
    } else if (future.isDone()) {
      LOG.error("Future for {} is already done", stateTag);
    }

    @SuppressWarnings("unchecked")
    Coder coder = (Coder) coders.remove(stateTag);
    if (coder == null) {
      throw new IllegalStateException("Missing coder for " + stateTag);
    }

    if (tagValue.hasValue()
        && tagValue.getValue().hasData()
        && !tagValue.getValue().getData().isEmpty()) {
      InputStream inputStream = tagValue.getValue().getData().newInput();
      try {
        T value = coder.decode(inputStream, Coder.Context.OUTER);
        future.set(value);
      } catch (IOException e) {
        throw new IllegalStateException("Unable to decode value using " + coder, e);
      }
    } else {
      future.set(null);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy