All downloads are free. The search and download functionality uses the official Maven repository.

org.jgroups.protocols.Encrypt Maven / Gradle / Ivy

Go to download

This artifact provides a single JAR that contains all classes required to use remote EJB and JMS, including all dependencies. It is intended for those who are not using Maven; Maven users should import the EJB and JMS BOMs instead (shaded JARs cause many problems with Maven, as it is very easy to inadvertently end up with different versions of classes on the class path).

There is a newer version: 34.0.0.Final
Show newest version
package org.jgroups.protocols;

import org.jgroups.*;
import org.jgroups.annotations.ManagedAttribute;
import org.jgroups.annotations.ManagedOperation;
import org.jgroups.annotations.Property;
import org.jgroups.stack.Protocol;
import org.jgroups.util.*;

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import java.security.Key;
import java.security.KeyStore;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.stream.Collectors;

/**
 * Super class of symmetric ({@link SYM_ENCRYPT}) and asymmetric ({@link ASYM_ENCRYPT}) encryption protocols.
 * @author Bela Ban
 */
public abstract class Encrypt extends Protocol {
    protected static final String DEFAULT_SYM_ALGO="AES";


    /* -----------------------------------------    Properties     -------------------------------------------------- */
    @Property(description="Cryptographic Service Provider")
    protected String                        provider;

    @Property(description="Cipher engine transformation for asymmetric algorithm. Default is RSA")
    protected String                        asym_algorithm="RSA";

    @Property(description="Cipher engine transformation for symmetric algorithm. Default is AES")
    protected String                        sym_algorithm=DEFAULT_SYM_ALGO;

    @Property(description="Initialization vector length for symmetric encryption. A value must be specified here " +
      "if the configured sym_algorithm requires an initialization vector.")
    protected int                           sym_iv_length;

    @Property(description="Initial public/private key length. Default is 2048")
    protected int                           asym_keylength=2048;

    @Property(description="Initial key length for matching symmetric algorithm. Default is 128")
    protected int                           sym_keylength=128;

    @Property(description="Number of ciphers in the pool to parallelize encrypt and decrypt requests",writable=false)
    protected int                           cipher_pool_size=8;

    @Property(description="Max number of keys in key_map")
    protected int                           key_map_max_size=20;

    protected volatile View                 view;

    // Cipher pools used for encryption and decryption. Size is cipher_pool_size
    protected volatile BlockingQueue encoding_ciphers, decoding_ciphers;

    // version filed for secret key
    protected volatile byte[]               sym_version;

    // shared secret key to encrypt/decrypt messages
    protected volatile Key                  secret_key;

    // map to hold previous keys so we can decrypt some earlier messages if we need to
    protected Map          key_map;

    // SecureRandom instance for generating IV's
    protected SecureRandom                  secure_random = new SecureRandom();

    protected MessageFactory                msg_factory;


    /**
     * Sets the key store entry used to configure this protocol.
     * @param entry a key store entry
     */
    public abstract > T setKeyStoreEntry(E entry);


    public int                      asymKeylength()                 {return asym_keylength;}
    public > T asymKeylength(int len)          {this.asym_keylength=len; return (T)this;}
    public int                      symKeylength()                  {return sym_keylength;}
    public > T symKeylength(int len)           {this.sym_keylength=len; return (T)this;}
    public Key                      secretKey()                     {return secret_key;}
    public String                   symAlgorithm()                  {return sym_algorithm;}
    public > T symAlgorithm(String alg)        {this.sym_algorithm=alg; return (T)this;}
    public String                   symKeyAlgorithm()               {return getAlgorithm(sym_algorithm);}
    public int                      simIvLength()                   {return sym_iv_length;}
    public > T symIvLength(int len)            {this.sym_iv_length=len; return (T)this;}
    public String                   asymAlgorithm()                 {return asym_algorithm;}
    public > T asymAlgorithm(String alg)       {this.asym_algorithm=alg; return (T)this;}
    public byte[]                   symVersion()                    {return sym_version;}
    public SecureRandom             secureRandom()                  {return this.secure_random;}
    /** Allows callers to replace secure_random with impl of their choice, e.g. for performance reasons. */
    public > T secureRandom(SecureRandom sr)   {this.secure_random = sr; return (T)this;}
    public > T msgFactory(MessageFactory f)    {this.msg_factory=f; return (T)this;}
    @ManagedAttribute public String version()                       {return Util.byteArrayToHexString(sym_version);}


    @ManagedOperation(description="Prints the versions of the shared group keys cached in the key map")
    public String printCachedGroupKeys() {
        return key_map.keySet().stream().map(v -> Util.byteArrayToHexString(v.chars()))
          .collect(Collectors.joining(", "));
    }


    public void init() throws Exception {
        int tmp=Util.getNextHigherPowerOfTwo(cipher_pool_size);
        if(tmp != cipher_pool_size) {
            log.warn("%s: setting cipher_pool_size (%d) to %d (power of 2) for faster modulo operation", local_addr, cipher_pool_size, tmp);
            cipher_pool_size=tmp;
        }
        key_map=new BoundedHashMap<>(key_map_max_size);
        initSymCiphers(sym_algorithm, secret_key);
        TP transport=getTransport();
        if(transport != null)
            msg_factory=transport.getMessageFactory();
    }


    public Object down(Event evt) {
        switch(evt.getType()) {
            case Event.VIEW_CHANGE:
                Object retval=down_prot.down(evt); // Start keyserver socket in SSL_KEY_EXCHANGE, for instance
                handleView(evt.getArg());
                return retval;
        }
        return down_prot.down(evt);
    }


    public Object down(Message msg) {
        try {
            if(secret_key == null) {
                log.trace("%s: discarded %s message to %s as secret key is null, hdrs: %s",
                          local_addr, msg.dest() == null? "mcast" : "unicast", msg.dest(), msg.printHeaders());
                return null;
            }
            down_prot.down(encrypt(msg));
        }
        catch(Exception e) {
            log.warn("%s: unable to send message down", local_addr, e);
        }
        return null;
    }


    public Object up(Event evt) {
        switch(evt.getType()) {
            case Event.VIEW_CHANGE:
                handleView(evt.getArg());
                break;
        }
        return up_prot.up(evt);
    }

    public Object up(Message msg) {
        EncryptHeader hdr=msg.getHeader(this.id);
        if(hdr == null) {
            log.error("%s: received message without encrypt header from %s; dropping it", local_addr, msg.src());
            return null;
        }
        try {
            return handleEncryptedMessage(msg);
        }
        catch(Exception e) {
            log.warn("%s: exception occurred decrypting message", local_addr, e);
        }
        return null;
    }

    public void up(MessageBatch batch) {
        if(secret_key == null) {
            log.trace("%s: discarded %s batch from %s as secret key is null",
                      local_addr, batch.dest() == null? "mcast" : "unicast", batch.sender());
            return;
        }
        BlockingQueue cipherQueue=decoding_ciphers;
        if(cipherQueue == null)
            return;
        Cipher cipher=null;
        try {
            cipher=cipherQueue.take();
            FastArray.FastIterator it=(FastArray.FastIterator)batch.iterator();
            while(it.hasNext()) {
                Message msg=it.next();
                if(msg.getHeader(id) == null) {
                    log.error("%s: received message without encrypt header from %s; dropping it",
                              local_addr, batch.sender());
                    it.remove(); // remove from batch to prevent passing the message further up as part of the batch
                    continue;
                }
                try {
                    Message tmpMsg=decrypt(cipher, msg.copy(true, true)); // need to copy for possible xmits
                    if(tmpMsg != null)
                        it.replace(tmpMsg);
                    else
                        it.remove();
                }
                catch(Exception e) {
                    log.error("%s: failed decrypting message from %s (offset=%d, length=%d, buf.length=%d): %s, headers are %s",
                              local_addr, msg.getSrc(), msg.getOffset(), msg.getLength(), msg.getArray().length, e, msg.printHeaders());
                    it.remove();
                }
            }
        }
        catch(InterruptedException e) {
            log.error("%s: failed processing batch; discarding batch", local_addr, e);
            // we need to drop the batch if we for example have a failure fetching a cipher, or else other messages
            // in the batch might make it up the stack, bypassing decryption! This is not an issue because encryption
            // is below NAKACK2 or UNICAST3, so messages will get retransmitted
            return;
        }
        finally {
            if(cipher != null)
                cipherQueue.offer(cipher);
        }
        if(!batch.isEmpty())
            up_prot.up(batch);
    }


    /** Initialises the ciphers for both encryption and decryption using the generated or supplied secret key */
    protected void initSymCiphers(String algorithm, Key secret) throws Exception {
        if(secret == null)
            return;

        BlockingQueue tmp_encoding_ciphers=new ArrayBlockingQueue<>(cipher_pool_size);
        BlockingQueue tmp_decoding_ciphers=new ArrayBlockingQueue<>(cipher_pool_size);
        for(int i=0; i < cipher_pool_size; i++ ) {
            tmp_encoding_ciphers.offer(createCipher(algorithm));
            tmp_decoding_ciphers.offer(createCipher(algorithm));
        }

        // set the version
        MessageDigest digest=MessageDigest.getInstance("MD5");
        byte[] tmp_sym_version=digest.digest(secret.getEncoded());

        this.encoding_ciphers = tmp_encoding_ciphers;
        this.decoding_ciphers = tmp_decoding_ciphers;
        this.sym_version = tmp_sym_version;
    }


    protected Cipher createCipher(String algorithm) throws Exception {
        return provider != null && !provider.trim().isEmpty()?
          Cipher.getInstance(algorithm, provider) : Cipher.getInstance(algorithm);
    }

    protected static void initCipher(Cipher cipher, int mode, Key secret_key, byte[] iv) throws Exception {
        if(iv != null)
            cipher.init(mode, secret_key, new IvParameterSpec(iv));
        else
            cipher.init(mode, secret_key);
    }

    protected byte[] makeIv() {
        if(sym_iv_length > 0) {
            byte[] iv=new byte[sym_iv_length];
            secure_random.nextBytes(iv);
            return iv;
        }
        return null;
    }

    protected Object handleEncryptedMessage(Message msg) throws Exception {
        // decrypt the message; we need to copy msg as we modify its buffer (https://issues.redhat.com/browse/JGRP-538)
        Message tmpMsg=decrypt(null, msg.copy(true, true)); // need to copy for possible xmits
        return tmpMsg != null? up_prot.up(tmpMsg) : null;
    }


    protected void handleView(View view) {
        this.view=view;
    }

    protected boolean inView(Address sender, String error_msg) {
        View curr_view=this.view;
        if(curr_view == null || curr_view.containsMember(sender))
            return true;
        log.error(error_msg, sender, curr_view);
        return false;
    }


    /** Does the actual work for decrypting - if version does not match current cipher then tries the previous cipher */
    protected Message decrypt(Cipher cipher, Message msg) throws Exception {
        EncryptHeader hdr=msg.getHeader(this.id);
        // If the versions of the group keys don't match, we only try to use a previous version if the sender is in
        // the current view
        if(!Arrays.equals(hdr.version(), sym_version)) {
            if(!inView(msg.src(),
                       String.format("%s: rejected decryption of %s message from non-member %s",
                                     local_addr, msg.dest() == null? "multicast" : "unicast", msg.getSrc())))
                return null;
            Key key=key_map.get(new AsciiString(hdr.version()));
            if(key == null) {
                log.trace("%s: message from %s (version: %s) dropped, as a key matching that version wasn't found " +
                            "(current version: %s)",
                          local_addr, msg.src(), Util.byteArrayToHexString(hdr.version()), Util.byteArrayToHexString(sym_version));
                return null;
            }
            log.trace("%s: decrypting msg from %s using previous key version %s",
                      local_addr, msg.src(), Util.byteArrayToHexString(hdr.version()));
            return _decrypt(cipher, key, msg, hdr);
        }
        return _decrypt(cipher, secret_key, msg, hdr);
    }

    protected Message _decrypt(final Cipher cipher, Key key, Message msg, EncryptHeader hdr) throws Exception {
        if(!msg.hasPayload())
            return msg;

        byte[] decrypted_msg;
        if(cipher == null)
            decrypted_msg=code(msg.getArray(), msg.getOffset(), msg.getLength(), hdr.iv(), true);
        else {
            initCipher(cipher, Cipher.DECRYPT_MODE, key, hdr.iv());
            decrypted_msg=cipher.doFinal(msg.getArray(), msg.getOffset(), msg.getLength());
        }
        if(hdr.needsDeserialization())
            return Util.messageFromBuffer(decrypted_msg, 0, decrypted_msg.length, msg_factory);
        else
            return msg.setArray(decrypted_msg, 0, decrypted_msg.length);
    }

    protected Message encrypt(Message msg) throws Exception {
        // copy needed because same message (object) may be retransmitted -> prevent double encryption
        if(!msg.hasPayload())
            return msg.putHeader(this.id, new EncryptHeader((byte)0, symVersion(), makeIv()));
        boolean serialize=!msg.hasArray();
        ByteArray tmp=null;
        byte[] payload=serialize? (tmp=Util.messageToBuffer(msg)).getArray() : msg.getArray();
        int offset=serialize? tmp.getOffset() : msg.getOffset();
        int length=serialize? tmp.getLength() : msg.getLength();
        byte[] iv=makeIv();
        Message encrypted=(serialize? new BytesMessage(msg.dest()) : msg.copy(false, true))
          .putHeader(this.id, new EncryptHeader((byte)0, symVersion(), iv).needsDeserialization(serialize));
        if(length > 0)
            encrypted.setArray(code(payload, offset, length, iv, false));
        else // length is 0, but buffer may be "" (empty, but *not null* buffer)! [JGRP-2153]
            encrypted.setArray(payload, offset, length);
        return encrypted;
    }


    protected byte[] code(byte[] buf, int offset, int length, byte[] iv, boolean decode) throws Exception {
        BlockingQueue queue=decode? decoding_ciphers : encoding_ciphers;
        Cipher cipher=queue.take();
        try {
            initCipher(cipher, decode ? Cipher.DECRYPT_MODE : Cipher.ENCRYPT_MODE, secret_key, iv);
            return cipher.doFinal(buf, offset, length);
        }
        finally {
            queue.offer(cipher);
        }
    }


    /* Get the algorithm name from "algorithm/mode/padding"  taken from original ENCRYPT */
    protected static String getAlgorithm(String s) {
        int index=s.indexOf('/');
        return index == -1? s : s.substring(0, index);
    }

    /* Get the mode/padding part of the transformation, if present */
    protected static String getModeAndPadding(String s) {
        int index=s.indexOf('/');
        String modeAndPadding = index == -1? null : s.substring(index + 1);
        if (modeAndPadding == null || modeAndPadding.isEmpty())
            return null;
        return modeAndPadding;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy