Please wait. This may take a few minutes...
Downloading a project requires significant server resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
com.digitalpetri.opcua.stack.server.handlers.UaTcpServerAsymmetricHandler Maven / Gradle / Ivy
package com.digitalpetri.opcua.stack.server.handlers;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteOrder;
import java.security.KeyPair;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import com.digitalpetri.opcua.stack.core.StatusCodes;
import com.digitalpetri.opcua.stack.core.UaException;
import com.digitalpetri.opcua.stack.core.application.CertificateManager;
import com.digitalpetri.opcua.stack.core.application.CertificateValidator;
import com.digitalpetri.opcua.stack.core.channel.ChannelSecurity;
import com.digitalpetri.opcua.stack.core.channel.ExceptionHandler;
import com.digitalpetri.opcua.stack.core.channel.SerializationQueue;
import com.digitalpetri.opcua.stack.core.channel.ServerSecureChannel;
import com.digitalpetri.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import com.digitalpetri.opcua.stack.core.channel.headers.HeaderDecoder;
import com.digitalpetri.opcua.stack.core.channel.messages.ErrorMessage;
import com.digitalpetri.opcua.stack.core.channel.messages.MessageType;
import com.digitalpetri.opcua.stack.core.security.SecurityAlgorithm;
import com.digitalpetri.opcua.stack.core.security.SecurityPolicy;
import com.digitalpetri.opcua.stack.core.types.builtin.ByteString;
import com.digitalpetri.opcua.stack.core.types.builtin.DateTime;
import com.digitalpetri.opcua.stack.core.types.builtin.StatusCode;
import com.digitalpetri.opcua.stack.core.types.enumerated.SecurityTokenRequestType;
import com.digitalpetri.opcua.stack.core.types.structured.ChannelSecurityToken;
import com.digitalpetri.opcua.stack.core.types.structured.EndpointDescription;
import com.digitalpetri.opcua.stack.core.types.structured.OpenSecureChannelRequest;
import com.digitalpetri.opcua.stack.core.types.structured.OpenSecureChannelResponse;
import com.digitalpetri.opcua.stack.core.types.structured.ResponseHeader;
import com.digitalpetri.opcua.stack.core.util.BufferUtil;
import com.digitalpetri.opcua.stack.server.tcp.UaTcpStackServer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static com.digitalpetri.opcua.stack.core.types.builtin.unsigned.Unsigned.uint;
import static com.digitalpetri.opcua.stack.core.util.NonceUtil.generateNonce;
import static com.digitalpetri.opcua.stack.core.util.NonceUtil.getNonceLength;
/**
 * Server-side Netty decoder for the asymmetric part of the OPC UA TCP secure conversation:
 * OpenSecureChannel and CloseSecureChannel messages.
 *
 * <p>OpenSecureChannel chunks are checked and retained until a final ('F') chunk arrives. The
 * assembled message is then decoded on the {@link SerializationQueue}, a new
 * {@link ChannelSecurityToken} is installed on the {@link ServerSecureChannel}, and an
 * OpenSecureChannelResponse is written back. The first response also inserts a
 * {@code UaTcpServerSymmetricHandler} at the front of the pipeline. Any other message type is
 * rejected with {@code Bad_TcpMessageTypeInvalid}.
 */
public class UaTcpServerAsymmetricHandler extends ByteToMessageDecoder implements HeaderDecoder {

    /**
     * Lower bound applied to the client's requested secure channel lifetime
     * (60000 * 60, i.e. one hour assuming millisecond units as used by OPC UA RequestedLifetime).
     */
    private static final long SecureChannelLifetimeMin = 60000L * 60;

    /**
     * Upper bound applied to the client's requested secure channel lifetime
     * (60000 * 60 * 24, i.e. 24 hours assuming millisecond units).
     */
    private static final long SecureChannelLifetimeMax = 60000L * 60 * 24;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Channel opened or looked up while processing the most recent OpenSecureChannel chunk. */
    private ServerSecureChannel secureChannel;

    /** Set once the symmetric handler has been added to the pipeline. */
    private volatile boolean symmetricHandlerAdded = false;

    /** Retained slices of the chunks of the OpenSecureChannel message currently being received. */
    private List<ByteBuf> chunkBuffers = new ArrayList<>();

    /** Security header of the first chunk of the current message; later chunks must equal it. */
    private final AtomicReference<AsymmetricSecurityHeader> headerRef = new AtomicReference<>();

    private final int maxChunkCount;
    private final int maxChunkSize;

    private final UaTcpStackServer server;
    private final SerializationQueue serializationQueue;

    /**
     * @param server             stack server providing secure channels, endpoints, configuration
     *                           and certificate management.
     * @param serializationQueue queue used to decode requests and encode responses; its parameters
     *                           supply the local max chunk count and receive buffer size, which is
     *                           used as the max chunk size.
     */
    public UaTcpServerAsymmetricHandler(UaTcpStackServer server, SerializationQueue serializationQueue) {
        this.server = server;
        this.serializationQueue = serializationQueue;

        maxChunkCount = serializationQueue.getParameters().getLocalMaxChunkCount();
        maxChunkSize = serializationQueue.getParameters().getLocalReceiveBufferSize();
    }

    /**
     * Frames complete messages out of {@code buffer} (little-endian) and dispatches them by type.
     *
     * @throws UaException if a message length is below the header size or above the max chunk size,
     *                     or the message type is neither OpenSecureChannel nor CloseSecureChannel.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        buffer = buffer.order(ByteOrder.LITTLE_ENDIAN);

        while (buffer.readableBytes() >= HEADER_LENGTH) {
            int messageLength = getMessageLength(buffer);

            // Validate the advertised length before waiting for the body. A zero length would never
            // advance the reader index (an endless loop on CloseSecureChannel), anything shorter than
            // the header cannot be a valid message, and an oversized length would make this decoder
            // buffer unbounded input before the per-chunk size check ever ran.
            if (messageLength < HEADER_LENGTH) {
                throw new UaException(StatusCodes.Bad_DecodingError,
                    String.format("invalid message length (%s)", messageLength));
            }
            if (messageLength > maxChunkSize) {
                throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                    String.format("max chunk size exceeded (%s)", maxChunkSize));
            }
            if (buffer.readableBytes() < messageLength) {
                break; // wait for the rest of this message
            }

            MessageType messageType = MessageType.fromMediumInt(buffer.getMedium(buffer.readerIndex()));

            switch (messageType) {
                case OpenSecureChannel:
                    onOpenSecureChannel(ctx, buffer.readSlice(messageLength));
                    break;

                case CloseSecureChannel:
                    logger.debug("Received CloseSecureChannelRequest");
                    if (secureChannel != null) {
                        server.closeSecureChannel(secureChannel);
                    }
                    buffer.skipBytes(messageLength);
                    break;

                default:
                    throw new UaException(StatusCodes.Bad_TcpMessageTypeInvalid,
                        "unexpected MessageType: " + messageType);
            }
        }
    }

    /**
     * Processes one OpenSecureChannel chunk.
     *
     * <p>An abort ('A') chunk discards the chunks collected so far. Any other chunk opens a new
     * channel (id 0, first chunk of a message) or looks up and checks the existing one (renewal),
     * verifies its security header against the message's first chunk, configures the channel's
     * security policy and certificates, and retains the chunk. A final ('F') chunk hands all
     * retained chunks to the serialization queue for decoding and token installation.
     *
     * @param buffer slice holding exactly one chunk.
     * @throws UaException on any security or size check failure.
     */
    private void onOpenSecureChannel(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
        buffer.skipBytes(3); // message type, already inspected by decode()
        char chunkType = (char) buffer.readByte();

        if (chunkType == 'A') {
            discardPendingChunks();
            return;
        }

        buffer.skipBytes(4); // message size, already used by decode() to frame this slice
        long secureChannelId = buffer.readUnsignedInt();

        AsymmetricSecurityHeader securityHeader = AsymmetricSecurityHeader.decode(buffer);

        if (secureChannelId == 0) {
            // Chunks that continue a message (chunkBuffers non-empty) keep the channel opened for its
            // first chunk; otherwise every further chunk would call server.openSecureChannel() again
            // and drop the channel opened for the previous chunk.
            if (chunkBuffers.isEmpty()) {
                String endpointUrl = ctx.channel().attr(UaTcpServerHelloHandler.ENDPOINT_URL_KEY).get();
                EndpointDescription endpointDescription =
                    findEndpoint(endpointUrl, securityHeader.getSecurityPolicyUri());

                secureChannel = server.openSecureChannel();
                secureChannel.setEndpointDescription(endpointDescription);
            }
        } else {
            secureChannel = server.getSecureChannel(secureChannelId);

            if (secureChannel == null) {
                throw new UaException(StatusCodes.Bad_TcpSecureChannelUnknown,
                    "unknown secure channel id: " + secureChannelId);
            }

            if (!secureChannel.getRemoteCertificateBytes().equals(securityHeader.getSenderCertificate())) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                    "certificate requesting renewal did not match existing certificate.");
            }

            Channel boundChannel = secureChannel.attr(UaTcpStackServer.BoundChannelKey).get();
            if (boundChannel != null && boundChannel != ctx.channel()) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                    "received a renewal request from channel other than the bound channel.");
            }
        }

        // The first chunk records the header; every later chunk of the same message must match it.
        if (!headerRef.compareAndSet(null, securityHeader)) {
            if (!securityHeader.equals(headerRef.get())) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                    "subsequent AsymmetricSecurityHeader did not match");
            }
        }

        SecurityPolicy securityPolicy = SecurityPolicy.fromUri(securityHeader.getSecurityPolicyUri());
        secureChannel.setSecurityPolicy(securityPolicy);

        if (!securityHeader.getSenderCertificate().isNull() && securityPolicy != SecurityPolicy.None) {
            secureChannel.setRemoteCertificate(securityHeader.getSenderCertificate().bytes());
            validateRemoteCertificate(ctx);
        }

        if (!securityHeader.getReceiverThumbprint().isNull()) {
            CertificateManager certificateManager = server.getCertificateManager();

            Optional<X509Certificate> localCertificate = certificateManager
                .getCertificate(securityHeader.getReceiverThumbprint());

            Optional<KeyPair> keyPair = certificateManager
                .getKeyPair(securityHeader.getReceiverThumbprint());

            if (localCertificate.isPresent() && keyPair.isPresent()) {
                secureChannel.setLocalCertificate(localCertificate.get());
                secureChannel.setKeyPair(keyPair.get());
            } else {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                    "no certificate for provided thumbprint");
            }
        }

        // Rewind the slice so the retained chunk includes its headers for the chunk decoder.
        int chunkSize = buffer.readerIndex(0).readableBytes();

        if (chunkSize > maxChunkSize) {
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                String.format("max chunk size exceeded (%s)", maxChunkSize));
        }

        // readSlice() does not retain; retain so the chunk outlives the cumulation buffer.
        chunkBuffers.add(buffer.retain());

        if (chunkBuffers.size() > maxChunkCount) {
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                String.format("max chunk count exceeded (%s)", maxChunkCount));
        }

        if (chunkType == 'F') {
            final List<ByteBuf> buffersToDecode = chunkBuffers;
            chunkBuffers = new ArrayList<>(maxChunkCount);
            headerRef.set(null);

            serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
                ByteBuf messageBuffer = null;

                try {
                    messageBuffer = chunkDecoder.decodeAsymmetric(secureChannel, buffersToDecode);

                    OpenSecureChannelRequest request = binaryDecoder
                        .setBuffer(messageBuffer)
                        .decodeMessage(null);

                    logger.debug("Received OpenSecureChannelRequest ({}, id={}).",
                        request.getRequestType(), secureChannelId);

                    long requestId = chunkDecoder.getLastRequestId();
                    installSecurityToken(ctx, request, requestId);
                } catch (UaException e) {
                    logger.error("Error decoding asymmetric message: {}", e.getMessage(), e);
                    ctx.close();
                } finally {
                    if (messageBuffer != null) {
                        messageBuffer.release();
                    }
                    buffersToDecode.clear();
                }
            });
        }
    }

    /**
     * Finds the endpoint a new secure channel is bound to.
     *
     * <p>An endpoint matches when the path of its URL equals the path of {@code endpointUrl} and its
     * security policy URI equals {@code securityPolicyUri}. If none matches and strict endpoint URLs
     * are disabled, the first endpoint with a matching security policy URI is used instead.
     *
     * @throws UaException {@code Bad_SecurityChecksFailed} if no endpoint qualifies.
     */
    private EndpointDescription findEndpoint(String endpointUrl, String securityPolicyUri) throws UaException {
        // Loop-invariant: parse the requested URL once rather than once per candidate endpoint.
        String requestedPath = pathOrUrl(endpointUrl);

        EndpointDescription endpointDescription = Arrays.stream(server.getEndpointDescriptions())
            .filter(e -> requestedPath != null && requestedPath.equals(pathOrUrl(e.getEndpointUrl())))
            .filter(e -> e.getSecurityPolicyUri().equals(securityPolicyUri))
            .findFirst().orElse(null);

        if (endpointDescription == null && !server.getConfig().isStrictEndpointUrlsEnabled()) {
            endpointDescription = Arrays.stream(server.getEndpointDescriptions())
                .filter(e -> e.getSecurityPolicyUri().equals(securityPolicyUri))
                .findFirst().orElse(null);
        }

        if (endpointDescription == null) {
            throw new UaException(StatusCodes.Bad_SecurityChecksFailed, "SecurityPolicy URI did not match");
        }

        return endpointDescription;
    }

    /**
     * Validates the remote certificate already set on {@link #secureChannel}, and its trust chain,
     * with the server's {@link CertificateValidator}.
     *
     * @throws UaException carrying the validator's status code and the generic message
     *                     "security checks failed" if validation fails; the detailed reason is only
     *                     logged locally. Throwing stops the handshake instead of accepting the chunk.
     */
    private void validateRemoteCertificate(ChannelHandlerContext ctx) throws UaException {
        try {
            CertificateValidator certificateValidator = server.getCertificateValidator();

            certificateValidator.validate(secureChannel.getRemoteCertificate());

            certificateValidator.verifyTrustChain(
                secureChannel.getRemoteCertificate(),
                secureChannel.getRemoteCertificateChain());
        } catch (UaException e) {
            logger.debug("[remote={}] remote certificate rejected: {}",
                ctx.channel().remoteAddress(), e.getMessage(), e);

            // exceptionCaught() sends the ErrorMessage to the peer and discards pending chunks.
            throw new UaException(e.getStatusCode(), "security checks failed");
        }
    }

    /**
     * Returns the path component of {@code endpointUrl}, or {@code endpointUrl} itself if it cannot
     * be parsed as a URI with a server-based authority. The result may be {@code null} (for a
     * {@code null} input, or a URI without a path such as an opaque URI).
     */
    private String pathOrUrl(String endpointUrl) {
        try {
            URI uri = new URI(endpointUrl).parseServerAuthority();
            return uri.getPath();
        } catch (Exception e) {
            logger.warn("Endpoint URL '{}' is not a valid URI: {}", endpointUrl, e.getMessage(), e);
            return endpointUrl;
        }
    }

    /**
     * Creates and installs a new security token on {@link #secureChannel} for a decoded
     * OpenSecureChannelRequest, then sends the response.
     *
     * <p>An Issue request sets the channel's MessageSecurityMode; a Renew request must keep it. The
     * requested lifetime is clamped to [SecureChannelLifetimeMin, SecureChannelLifetimeMax]. When
     * symmetric signing is enabled, the client nonce is checked, a local nonce is generated and new
     * keys are derived. The previous keys and token are kept alongside the new ones.
     *
     * @throws UaException if a renewal changes the security mode, or the client nonce is missing or
     *                     shorter than the algorithm's nonce length.
     */
    private void installSecurityToken(ChannelHandlerContext ctx,
                                      OpenSecureChannelRequest request,
                                      long requestId) throws UaException {

        SecurityTokenRequestType requestType = request.getRequestType();

        if (requestType == SecurityTokenRequestType.Issue) {
            secureChannel.setMessageSecurityMode(request.getSecurityMode());
        } else if (requestType == SecurityTokenRequestType.Renew &&
            secureChannel.getMessageSecurityMode() != request.getSecurityMode()) {

            throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                "secure channel renewal requested a different MessageSecurityMode.");
        }

        long channelLifetime = request.getRequestedLifetime().longValue();
        channelLifetime = Math.min(SecureChannelLifetimeMax, channelLifetime);
        channelLifetime = Math.max(SecureChannelLifetimeMin, channelLifetime);

        ChannelSecurityToken newToken = new ChannelSecurityToken(
            uint(secureChannel.getChannelId()),
            uint(server.nextTokenId()),
            DateTime.now(),
            uint(channelLifetime)
        );

        ChannelSecurity.SecuritySecrets newKeys = null;

        if (secureChannel.isSymmetricSigningEnabled()) {
            SecurityAlgorithm algorithm = secureChannel.getSecurityPolicy().getSymmetricEncryptionAlgorithm();
            ByteString remoteNonce = request.getClientNonce();

            if (remoteNonce == null || remoteNonce.isNull()) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed, "remote nonce must be non-null");
            }

            if (remoteNonce.length() < getNonceLength(algorithm)) {
                String message = String.format(
                    "remote nonce length must be at least %d bytes",
                    getNonceLength(algorithm));

                throw new UaException(StatusCodes.Bad_SecurityChecksFailed, message);
            }

            ByteString localNonce = generateNonce(getNonceLength(algorithm));

            secureChannel.setLocalNonce(localNonce);
            secureChannel.setRemoteNonce(remoteNonce);

            newKeys = ChannelSecurity.generateKeyPair(
                secureChannel,
                secureChannel.getRemoteNonce(),
                secureChannel.getLocalNonce()
            );
        }

        ChannelSecurity oldSecrets = secureChannel.getChannelSecurity();
        ChannelSecurity.SecuritySecrets oldKeys = oldSecrets != null ? oldSecrets.getCurrentKeys() : null;
        ChannelSecurityToken oldToken = oldSecrets != null ? oldSecrets.getCurrentToken() : null;

        ChannelSecurity newSecrets = new ChannelSecurity(
            newKeys,
            newToken,
            oldKeys,
            oldToken
        );

        secureChannel.setChannelSecurity(newSecrets);

        ResponseHeader responseHeader = new ResponseHeader(
            DateTime.now(),
            request.getRequestHeader().getRequestHandle(),
            StatusCode.GOOD,
            null, null, null
        );

        OpenSecureChannelResponse response = new OpenSecureChannelResponse(
            responseHeader,
            uint(PROTOCOL_VERSION),
            newToken,
            secureChannel.getLocalNonce()
        );

        sendOpenSecureChannelResponse(ctx, requestId, response);
    }

    /**
     * Encodes {@code response} into asymmetric chunks on the serialization queue and writes them.
     * Before the first response is written, the symmetric handler is added to the front of the
     * pipeline. After writing, the server is notified that the channel was issued or renewed with
     * the revised lifetime. On an encoding error the channel is closed.
     */
    private void sendOpenSecureChannelResponse(ChannelHandlerContext ctx, long requestId, OpenSecureChannelResponse response) {
        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = BufferUtil.buffer();

            try {
                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.encodeMessage(null, response);

                List<ByteBuf> chunks = chunkEncoder.encodeAsymmetric(
                    secureChannel,
                    MessageType.OpenSecureChannel,
                    messageBuffer,
                    requestId
                );

                if (!symmetricHandlerAdded) {
                    ctx.pipeline().addFirst(new UaTcpServerSymmetricHandler(server, serializationQueue, secureChannel));
                    symmetricHandlerAdded = true;
                }

                chunks.forEach(c -> ctx.write(c, ctx.voidPromise()));
                ctx.flush();

                long lifetime = response.getSecurityToken().getRevisedLifetime().longValue();
                server.secureChannelIssuedOrRenewed(secureChannel, lifetime);

                logger.debug("Sent OpenSecureChannelResponse.");
            } catch (UaException e) {
                logger.error("Error encoding OpenSecureChannelResponse: {}", e.getMessage(), e);
                ctx.close();
            } finally {
                messageBuffer.release();
            }
        });
    }

    /**
     * Releases any partially received OpenSecureChannel chunks when the channel goes inactive, so
     * retained slices are not leaked if the peer disconnects in the middle of a message.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        try {
            super.channelInactive(ctx);
        } finally {
            discardPendingChunks();
        }
    }

    /**
     * Discards pending chunks, then either closes the channel (IOException) or reports the failure
     * to the peer via {@link ExceptionHandler#sendErrorMessage}.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        discardPendingChunks();

        if (cause instanceof IOException) {
            ctx.close();
            logger.debug("[remote={}] IOException caught; channel closed", ctx.channel().remoteAddress());
        } else {
            ErrorMessage errorMessage = ExceptionHandler.sendErrorMessage(ctx, cause);

            // UaExceptions thrown from decode() reach this method wrapped by ByteToMessageDecoder
            // (in a DecoderException), so also look one cause level down.
            boolean expected = cause instanceof UaException || cause.getCause() instanceof UaException;

            if (expected) {
                logger.debug("[remote={}] UaException caught; sent {}",
                    ctx.channel().remoteAddress(), errorMessage, cause);
            } else {
                logger.error("[remote={}] Exception caught; sent {}",
                    ctx.channel().remoteAddress(), errorMessage, cause);
            }
        }
    }

    /** Releases every retained chunk of the in-progress message and forgets its security header. */
    private void discardPendingChunks() {
        chunkBuffers.forEach(ByteBuf::release);
        chunkBuffers.clear();
        headerRef.set(null);
    }
}