io.grpc.alts.internal.AltsHandshakerClient Maven / Gradle / Ivy
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.protobuf.ByteString;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.Status;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
/** An API for conducting handshakes via ALTS handshaker service. */
class AltsHandshakerClient {
private static final String APPLICATION_PROTOCOL = "grpc";
private static final String RECORD_PROTOCOL = "ALTSRP_GCM_AES128_REKEY";
private static final int KEY_LENGTH = AltsChannelCrypter.getKeyLength();
private final AltsHandshakerStub handshakerStub;
private final AltsHandshakerOptions handshakerOptions;
private HandshakerResult result;
private HandshakerStatus status;
private final ChannelLogger logger;
/** Starts a new handshake interacting with the handshaker service. */
AltsHandshakerClient(
HandshakerServiceStub stub, AltsHandshakerOptions options, ChannelLogger logger) {
handshakerStub = new AltsHandshakerStub(stub);
handshakerOptions = options;
this.logger = logger;
}
@VisibleForTesting
AltsHandshakerClient(
AltsHandshakerStub handshakerStub, AltsHandshakerOptions options, ChannelLogger logger) {
this.handshakerStub = handshakerStub;
handshakerOptions = options;
this.logger = logger;
}
static String getApplicationProtocol() {
return APPLICATION_PROTOCOL;
}
static String getRecordProtocol() {
return RECORD_PROTOCOL;
}
/** Sets the start client fields for the passed handshake request. */
private void setStartClientFields(HandshakerReq.Builder req) {
// Sets the default values.
StartClientHandshakeReq.Builder startClientReq =
StartClientHandshakeReq.newBuilder()
.setHandshakeSecurityProtocol(HandshakeProtocol.ALTS)
.addApplicationProtocols(APPLICATION_PROTOCOL)
.addRecordProtocols(RECORD_PROTOCOL);
// Sets handshaker options.
if (handshakerOptions.getRpcProtocolVersions() != null) {
startClientReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
}
if (handshakerOptions instanceof AltsClientOptions) {
AltsClientOptions clientOptions = (AltsClientOptions) handshakerOptions;
if (!Strings.isNullOrEmpty(clientOptions.getTargetName())) {
startClientReq.setTargetName(clientOptions.getTargetName());
}
for (String serviceAccount : clientOptions.getTargetServiceAccounts()) {
startClientReq.addTargetIdentitiesBuilder().setServiceAccount(serviceAccount);
}
}
startClientReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize());
req.setClientStart(startClientReq);
}
/** Sets the start server fields for the passed handshake request. */
private void setStartServerFields(HandshakerReq.Builder req, ByteBuffer inBytes) {
ServerHandshakeParameters serverParameters =
ServerHandshakeParameters.newBuilder().addRecordProtocols(RECORD_PROTOCOL).build();
StartServerHandshakeReq.Builder startServerReq =
StartServerHandshakeReq.newBuilder()
.addApplicationProtocols(APPLICATION_PROTOCOL)
.putHandshakeParameters(HandshakeProtocol.ALTS.getNumber(), serverParameters)
.setInBytes(ByteString.copyFrom(inBytes.duplicate()));
if (handshakerOptions.getRpcProtocolVersions() != null) {
startServerReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
}
startServerReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize());
req.setServerStart(startServerReq);
}
/** Returns true if the handshake is complete. */
public boolean isFinished() {
// If we have a HandshakeResult, we are done.
if (result != null) {
return true;
}
// If we have an error status, we are done.
if (status != null && status.getCode() != Status.Code.OK.value()) {
return true;
}
return false;
}
/** Returns the handshake status. */
public HandshakerStatus getStatus() {
return status;
}
/** Returns the result data of the handshake, if the handshake is completed. */
public HandshakerResult getResult() {
return result;
}
/**
* Returns the resulting key of the handshake, if the handshake is completed. Note that the key
* data returned from the handshake may be more than the key length required for the record
* protocol, thus we need to truncate to the right size.
*/
public byte[] getKey() {
if (result == null) {
return null;
}
if (result.getKeyData().size() < KEY_LENGTH) {
throw new IllegalStateException("Could not get enough key data from the handshake.");
}
byte[] key = new byte[KEY_LENGTH];
result.getKeyData().substring(0, KEY_LENGTH).copyTo(key, 0);
return key;
}
/**
* Parses a handshake response, setting the status, result, and closing the handshaker, as needed.
*/
private void handleResponse(HandshakerResp resp) throws GeneralSecurityException {
status = resp.getStatus();
if (resp.hasResult()) {
result = resp.getResult();
close();
}
if (status.getCode() != Status.Code.OK.value()) {
String error = "Handshaker service error: " + status.getDetails();
logger.log(ChannelLogLevel.DEBUG, error);
close();
throw new GeneralSecurityException(error);
}
}
/**
* Starts a client handshake. A GeneralSecurityException is thrown if the handshaker service is
* interrupted or fails. Note that isFinished() must be false before this function is called.
*
* @return the frame to give to the peer.
* @throws GeneralSecurityException or IllegalStateException
*/
public ByteBuffer startClientHandshake() throws GeneralSecurityException {
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
HandshakerReq.Builder req = HandshakerReq.newBuilder();
setStartClientFields(req);
HandshakerResp resp;
try {
logger.log(ChannelLogLevel.DEBUG, "Send ALTS handshake request to upstream");
resp = handshakerStub.send(req.build());
logger.log(ChannelLogLevel.DEBUG, "Receive ALTS handshake response from upstream");
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
return resp.getOutFrames().asReadOnlyByteBuffer();
}
/**
* Starts a server handshake. A GeneralSecurityException is thrown if the handshaker service is
* interrupted or fails. Note that isFinished() must be false before this function is called.
*
* @param inBytes the bytes received from the peer.
* @return the frame to give to the peer.
* @throws GeneralSecurityException or IllegalStateException
*/
public ByteBuffer startServerHandshake(ByteBuffer inBytes) throws GeneralSecurityException {
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
HandshakerReq.Builder req = HandshakerReq.newBuilder();
setStartServerFields(req, inBytes);
HandshakerResp resp;
try {
resp = handshakerStub.send(req.build());
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}
/**
* Processes the next bytes in a handshake. A GeneralSecurityException is thrown if the handshaker
* service is interrupted or fails. Note that isFinished() must be false before this function is
* called.
*
* @param inBytes the bytes received from the peer.
* @return the frame to give to the peer.
* @throws GeneralSecurityException or IllegalStateException
*/
public ByteBuffer next(ByteBuffer inBytes) throws GeneralSecurityException {
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
HandshakerReq.Builder req =
HandshakerReq.newBuilder()
.setNext(
NextHandshakeMessageReq.newBuilder()
.setInBytes(ByteString.copyFrom(inBytes.duplicate()))
.build());
HandshakerResp resp;
try {
logger.log(ChannelLogLevel.DEBUG, "Send ALTS handshake request to upstream");
resp = handshakerStub.send(req.build());
logger.log(ChannelLogLevel.DEBUG, "Receive ALTS handshake response from upstream");
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}
private boolean closed = false;
/** Closes the connection. */
public void close() {
if (closed) {
return;
}
closed = true;
handshakerStub.close();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy