All downloads are free. Search and download functionality uses the official Maven repository.

com.google.crypto.tink.subtle.StreamingAeadSeekableDecryptingChannel Maven / Gradle / Ivy

Go to download

Tink is a small cryptographic library that provides a safe, simple, agile and fast way to accomplish some common cryptographic tasks.

The newest version!
// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////

package com.google.crypto.tink.subtle;

import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NonWritableChannelException;
import java.nio.channels.SeekableByteChannel;
import java.security.GeneralSecurityException;
import java.util.Arrays;

/**
 * An instance of {@link SeekableByteChannel} that allows random access to the plaintext of some
 * ciphertext.
 *
 * <p>Ciphertext layout as used by this class: {@code firstSegmentOffset} bytes, then the header
 * ({@code streamAead.getHeaderLength()} bytes), then segment 0 starting at {@code
 * ciphertextOffset}. Segment {@code i >= 1} starts at {@code i * ciphertextSegmentSize}.
 *
 * <p>{@link #size()} is derived from the ciphertext length only and is not authenticated; it is
 * verified when the last segment is decrypted (see {@link #verifiedSize()}). End-of-stream (-1) is
 * only reported by {@link #read(ByteBuffer)} after the last segment has been authenticated.
 *
 * <p>Reads, position changes and close are {@code synchronized} on this instance.
 */
class StreamingAeadSeekableDecryptingChannel implements SeekableByteChannel {
  // Each plaintext segment has 16 bytes more of memory than the actual plaintext that it contains.
  // This is a workaround for an incompatibility between Conscrypt and OpenJDK in their
  // AES-GCM implementations, see b/67416642, b/31574439, and cr/170969008 for more information.
  // Conscrypt refused to fix this issue, but even if they fixed it, there are always Android phones
  // running old versions of Conscrypt, so we decided to take matters into our own hands.
  // Why 16? Any number greater than or equal to 16 should work. 16 is the lower bound because it
  // is the size of the tags of each AES-GCM ciphertext segment.
  private static final int PLAINTEXT_SEGMENT_EXTRA_SIZE = 16;

  private final SeekableByteChannel ciphertextChannel;
  private final ByteBuffer ciphertextSegment;
  private final ByteBuffer plaintextSegment;
  private final ByteBuffer header;
  private final long ciphertextChannelSize;  // unverified size of the ciphertext
  private final int numberOfSegments;  // unverified number of segments
  private final int lastCiphertextSegmentSize;  // unverified size of the last segment.
  private final byte[] aad;
  private final StreamSegmentDecrypter decrypter;
  private long plaintextPosition;
  private long plaintextSize;
  private boolean headerRead;
  private boolean isCurrentSegmentDecrypted;
  private int currentSegmentNr;
  private boolean isopen;
  private final int plaintextSegmentSize;
  private final int ciphertextSegmentSize;
  private final int ciphertextOffset;
  private final int firstSegmentOffset;

  /**
   * Creates a decrypting channel over {@code ciphertext}.
   *
   * <p>The plaintext size is computed from the current size of {@code ciphertext}; no ciphertext
   * bytes are read here.
   *
   * @param streamAead supplies the segment decrypter and the segment layout parameters
   * @param ciphertext the channel containing the ciphertext
   * @param associatedData associated data; a copy is kept
   * @throws IOException if the ciphertext size is inconsistent with the segment layout, or if the
   *     ciphertext offset is smaller than the header length
   */
  public StreamingAeadSeekableDecryptingChannel(
      NonceBasedStreamingAead streamAead,
      SeekableByteChannel ciphertext,
      byte[] associatedData) throws IOException, GeneralSecurityException {
    decrypter = streamAead.newStreamSegmentDecrypter();
    ciphertextChannel = ciphertext;
    header = ByteBuffer.allocate(streamAead.getHeaderLength());
    ciphertextSegmentSize = streamAead.getCiphertextSegmentSize();
    ciphertextSegment = ByteBuffer.allocate(ciphertextSegmentSize);
    plaintextSegmentSize = streamAead.getPlaintextSegmentSize();
    plaintextSegment = ByteBuffer.allocate(plaintextSegmentSize + PLAINTEXT_SEGMENT_EXTRA_SIZE);
    plaintextPosition = 0;
    headerRead = false;
    currentSegmentNr = -1;
    isCurrentSegmentDecrypted = false;
    ciphertextChannelSize = ciphertextChannel.size();
    aad = Arrays.copyOf(associatedData, associatedData.length);
    isopen = ciphertextChannel.isOpen();
    int fullSegments = (int) (ciphertextChannelSize / ciphertextSegmentSize);
    int remainder = (int) (ciphertextChannelSize % ciphertextSegmentSize);
    int ciphertextOverhead = streamAead.getCiphertextOverhead();
    if (remainder > 0) {
      numberOfSegments = fullSegments + 1;
      // A trailing partial segment must at least hold the per-segment overhead.
      if (remainder < ciphertextOverhead) {
        throw new IOException("Invalid ciphertext size");
      }
      lastCiphertextSegmentSize = remainder;
    } else {
      numberOfSegments = fullSegments;
      lastCiphertextSegmentSize = ciphertextSegmentSize;
    }
    ciphertextOffset = streamAead.getCiphertextOffset();
    firstSegmentOffset = ciphertextOffset - streamAead.getHeaderLength();
    if (firstSegmentOffset < 0) {
      throw new IOException("Invalid ciphertext offset or header length");
    }
    long overhead = (long) numberOfSegments * ciphertextOverhead + ciphertextOffset;
    if (overhead > ciphertextChannelSize) {
      throw new IOException("Ciphertext is too short");
    }
    plaintextSize = ciphertextChannelSize - overhead;
  }

  /**
   * A description of the state of this StreamingAeadSeekableDecryptingChannel.
   * While this description does not contain plaintext or key material
   * it contains length information that might leak some information.
   */
  @Override
  public synchronized String toString() {
    StringBuilder res = new StringBuilder();
    String ctChannel;
    try {
      ctChannel = "position:" + ciphertextChannel.position();
    } catch (IOException ex) {
      ctChannel = "position: n/a";
    }
    res.append("StreamingAeadSeekableDecryptingChannel")
        .append("\nciphertextChannel ").append(ctChannel)
        .append("\nciphertextChannelSize:").append(ciphertextChannelSize)
        .append("\nplaintextSize:").append(plaintextSize)
        .append("\nciphertextSegmentSize:").append(ciphertextSegmentSize)
        .append("\nnumberOfSegments:").append(numberOfSegments)
        .append("\nheaderRead:").append(headerRead)
        .append("\nplaintextPosition:").append(plaintextPosition)
        .append("\nHeader")
        .append(" position:").append(header.position())
        .append(" limit:").append(header.limit())
        .append("\ncurrentSegmentNr:").append(currentSegmentNr)
        .append("\nciphertextSegment")
        .append(" position:").append(ciphertextSegment.position())
        .append(" limit:").append(ciphertextSegment.limit())
        .append("\nisCurrentSegmentDecrypted:").append(isCurrentSegmentDecrypted)
        .append("\nplaintextSegment")
        .append(" position:").append(plaintextSegment.position())
        .append(" limit:").append(plaintextSegment.limit());
    return res.toString();
  }

  /**
   * Returns the position of this channel.
   * The position is relative to the plaintext.
   */
  @Override
  public synchronized long position() {
    return plaintextPosition;
  }

  /**
   * Sets the position in the plaintext. Setting the position to a value greater than or equal to
   * the plaintext size is legal; a later read at such a position authenticates the last segment
   * and then reports end-of-stream (-1), or throws an IOException if that segment does not
   * decrypt.
   *
   * @throws IllegalArgumentException if {@code newPosition} is negative
   */
  @CanIgnoreReturnValue
  @Override
  public synchronized SeekableByteChannel position(long newPosition) {
    if (newPosition < 0) {
      throw new IllegalArgumentException("Negative position: " + newPosition);
    }
    plaintextPosition = newPosition;
    return this;
  }

  /**
   * Tries to read the header of the ciphertext and derive the key used for the ciphertext from the
   * information in the header.
   *
   * @return true if the header was fully read and has a correct format. Returns false if the header
   *     could not be read.
   * @throws IOException if the header was incorrectly formatted or if there was an exception during
   *     the key derivation.
   */
  private boolean tryReadHeader() throws IOException {
    // Resume where a previous partial header read stopped.
    ciphertextChannel.position(header.position() + firstSegmentOffset);
    ciphertextChannel.read(header);
    if (header.remaining() > 0) {
      return false;
    }
    header.flip();
    try {
      decrypter.init(header, aad);
      headerRead = true;
    } catch (GeneralSecurityException ex) {
      // TODO(bleichen): Define the state of this.
      throw new IOException(ex);
    }
    return true;
  }

  /**
   * Maps a plaintext position to its segment number. Positions are shifted by {@code
   * ciphertextOffset}, so segment 0 covers positions below {@code plaintextSegmentSize -
   * ciphertextOffset}.
   */
  private int getSegmentNr(long plaintextPosition) {
    return (int) ((plaintextPosition + ciphertextOffset) / plaintextSegmentSize);
  }

  /**
   * Tries to read and decrypt a ciphertext segment.
   *
   * <p>If the segment is already the current, decrypted segment, nothing is read. If a previous
   * call only partially read the same segment, reading continues from where it stopped.
   *
   * @param segmentNr the number of the segment
   * @return true if the segment was read and correctly decrypted.
   *          Returns false if the segment could not be fully read.
   * @throws IOException if there was an exception reading the ciphertext,
   *         if the segment number was incorrect, or
   *         if there was an exception trying to decrypt the ciphertext segment.
   */
  private boolean tryLoadSegment(int segmentNr) throws IOException {
    if (segmentNr < 0 || segmentNr >= numberOfSegments) {
      throw new IOException("Invalid position");
    }
    boolean isLast = segmentNr == numberOfSegments - 1;
    if (segmentNr == currentSegmentNr) {
      if (isCurrentSegmentDecrypted) {
        return true;
      }
    } else {
      // Start loading a different segment from scratch.
      long ciphertextPosition = (long) segmentNr * ciphertextSegmentSize;
      int segmentSize = isLast ? lastCiphertextSegmentSize : ciphertextSegmentSize;
      if (segmentNr == 0) {
        // Segment 0 shares its ciphertext slot with the prefix and header.
        segmentSize -= ciphertextOffset;
        ciphertextPosition = ciphertextOffset;
      }
      ciphertextChannel.position(ciphertextPosition);
      ciphertextSegment.clear();
      ciphertextSegment.limit(segmentSize);
      currentSegmentNr = segmentNr;
      isCurrentSegmentDecrypted = false;
    }
    if (ciphertextSegment.remaining() > 0) {
      ciphertextChannel.read(ciphertextSegment);
    }
    if (ciphertextSegment.remaining() > 0) {
      return false;
    }
    ciphertextSegment.flip();
    plaintextSegment.clear();
    try {
      decrypter.decryptSegment(ciphertextSegment, segmentNr, isLast, plaintextSegment);
    } catch (GeneralSecurityException ex) {
      // The current segment did not validate. Ensure that this instance remains
      // in a valid state.
      currentSegmentNr = -1;
      throw new IOException("Failed to decrypt", ex);
    }
    plaintextSegment.flip();
    isCurrentSegmentDecrypted = true;
    return true;
  }

  /**
   * Returns true if the plaintext position is at or beyond the end of the plaintext and the last
   * segment has been authenticated. Loads and decrypts the last segment if necessary, so that a
   * truncated or modified end of the ciphertext is detected before end-of-stream is reported.
   *
   * @throws IOException if the last segment could not be read or failed to decrypt
   */
  private boolean reachedEnd() throws IOException {
    return plaintextPosition >= plaintextSize && tryLoadSegment(numberOfSegments - 1);
  }

  /**
   * Atomic read from a given position.
   *
   * This method works in the same way as read(ByteBuffer), except that it starts at the given
   * position and does not modify the channel's position.
   *
   * @throws IllegalArgumentException if {@code start} is negative
   */
  public synchronized int read(ByteBuffer dst, long start) throws IOException {
    long oldPosition = position();
    try {
      position(start);
      return read(dst);
    } finally {
      position(oldPosition);
    }
  }

  /**
   * Reads plaintext starting at the current position into {@code dst}.
   *
   * @return the number of bytes read; 0 if the header or the needed segment is not yet fully
   *     available; -1 if the position is at or past the end and the last segment authenticated
   * @throws ClosedChannelException if this channel is closed
   * @throws IOException if the header is invalid or a segment fails to decrypt
   */
  @Override
  public synchronized int read(ByteBuffer dst) throws IOException {
    if (!isopen) {
      throw new ClosedChannelException();
    }
    if (!headerRead && !tryReadHeader()) {
      return 0;
    }
    int startPos = dst.position();
    while (dst.remaining() > 0 && plaintextPosition < plaintextSize) {
      // Determine segmentNr for the plaintext to read and the offset in
      // the plaintext, where reading should start.
      int segmentNr = getSegmentNr(plaintextPosition);
      int segmentOffset =
          segmentNr == 0
              ? (int) plaintextPosition
              : (int) ((plaintextPosition + ciphertextOffset) % plaintextSegmentSize);
      if (!tryLoadSegment(segmentNr)) {
        // The segment is not fully available yet; return what has been read so far.
        break;
      }
      plaintextSegment.position(segmentOffset);
      if (plaintextSegment.remaining() <= dst.remaining()) {
        plaintextPosition += plaintextSegment.remaining();
        dst.put(plaintextSegment);
      } else {
        int sliceSize = dst.remaining();
        ByteBuffer slice = plaintextSegment.duplicate();
        slice.limit(slice.position() + sliceSize);
        dst.put(slice);
        plaintextPosition += sliceSize;
        plaintextSegment.position(plaintextSegment.position() + sliceSize);
      }
    }
    int read = dst.position() - startPos;
    if (read == 0 && reachedEnd()) {
      return -1;
    }
    return read;
  }

  /**
   * Returns the expected size of the plaintext.
   * Note that this implementation does not perform an integrity check on the size.
   * I.e. if the file has been truncated then size() will return the wrong
   * result. Reading the last block of the ciphertext will verify whether size()
   * is correct.
   */
  @Override
  public long size() {
    return plaintextSize;
  }

  /**
   * Returns the plaintext size after verifying it by reading and authenticating the last
   * ciphertext segment.
   *
   * @throws ClosedChannelException if this channel is closed
   * @throws IOException if the header or the last segment could not be fully read, or if either
   *     failed to validate
   */
  public synchronized long verifiedSize() throws IOException {
    if (!isopen) {
      throw new ClosedChannelException();
    }
    // read() initializes the decrypter from the header before decrypting any segment; do the
    // same here so the last segment is never decrypted with an uninitialized decrypter.
    if (!headerRead && !tryReadHeader()) {
      throw new IOException("could not verify the size");
    }
    if (tryLoadSegment(numberOfSegments - 1)) {
      return plaintextSize;
    }
    throw new IOException("could not verify the size");
  }

  @Override
  public SeekableByteChannel truncate(long size) throws NonWritableChannelException {
    throw new NonWritableChannelException();
  }

  @Override
  public int write(ByteBuffer src) throws NonWritableChannelException {
    throw new NonWritableChannelException();
  }

  /** Closes the underlying ciphertext channel; this channel is marked closed even if that fails. */
  @Override
  public synchronized void close() throws IOException {
    try {
      ciphertextChannel.close();
    } finally {
      isopen = false;
    }
  }

  @Override
  public synchronized boolean isOpen() {
    return isopen;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy