All downloads are free. The search and download functionality uses the official Maven repository.

io.grpc.alts.internal.AltsChannelCrypter Maven / Gradle / Ivy

There is a newer version: 1.69.1
Show newest version
/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.alts.internal;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;

import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.List;

/**
 * Performs encryption and decryption with AES-GCM using JCE. All methods are thread-compatible.
 *
 * <p>The crypter keeps one {@value #COUNTER_LENGTH}-byte counter per direction ({@code outCounter}
 * for {@link #encrypt}, {@code inCounter} for the {@code decrypt} overloads). The current counter
 * value is handed to the underlying {@link AeadCrypter} on every call and then incremented, so the
 * same counter value is never passed twice for one direction. The constructor sets the last byte
 * of exactly one of the two counters to {@code 0x80}, which keeps the two directions' counter
 * values distinct.
 *
 * <p>Both directions share a single scratch array ({@code oldCounter}) to carry the pre-increment
 * counter value, which is one reason instances must not be used concurrently.
 */
final class AltsChannelCrypter implements ChannelCrypterNetty {
  private static final int KEY_LENGTH = AesGcmHkdfAeadCrypter.getKeyLength();
  private static final int COUNTER_LENGTH = 12;
  // The counter will overflow after 2^64 operations and encryption/decryption will stop working.
  private static final int COUNTER_OVERFLOW_LENGTH = 8;
  private static final int TAG_LENGTH = 16;

  private final AeadCrypter aeadCrypter;

  private final byte[] outCounter = new byte[COUNTER_LENGTH];
  private final byte[] inCounter = new byte[COUNTER_LENGTH];
  // Scratch space holding the most recent pre-increment counter value (shared by both directions).
  private final byte[] oldCounter = new byte[COUNTER_LENGTH];

  /**
   * Creates a crypter for one endpoint of a channel.
   *
   * @param key the key material; its length must equal {@link #getKeyLength()}
   * @param isClient whether this endpoint is the client. For a client the last byte of the
   *     incoming counter is set to {@code 0x80}; otherwise the last byte of the outgoing counter
   *     is set instead.
   * @throws IllegalArgumentException if {@code key} has the wrong length
   */
  AltsChannelCrypter(byte[] key, boolean isClient) {
    checkArgument(key.length == KEY_LENGTH);
    // Mark exactly one direction so that the two directions never share a counter value.
    byte[] counter = isClient ? inCounter : outCounter;
    counter[counter.length - 1] = (byte) 0x80;
    this.aeadCrypter = new AesGcmHkdfAeadCrypter(key);
  }

  /** Returns the key length, in bytes, required by the constructor. */
  static int getKeyLength() {
    return KEY_LENGTH;
  }

  /** Returns the length, in bytes, of the per-direction counters. */
  static int getCounterLength() {
    return COUNTER_LENGTH;
  }

  /**
   * Encrypts the concatenation of {@code plainBufs} into the writable region of {@code outBuf}.
   *
   * <p>The plaintext is first copied into {@code outBuf}'s writable region and then encrypted in
   * place. On return {@code outBuf}'s writer index has been advanced by the number of bytes the
   * underlying crypter produced, and the buffer is verified to be completely full.
   *
   * @param outBuf destination backed by exactly one NIO buffer; its writable bytes must equal the
   *     total plaintext length plus {@link #getSuffixLength()}
   * @param plainBufs plaintext buffers, consumed in order
   * @throws GeneralSecurityException if the outgoing counter has overflowed or encryption fails
   */
  @SuppressWarnings("BetaApi") // verify is stable in Guava
  @Override
  public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
    checkArgument(outBuf.nioBufferCount() == 1);
    // Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
    // The slice has its own indices, so outBuf's writer index is not moved by the copy.
    ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes());
    plainBuf.writerIndex(0);
    for (ByteBuf inBuf : plainBufs) {
      plainBuf.writeBytes(inBuf);
    }

    verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH);
    ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes());
    // Same memory as 'out', limited to the plaintext bytes (everything except the trailing tag).
    ByteBuffer plain = out.duplicate();
    plain.limit(out.limit() - TAG_LENGTH);

    byte[] counter = incrementOutCounter();
    int outPosition = out.position();
    aeadCrypter.encrypt(out, plain, counter);
    int bytesWritten = out.position() - outPosition;
    outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
    verify(!outBuf.isWritable());
  }

  /**
   * Decrypts ciphertext that arrives as several buffers plus a separate tag buffer.
   *
   * <p>The ciphertext buffers and the tag are gathered contiguously into the writable region of
   * {@code out} and then decrypted via {@link #decrypt(ByteBuf, ByteBuf)}, i.e. the plaintext
   * overwrites the gathered ciphertext.
   *
   * @param out destination; its writable bytes must equal the total ciphertext length plus the
   *     tag length
   * @param tag the authentication tag
   * @param ciphertextBufs ciphertext buffers, consumed in order
   * @throws GeneralSecurityException if the incoming counter has overflowed or decryption fails
   */
  @Override
  public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)
      throws GeneralSecurityException {

    // View over out's writable region with independent indices, used as the gather target.
    ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes());
    cipherTextAndTag.writerIndex(0);

    for (ByteBuf inBuf : ciphertextBufs) {
      cipherTextAndTag.writeBytes(inBuf);
    }
    cipherTextAndTag.writeBytes(tag);

    decrypt(out, cipherTextAndTag);
  }

  /**
   * Decrypts {@code ciphertextAndTag} into the writable region of {@code out}.
   *
   * <p>On return {@code out}'s writer index has been advanced by the number of plaintext bytes
   * produced (leaving exactly {@link #getSuffixLength()} writable bytes), and all readable bytes
   * of {@code ciphertextAndTag} have been marked as consumed.
   *
   * @param out destination backed by exactly one NIO buffer; its writable bytes must equal the
   *     readable bytes of {@code ciphertextAndTag}
   * @param ciphertextAndTag ciphertext followed by the tag, backed by exactly one NIO buffer
   * @throws GeneralSecurityException if the incoming counter has overflowed or decryption fails
   */
  @SuppressWarnings("BetaApi") // verify is stable in Guava
  @Override
  public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
    int bytesRead = ciphertextAndTag.readableBytes();
    checkArgument(bytesRead == out.writableBytes());

    checkArgument(out.nioBufferCount() == 1);
    ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes());

    checkArgument(ciphertextAndTag.nioBufferCount() == 1);
    ByteBuffer ciphertextAndTagBuffer =
        ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);

    byte[] counter = incrementInCounter();
    int outPosition = outBuffer.position();
    aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter);
    int bytesWritten = outBuffer.position() - outPosition;
    out.writerIndex(out.writerIndex() + bytesWritten);
    // Consume the input relative to its OWN reader index. Basing this on out.readerIndex() is
    // only correct when both buffers happen to have the same reader index.
    ciphertextAndTag.readerIndex(ciphertextAndTag.readerIndex() + bytesRead);
    verify(out.writableBytes() == TAG_LENGTH);
  }

  /** Returns the number of bytes appended to every ciphertext (the tag length). */
  @Override
  public int getSuffixLength() {
    return TAG_LENGTH;
  }

  @Override
  public void destroy() {
    // no destroy required
  }

  /**
   * Increments {@code counter}, store the unincremented value in {@code oldCounter}.
   *
   * <p>Only the first {@value #COUNTER_OVERFLOW_LENGTH} bytes are incremented, least-significant
   * byte first; the remaining bytes are never modified here.
   *
   * @param counter the counter to increment in place
   * @param oldCounter receives a copy of {@code counter} as it was before the increment; must be
   *     at least as long as {@code counter}
   * @throws GeneralSecurityException if all {@value #COUNTER_OVERFLOW_LENGTH} incremented bytes
   *     wrapped around to zero. In that case {@code counter} is restored to its previous value so
   *     that every subsequent call fails the same way.
   */
  static void incrementCounter(byte[] counter, byte[] oldCounter) throws GeneralSecurityException {
    System.arraycopy(counter, 0, oldCounter, 0, counter.length);
    int i = 0;
    for (; i < COUNTER_OVERFLOW_LENGTH; i++) {
      counter[i]++;
      // A non-zero byte means no carry into the next byte, so the increment is complete.
      if (counter[i] != (byte) 0x00) {
        break;
      }
    }

    if (i == COUNTER_OVERFLOW_LENGTH) {
      // Restore old counter value to ensure that encrypt and decrypt keep failing.
      System.arraycopy(oldCounter, 0, counter, 0, counter.length);
      throw new GeneralSecurityException("Counter has overflowed.");
    }
  }

  /**
   * Increments the input counter, returning the previous (unincremented) value.
   *
   * <p>The returned array is the shared scratch buffer and is overwritten by the next counter
   * increment in either direction.
   */
  private byte[] incrementInCounter() throws GeneralSecurityException {
    incrementCounter(inCounter, oldCounter);
    return oldCounter;
  }

  /**
   * Increments the output counter, returning the previous (unincremented) value.
   *
   * <p>The returned array is the shared scratch buffer and is overwritten by the next counter
   * increment in either direction.
   */
  private byte[] incrementOutCounter() throws GeneralSecurityException {
    incrementCounter(outCounter, oldCounter);
    return oldCounter;
  }

  /** Advances the input counter {@code n} times; does nothing when {@code n <= 0}. */
  @VisibleForTesting
  void incrementInCounterForTesting(int n) throws GeneralSecurityException {
    for (int i = 0; i < n; i++) {
      incrementInCounter();
    }
  }

  /** Advances the output counter {@code n} times; does nothing when {@code n <= 0}. */
  @VisibleForTesting
  void incrementOutCounterForTesting(int n) throws GeneralSecurityException {
    for (int i = 0; i < n; i++) {
      incrementOutCounter();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy