All Downloads are FREE. Search and download functionalities are using the official Maven repository.

fr.acinq.eclair.crypto.ShaChain.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 ACINQ SAS
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package fr.acinq.eclair.crypto

import fr.acinq.bitcoin._
import fr.acinq.eclair.wire.CommonCodecs
import scodec.Codec

import scala.annotation.tailrec

/**
  * see https://github.com/rustyrussell/lightning-rfc/blob/master/early-drafts/shachain.txt
  */
object ShaChain {

  /**
    * A node of the hash tree.
    *
    * @param value  hash held by this node
    * @param height depth of this node: derived nodes are one level deeper than the node they are derived from
    * @param parent node this one was derived from, if any
    */
  case class Node(
    value: ByteVector32,
    height: Int,
    parent: Option[Node]
  )

  /**
    * Inverts a single bit of a 32-byte value.
    *
    * @param in    value to update
    * @param index position of the bit to invert: `index / 8` selects the byte, `index % 8` selects the bit inside
    *              that byte (bit 0 being the least significant one)
    * @return a copy of `in` where only that bit has changed
    */
  def flip(in: ByteVector32, index: Int): ByteVector32 = {
    val byteIndex = index / 8
    val bitMask = 1 << (index % 8)
    val flippedByte = (in(byteIndex) ^ bitMask).toByte
    ByteVector32(in.update(byteIndex, flippedByte))
  }

  /**
    * Converts an index into the path that leads to it in the tree.
    *
    * @param index 64-bit integer
    * @return the 64 bits of index, most significant bit first, where a set bit is `true`. Each bool represents a move
    *         down the tree
    */
  def moves(index: Long): Vector[Boolean] =
    Vector.tabulate(64) { position =>
      // position 0 is the most significant bit (bit 63), position 63 the least significant one (bit 0)
      val mask = 1L << (63 - position)
      (index & mask) != 0
    }

  /**
    * Computes one child of a node.
    *
    * @param node      initial node
    * @param direction false means left, true means right
    * @return the child of our node in the specified direction, one level deeper and with `node` as its parent
    */
  def derive(node: Node, direction: Boolean): Node = {
    val childValue =
      if (direction) {
        // right child: flip the bit matching the current height, then hash
        Crypto.sha256(flip(node.value, 63 - node.height))
      } else {
        // left child: same value as its parent
        node.value
      }
    Node(childValue, node.height + 1, Some(node))
  }

  /**
    * Walks down the tree from a node, applying each move in order.
    *
    * @param node       initial node
    * @param directions moves to apply (false means left, true means right)
    * @return the node reached after the last move, or `node` itself if there are no moves
    */
  def derive(node: Node, directions: Seq[Boolean]): Node =
    directions.foldLeft(node) { (current, direction) => derive(current, direction) }

  /**
    * Walks down the tree from a node, following the 64 moves encoded by an index.
    *
    * @param node       initial node
    * @param directions 64-bit index, converted to moves with `moves`
    * @return the node reached after the last move
    */
  def derive(node: Node, directions: Long): Node = {
    val path = moves(directions)
    derive(node, path)
  }

  /**
    * Computes the hash at a given index from a seed.
    *
    * @param hash  seed, used as the value of the root node (height 0, no parent)
    * @param index 64-bit index of the hash to compute
    * @return the value of the node reached by following the moves encoded by index from the root
    */
  def shaChainFromSeed(hash: ByteVector32, index: Long): ByteVector32 = {
    val root = Node(hash, 0, None)
    derive(root, index).value
  }

  /** Position of a node in the tree, as the sequence of moves that leads to it (false means left, true means right). */
  type Index = Vector[Boolean]

  /** A chain that knows no hash and has no last index. */
  val empty: ShaChain = ShaChain(knownHashes = Map.empty[Index, ByteVector32])

  /** Starting point for a new chain: same value as `empty`. */
  val init: ShaChain = empty

