All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.google.cloud.genomics.dataflow.readers.VariantStreamer Maven / Gradle / Ivy

/*
 * Copyright (C) 2015 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.genomics.dataflow.readers;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;

import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Max;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.genomics.utils.OfflineAuth;
import com.google.cloud.genomics.utils.ShardBoundary;
import com.google.cloud.genomics.utils.grpc.VariantStreamIterator;
import com.google.common.base.Stopwatch;
import com.google.genomics.v1.StreamVariantsRequest;
import com.google.genomics.v1.StreamVariantsResponse;
import com.google.genomics.v1.Variant;

/**
 * PTransform for streaming variants via gRPC.
 */
public class VariantStreamer extends
PTransform, PCollection> {

  private static final Logger LOG = LoggerFactory.getLogger(VariantStreamer.class);
  protected final OfflineAuth auth;
  protected final ShardBoundary.Requirement shardBoundary;
  protected final String fields;

  /**
   * Create a streamer that can enforce shard boundary semantics.
   * 
   * @param auth The OfflineAuth to use for the request.
   * @param shardBoundary The shard boundary semantics to enforce.
   * @param fields Which fields to include in a partial response or null for all.
   */
  public VariantStreamer(OfflineAuth auth, ShardBoundary.Requirement shardBoundary, String fields) {
    this.auth = auth;
    this.shardBoundary = shardBoundary;
    this.fields = fields;
  }

  @Override
  public PCollection apply(PCollection input) {
    return input.apply(ParDo.of(new RetrieveVariants()))
        .apply(ParDo.of(new ConvergeVariantsList()));
  }

  private class RetrieveVariants extends DoFn> {

    protected Aggregator initializedShardCount;
    protected Aggregator finishedShardCount;
    protected Aggregator shardTimeMaxSec;
    DescriptiveStatistics stats;

    public RetrieveVariants() {
      initializedShardCount = createAggregator("Initialized Shard Count", new Sum.SumIntegerFn());
      finishedShardCount = createAggregator("Finished Shard Count", new Sum.SumIntegerFn());
      shardTimeMaxSec = createAggregator("Maximum Shard Processing Time (sec)", new Max.MaxLongFn());
      stats = new DescriptiveStatistics(500);
    }

    @Override
    public void processElement(ProcessContext c) throws IOException, GeneralSecurityException, InterruptedException {
      initializedShardCount.addValue(1);
      shardTimeMaxSec.addValue(0L);
      Stopwatch stopWatch = Stopwatch.createStarted();
      Iterator iter = VariantStreamIterator.enforceShardBoundary(auth, c.element(), shardBoundary, fields);
      while (iter.hasNext()) {
        StreamVariantsResponse variantResponse = iter.next();
        c.output(variantResponse.getVariantsList());
      }
      stopWatch.stop();
      shardTimeMaxSec.addValue(stopWatch.elapsed(TimeUnit.SECONDS));
      stats.addValue(stopWatch.elapsed(TimeUnit.SECONDS));
      finishedShardCount.addValue(1);
      LOG.info("Shard Duration in Seconds - Min: " + stats.getMin() + " Max: " + stats.getMax() +
          " Avg: " + stats.getMean() + " StdDev: " + stats.getStandardDeviation());      
    }
  }
  
  /**
   * This step exists to emit the individual variants in a parallel step to the StreamVariants step
   * in order to increase throughput.
   */
  private class ConvergeVariantsList extends DoFn, Variant> {

    protected Aggregator itemCount;

    public ConvergeVariantsList() {
      itemCount = createAggregator("Number of variants", new Sum.SumLongFn());
    }

    @Override
    public void processElement(ProcessContext c) {
      for (Variant v : c.element()) {
        c.output(v);
        itemCount.addValue(1L);
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy