All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.azure.synapse.ml.vw.VowpalWabbitSyncSchedule.scala Maven / Gradle / Ivy

There is a newer version: 1.0.9
Show newest version
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.vw

import com.microsoft.azure.synapse.ml.core.utils.ClusterUtil
import org.apache.spark.TaskContext
import org.apache.spark.sql.{DataFrame, Row, functions => F}

import java.io.Serializable

/**
  * Defines when VW needs to synchronize across partitions.
  */
trait VowpalWabbitSyncSchedule extends Serializable {
  /**
    * Implementations must guarantee to trigger the same number of times across all partitions.
    *
    * @param row passed in to enable content passed schedules (e.g. temporal)
    * @return True if this partition needs to synchronize, false otherwise.
    */
  def shouldTriggerAllReduce(row: Row): Boolean
}

class VowpalWabbitSyncScheduleDisabled extends VowpalWabbitSyncSchedule {
  override def shouldTriggerAllReduce(row: Row): Boolean = false
}

object VowpalWabbitSyncSchedule {
  lazy val Disabled = new VowpalWabbitSyncScheduleDisabled
}

/**
  * Row-count based splitting.
  */
class VowpalWabbitSyncScheduleSplits(df: DataFrame,
                                     numSplits: Integer)
  extends VowpalWabbitSyncSchedule {

  assert(numSplits > 0, "Number of splits must be greater than zero")

  private val rowsPerPartitions = ClusterUtil.getNumRowsPerPartition(df, F.lit(0))

  private val stepSizePerPartition = rowsPerPartitions.map { c => c / numSplits.toDouble }

  private lazy val rowCount = rowsPerPartitions(TaskContext.getPartitionId())

  @transient
  private lazy val stepSize = {
    val s = stepSizePerPartition(TaskContext.getPartitionId())

    assert (s > 1, s"Number of splits $numSplits > $rowCount")

    Math.ceil(s).toLong
  }

  private lazy val needToSyncOnLastRow = stepSize * numSplits != rowCount

  @transient
  private var i = 0

  override def shouldTriggerAllReduce(row: Row): Boolean = {
    i += 1

    if (i % stepSize == 0)
      true
    else {
      // let's make sure even and odd partitions have the same number synchronizations
      needToSyncOnLastRow && i == rowCount
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy