All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.data.DataFrame.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.data

import java.util.Optional
import java.util.stream.IntStream

import smile.data.measure.CategoricalMeasure
import smile.json._

/**
  * Enriched data frame with Scala-style collection methods.
  *
  * @param data the wrapped data frame.
  */
case class DataFrameOps(data: DataFrame) {
  /** Selects a new DataFrame with the given column indices. */
  def select(range: Range): DataFrame = data.select(range.toArray: _*)

  /** Returns a new DataFrame without the given column indices. */
  def drop(range: Range): DataFrame = data.drop(range.toArray: _*)

  /** Returns a new DataFrame containing the rows at the given indices. */
  def of(range: Range): DataFrame = data.of(range.toArray: _*)

  /**
    * Finds the first row satisfying a predicate.
    *
    * @param p the predicate to test each row with.
    * @return the first matching row in the stream's encounter order, or an
    *         empty `Optional` if no row matches.
    */
  def find(p: Tuple => Boolean): Optional[Tuple] =
    // findFirst, not findAny: findAny may return any matching element
    // (notably on parallel streams), which would break the "first row" contract.
    data.stream().filter(t => p(t)).findFirst()

  /** Tests if a predicate holds for at least one row of the data frame. */
  def exists(p: Tuple => Boolean): Boolean = data.stream().anyMatch(t => p(t))

  /** Tests if a predicate holds for all rows of the data frame (vacuously true when empty). */
  def forall(p: Tuple => Boolean): Boolean = data.stream().allMatch(t => p(t))

  /**
    * Applies a function for its side effect to every row.
    *
    * Delegates to `java.util.stream.Stream.forEach`, which does not guarantee
    * encounter order if the underlying stream is parallel.
    */
  def foreach[U](p: Tuple => U): Unit = data.stream().forEach(t => p(t))

  /**
    * Builds a new collection by applying a function to every row, in row order.
    *
    * @param p the function applied to each row.
    * @tparam U the result element type.
    * @return the results, one per row.
    */
  def map[U](p: Tuple => U): Iterable[U] = (0 until data.size).map(i => p(data(i)))

  /**
    * Selects all rows which satisfy a predicate.
    *
    * @param p the predicate to test each row with.
    * @return a DataFrame of the matching rows, in their original order.
    */
  def filter(p: Tuple => Boolean): DataFrame = {
    val index = IntStream.range(0, data.size).filter(i => p(data(i))).toArray
    data.of(index: _*)
  }

  /**
    * Partitions this DataFrame in two according to a predicate.
    *
    * @param p the predicate on which to partition.
    * @return a pair of DataFrames: the first contains all rows that satisfy
    *         the predicate `p` and the second contains all rows that don't.
    *         The relative order of the rows in each resulting DataFrame is
    *         the same as in the original DataFrame.
    */
  def partition(p: Tuple => Boolean): (DataFrame, DataFrame) = {
    // Range is a strict collection, so partition walks the indices once, in
    // order, evaluating p exactly once per row.
    val (matching, rest) = (0 until data.size).partition(i => p(data(i)))
    (data.of(matching: _*), data.of(rest: _*))
  }

  /**
    * Partitions the DataFrame into a map of DataFrames according to
    * some discriminator function.
    *
    * @param f the discriminator function.
    * @tparam K the type of keys returned by the discriminator function.
    * @return a map from keys to DataFrames; rows within each group keep
    *         their original relative order.
    */
  def groupBy[K](f: Tuple => K): scala.collection.immutable.Map[K, DataFrame] = {
    val groups = (0 until data.size).groupBy(i => f(data(i)))
    groups.view.mapValues(index => data.of(index: _*)).toMap
  }

  /** Converts the data frame to a JSON array with one element per row, in row order. */
  def toJSON: JsArray = {
    JsArray(
      (0 until data.size).map(i => data(i).toJSON): _*
    )
  }
}

/**
  * Enriched tuple offering extra Scala-friendly conversions.
  *
  * @param t the tuple to wrap.
  */
case class TupleOps(t: Tuple) {
  /** Converts the tuple to a JSON object with one entry per field, keyed by field name. */
  def toJSON: JsObject = {
    val entries = for (i <- 0 until t.length()) yield valueOf(i)
    JsObject(entries: _*)
  }

  /** Returns the (field name, JSON value) pair for the field at index `i`. */
  private def valueOf(i: Int): (String, JsValue) = {
    val field = t.schema().field(i)

    val json: JsValue = field.measure match {
      // Categorical fields are rendered via getString; a type pattern never
      // matches null, so a missing measure falls through to the generic case.
      case _: CategoricalMeasure =>
        if (t.isNullAt(i)) JsNull else JsString(t.getString(i))

      case _ =>
        t.get(i) match {
          case null                   => JsNull
          case b: java.lang.Boolean   => JsBoolean(b)
          case b: java.lang.Byte      => JsInt(b: Byte)
          case s: java.lang.Short     => JsInt(s: Short)
          case n: java.lang.Integer   => JsInt(n)
          case n: java.lang.Long      => JsLong(n)
          case f: java.lang.Float     => JsDouble(f: Float)
          case d: java.lang.Double    => JsDouble(d)
          case _                      => JsString(t.getString(i))
        }
    }

    (field.name, json)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy