All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.spark.util.package.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2015 Skymind,Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.deeplearning4j.spark

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.apache.spark.mllib.linalg.{Vectors, Vector}

package object util {
  object conversions {
    /**
     * Convert a vector to an ndarray
     * @param vector the vector
     * @return an ndarray
     */
    implicit def toINDArray(vector: Vector): INDArray = {
      Nd4j.create(Nd4j.createBuffer(vector.toArray))
    }

    /**
     * Convert an ndarray to a vector
     * @param array the array
     * @return an mllib vector
     */
    implicit def toVector(array: INDArray): Vector = {
      if (!array.isVector) {
        throw new IllegalArgumentException("implicit array must be a vector")
      }
      val ret = new Array[Double](array.length)
      for(i <- 0 to array.length - 1)
        ret(i) = array.getDouble(i)

      return Vectors.dense(ret)
    }

    implicit def toMatrix(array: INDArray): Array[Vector] = {
      if (!array.isRowVector && !array.isMatrix) {
        throw new IllegalArgumentException("implicit array must be a matrix")
      }

      val matrix = new Array[Vector](array.rows())
      for(i <- 0 to array.rows() - 1) {
        val row = array.getRow(i)
        val v = new Array[Double](row.length)
        for(j <- 0 to row.length - 1)
          v(j) = row.getDouble(j)
        matrix(i) = Vectors.dense(v)
      }

      matrix
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy