Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
io.grpc.alts.internal.AltsTsiFrameProtector Maven / Gradle / Ivy
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import com.google.common.primitives.Ints;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
/** Frame protector that uses the ALTS framing. */
public final class AltsTsiFrameProtector implements TsiFrameProtector {
private static final int HEADER_LEN_FIELD_BYTES = 4;
private static final int HEADER_TYPE_FIELD_BYTES = 4;
private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES;
private static final int HEADER_TYPE_DEFAULT = 6;
private static final int LIMIT_MAX_ALLOWED_FRAME_SIZE = 1024 * 1024;
// Frame size negotiation extends frame size range to [MIN_FRAME_SIZE, MAX_FRAME_SIZE].
private static final int MIN_FRAME_SIZE = 16 * 1024;
private static final int MAX_FRAME_SIZE = 128 * 1024;
private final Protector protector;
private final Unprotector unprotector;
/** Create a new AltsTsiFrameProtector. */
public AltsTsiFrameProtector(
int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength());
maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_SIZE, maxProtectedFrameBytes);
protector = new Protector(maxProtectedFrameBytes, crypter);
unprotector = new Unprotector(crypter, alloc);
}
static int getHeaderLenFieldBytes() {
return HEADER_LEN_FIELD_BYTES;
}
static int getHeaderTypeFieldBytes() {
return HEADER_TYPE_FIELD_BYTES;
}
public static int getHeaderBytes() {
return HEADER_BYTES;
}
static int getHeaderTypeDefault() {
return HEADER_TYPE_DEFAULT;
}
static int getLimitMaxAllowedFrameSize() {
return LIMIT_MAX_ALLOWED_FRAME_SIZE;
}
public static int getMinFrameSize() {
return MIN_FRAME_SIZE;
}
public static int getMaxFrameSize() {
return MAX_FRAME_SIZE;
}
@Override
public void protectFlush(
List unprotectedBufs, Consumer ctxWrite, ByteBufAllocator alloc)
throws GeneralSecurityException {
protector.protectFlush(unprotectedBufs, ctxWrite, alloc);
}
@Override
public void unprotect(ByteBuf in, List out, ByteBufAllocator alloc)
throws GeneralSecurityException {
unprotector.unprotect(in, out, alloc);
}
@Override
public void destroy() {
try {
unprotector.destroy();
} finally {
protector.destroy();
}
}
static final class Protector {
private final int maxUnprotectedBytesPerFrame;
private final int suffixBytes;
private ChannelCrypterNetty crypter;
Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter) {
this.suffixBytes = crypter.getSuffixLength();
this.maxUnprotectedBytesPerFrame = maxProtectedFrameBytes - HEADER_BYTES - suffixBytes;
this.crypter = crypter;
}
void destroy() {
// Shared with Unprotector and destroyed there.
crypter = null;
}
void protectFlush(
List unprotectedBufs, Consumer ctxWrite, ByteBufAllocator alloc)
throws GeneralSecurityException {
checkState(crypter != null, "Cannot protectFlush after destroy.");
ByteBuf protectedBuf;
try {
protectedBuf = handleUnprotected(unprotectedBufs, alloc);
} finally {
for (ByteBuf buf : unprotectedBufs) {
buf.release();
}
}
if (protectedBuf != null) {
ctxWrite.accept(protectedBuf);
}
}
private ByteBuf handleUnprotected(List unprotectedBufs, ByteBufAllocator alloc)
throws GeneralSecurityException {
long unprotectedBytes = 0;
for (ByteBuf buf : unprotectedBufs) {
unprotectedBytes += buf.readableBytes();
}
// Empty plaintext not allowed since this should be handled as no-op in layer above.
checkArgument(unprotectedBytes > 0);
// Compute number of frames and allocate a single buffer for all frames.
long frameNum = unprotectedBytes / maxUnprotectedBytesPerFrame + 1;
int lastFrameUnprotectedBytes = (int) (unprotectedBytes % maxUnprotectedBytesPerFrame);
if (lastFrameUnprotectedBytes == 0) {
frameNum--;
lastFrameUnprotectedBytes = maxUnprotectedBytesPerFrame;
}
long protectedBytes = frameNum * (HEADER_BYTES + suffixBytes) + unprotectedBytes;
ByteBuf protectedBuf = alloc.directBuffer(Ints.checkedCast(protectedBytes));
try {
int bufferIdx = 0;
for (int frameIdx = 0; frameIdx < frameNum; ++frameIdx) {
int unprotectedBytesLeft =
(frameIdx == frameNum - 1) ? lastFrameUnprotectedBytes : maxUnprotectedBytesPerFrame;
// Write header (at most LIMIT_MAX_ALLOWED_FRAME_BYTES).
protectedBuf.writeIntLE(unprotectedBytesLeft + HEADER_TYPE_FIELD_BYTES + suffixBytes);
protectedBuf.writeIntLE(HEADER_TYPE_DEFAULT);
// Ownership of the backing buffer remains with protectedBuf.
ByteBuf frameOut = writeSlice(protectedBuf, unprotectedBytesLeft + suffixBytes);
List framePlain = new ArrayList<>();
while (unprotectedBytesLeft > 0) {
// Ownership of the buffer backing in remains with unprotectedBufs.
ByteBuf in = unprotectedBufs.get(bufferIdx);
if (in.readableBytes() <= unprotectedBytesLeft) {
// The complete buffer belongs to this frame.
framePlain.add(in);
unprotectedBytesLeft -= in.readableBytes();
bufferIdx++;
} else {
// The remainder of in will be part of the next frame.
framePlain.add(in.readSlice(unprotectedBytesLeft));
unprotectedBytesLeft = 0;
}
}
crypter.encrypt(frameOut, framePlain);
verify(!frameOut.isWritable());
}
protectedBuf.readerIndex(0);
protectedBuf.writerIndex(protectedBuf.capacity());
return protectedBuf.retain();
} finally {
protectedBuf.release();
}
}
}
static final class Unprotector {
private final int suffixBytes;
private final ChannelCrypterNetty crypter;
private DeframerState state = DeframerState.READ_HEADER;
private int requiredProtectedBytes;
private ByteBuf header;
private ByteBuf firstFrameTag;
private int unhandledIdx = 0;
private long unhandledBytes = 0;
private List unhandledBufs = new ArrayList<>(16);
Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
this.crypter = crypter;
this.suffixBytes = crypter.getSuffixLength();
this.header = alloc.directBuffer(HEADER_BYTES);
this.firstFrameTag = alloc.directBuffer(suffixBytes);
}
private void addUnhandled(ByteBuf in) {
if (in.isReadable()) {
ByteBuf buf = in.readRetainedSlice(in.readableBytes());
unhandledBufs.add(buf);
unhandledBytes += buf.readableBytes();
}
}
void unprotect(ByteBuf in, List out, ByteBufAllocator alloc)
throws GeneralSecurityException {
checkState(header != null, "Cannot unprotect after destroy.");
addUnhandled(in);
decodeFrame(alloc, out);
}
@SuppressWarnings("fallthrough")
private void decodeFrame(ByteBufAllocator alloc, List out)
throws GeneralSecurityException {
switch (state) {
case READ_HEADER:
if (unhandledBytes < HEADER_BYTES) {
return;
}
handleHeader();
// fall through
case READ_PROTECTED_PAYLOAD:
if (unhandledBytes < requiredProtectedBytes) {
return;
}
ByteBuf unprotectedBuf;
try {
unprotectedBuf = handlePayload(alloc);
} finally {
clearState();
}
if (unprotectedBuf != null) {
out.add(unprotectedBuf);
}
break;
default:
throw new AssertionError("impossible enum value");
}
}
private void handleHeader() {
while (header.isWritable()) {
ByteBuf in = unhandledBufs.get(unhandledIdx);
int headerBytesToRead = Math.min(in.readableBytes(), header.writableBytes());
header.writeBytes(in, headerBytesToRead);
unhandledBytes -= headerBytesToRead;
if (!in.isReadable()) {
unhandledIdx++;
}
}
requiredProtectedBytes = header.readIntLE() - HEADER_TYPE_FIELD_BYTES;
checkArgument(
requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small");
checkArgument(
requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_SIZE - HEADER_BYTES,
"Invalid header field: frame size too large");
int frameType = header.readIntLE();
checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type");
state = DeframerState.READ_PROTECTED_PAYLOAD;
}
private ByteBuf handlePayload(ByteBufAllocator alloc) throws GeneralSecurityException {
int requiredCiphertextBytes = requiredProtectedBytes - suffixBytes;
int firstFrameUnprotectedLen = requiredCiphertextBytes;
// We get the ciphertexts of the first frame and copy over the tag into a single buffer.
List firstFrameCiphertext = new ArrayList<>();
while (requiredCiphertextBytes > 0) {
ByteBuf buf = unhandledBufs.get(unhandledIdx);
if (buf.readableBytes() <= requiredCiphertextBytes) {
// We use the whole buffer.
firstFrameCiphertext.add(buf);
requiredCiphertextBytes -= buf.readableBytes();
unhandledIdx++;
} else {
firstFrameCiphertext.add(buf.readSlice(requiredCiphertextBytes));
requiredCiphertextBytes = 0;
}
}
int requiredSuffixBytes = suffixBytes;
while (true) {
ByteBuf buf = unhandledBufs.get(unhandledIdx);
if (buf.readableBytes() <= requiredSuffixBytes) {
// We use the whole buffer.
requiredSuffixBytes -= buf.readableBytes();
firstFrameTag.writeBytes(buf);
if (requiredSuffixBytes == 0) {
break;
}
unhandledIdx++;
} else {
firstFrameTag.writeBytes(buf, requiredSuffixBytes);
break;
}
}
verify(unhandledIdx == unhandledBufs.size() - 1);
ByteBuf lastBuf = unhandledBufs.get(unhandledIdx);
// We get the remaining ciphertexts and tags contained in the last buffer.
List ciphertextsAndTags = new ArrayList<>();
List unprotectedLens = new ArrayList<>();
long requiredUnprotectedBytesCompleteFrames = firstFrameUnprotectedLen;
while (lastBuf.readableBytes() >= HEADER_BYTES + suffixBytes) {
// Read frame size.
int frameSize = lastBuf.readIntLE();
int payloadSize = frameSize - HEADER_TYPE_FIELD_BYTES - suffixBytes;
// Break and undo read if we don't have the complete frame yet.
if (lastBuf.readableBytes() < frameSize) {
lastBuf.readerIndex(lastBuf.readerIndex() - HEADER_LEN_FIELD_BYTES);
break;
}
// Check the type header.
checkArgument(lastBuf.readIntLE() == 6);
// Create a new frame (except for out buffer).
ciphertextsAndTags.add(lastBuf.readSlice(payloadSize + suffixBytes));
// Update sizes for frame.
requiredUnprotectedBytesCompleteFrames += payloadSize;
unprotectedLens.add(payloadSize);
}
// We leave space for suffixBytes to allow for in-place encryption. This allows for calling
// doFinal in the JCE implementation which can be optimized better than update and doFinal.
ByteBuf unprotectedBuf =
alloc.directBuffer(
Ints.checkedCast(requiredUnprotectedBytesCompleteFrames + suffixBytes));
try {
ByteBuf out = writeSlice(unprotectedBuf, firstFrameUnprotectedLen + suffixBytes);
crypter.decrypt(out, firstFrameTag, firstFrameCiphertext);
verify(out.writableBytes() == suffixBytes);
unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
for (int frameIdx = 0; frameIdx < ciphertextsAndTags.size(); ++frameIdx) {
out = writeSlice(unprotectedBuf, unprotectedLens.get(frameIdx) + suffixBytes);
crypter.decrypt(out, ciphertextsAndTags.get(frameIdx));
verify(out.writableBytes() == suffixBytes);
unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
}
return unprotectedBuf.retain();
} finally {
unprotectedBuf.release();
}
}
private void clearState() {
int bufsSize = unhandledBufs.size();
ByteBuf lastBuf = unhandledBufs.get(bufsSize - 1);
boolean keepLast = lastBuf.isReadable();
for (int bufIdx = 0; bufIdx < (keepLast ? bufsSize - 1 : bufsSize); ++bufIdx) {
unhandledBufs.get(bufIdx).release();
}
unhandledBufs.clear();
unhandledBytes = 0;
unhandledIdx = 0;
if (keepLast) {
unhandledBufs.add(lastBuf);
unhandledBytes = lastBuf.readableBytes();
}
state = DeframerState.READ_HEADER;
requiredProtectedBytes = 0;
header.clear();
firstFrameTag.clear();
}
void destroy() {
for (ByteBuf unhandledBuf : unhandledBufs) {
unhandledBuf.release();
}
unhandledBufs.clear();
if (header != null) {
header.release();
header = null;
}
if (firstFrameTag != null) {
firstFrameTag.release();
firstFrameTag = null;
}
crypter.destroy();
}
}
private enum DeframerState {
READ_HEADER,
READ_PROTECTED_PAYLOAD
}
private static ByteBuf writeSlice(ByteBuf in, int len) {
checkArgument(len <= in.writableBytes());
ByteBuf out = in.slice(in.writerIndex(), len);
in.writerIndex(in.writerIndex() + len);
return out.writerIndex(0);
}
}