  /**
    * Adds a hash at the given position in the tree.
    *
    * A right node (last move is `true`) is stored as-is. A left node (last move is `false`) has the same hash as its
    * parent: once we have checked that the right sibling we already know can be recomputed from this hash, both
    * children are dropped and the hash is stored at the parent's position instead, recursively.
    *
    * When every move is `false` (e.g. for index 0) the recursion reaches the empty index, which is the root of the tree:
    * the hash is then stored at the root, from which every other hash can be derived.
    *
    * Note that storing builds a new ShaChain from the updated hashes only, so the result does not carry the receiver's
    * lastIndex: the overload that takes a Long index sets it afterwards.
    *
    * @param receiver chain to add the hash to
    * @param hash     hash to add
    * @param index    position of the hash in the tree
    * @return a chain that contains the new hash
    * @throws IllegalArgumentException if the hash of the right sibling is unknown or does not match what is derived
    *                                  from `hash`
    */
  @tailrec
  def addHash(receiver: ShaChain, hash: ByteVector32, index: Index): ShaChain = {
    // lastOption rather than last: last throws on the empty index, which is reached when all moves are `false`
    index.lastOption match {
      case None | Some(true) => ShaChain(receiver.knownHashes + (index -> hash))
      case Some(false) =>
        val parentIndex = index.dropRight(1)
        // hashes are supposed to be received in reverse order so we already have parent :+ true
        // which we should be able to recompute (it's a left node so its hash is the same as its parent's hash)
        require(getHash(receiver, parentIndex :+ true) == Some(derive(Node(hash, parentIndex.length, None), true).value), "invalid hash")
        val nodes1 = receiver.knownHashes - (parentIndex :+ false) - (parentIndex :+ true)
        addHash(receiver.copy(knownHashes = nodes1), hash, parentIndex)
    }
  }

  /**
    * Adds the hash for a given 64-bit index and records that index as the chain's last index.
    *
    * @param receiver chain to add the hash to
    * @param hash     hash to add
    * @param index    index of the hash; when the chain already has a last index, it must be that index minus one
    * @return a chain that contains the new hash, with lastIndex set to index
    * @throws IllegalArgumentException if index does not immediately precede the chain's last index, or if the hash is
    *                                  rejected by the overload that takes a tree position
    */
  def addHash(receiver: ShaChain, hash: ByteVector32, index: Long): ShaChain = {
    // hashes must be added one by one, in decreasing index order
    receiver.lastIndex.foreach(previous => require(index == previous - 1L))
    val updated = addHash(receiver, hash, moves(index))
    updated.copy(lastIndex = Some(index))
  }

  /**
    * Looks up the hash at a given position in the tree.
    *
    * @param receiver chain to read from
    * @param index    position of the requested hash
    * @return the hash derived from the first known node whose position is a prefix of index, or None if there is no
    *         such node
    */
  def getHash(receiver: ShaChain, index: Index): Option[ByteVector32] =
    receiver.knownHashes.collectFirst {
      case (prefix, knownHash) if index.startsWith(prefix) =>
        // start from the known ancestor and apply the remaining moves
        val ancestor = Node(knownHash, prefix.length, None)
        val remainingMoves = index.drop(prefix.length)
        derive(ancestor, remainingMoves).value
    }

  /**
    * Looks up the hash for a given 64-bit index.
    *
    * Indexes are unsigned 64-bit integers stored in a signed Long (hashes are added from 0xFFFFFFFFFFFFFFFF down to 0),
    * so they are compared as unsigned values: a signed comparison would give the wrong answer whenever the chain's last
    * index and the requested index sit on opposite sides of the sign bit.
    *
    * @param receiver chain to read from
    * @param index    index of the requested hash
    * @return None if the chain has no last index or if index is lower (unsigned) than the chain's last index, otherwise
    *         the result of looking up the tree position encoded by index
    */
  def getHash(receiver: ShaChain, index: Long): Option[ByteVector32] = {
    receiver.lastIndex match {
      case None => None
      // the requested hash has not been received yet
      case Some(value) if java.lang.Long.compareUnsigned(value, index) > 0 => None
      case _ => getHash(receiver, moves(index))
    }
  }

  /**
    * Iterates over the hashes of a chain by increasing index, starting at the chain's last index.
    *
    * Indexes are unsigned 64-bit integers stored in a signed Long: iteration stops once the hash at index
    * 0xFFFFFFFFFFFFFFFF has been returned.
    *
    * @param chain chain to read hashes from
    * @return an iterator over the chain's hashes, empty if the chain has no last index
    */
  def iterator(chain: ShaChain): Iterator[ByteVector32] = chain.lastIndex match {
    case None => Iterator.empty
    case Some(index) => new Iterator[ByteVector32] {
      // 0xffffffffffffffffL is -1L as a signed Long, so a signed `pos <= 0xffffffffffffffffL` test is false for every
      // non-negative index, and `pos + 1` wraps around to 0 after the last index: we track exhaustion explicitly
      private var pos = index
      private var exhausted = false

      override def hasNext: Boolean = !exhausted

      override def next(): ByteVector32 = {
        if (exhausted) throw new NoSuchElementException("next on empty iterator")
        // pos never goes below the chain's last index, so we can look the tree position up directly
        val value = getHash(chain, moves(pos)).getOrElse(throw new NoSuchElementException(s"no hash for index $pos"))
        if (pos == 0xffffffffffffffffL) exhausted = true else pos = pos + 1
        value
      }
    }
  }


  /**
    * Codec for a ShaChain: its known hashes followed by its optional last index.
    */
  val shaChainCodec: Codec[ShaChain] = {
    import scodec.codecs._

    // codec for a single map entry (i.e. Vector[Boolean] -> ByteVector32): the index is written as a uint16 count
    // followed by its booleans, the hash is written with a uint16 size prefix
    val entryCodec = vectorOfN(uint16, bool) ~ variableSizeBytes(uint16, CommonCodecs.bytes32)

    // codec for all the k -> v pairs, prefixed by their count as a uint16
    val entriesCodec = vectorOfN(uint16, entryCodec)

    // codec for a Map[Vector[Boolean], ByteVector32]: the map is written as the vector of its entries
    val mapCodec: Codec[Map[Vector[Boolean], ByteVector32]] =
      entriesCodec.xmap[Map[Vector[Boolean], ByteVector32]](entries => entries.toMap, hashes => hashes.toVector)

    // our shachain codec
    (("knownHashes" | mapCodec) :: ("lastIndex" | optional(bool, int64))).as[ShaChain]
  }

}

/**
  * Structure used to intelligently store unguessable hashes.
  *
  * @param knownHashes known hashes, keyed by their position in the tree
  * @param lastIndex   index of the last known hash. Hashes are supposed to be added in reverse order i.e.
  *                    from 0xFFFFFFFFFFFFFFFF down to 0
  */
case class ShaChain(knownHashes: Map[Vector[Boolean], ByteVector32], lastIndex: Option[Long] = None) {

  /** Adds the hash for the given index to this chain, see `ShaChain.addHash`. */
  def addHash(hash: ByteVector32, index: Long): ShaChain =
    ShaChain.addHash(this, hash, index)

  /** Looks up the hash for the given index in this chain, see `ShaChain.getHash`. */
  def getHash(index: Long): Option[ByteVector32] =
    ShaChain.getHash(this, index)

  /** Iterates over the hashes of this chain, see `ShaChain.iterator`. */
  def iterator: Iterator[ByteVector32] =
    ShaChain.iterator(this)

  // only the last index is printed, not the hashes themselves
  override def toString: String =
    s"ShaChain(lastIndex = $lastIndex)"
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy