
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.samza.checkpoint.azure;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.microsoft.azure.storage.StorageException;
import com.microsoft.azure.storage.table.*;
import org.apache.samza.AzureClient;
import org.apache.samza.AzureException;
import org.apache.samza.Partition;
import org.apache.samza.SamzaException;
import org.apache.samza.checkpoint.Checkpoint;
import org.apache.samza.checkpoint.CheckpointManager;
import org.apache.samza.checkpoint.CheckpointV1;
import org.apache.samza.config.AzureConfig;
import org.apache.samza.container.TaskName;
import org.apache.samza.serializers.JsonSerdeV2;
import org.apache.samza.system.SystemStreamPartition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;

/**
 * Azure checkpoint manager is used to store checkpoints in an Azure Table.
 * All the checkpoints for a job's tasks are added to a single table named after the job.
 * The table entities take the following form:
 *
 * +-----------------+---------------------+-------------------+
 * |                 |     Serialized      |                   |
 * |   TaskName      |     JSON SSP        |     Offset        |
 * |                 |                     |                   |
 * +-----------------+---------------------+-------------------+
 *
 *  Each entity has its partitionKey set to the TaskName and its rowKey set to the serialized SSP.
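 *
 *  For example, an offset checkpoint for task "Partition 0" on partition 3 of stream
 *  "PageViews" in system "kafka" would be stored as (values are illustrative):
 *    PartitionKey = "Partition 0"
 *    RowKey       = {"system": "kafka", "stream": "PageViews", "partition": "3"}
 *    Offset       = "42"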
 */
public class AzureCheckpointManager implements CheckpointManager {
  private static final Logger LOG = LoggerFactory.getLogger(AzureCheckpointManager.class.getName());
  private static final String PARTITION_KEY = "PartitionKey";

  // Matches keys containing characters that are invalid in an Azure Table key field
  public static final String REGEX_INVALID_KEY = ".*[#?/\\\\].*";
  // Matches characters that are not allowed in Azure Table names
  public static final String REGEX_TABLE_NAME = "[^A-Za-z0-9]";

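  // Azure Table batch operations accept at most 100 operations per batch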
  public static final int MAX_WRITE_BATCH_SIZE = 100;
  public static final String SYSTEM_PROP_NAME = "system";
  public static final String STREAM_PROP_NAME = "stream";
  public static final String PARTITION_PROP_NAME = "partition";

  private final String jobTableName;
  private final String storageConnectionString;
  private final AzureClient azureClient;
  private CloudTable cloudTable;

  private final Set<TaskName> taskNames = new HashSet<>();
  private final JsonSerdeV2<Map<String, String>> jsonSerde = new JsonSerdeV2<>();

  AzureCheckpointManager(AzureConfig azureConfig, Optional<String> jobName) {
    if (!jobName.isPresent()) {
      throw new AzureException("Jobs must have a name to use Azure Checkpoint Manager");
    }
    // Remove invalid characters
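    // e.g. a job named "my-job_1" becomes the table name "myjob1" (illustrative job name)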
    jobTableName = jobName.get().replaceAll(REGEX_TABLE_NAME, "");
    storageConnectionString = azureConfig.getAzureConnectionString();
    azureClient = new AzureClient(storageConnectionString);
  }

  @Override
  public void start() {
    try {
      // Create the table if it doesn't exist.
      cloudTable = azureClient.getTableClient().getTableReference(jobTableName);
      cloudTable.createIfNotExists();

    } catch (URISyntaxException e) {
      LOG.error("Connection string {} specifies an invalid URI while creating checkpoint table.",
              storageConnectionString);
      throw new AzureException(e);

    } catch (StorageException e) {
      LOG.error("Azure Storage failed when creating checkpoint table", e);
      throw new AzureException(e);
    }
  }

  @Override
  public void register(TaskName taskName) {
    taskNames.add(taskName);
  }

  @Override
  public void writeCheckpoint(TaskName taskName, Checkpoint checkpoint) {
    Preconditions.checkArgument(checkpoint instanceof CheckpointV1,
        "Only CheckpointV1 can be written to Azure");

    if (!taskNames.contains(taskName)) {
      throw new SamzaException("writing checkpoint of unregistered task");
    }

    TableBatchOperation batchOperation = new TableBatchOperation();

    Iterator<Map.Entry<SystemStreamPartition, String>> iterator =
        checkpoint.getOffsets().entrySet().iterator();
    while (iterator.hasNext()) {
      Map.Entry<SystemStreamPartition, String> entry = iterator.next();
      SystemStreamPartition ssp = entry.getKey();
      String offset = entry.getValue();

      String partitionKey = taskName.toString();
      checkValidKey(partitionKey, "TaskName");
      String rowKey = serializeSystemStreamPartition(ssp);
      checkValidKey(rowKey, "SystemStreamPartition");

      // Create table entity
      TaskCheckpointEntity taskCheckpoint = new TaskCheckpointEntity(partitionKey, rowKey, offset);

      // Add to batch operation
      batchOperation.insertOrReplace(taskCheckpoint);

      // Execute when batch reaches capacity or this is the last item
      if (batchOperation.size() >= MAX_WRITE_BATCH_SIZE || !iterator.hasNext()) {
        try {
          cloudTable.execute(batchOperation);
        } catch (StorageException e) {
          LOG.error("Executing batch failed for task: {}", taskName);
          throw new AzureException(e);
        }
        batchOperation.clear();
      }
    }
  }

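  // Azure Table partition and row keys may not contain the characters '#', '?', '/', or '\'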
  private void checkValidKey(String key, String fieldUsed) {
    if (key == null || key.matches(REGEX_INVALID_KEY)) {
      throw new AzureException(String.format(
          "Cannot insert to Azure Checkpoint Manager; %s %s contains invalid characters [#, ?, /, \\]",
          fieldUsed, key));
    }
  }

  private String serializeSystemStreamPartition(SystemStreamPartition ssp) {
    // Create the Json string for SystemStreamPartition
    Map<String, String> sspMap = new HashMap<>();

    sspMap.put(SYSTEM_PROP_NAME, ssp.getSystem());
    sspMap.put(STREAM_PROP_NAME, ssp.getStream());
    sspMap.put(PARTITION_PROP_NAME, String.valueOf(ssp.getPartition().getPartitionId()));

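    // Produces JSON of the form {"system":"kafka","stream":"PageViews","partition":"3"}
    // (illustrative values; key order is not guaranteed by HashMap)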
    return new String(jsonSerde.toBytes(sspMap));
  }

  private SystemStreamPartition deserializeSystemStreamPartition(String serializedSSP) {
    Map<String, String> sspPropertiesMap = jsonSerde.fromBytes(serializedSSP.getBytes());

    String systemName = sspPropertiesMap.get(SYSTEM_PROP_NAME);
    String streamName = sspPropertiesMap.get(STREAM_PROP_NAME);
    Partition partition = new Partition(Integer.parseInt(sspPropertiesMap.get(PARTITION_PROP_NAME)));

    return new SystemStreamPartition(systemName, streamName, partition);
  }

  @Override
  public Checkpoint readLastCheckpoint(TaskName taskName) {
    if (!taskNames.contains(taskName)) {
      throw new SamzaException("reading checkpoint of unregistered/unwritten task");
    }

    // Create the query for taskName
    String partitionQueryKey = taskName.toString();
    String partitionFilter = TableQuery.generateFilterCondition(
            PARTITION_KEY,
            TableQuery.QueryComparisons.EQUAL,
            partitionQueryKey);
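    // The generated filter is equivalent to: PartitionKey eq '<taskName>'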
    TableQuery<TaskCheckpointEntity> query = TableQuery.from(TaskCheckpointEntity.class).where(partitionFilter);

    ImmutableMap.Builder<SystemStreamPartition, String> builder = ImmutableMap.builder();
    try {
      for (TaskCheckpointEntity taskCheckpointEntity : cloudTable.execute(query)) {
        // Deserialize the SSP from the row key and map it to its checkpointed offset
        String serializedSSP = taskCheckpointEntity.getRowKey();
        builder.put(deserializeSystemStreamPartition(serializedSSP), taskCheckpointEntity.getOffset());
      }

    } catch (NoSuchElementException e) {
      LOG.warn("No checkpoints found for registered taskName={}", taskName);
      // Return null if no entities are found
      return null;
    }
    LOG.debug("Received checkpoint state for taskName=%s", taskName);
    return new CheckpointV1(builder.build());
  }

  @Override
  public void stop() {
    // Nothing to do here
  }

  @Override
  public void clearCheckpoints() {
    LOG.debug("Clearing all checkpoints in Azure table");

    for (TaskName taskName : taskNames) {
      String partitionQueryKey = taskName.toString();

      // Generate table query
      String partitionFilter = TableQuery.generateFilterCondition(
              PARTITION_KEY,
              TableQuery.QueryComparisons.EQUAL,
              partitionQueryKey);
      TableQuery<TaskCheckpointEntity> partitionQuery = TableQuery.from(TaskCheckpointEntity.class)
              .where(partitionFilter);

      // All entities in a given batch must have the same partition key
      deleteEntities(cloudTable.execute(partitionQuery).iterator());
    }
  }

  private void deleteEntities(Iterator<TaskCheckpointEntity> entitiesToDelete) {
    TableBatchOperation batchOperation = new TableBatchOperation();

    while (entitiesToDelete.hasNext()) {
      TaskCheckpointEntity entity = entitiesToDelete.next();

      // Add to batch operation
      batchOperation.delete(entity);

      // Execute when batch reaches capacity or when this is the last item
      if (batchOperation.size() >= MAX_WRITE_BATCH_SIZE || !entitiesToDelete.hasNext()) {
        try {
          cloudTable.execute(batchOperation);
        } catch (StorageException e) {
          LOG.error("Executing batch failed for deleting checkpoints");
          throw new AzureException(e);
        }
        batchOperation.clear();
      }
    }
  }
}