All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.python.MLSerDe.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.python

import java.io.OutputStream
import java.nio.{ByteBuffer, ByteOrder}

import net.razorvine.pickle._

import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.ml.linalg._
import org.apache.spark.mllib.api.python.SerDeBase

/**
 * SerDe utility functions for pyspark.ml.
 */
private[spark] object MLSerDe extends SerDeBase with Serializable {

  override val PYSPARK_PACKAGE = "pyspark.ml"

  // Pickler for DenseVector
  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {

    /**
     * Writes the state of a dense vector as a 1-tuple holding a single byte string.
     *
     * The byte string contains the vector's values, 8 bytes per double, encoded in the
     * platform's native byte order.
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val dense = obj.asInstanceOf[DenseVector]
      val nativeOrder = ByteOrder.nativeOrder()
      val payload = new Array[Byte](8 * dense.size)
      ByteBuffer.wrap(payload).order(nativeOrder).asDoubleBuffer().put(dense.values)

      // BINSTRING <length> <bytes>, then wrap the single item into a tuple.
      out.write(Opcodes.BINSTRING)
      out.write(PickleUtils.integer_to_bytes(payload.length))
      out.write(payload)
      out.write(Opcodes.TUPLE1)
    }

    /**
     * Rebuilds a dense vector from the single byte-string argument written by `saveState`.
     *
     * Throws a PickleException unless exactly one argument is supplied.
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 1) {
        throw new PickleException("length of args should be 1")
      }
      val payload = getBytes(args(0))
      val nativeOrder = ByteOrder.nativeOrder()
      // Every 8 bytes of the payload decode to one double.
      val decoded = new Array[Double](payload.length / 8)
      ByteBuffer.wrap(payload).order(nativeOrder).asDoubleBuffer().get(decoded)
      Vectors.dense(decoded)
    }
  }

  // Pickler for DenseMatrix
  private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {

    /**
     * Writes the state of a dense matrix as the tuple
     * (numRows, numCols, values bytes, isTransposed flag).
     *
     * The values are encoded 8 bytes per double in the platform's native byte order and the
     * transposed flag is written as the integer 1 or 0.
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val dense = obj.asInstanceOf[DenseMatrix]
      val payload = new Array[Byte](8 * dense.values.length)
      val transposedFlag = if (dense.isTransposed) 1 else 0
      ByteBuffer.wrap(payload).order(ByteOrder.nativeOrder()).asDoubleBuffer().put(dense.values)

      // Emits one integer item: the BININT opcode followed by the encoded value.
      def emitInt(value: Int): Unit = {
        out.write(Opcodes.BININT)
        out.write(PickleUtils.integer_to_bytes(value))
      }

      // MARK ... TUPLE collects every item written in between into one tuple.
      out.write(Opcodes.MARK)
      emitInt(dense.numRows)
      emitInt(dense.numCols)
      out.write(Opcodes.BINSTRING)
      out.write(PickleUtils.integer_to_bytes(payload.length))
      out.write(payload)
      emitInt(transposedFlag)
      out.write(Opcodes.TUPLE)
    }

    /**
     * Rebuilds a dense matrix from the four arguments written by `saveState`.
     *
     * Throws a PickleException unless exactly four arguments are supplied.
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 4) {
        throw new PickleException("length of args should be 4")
      }
      val payload = getBytes(args(2))
      // Every 8 bytes of the payload decode to one double.
      val decoded = new Array[Double](payload.length / 8)
      ByteBuffer.wrap(payload).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(decoded)
      val transposed = args(3).asInstanceOf[Int] == 1
      val numRows = args(0).asInstanceOf[Int]
      val numCols = args(1).asInstanceOf[Int]
      new DenseMatrix(numRows, numCols, decoded, transposed)
    }
  }

  // Pickler for SparseMatrix
  private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {

    /**
     * Writes the state of a sparse matrix as the tuple
     * (numRows, numCols, colPtrs bytes, rowIndices bytes, values bytes, isTransposed flag).
     *
     * Ints take 4 bytes and doubles 8 bytes, all in the platform's native byte order; the
     * transposed flag is written as the integer 1 or 0.
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val sparse = obj.asInstanceOf[SparseMatrix]
      val nativeOrder = ByteOrder.nativeOrder()

      // Encodes an int array into a freshly allocated byte array.
      def packInts(xs: Array[Int]): Array[Byte] = {
        val packed = new Array[Byte](4 * xs.length)
        ByteBuffer.wrap(packed).order(nativeOrder).asIntBuffer().put(xs)
        packed
      }

      // Emits one integer item: the BININT opcode followed by the encoded value.
      def emitInt(value: Int): Unit = {
        out.write(Opcodes.BININT)
        out.write(PickleUtils.integer_to_bytes(value))
      }

      // Emits one byte-string item: the BINSTRING opcode, the length, then the raw bytes.
      def emitBytes(payload: Array[Byte]): Unit = {
        out.write(Opcodes.BINSTRING)
        out.write(PickleUtils.integer_to_bytes(payload.length))
        out.write(payload)
      }

      // All payloads are encoded up front, before anything is written to the stream.
      val colPtrsPayload = packInts(sparse.colPtrs)
      val rowIndicesPayload = packInts(sparse.rowIndices)
      val valuesPayload = new Array[Byte](8 * sparse.values.length)
      ByteBuffer.wrap(valuesPayload).order(nativeOrder).asDoubleBuffer().put(sparse.values)
      val transposedFlag = if (sparse.isTransposed) 1 else 0

      // MARK ... TUPLE collects every item written in between into one tuple.
      out.write(Opcodes.MARK)
      emitInt(sparse.numRows)
      emitInt(sparse.numCols)
      emitBytes(colPtrsPayload)
      emitBytes(rowIndicesPayload)
      emitBytes(valuesPayload)
      emitInt(transposedFlag)
      out.write(Opcodes.TUPLE)
    }

    /**
     * Rebuilds a sparse matrix from the six arguments written by `saveState`.
     *
     * Throws a PickleException unless exactly six arguments are supplied.
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 6) {
        throw new PickleException("length of args should be 6")
      }
      val nativeOrder = ByteOrder.nativeOrder()
      val colPtrsPayload = getBytes(args(2))
      val rowIndicesPayload = getBytes(args(3))
      val valuesPayload = getBytes(args(4))

      // 4 bytes per int, 8 bytes per double.
      val colPtrs = new Array[Int](colPtrsPayload.length / 4)
      val rowIndices = new Array[Int](rowIndicesPayload.length / 4)
      val values = new Array[Double](valuesPayload.length / 8)
      ByteBuffer.wrap(colPtrsPayload).order(nativeOrder).asIntBuffer().get(colPtrs)
      ByteBuffer.wrap(rowIndicesPayload).order(nativeOrder).asIntBuffer().get(rowIndices)
      ByteBuffer.wrap(valuesPayload).order(nativeOrder).asDoubleBuffer().get(values)

      val transposed = args(5).asInstanceOf[Int] == 1
      val numRows = args(0).asInstanceOf[Int]
      val numCols = args(1).asInstanceOf[Int]
      new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, transposed)
    }
  }

  // Pickler for SparseVector
  private[python] class SparseVectorPickler extends BasePickler[SparseVector] {

    /**
     * Writes the state of a sparse vector as the 3-tuple
     * (size, indices bytes, values bytes).
     *
     * Indices take 4 bytes each and values 8 bytes each, both in the platform's native byte
     * order. Both payloads are sized from the number of indices.
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val v: SparseVector = obj.asInstanceOf[SparseVector]
      val n = v.indices.length
      val indiceBytes = new Array[Byte](4 * n)
      val order = ByteOrder.nativeOrder()
      ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
      val valueBytes = new Array[Byte](8 * n)
      ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)

      out.write(Opcodes.BININT)
      out.write(PickleUtils.integer_to_bytes(v.size))
      out.write(Opcodes.BINSTRING)
      out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
      out.write(indiceBytes)
      out.write(Opcodes.BINSTRING)
      out.write(PickleUtils.integer_to_bytes(valueBytes.length))
      out.write(valueBytes)
      out.write(Opcodes.TUPLE3)
    }

    /**
     * Rebuilds a sparse vector from the three arguments written by `saveState`.
     *
     * Throws a PickleException when the number of arguments is not three, when the index
     * payload is not a whole number of 4-byte ints, or when the value payload does not hold
     * exactly one 8-byte double per index.
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 3) {
        throw new PickleException("length of args should be 3")
      }
      val size = args(0).asInstanceOf[Int]
      val indiceBytes = getBytes(args(1))
      val valueBytes = getBytes(args(2))
      if (indiceBytes.length % 4 != 0) {
        throw new PickleException(
          s"length of indices bytes should be a multiple of 4, but got ${indiceBytes.length}")
      }
      val n = indiceBytes.length / 4
      // Both output arrays are sized from the index payload, so a value payload of the wrong
      // length would otherwise be silently truncated (too long) or fail with a bare
      // BufferUnderflowException (too short). Compare as Long so that 8 * n cannot overflow.
      if (valueBytes.length.toLong != 8L * n) {
        throw new PickleException(
          s"length of values bytes should be ${8L * n} for $n indices, " +
            s"but got ${valueBytes.length}")
      }
      val order = ByteOrder.nativeOrder()
      val indices = new Array[Int](n)
      val values = new Array[Double](n)
      // Bulk gets into empty arrays are no-ops, so n == 0 needs no special case.
      ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
      ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
      new SparseVector(size, indices, values)
    }
  }

  // Set to true once the four picklers below have been registered; it is read and written
  // inside the synchronized block in initialize().
  var initialized = false
  // This should be called before trying to serialize any of the above classes.
  // In cluster mode, this should be put in the closure.
  //
  // SerDeUtil.initialize() is invoked on every call, outside the lock; the pickler
  // registration below runs only while `initialized` is still false.
  override def initialize(): Unit = {
    SerDeUtil.initialize()
    // Lock on this object so the check-and-set of `initialized` is not interleaved.
    synchronized {
      if (!initialized) {
        new DenseVectorPickler().register()
        new DenseMatrixPickler().register()
        new SparseMatrixPickler().register()
        new SparseVectorPickler().register()
        // Flip the flag only after every register() call above has returned.
        initialized = true
      }
    }
  }
  // Runs as part of this object's own initialization.
  // It will not be called in the Executor automatically.
  initialize()
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy