All Downloads are FREE. The search and download features use the official Maven repository.

org.apache.sshd.common.cipher.ChaCha20Cipher Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sshd.common.cipher;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import javax.crypto.AEADBadTagException;

import org.apache.sshd.common.mac.Mac;
import org.apache.sshd.common.mac.Poly1305Mac;
import org.apache.sshd.common.util.NumberUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.BufferUtils;

/**
 * AEAD cipher based on the OpenSSH ChaCha20-Poly1305 cipher extension (described in the OpenSSH
 * {@code PROTOCOL.chacha20poly1305} document).
 *
 * <p>
 * Two independent ChaCha engines are used: {@link #headerEngine} encrypts the 4-byte packet length supplied through
 * {@link #updateAAD(byte[], int, int)}, while {@link #bodyEngine} encrypts the packet payload and also produces the
 * per-packet Poly1305 key. The authentication tag is always computed over the <em>ciphertext</em> (length and
 * payload), and is read from / written to the bytes that immediately follow the payload in the caller's buffer.
 * </p>
 *
 * <p>
 * Instances keep mutable per-packet state (nonce, block counter, MAC state) and nothing in this class synchronizes
 * access to it.
 * </p>
 */
public class ChaCha20Cipher implements Cipher {
    /** Total number of key bytes consumed by {@link #init}: one ChaCha key for the body, one for the header. */
    private static final int KDF_BYTES = 2 * ChaChaEngine.KEY_BYTES;

    /** Engine used for the 4-byte packet length; keyed with bytes 32..63 of the key material. */
    protected final ChaChaEngine headerEngine = new ChaChaEngine();
    /** Engine used for the payload and for deriving the Poly1305 key; keyed with bytes 0..31 of the key material. */
    protected final ChaChaEngine bodyEngine = new ChaChaEngine();
    /** Poly1305 instance, re-keyed from {@link #bodyEngine} for every packet. */
    protected final Mac mac = new Poly1305Mac();
    /** Direction this cipher was initialized for; {@code null} until {@link #init} has completed successfully. */
    protected Mode mode;

    public ChaCha20Cipher() {
        // empty
    }

    @Override
    public String getAlgorithm() {
        return "ChaCha20";
    }

    /**
     * Keys both engines and prepares the MAC for the first packet.
     *
     * <p>
     * Only the first 64 bytes of {@code key} are used: bytes 0..31 key the body engine and bytes 32..63 key the header
     * engine. Both engines receive the same {@code iv} as their starting nonce.
     * </p>
     *
     * @param  mode                     whether this instance encrypts or decrypts
     * @param  key                      key material; must contain at least 64 bytes
     * @param  iv                       initial nonce; its first 8 bytes are loaded into the engines
     * @throws IllegalArgumentException if {@code key} is {@code null} or shorter than 64 bytes
     * @throws Exception                if the MAC or one of the engines cannot be initialized
     */
    @Override
    public void init(Mode mode, byte[] key, byte[] iv) throws Exception {
        // A failed (re-)initialization must not leave the instance looking usable.
        this.mode = null;

        // Arrays.copyOfRange zero-pads when the source is too short, which would silently key the
        // header engine with (partly) all-zero key material - reject such input explicitly instead.
        ValidateUtils.checkTrue(NumberUtils.length(key) >= KDF_BYTES,
                "ChaCha20 requires at least 64 bytes of key material");

        byte[] bodyKey = Arrays.copyOfRange(key, 0, ChaChaEngine.KEY_BYTES);
        byte[] headerKey = Arrays.copyOfRange(key, ChaChaEngine.KEY_BYTES, KDF_BYTES);
        try {
            bodyEngine.initKey(bodyKey);
            bodyEngine.initNonce(iv);
            // polyKey() consumes keystream block 0 and leaves the body counter at 1 for the payload
            mac.init(bodyEngine.polyKey());

            headerEngine.initKey(headerKey);
            headerEngine.initNonce(iv);
            headerEngine.initCounter(0);
        } finally {
            // initKey() has unpacked the bytes into the engine state; the temporary copies are no longer needed
            Arrays.fill(bodyKey, (byte) 0);
            Arrays.fill(headerKey, (byte) 0);
        }

        this.mode = mode;
    }

    /**
     * Encrypts or decrypts the 4-byte packet length in place and feeds its ciphertext form to the MAC.
     *
     * @param  data      buffer holding the packet length
     * @param  offset    index of the first length byte
     * @param  length    number of bytes; must be exactly 4
     * @throws Exception if the cipher is not initialized, {@code length} is not 4, or the MAC update fails
     */
    @Override
    public void updateAAD(byte[] data, int offset, int length) throws Exception {
        ValidateUtils.checkState(mode != null, "Cipher not initialized");
        ValidateUtils.checkTrue(length == 4, "AAD only supported for encrypted packet length");

        // When decrypting, the buffer holds ciphertext now - authenticate it before it is overwritten
        if (mode == Mode.Decrypt) {
            mac.update(data, offset, length);
        }

        headerEngine.crypt(data, offset, length, data, offset);

        // When encrypting, the buffer holds ciphertext only after the engine has run
        if (mode == Mode.Encrypt) {
            mac.update(data, offset, length);
        }
    }

    /**
     * Processes one packet payload in place and handles its authentication tag.
     *
     * <p>
     * When decrypting, the tag stored at {@code inputOffset + inputLen} is verified <em>before</em> any payload byte is
     * decrypted. When encrypting, the tag is written to {@code inputOffset + inputLen} after encryption, so the buffer
     * must have room for it. Afterwards both engines move to the next nonce and the MAC is re-keyed for the next
     * packet.
     * </p>
     *
     * @param  input               buffer holding the payload, followed by (room for) the tag
     * @param  inputOffset         index of the first payload byte
     * @param  inputLen            number of payload bytes, excluding the tag
     * @throws AEADBadTagException if decrypting and the computed tag does not match the one in the buffer
     * @throws Exception           if the cipher is not initialized or the MAC fails
     */
    @Override
    public void update(byte[] input, int inputOffset, int inputLen) throws Exception {
        ValidateUtils.checkState(mode != null, "Cipher not initialized");

        if (mode == Mode.Decrypt) {
            // authenticate the ciphertext first; nothing is decrypted if the tag is wrong
            mac.update(input, inputOffset, inputLen);
            byte[] actual = mac.doFinal();
            if (!Mac.equals(input, inputOffset + inputLen, actual, 0, actual.length)) {
                throw new AEADBadTagException("Tag mismatch");
            }
        }

        bodyEngine.crypt(input, inputOffset, inputLen, input, inputOffset);

        if (mode == Mode.Encrypt) {
            // tag covers the freshly produced ciphertext and is appended right behind it
            mac.update(input, inputOffset, inputLen);
            mac.doFinal(input, inputOffset + inputLen);
        }

        // Prepare for the next packet: new nonce for both engines, header keystream restarts at block 0,
        // and a fresh Poly1305 key is taken from block 0 of the body keystream.
        headerEngine.advanceNonce();
        headerEngine.initCounter(0);
        bodyEngine.advanceNonce();
        mac.init(bodyEngine.polyKey());
    }

    @Override
    public String getTransformation() {
        return "ChaCha20";
    }

    @Override
    public int getIVSize() {
        return 8;
    }

    @Override
    public int getAuthenticationTagSize() {
        return 16;
    }

    @Override
    public int getCipherBlockSize() {
        return 8;
    }

    @Override
    public int getKdfSize() {
        return 64;
    }

    @Override
    public int getKeySize() {
        return 256;
    }

    /**
     * Minimal ChaCha keystream generator operating on a 16-word state: words 0..3 hold the
     * {@code "expand 32-byte k"} constant, words 4..11 the key, words 12..13 the 64-bit block counter and words
     * 14..15 the 64-bit nonce.
     */
    protected static class ChaChaEngine {
        private static final int BLOCK_BYTES = 64;
        private static final int BLOCK_INTS = BLOCK_BYTES / Integer.BYTES;
        private static final int KEY_OFFSET = 4;
        private static final int KEY_BYTES = 32;
        private static final int KEY_INTS = KEY_BYTES / Integer.BYTES;
        private static final int COUNTER_OFFSET = 12;
        private static final int NONCE_OFFSET = 14;
        private static final int NONCE_BYTES = 8;
        private static final int NONCE_INTS = NONCE_BYTES / Integer.BYTES;
        private static final int[] ENGINE_STATE_HEADER
                = unpackSigmaString("expand 32-byte k".getBytes(StandardCharsets.US_ASCII));

        /** Scratch copy of the state that is permuted for every keystream block. */
        protected final int[] x = new int[BLOCK_INTS];
        /** Input state (constant, key, counter, nonce) from which each keystream block is derived. */
        protected final int[] engineState = new int[BLOCK_INTS];
        /** Current nonce in the byte form that was supplied to {@link #initNonce(byte[])}. */
        protected final byte[] nonce = new byte[NONCE_BYTES];
        /** Numeric value of the nonce supplied to {@link #initNonce(byte[])}; used to detect wrap-around. */
        protected long initialNonce;

        protected ChaChaEngine() {
            System.arraycopy(ENGINE_STATE_HEADER, 0, engineState, 0, 4);
        }

        /**
         * Loads the first 32 bytes of {@code key} into the state as little-endian words. The array itself is not
         * retained.
         *
         * @param key the key bytes
         */
        protected void initKey(byte[] key) {
            unpackIntsLE(key, 0, KEY_INTS, engineState, KEY_OFFSET);
        }

        /**
         * Sets the starting nonce and remembers its numeric value so that {@link #advanceNonce()} can detect when the
         * sequence would repeat.
         *
         * @param nonce the nonce; its first 8 bytes are copied into the engine
         */
        protected void initNonce(byte[] nonce) {
            // numeric interpretation (byte order) is whatever BufferUtils.getLong defines
            initialNonce = BufferUtils.getLong(nonce, 0, NumberUtils.length(nonce));
            unpackIntsLE(nonce, 0, NONCE_INTS, engineState, NONCE_OFFSET);
            System.arraycopy(nonce, 0, this.nonce, 0, NONCE_BYTES);
        }

        /**
         * Increments the nonce by one and reloads it into the state.
         *
         * @throws IllegalStateException if the incremented value equals the initial nonce, i.e. every value has been
         *                               used once with the current key
         */
        protected void advanceNonce() {
            long counter = BufferUtils.getLong(nonce, 0, NONCE_BYTES) + 1;
            ValidateUtils.checkState(counter != initialNonce, "Packet sequence number cannot be reused with the same key");
            BufferUtils.putLong(counter, nonce, 0, NONCE_BYTES);
            unpackIntsLE(nonce, 0, NONCE_INTS, engineState, NONCE_OFFSET);
        }

        /**
         * Sets the 64-bit block counter (low word first).
         *
         * @param counter index of the next keystream block to generate
         */
        protected void initCounter(long counter) {
            engineState[COUNTER_OFFSET] = (int) counter;
            engineState[COUNTER_OFFSET + 1] = (int) (counter >>> Integer.SIZE);
        }

        /**
         * XORs {@code length} bytes of input with keystream, starting at the current block counter.
         *
         * <p>
         * One-shot usage: the counter is advanced once per (possibly partial) block, so keystream left over from a
         * trailing partial block is discarded and a subsequent call starts at the next block. {@code in} and
         * {@code out} may be the same array at the same offset.
         * </p>
         *
         * @param in        source buffer
         * @param offset    index of the first source byte
         * @param length    number of bytes to process
         * @param out       destination buffer
         * @param outOffset index of the first destination byte
         */
        // one-shot usage
        protected void crypt(byte[] in, int offset, int length, byte[] out, int outOffset) {
            while (length > 0) {
                System.arraycopy(engineState, 0, x, 0, BLOCK_INTS);
                permute(x);
                int want = Math.min(BLOCK_BYTES, length);
                for (int i = 0, j = 0; i < want; i += Integer.BYTES, j++) {
                    // keystream word = permuted state + input state
                    int keyStream = engineState[j] + x[j];
                    // the last word of the message may be shorter than 4 bytes
                    int take = Math.min(Integer.BYTES, length);
                    int input = unpackIntLE(in, offset, take);
                    int output = keyStream ^ input;
                    packIntLE(output, out, outOffset, take);
                    offset += take;
                    outOffset += take;
                    length -= take;
                }
                int lo = ++engineState[COUNTER_OFFSET];
                if (lo == 0) {
                    // overflow
                    ++engineState[COUNTER_OFFSET + 1];
                }
            }
        }

        /**
         * Produces a Poly1305 key from the start of keystream block 0 for the current nonce and leaves the block
         * counter at 1, so that a following {@link #crypt} call continues with block 1.
         *
         * @return a new array of {@code Poly1305Mac.KEY_BYTES} keystream bytes
         */
        protected byte[] polyKey() {
            // encrypting zeros yields the raw keystream
            byte[] block = new byte[Poly1305Mac.KEY_BYTES];
            initCounter(0);
            crypt(block, 0, block.length, block, 0);
            initCounter(1);
            return block;
        }

        /**
         * Applies 10 double rounds (column round followed by diagonal round) to {@code state} in place.
         *
         * @param state the 16-word working state
         */
        protected static void permute(int[] state) {
            for (int i = 0; i < 10; i++) {
                columnRound(state);
                diagonalRound(state);
            }
        }

        /** Applies the quarter round to each of the four columns of the 4x4 state matrix. */
        protected static void columnRound(int[] state) {
            quarterRound(state, 0, 4, 8, 12);
            quarterRound(state, 1, 5, 9, 13);
            quarterRound(state, 2, 6, 10, 14);
            quarterRound(state, 3, 7, 11, 15);
        }

        /** Applies the quarter round to each of the four diagonals of the 4x4 state matrix. */
        protected static void diagonalRound(int[] state) {
            quarterRound(state, 0, 5, 10, 15);
            quarterRound(state, 1, 6, 11, 12);
            quarterRound(state, 2, 7, 8, 13);
            quarterRound(state, 3, 4, 9, 14);
        }

        /**
         * Add-rotate-xor mixing step over the four state words at indices {@code a}, {@code b}, {@code c} and
         * {@code d}, using rotation distances 16, 12, 8 and 7.
         */
        protected static void quarterRound(int[] state, int a, int b, int c, int d) {
            state[a] += state[b];
            state[d] = Integer.rotateLeft(state[d] ^ state[a], 16);

            state[c] += state[d];
            state[b] = Integer.rotateLeft(state[b] ^ state[c], 12);

            state[a] += state[b];
            state[d] = Integer.rotateLeft(state[d] ^ state[a], 8);

            state[c] += state[d];
            state[b] = Integer.rotateLeft(state[b] ^ state[c], 7);
        }

        /** Reads a full 4-byte little-endian word starting at {@code off}. */
        private static int unpackIntLE(byte[] buf, int off) {
            return unpackIntLE(buf, off, Integer.BYTES);
        }

        /** Reads {@code len} (at most 4) bytes starting at {@code off} as a little-endian value; missing high bytes are zero. */
        private static int unpackIntLE(byte[] buf, int off, int len) {
            int ret = 0;
            for (int i = 0; i < len; i++) {
                ret |= Byte.toUnsignedInt(buf[off + i]) << i * Byte.SIZE;
            }
            return ret;
        }

        /** Reads {@code nrInts} consecutive little-endian words from {@code buf} into {@code dst}. */
        private static void unpackIntsLE(byte[] buf, int off, int nrInts, int[] dst, int dstOff) {
            for (int i = 0; i < nrInts; i++) {
                dst[dstOff++] = unpackIntLE(buf, off);
                off += Integer.BYTES;
            }
        }

        /** Converts the 16-byte ChaCha constant string into its four little-endian state words. */
        private static int[] unpackSigmaString(byte[] buf) {
            int[] values = new int[4];
            unpackIntsLE(buf, 0, 4, values, 0);
            return values;
        }

        /** Writes the {@code len} (at most 4) low-order bytes of {@code value} to {@code dst} in little-endian order. */
        private static void packIntLE(int value, byte[] dst, int off, int len) {
            for (int i = 0; i < len; i++) {
                dst[off + i] = (byte) (value >>> i * Byte.SIZE);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy