Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
com.digitalpetri.opcua.stack.client.handlers.UaTcpClientMessageHandler Maven / Gradle / Ivy
/*
* Copyright 2016 Kevin Herron
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.digitalpetri.opcua.stack.client.handlers;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import com.digitalpetri.opcua.stack.client.UaTcpStackClient;
import com.digitalpetri.opcua.stack.core.StatusCodes;
import com.digitalpetri.opcua.stack.core.UaException;
import com.digitalpetri.opcua.stack.core.UaRuntimeException;
import com.digitalpetri.opcua.stack.core.UaServiceFaultException;
import com.digitalpetri.opcua.stack.core.channel.ChannelSecurity;
import com.digitalpetri.opcua.stack.core.channel.ClientSecureChannel;
import com.digitalpetri.opcua.stack.core.channel.MessageAbortedException;
import com.digitalpetri.opcua.stack.core.channel.SerializationQueue;
import com.digitalpetri.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import com.digitalpetri.opcua.stack.core.channel.headers.HeaderDecoder;
import com.digitalpetri.opcua.stack.core.channel.messages.ErrorMessage;
import com.digitalpetri.opcua.stack.core.channel.messages.MessageType;
import com.digitalpetri.opcua.stack.core.channel.messages.TcpMessageDecoder;
import com.digitalpetri.opcua.stack.core.security.SecurityAlgorithm;
import com.digitalpetri.opcua.stack.core.serialization.UaRequestMessage;
import com.digitalpetri.opcua.stack.core.serialization.UaResponseMessage;
import com.digitalpetri.opcua.stack.core.types.builtin.ByteString;
import com.digitalpetri.opcua.stack.core.types.builtin.DateTime;
import com.digitalpetri.opcua.stack.core.types.builtin.StatusCode;
import com.digitalpetri.opcua.stack.core.types.builtin.unsigned.UInteger;
import com.digitalpetri.opcua.stack.core.types.enumerated.SecurityTokenRequestType;
import com.digitalpetri.opcua.stack.core.types.structured.ChannelSecurityToken;
import com.digitalpetri.opcua.stack.core.types.structured.CloseSecureChannelRequest;
import com.digitalpetri.opcua.stack.core.types.structured.OpenSecureChannelRequest;
import com.digitalpetri.opcua.stack.core.types.structured.OpenSecureChannelResponse;
import com.digitalpetri.opcua.stack.core.types.structured.RequestHeader;
import com.digitalpetri.opcua.stack.core.types.structured.ServiceFault;
import com.digitalpetri.opcua.stack.core.util.BufferUtil;
import com.digitalpetri.opcua.stack.core.util.LongSequence;
import com.digitalpetri.opcua.stack.core.util.NonceUtil;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.util.AttributeKey;
import io.netty.util.Timeout;
import org.jooq.lambda.tuple.Tuple2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static com.digitalpetri.opcua.stack.core.types.builtin.unsigned.Unsigned.uint;
public class UaTcpClientMessageHandler extends ByteToMessageCodec implements HeaderDecoder {
public static final AttributeKey> KEY_PENDING_REQUEST_FUTURES =
AttributeKey.valueOf("pending-request-futures");
private final Logger logger = LoggerFactory.getLogger(getClass());
private List chunkBuffers = new LinkedList<>();
private final AtomicReference headerRef = new AtomicReference<>();
private ScheduledFuture renewFuture;
private Timeout secureChannelTimeout;
private final Map pending;
private final LongSequence requestIdSequence;
private final UaTcpStackClient client;
private final ClientSecureChannel secureChannel;
private final SerializationQueue serializationQueue;
private final CompletableFuture handshakeFuture;
public UaTcpClientMessageHandler(
UaTcpStackClient client,
ClientSecureChannel secureChannel,
SerializationQueue serializationQueue,
CompletableFuture handshakeFuture) {
this.client = client;
this.secureChannel = secureChannel;
this.serializationQueue = serializationQueue;
this.handshakeFuture = handshakeFuture;
secureChannel
.attr(KEY_PENDING_REQUEST_FUTURES)
.setIfAbsent(Maps.newConcurrentMap());
pending = secureChannel.attr(KEY_PENDING_REQUEST_FUTURES).get();
secureChannel
.attr(ClientSecureChannel.KEY_REQUEST_ID_SEQUENCE)
.setIfAbsent(new LongSequence(1L, UInteger.MAX_VALUE));
requestIdSequence = secureChannel.attr(ClientSecureChannel.KEY_REQUEST_ID_SEQUENCE).get();
handshakeFuture.thenAccept(sc -> {
Channel channel = sc.getChannel();
channel.eventLoop().execute(() -> {
List awaitingHandshake = channel.attr(
UaTcpClientAcknowledgeHandler.KEY_AWAITING_HANDSHAKE).get();
if (awaitingHandshake != null) {
channel.attr(UaTcpClientAcknowledgeHandler.KEY_AWAITING_HANDSHAKE).remove();
logger.debug("{} message(s) queued before handshake completed; sending now.", awaitingHandshake.size());
awaitingHandshake.forEach(channel::writeAndFlush);
}
});
});
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
SecurityTokenRequestType requestType = secureChannel.getChannelId() == 0 ?
SecurityTokenRequestType.Issue : SecurityTokenRequestType.Renew;
secureChannelTimeout = client.getConfig().getWheelTimer().newTimeout(
timeout -> {
if (!timeout.isCancelled()) {
handshakeFuture.completeExceptionally(
new UaException(
StatusCodes.Bad_Timeout,
"timed out waiting for secure channel"));
ctx.close();
}
},
5, TimeUnit.SECONDS
);
logger.debug("OpenSecureChannel timeout scheduled for +5s");
sendOpenSecureChannelRequest(ctx, requestType);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (renewFuture != null) renewFuture.cancel(false);
handshakeFuture.completeExceptionally(
new UaException(StatusCodes.Bad_ConnectionClosed, "connection closed"));
super.channelInactive(ctx);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.error("Exception caught: {}", cause.getMessage(), cause);
ctx.close();
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof CloseSecureChannelRequest) {
sendCloseSecureChannelRequest(ctx, (CloseSecureChannelRequest) evt);
}
}
private void sendOpenSecureChannelRequest(ChannelHandlerContext ctx, SecurityTokenRequestType requestType) {
SecurityAlgorithm algorithm = secureChannel.getSecurityPolicy().getSymmetricEncryptionAlgorithm();
int nonceLength = NonceUtil.getNonceLength(algorithm);
ByteString clientNonce = secureChannel.isSymmetricSigningEnabled() ?
NonceUtil.generateNonce(nonceLength) :
ByteString.NULL_VALUE;
secureChannel.setLocalNonce(clientNonce);
OpenSecureChannelRequest request = new OpenSecureChannelRequest(
new RequestHeader(null, DateTime.now(), uint(0), uint(0), null, uint(0), null),
uint(PROTOCOL_VERSION),
requestType,
secureChannel.getMessageSecurityMode(),
secureChannel.getLocalNonce(),
client.getChannelLifetime());
encodeMessage(request, MessageType.OpenSecureChannel).whenComplete((t2, ex) -> {
if (ex != null) {
ctx.close();
return;
}
List chunks = t2.v2();
ctx.executor().execute(() -> {
chunks.forEach(c -> ctx.write(c, ctx.voidPromise()));
ctx.flush();
});
ChannelSecurity channelSecurity = secureChannel.getChannelSecurity();
long currentTokenId = channelSecurity != null ?
channelSecurity.getCurrentToken().getTokenId().longValue() : -1L;
long previousTokenId = channelSecurity != null ?
channelSecurity.getPreviousToken().map(token -> token.getTokenId().longValue()).orElse(-1L) : -1L;
logger.debug(
"Sent OpenSecureChannelRequest ({}, id={}, currentToken={}, previousToken={}).",
request.getRequestType(), secureChannel.getChannelId(), currentTokenId, previousTokenId);
});
}
private void sendCloseSecureChannelRequest(ChannelHandlerContext ctx, CloseSecureChannelRequest request) {
encodeMessage(request, MessageType.CloseSecureChannel).whenComplete((t2, ex) -> {
if (ex != null) {
ctx.close();
return;
}
List chunks = t2.v2();
ctx.executor().execute(() -> {
chunks.forEach(c -> ctx.write(c, ctx.voidPromise()));
ctx.flush();
ctx.close();
});
secureChannel.setChannelId(0);
});
}
@Override
protected void encode(ChannelHandlerContext ctx, UaRequestFuture request, ByteBuf buffer) throws Exception {
encodeMessage(request.getRequest(), MessageType.SecureMessage).whenComplete((t2, ex) -> {
if (ex != null) {
ctx.close();
return;
}
long requestId = t2.v1();
List chunks = t2.v2();
pending.put(requestId, request);
// No matter how we complete, make sure the entry in pending is removed.
// This covers the case where the request fails due to a timeout in the
// upper layers as well as normal completion.
request.getFuture().whenComplete((r, x) -> pending.remove(requestId));
ctx.executor().execute(() -> {
chunks.forEach(c -> ctx.write(c, ctx.voidPromise()));
ctx.flush();
});
});
}
private CompletableFuture>> encodeMessage(
UaRequestMessage request,
MessageType messageType) {
CompletableFuture>> future = new CompletableFuture<>();
serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
ByteBuf messageBuffer = null;
try {
messageBuffer = BufferUtil.buffer();
binaryEncoder.setBuffer(messageBuffer);
binaryEncoder.encodeMessage(null, request);
List chunks;
if (messageType == MessageType.OpenSecureChannel) {
chunks = chunkEncoder.encodeAsymmetric(
secureChannel,
messageType,
messageBuffer,
requestIdSequence.getAndIncrement()
);
} else {
chunks = chunkEncoder.encodeSymmetric(
secureChannel,
messageType,
messageBuffer,
requestIdSequence.getAndIncrement()
);
}
future.complete(new Tuple2<>(chunkEncoder.getLastRequestId(), chunks));
} catch (UaException ex) {
logger.error("Error encoding {}: {}", request, ex.getMessage(), ex);
future.completeExceptionally(ex);
} finally {
if (messageBuffer != null) {
messageBuffer.release();
}
}
});
return future;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List out) throws Exception {
if (buffer.readableBytes() >= HEADER_LENGTH) {
buffer = buffer.order(ByteOrder.LITTLE_ENDIAN);
if (buffer.readableBytes() >= getMessageLength(buffer)) {
decodeMessage(ctx, buffer);
}
}
}
private void decodeMessage(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
int messageLength = getMessageLength(buffer);
MessageType messageType = MessageType.fromMediumInt(buffer.getMedium(buffer.readerIndex()));
switch (messageType) {
case OpenSecureChannel:
onOpenSecureChannel(ctx, buffer.readSlice(messageLength));
break;
case SecureMessage:
onSecureMessage(ctx, buffer.readSlice(messageLength));
break;
case Error:
onError(ctx, buffer.readSlice(messageLength));
break;
default:
throw new UaException(
StatusCodes.Bad_TcpMessageTypeInvalid,
"unexpected MessageType: " + messageType);
}
}
private boolean accumulateChunk(ByteBuf buffer) throws UaException {
int maxChunkCount = serializationQueue.getParameters().getLocalMaxChunkCount();
int maxChunkSize = serializationQueue.getParameters().getLocalReceiveBufferSize();
int chunkSize = buffer.readerIndex(0).readableBytes();
if (chunkSize > maxChunkSize) {
throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
String.format("max chunk size exceeded (%s)", maxChunkSize));
}
chunkBuffers.add(buffer.retain());
if (chunkBuffers.size() > maxChunkCount) {
throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
String.format("max chunk count exceeded (%s)", maxChunkCount));
}
char chunkType = (char) buffer.getByte(3);
return (chunkType == 'A' || chunkType == 'F');
}
private void onOpenSecureChannel(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
if (secureChannelTimeout != null) {
if (secureChannelTimeout.cancel()) {
logger.debug("OpenSecureChannel timeout canceled");
secureChannelTimeout = null;
} else {
logger.warn("timed out waiting for secure channel");
handshakeFuture.completeExceptionally(
new UaException(StatusCodes.Bad_Timeout,
"timed out waiting for secure channel"));
ctx.close();
return;
}
}
buffer.skipBytes(3 + 1 + 4 + 4); // skip messageType, chunkType, messageSize, secureChannelId
AsymmetricSecurityHeader securityHeader = AsymmetricSecurityHeader.decode(buffer);
if (!headerRef.compareAndSet(null, securityHeader)) {
if (!securityHeader.equals(headerRef.get())) {
throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
"subsequent AsymmetricSecurityHeader did not match");
}
}
if (accumulateChunk(buffer)) {
final List buffersToDecode = ImmutableList.copyOf(chunkBuffers);
chunkBuffers = new LinkedList<>();
serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
ByteBuf decodedBuffer = null;
try {
decodedBuffer = chunkDecoder.decodeAsymmetric(secureChannel, buffersToDecode);
UaResponseMessage responseMessage = binaryDecoder
.setBuffer(decodedBuffer)
.decodeMessage(null);
StatusCode serviceResult = responseMessage.getResponseHeader().getServiceResult();
if (serviceResult.isGood()) {
OpenSecureChannelResponse response = (OpenSecureChannelResponse) responseMessage;
secureChannel.setChannelId(response.getSecurityToken().getChannelId().longValue());
logger.debug("Received OpenSecureChannelResponse.");
installSecurityToken(ctx, response);
handshakeFuture.complete(secureChannel);
} else {
ServiceFault serviceFault = (responseMessage instanceof ServiceFault) ?
(ServiceFault) responseMessage :
new ServiceFault(responseMessage.getResponseHeader());
throw new UaServiceFaultException(serviceFault);
}
} catch (MessageAbortedException e) {
logger.error("Received message abort chunk; error={}, reason={}", e.getStatusCode(), e.getMessage());
ctx.close();
} catch (Throwable t) {
logger.error("Error decoding OpenSecureChannelResponse: {}", t.getMessage(), t);
ctx.close();
} finally {
if (decodedBuffer != null) {
decodedBuffer.release();
}
}
});
}
}
private void installSecurityToken(ChannelHandlerContext ctx, OpenSecureChannelResponse response) {
ChannelSecurity.SecuritySecrets newKeys = null;
if (response.getServerProtocolVersion().longValue() < PROTOCOL_VERSION) {
throw new UaRuntimeException(StatusCodes.Bad_ProtocolVersionUnsupported,
"server protocol version unsupported: " + response.getServerProtocolVersion());
}
ChannelSecurityToken newToken = response.getSecurityToken();
if (secureChannel.isSymmetricSigningEnabled()) {
secureChannel.setRemoteNonce(response.getServerNonce());
newKeys = ChannelSecurity.generateKeyPair(
secureChannel,
secureChannel.getLocalNonce(),
secureChannel.getRemoteNonce()
);
}
ChannelSecurity oldSecrets = secureChannel.getChannelSecurity();
ChannelSecurity.SecuritySecrets oldKeys = oldSecrets != null ? oldSecrets.getCurrentKeys() : null;
ChannelSecurityToken oldToken = oldSecrets != null ? oldSecrets.getCurrentToken() : null;
secureChannel.setChannelSecurity(new ChannelSecurity(newKeys, newToken, oldKeys, oldToken));
DateTime createdAt = response.getSecurityToken().getCreatedAt();
long revisedLifetime = response.getSecurityToken().getRevisedLifetime().longValue();
if (revisedLifetime > 0) {
long renewAt = (long) (revisedLifetime * 0.75);
renewFuture = ctx.executor().schedule(
() -> sendOpenSecureChannelRequest(ctx, SecurityTokenRequestType.Renew),
renewAt, TimeUnit.MILLISECONDS);
} else {
logger.warn("Server revised secure channel lifetime to 0; renewal will not occur.");
}
ctx.executor().execute(() -> {
// SecureChannel is ready; remove the acknowledge handler.
if (ctx.pipeline().get(UaTcpClientAcknowledgeHandler.class) != null) {
ctx.pipeline().remove(UaTcpClientAcknowledgeHandler.class);
}
});
ChannelSecurity channelSecurity = secureChannel.getChannelSecurity();
long currentTokenId = channelSecurity.getCurrentToken().getTokenId().longValue();
long previousTokenId = channelSecurity.getPreviousToken()
.map(t -> t.getTokenId().longValue()).orElse(-1L);
logger.debug(
"SecureChannel id={}, currentTokenId={}, previousTokenId={}, lifetime={}ms, createdAt={}",
secureChannel.getChannelId(), currentTokenId, previousTokenId, revisedLifetime, createdAt);
}
private void onSecureMessage(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
buffer.skipBytes(3 + 1 + 4); // skip messageType, chunkType, messageSize
long secureChannelId = buffer.readUnsignedInt();
if (secureChannelId != secureChannel.getChannelId()) {
throw new UaException(StatusCodes.Bad_SecureChannelIdInvalid,
"invalid secure channel id: " + secureChannelId);
}
if (accumulateChunk(buffer)) {
final List buffersToDecode = ImmutableList.copyOf(chunkBuffers);
chunkBuffers = new LinkedList<>();
serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
ByteBuf decodedBuffer = null;
try {
decodedBuffer = chunkDecoder.decodeSymmetric(secureChannel, buffersToDecode);
binaryDecoder.setBuffer(decodedBuffer);
UaResponseMessage response = binaryDecoder.decodeMessage(null);
UaRequestFuture request = pending.remove(chunkDecoder.getLastRequestId());
if (request != null) {
client.getExecutorService().execute(
() -> request.getFuture().complete(response));
} else {
logger.warn("No UaRequestFuture for requestId={}", chunkDecoder.getLastRequestId());
}
} catch (MessageAbortedException e) {
logger.debug("Received message abort chunk; error={}, reason={}", e.getStatusCode(), e.getMessage());
UaRequestFuture request = pending.remove(chunkDecoder.getLastRequestId());
if (request != null) {
client.getExecutorService().execute(
() -> request.getFuture().completeExceptionally(e));
} else {
logger.warn("No UaRequestFuture for requestId={}", chunkDecoder.getLastRequestId());
}
} catch (Throwable t) {
logger.error("Error decoding symmetric message: {}", t.getMessage(), t);
serializationQueue.pause();
ctx.close();
} finally {
if (decodedBuffer != null) {
decodedBuffer.release();
}
}
});
}
}
private void onError(ChannelHandlerContext ctx, ByteBuf buffer) {
try {
ErrorMessage errorMessage = TcpMessageDecoder.decodeError(buffer);
StatusCode statusCode = errorMessage.getError();
long errorCode = statusCode.getValue();
boolean secureChannelError =
errorCode == StatusCodes.Bad_SecurityChecksFailed ||
errorCode == StatusCodes.Bad_TcpSecureChannelUnknown ||
errorCode == StatusCodes.Bad_SecureChannelIdInvalid;
if (secureChannelError) {
secureChannel.setChannelId(0);
}
logger.error("Received error message: " + errorMessage);
handshakeFuture.completeExceptionally(new UaException(statusCode, errorMessage.getReason()));
} catch (UaException e) {
logger.error("An exception occurred while decoding an error message: {}", e.getMessage(), e);
handshakeFuture.completeExceptionally(e);
} finally {
ctx.close();
}
}
}