All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.spotify.featran.CrossingFeatureBuilder.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2017 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.featran

import com.spotify.featran.transformers.Transformer

import scala.collection.immutable.SortedMap
import scala.collection.mutable

/** Type aliases and helpers shared by [[Crossings]]. */
private object Crossings {

  /** A crossing is identified by the pair of transformer names it combines. */
  type KEY = (String, String)

  /** Combines one value from each crossed transformer into the crossed value. */
  type FN = (Double, Double) => Double

  /** Crossing functions ordered by their key. */
  type MAP = SortedMap[KEY, FN]

  /** A [[Crossings]] with no crossings and no keys. */
  def empty: Crossings = Crossings(SortedMap.empty, Set.empty)

  /** Name of the crossed feature built from features `n1` and `n2`, e.g. `cross_a_x_b`. */
  def name(n1: String, n2: String): String = s"cross_${n1}_x_$n2"
}

/**
 * An immutable collection of feature crossings.
 *
 * @param map  crossing function for each (transformer, transformer) name pair
 * @param keys names of the transformers that take part in at least one crossing
 */
private case class Crossings(map: Crossings.MAP, keys: Set[String]) {

  /**
   * Add one crossing.
   *
   * @param v crossing key and the function combining the two crossed values
   * @return a new [[Crossings]] containing `v`; both transformer names are added to `keys`
   * @throws IllegalArgumentException if a crossing with the same key is already defined
   */
  def +(v: (Crossings.KEY, Crossings.FN)): Crossings = {
    val (k, f) = v
    require(!map.contains(k), s"Crossing $k already defined")
    copy(map = map + (k -> f), keys = keys + k._1 + k._2)
  }

  /**
   * Merge two sets of crossings.
   *
   * @throws IllegalArgumentException if both sides define a crossing with the same key
   */
  def ++(that: Crossings): Crossings = {
    val k = this.map.keySet.intersect(that.map.keySet)
    require(k.isEmpty, s"Duplicate crossing ${k.mkString(", ")}")
    Crossings(this.map ++ that.map, this.keys ++ that.keys)
  }

  /**
   * Keep only the transformer names accepted by `predicate`, and the crossings that involve at
   * least one of them.
   *
   * NOTE(review): because of `||`, a crossing survives when only one of its two transformers is
   * kept, while the other name is dropped from `keys`. Confirm whether `&&` is intended.
   *
   * @tparam T unused; kept only so existing call sites keep compiling
   * @param predicate selects the transformer names to keep
   */
  def filter[T](predicate: String => Boolean): Crossings = {
    val filteredKeys = keys.filter(predicate)
    // a strict `filter` on a SortedMap already yields a SortedMap (and keeps its order), so no
    // deprecated `filterKeys` view or manual builder is needed
    val filteredMap = map.filter { case ((k1, k2), _) =>
      filteredKeys.contains(k1) || filteredKeys.contains(k2)
    }
    Crossings(filteredMap, filteredKeys)
  }

}

/** Factory for builders that add feature crossings on top of another [[FeatureBuilder]]. */
object CrossingFeatureBuilder {

  /**
   * Wrap `fb` so that it also receives the crossed features described by `crossings`.
   *
   * @param fb        builder that receives both the regular and the crossed features
   * @param crossings crossings to compute
   * @return a builder delegating to `fb`
   */
  @inline def apply[F](fb: FeatureBuilder[F], crossings: Crossings): FeatureBuilder[F] = {
    val wrapped: FeatureBuilder[F] = new CrossingFeatureBuilder[F](fb, crossings)
    wrapped
  }
}

/**
 * A [[FeatureBuilder]] that forwards every call to `fb` and additionally emits feature crossings.
 *
 * While a transformer whose name is in `crossings.keys` is active (see [[prepare]]), each value
 * it emits is recorded along with its position (offset) in that transformer's output. The
 * position counts skipped slots too.
 *
 * For every crossing `(t1, t2)` with dimensions `d1` and `d2`, [[result]] then appends a block of
 * `d1 * d2` slots laid out as `d2 * o1 + o2`. Slots where both sides produced a value receive
 * `f(v1, v2)`; all other slots are skipped.
 *
 * @param fb        underlying builder receiving both regular and crossed features
 * @param crossings crossings to compute
 */
private class CrossingFeatureBuilder[F] private (
  private val fb: FeatureBuilder[F],
  private val crossings: Crossings
) extends FeatureBuilder[F] {
  // a value emitted by a crossed transformer: feature name, offset within that transformer's
  // output, and the value itself
  private case class CrossValue(name: String, offset: Int, value: Double)
  private[this] var xEnabled = false // true if current transformer will be crossed
  // name, offset and values of the current transformer
  private[this] var xName: String = _
  private[this] var xOffset = 0
  private[this] var xQueue: mutable.Queue[CrossValue] = _
  // values and dimensions of transformers to be crossed, keyed by transformer name
  private[this] val xValues = mutable.Map.empty[String, mutable.Queue[CrossValue]]
  private[this] val xDims = mutable.Map.empty[String, Int]

  /**
   * Start a new record. Crossing state left over from a previous record is discarded, so stale
   * values cannot be crossed again if this instance is reused.
   */
  override def init(dimension: Int): Unit = {
    xEnabled = false
    xValues.clear()
    xDims.clear()
    fb.init(dimension)
  }

  /**
   * Called before `transformer` emits its values. Records the dimension of the previously active
   * crossed transformer, then starts recording `transformer` if it takes part in a crossing.
   */
  override def prepare(transformer: Transformer[_, _, _]): Unit = {
    updateDim()
    val name = transformer.name
    xEnabled = crossings.keys.contains(name)
    if (xEnabled) {
      xName = name
      xOffset = 0
      xQueue = mutable.Queue.empty
      xValues(name) = xQueue
    }
  }

  // Store the dimension (added + skipped slots) of the current crossed transformer.
  private def updateDim(): Unit =
    if (xEnabled) {
      xDims(xName) = xOffset
    }

  override def add(name: String, value: Double): Unit = {
    if (xEnabled) {
      xQueue.enqueue(CrossValue(name, xOffset, value))
      xOffset += 1
    }
    fb.add(name, value)
  }

  override def add[M[_]](names: Iterable[String], values: M[Double])(implicit
    ev: M[Double] => Seq[Double]
  ): Unit = {
    if (xEnabled) {
      // pairs names with values; stops at the shorter of the two
      val i = names.iterator
      val j = values.iterator
      while (i.hasNext && j.hasNext) {
        xQueue.enqueue(CrossValue(i.next(), xOffset, j.next()))
        xOffset += 1
      }
    }
    fb.add(names, values)
  }

  override def skip(): Unit = {
    // offsets are only tracked while a crossed transformer is active, mirroring `add`
    if (xEnabled) xOffset += 1
    fb.skip()
  }

  override def skip(n: Int): Unit = {
    if (xEnabled) xOffset += n
    fb.skip(n)
  }

  /**
   * Append all crossed features to `fb`, then return `fb.result`.
   *
   * A crossing is emitted only when both of its transformers have a recorded, non-zero
   * dimension.
   */
  override def result: F = {
    updateDim()
    crossings.map.foreach { case ((t1, t2), f) =>
      val d1 = xDims.getOrElse(t1, 0)
      val d2 = xDims.getOrElse(t2, 0)
      if (d1 > 0 && d2 > 0) {
        val q1 = xValues(t1)
        val q2 = xValues(t2)
        var prev = -1 // last slot written inside this crossing's block
        for (CrossValue(n1, o1, v1) <- q1) {
          for (CrossValue(n2, o2, v2) <- q2) {
            // offsets within each queue are increasing, so slots are visited in order
            val offset = d2 * o1 + o2
            fb.skip(offset - prev - 1)
            fb.add(Crossings.name(n1, n2), f(v1, v2))
            prev = offset
          }
        }
        // pad the remainder of the d1 * d2 block
        fb.skip(d1 * d2 - prev - 1)
      }
    }
    fb.result
  }

  override def reject(transformer: Transformer[_, _, _], reason: FeatureRejection): Unit =
    fb.reject(transformer, reason)
  override def rejections: Map[String, FeatureRejection] = fb.rejections

  /**
   * A fresh crossing builder over a fresh underlying builder. The original wrapped the same `fb`,
   * so the "new" builder would have shared output and rejection state with this one.
   */
  override def newBuilder: FeatureBuilder[F] = CrossingFeatureBuilder(fb.newBuilder, crossings)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy