All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.combust.bundle.tree.decision.TreeSerializer.scala Maven / Gradle / Ivy

package ml.combust.bundle.tree.decision

import java.io._
import java.nio.file.{Files, Path}

import ml.bundle.dtree.Node
import ml.combust.bundle.BundleContext
import ml.combust.bundle.serializer.SerializationFormat
import ml.combust.bundle.tree.JsonSupport._
import spray.json._
import resource._

import scala.util.Try

/**
  * Created by hollinwilkins on 8/22/16.
  */
object FormatTreeSerializer {
  def writer(format: SerializationFormat,
             out: OutputStream): FormatTreeWriter = format match {
    case SerializationFormat.Json => JsonFormatTreeWriter(new BufferedWriter(new OutputStreamWriter(out)))
    case SerializationFormat.Protobuf => ProtoFormatTreeWriter(new DataOutputStream(out))
  }

  def reader(format: SerializationFormat,
             in: InputStream): FormatTreeReader = format match {
    case SerializationFormat.Json => JsonFormatTreeReader(new BufferedReader(new InputStreamReader(in)))
    case SerializationFormat.Protobuf => ProtoFormatTreeReader(new DataInputStream(in))
  }
}

trait FormatTreeWriter extends Closeable {
  def write(node: Node): Unit
}

trait FormatTreeReader extends Closeable {
  def read(): Node
}

case class JsonFormatTreeWriter(out: BufferedWriter) extends FormatTreeWriter {
  override def write(node: Node): Unit = {
    out.write(node.toJson.compactPrint + "\n")
  }

  override def close(): Unit = out.close()
}

case class JsonFormatTreeReader(in: BufferedReader) extends FormatTreeReader {
  override def read(): Node = {
    in.readLine().parseJson.convertTo[Node]
  }

  override def close(): Unit = in.close()
}

case class ProtoFormatTreeWriter(out: DataOutputStream) extends FormatTreeWriter {
  override def write(node: Node): Unit = {
    val size = node.serializedSize
    for(writer <- managed(new ByteArrayOutputStream(size))) {
      node.writeTo(writer)
      out.writeInt(size)
      out.write(writer.toByteArray)
    }
  }

  override def close(): Unit = out.close()
}

case class ProtoFormatTreeReader(in: DataInputStream) extends FormatTreeReader {
  override def read(): Node = {
    val size = in.readInt()
    val bytes = new Array[Byte](size)
    in.readFully(bytes)
    Node.parseFrom(bytes)
  }

  override def close(): Unit = in.close()
}

case class TreeSerializer[N: NodeWrapper](path: Path,
                                          withImpurities: Boolean)
                                         (implicit bundleContext: BundleContext[_]) {
  val extension = bundleContext.format match {
    case SerializationFormat.Json => "json"
    case SerializationFormat.Protobuf => "pb"
  }
  val ntc = implicitly[NodeWrapper[N]]

  def write(node: N): Unit = {
    val open = () => Files.newOutputStream(path.getFileSystem.getPath(s"${path.toString}.$extension"))
    for(writer <- managed(FormatTreeSerializer.writer(bundleContext.format, open()))) {
      write(node, writer)
    }
  }

  def write(node: N, writer: FormatTreeWriter): Unit = {
    val n = ntc.node(node, withImpurities)
    writer.write(n)

    if(ntc.isInternal(node)) {
      write(ntc.left(node), writer)
      write(ntc.right(node), writer)
    }
  }

  def read(): Try[N] = {
    (for(in <- managed(Files.newInputStream(path.getFileSystem.getPath(s"${path.toString}.$extension")))) yield {
      val reader = FormatTreeSerializer.reader(bundleContext.format, in)
      read(reader)
    }).tried
  }

  def read(reader: FormatTreeReader): N = {
    val node = reader.read()

    if(node.n.isInternal) {
      ntc.internal(node.getInternal,
        read(reader),
        read(reader))
    } else if(node.n.isLeaf) {
      ntc.leaf(node.getLeaf, withImpurities)
    } else { throw new IllegalArgumentException("invalid tree") }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy