All downloads are free. Search and download functionality uses the official Maven repository.

org.wildfly.security.asn1.DEREncoder Maven / Gradle / Ivy

Go to download

This artifact provides a single JAR that contains all classes required to use remote EJB and JMS, including all dependencies. It is intended for those not using Maven; Maven users should instead import the EJB and JMS BOMs (shaded JARs cause many problems with Maven, because it is very easy to inadvertently end up with different versions of classes on the class path).

The newest version!
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.security.asn1;

import static org.wildfly.security.asn1.ElytronMessages.log;
import static org.wildfly.security.asn1.ASN1.*;

import java.math.BigInteger;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.LinkedList;

import org.wildfly.common.bytes.ByteStringBuilder;
import org.wildfly.common.iteration.ByteIterator;

/**
 * A class used to encode ASN.1 values using the Distinguished Encoding Rules (DER), as specified
 * in ITU-T X.690.
 *
 * <p>Because DER uses definite lengths, the contents of a constructed element are collected in an
 * intermediate buffer. The length octets are written to the enclosing buffer only when the element
 * is ended. Each child of a SET is written to its own buffer so that the children can be sorted
 * before they are concatenated: by tag for SET, lexicographically for SET OF.
 *
 * <p>Instances hold mutable encoding state and perform no synchronization.
 *
 * @author Farah Juma
 */
public class DEREncoder implements ASN1Encoder {
    private static final int[] BITS = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
    private static final byte[] NULL_CONTENTS = new byte[0];
    private static final TagComparator TAG_COMPARATOR = new TagComparator();
    private static final LexicographicComparator LEXICOGRAPHIC_COMPARATOR = new LexicographicComparator();
    private static final byte[] BOOLEAN_TRUE_AS_BYTES = new byte[] { ~0 };
    private static final byte[] BOOLEAN_FALSE_AS_BYTES = new byte[] { 0 };

    /** Currently open constructed elements, innermost last. */
    private final ArrayDeque<EncoderState> states = new ArrayDeque<>();
    /** Intermediate buffers indexed by buffer position; a buffer is reused once it has been cleared. */
    private final ArrayList<ByteStringBuilder> buffers = new ArrayList<>();
    private ByteStringBuilder currentBuffer;
    /** Index of {@link #currentBuffer} in {@link #buffers}, or -1 when writing directly to {@link #target}. */
    private int currentBufferPos = -1;
    private final ByteStringBuilder target;
    /** Pending implicit tag (class | number) to apply to the next element written, or -1 if none. */
    private int implicitTag = -1;

    /**
     * Create a DER encoder.
     */
    public DEREncoder() {
        this(new ByteStringBuilder());
    }

    /**
     * Create a DER encoder that writes its output to the given {@code ByteStringBuilder}.
     *
     * @param target the {@code ByteStringBuilder} to which the DER encoded values are written
     */
    DEREncoder(ByteStringBuilder target) {
        this.target = target;
        currentBuffer = target;
    }

    @Override
    public void startSequence() {
        startConstructedElement(SEQUENCE_TYPE);
    }

    @Override
    public void startSet() {
        startConstructedElement(SET_TYPE);
    }

    @Override
    public void startSetOf() {
        startSet();
    }

    @Override
    public void startExplicit(int number) {
        startExplicit(CONTEXT_SPECIFIC_MASK, number);
    }

    @Override
    public void startExplicit(int clazz, int number) {
        int explicitTag = clazz | CONSTRUCTED_MASK | number;
        startConstructedElement(explicitTag);
    }

    /**
     * Open a constructed element: write its identifier octets and push a new encoder state.
     *
     * @param tag the element's (non-implicit) tag; the state is recorded under this tag so that the
     *            matching {@code end*} method can validate it
     */
    private void startConstructedElement(int tag) {
        EncoderState lastState = states.peekLast();
        if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
            // Each child of a SET gets its own buffer so the children can be sorted when the set ends.
            // The sort key must be the tag actually emitted, including any pending implicit tag.
            updateCurrentBuffer();
            lastState.addChildElement(effectiveTag(tag), currentBufferPos);
        }
        writeTag(tag, currentBuffer);
        if (tag != SET_TYPE) {
            // The contents go to a separate buffer so the length can be written once the element ends.
            // The contents of a SET are collected in per-child buffers instead.
            updateCurrentBuffer();
        }
        states.add(new EncoderState(tag, currentBufferPos));
    }

    @Override
    public void endSequence() throws IllegalStateException {
        EncoderState lastState = states.peekLast();
        if ((lastState == null) || (lastState.getTag() != SEQUENCE_TYPE)) {
            throw log.noSequenceToEnd();
        }
        endConstructedElement();
    }

    @Override
    public void endExplicit() throws IllegalStateException {
        EncoderState lastState = states.peekLast();
        if ((lastState == null) || (lastState.getTag() == SEQUENCE_TYPE)
                || (lastState.getTag() == SET_TYPE) || ((lastState.getTag() & CONSTRUCTED_MASK) == 0)) {
            throw log.noExplicitlyTaggedElementToEnd();
        }
        endConstructedElement();
    }

    /**
     * Close the innermost non-SET constructed element. The length and the buffered contents are
     * written to the enclosing buffer, which then becomes the current buffer.
     */
    private void endConstructedElement() {
        // Position 0 is the contents buffer of a top-level element, which is emitted straight to the target
        ByteStringBuilder dest = (currentBufferPos > 0) ? buffers.get(currentBufferPos - 1) : target;
        int length = currentBuffer.length();
        writeLength(length, dest);
        dest.append(currentBuffer);
        currentBuffer.setLength(0);
        currentBuffer = dest;
        currentBufferPos -= 1;
        states.removeLast();

        // If the parent is a SET, dest is the buffer dedicated to this child in startConstructedElement.
        // Its full length (identifier + length + contents) is therefore this child's encoded size,
        // whatever the number of identifier octets.
        EncoderState parent = states.peekLast();
        if ((parent != null) && (parent.getTag() == SET_TYPE)) {
            parent.addChildLength(dest.length());
        }
    }

    @Override
    public void endSet() throws IllegalStateException {
        endSet(TAG_COMPARATOR);
    }

    @Override
    public void endSetOf() throws IllegalStateException {
        endSet(LEXICOGRAPHIC_COMPARATOR);
    }

    /**
     * Close the innermost SET. The accumulated length is written, followed by the children's
     * encodings in the order defined by {@code comparator}.
     *
     * @param comparator the ordering applied to the set's child elements
     */
    private void endSet(Comparator<EncoderState> comparator) {
        EncoderState lastState = states.peekLast();
        if ((lastState == null) || (lastState.getTag() != SET_TYPE)) {
            throw log.noSetToEnd();
        }

        int setBufferPos = lastState.getBufferPos();
        ByteStringBuilder dest = (setBufferPos >= 0) ? buffers.get(setBufferPos) : target;

        writeLength(lastState.getChildLength(), dest);
        // DER requires the components of a set to appear in a canonical order
        for (EncoderState child : lastState.getSortedChildElements(comparator)) {
            ByteStringBuilder childBuffer = buffers.get(child.getBufferPos());
            dest.append(childBuffer);
            childBuffer.setLength(0);
        }
        currentBuffer = dest;
        currentBufferPos = setBufferPos;
        states.removeLast();

        // If the parent is a SET, dest is the buffer dedicated to this set, so its length is the set's full size
        EncoderState parent = states.peekLast();
        if ((parent != null) && (parent.getTag() == SET_TYPE)) {
            parent.addChildLength(dest.length());
        }
    }

    @Override
    public void encodeOctetString(String str) {
        encodeOctetString(str.getBytes(StandardCharsets.UTF_8));
    }

    @Override
    public void encodeOctetString(byte[] str) {
        writeElement(OCTET_STRING_TYPE, str);
    }

    void encodeOctetString(ByteStringBuilder str) {
        writeElement(OCTET_STRING_TYPE, str);
    }

    @Override
    public void encodeIA5String(String str) {
        encodeIA5String(str.getBytes(StandardCharsets.US_ASCII));
    }

    @Override
    public void encodeIA5String(byte[] str) {
        writeElement(IA5_STRING_TYPE, str);
    }

    void encodeIA5String(ByteStringBuilder str) {
        writeElement(IA5_STRING_TYPE, str);
    }

    @Override
    public void encodePrintableString(final byte[] str) {
        for (byte b : str) {
            validatePrintableByte(b & 0xff);
        }
        writeElement(PRINTABLE_STRING_TYPE, str);
    }

    @Override
    public void encodePrintableString(final String str) {
        for (int i = 0; i < str.length(); i = str.offsetByCodePoints(i, 1)) {
            validatePrintableByte(str.codePointAt(i));
        }
        writeElement(PRINTABLE_STRING_TYPE, str.getBytes(StandardCharsets.US_ASCII));
    }

    @Override
    public void encodeUTF8String(final String str) {
        writeElement(UTF8_STRING_TYPE, str.getBytes(StandardCharsets.UTF_8));
    }

    @Override
    public void encodeBMPString(final String str) {
        // technically this may fail if str contains a code point outside of the BMP
        writeElement(BMP_STRING_TYPE, str.getBytes(StandardCharsets.UTF_16BE));
    }

    private static final Charset UTF_32BE = Charset.forName("UTF-32BE");

    @Override
    public void encodeUniversalString(final String str) {
        writeElement(UNIVERSAL_STRING_TYPE, str.getBytes(UTF_32BE));
    }

    @Override
    public void encodeBitString(byte[] str) {
        encodeBitString(str, 0); // All bits will be used
    }

    @Override
    public void encodeBitString(byte[] str, int numUnusedBits) {
        // The first content octet holds the number of unused bits in the final octet
        byte[] contents = new byte[str.length + 1];
        contents[0] = (byte) numUnusedBits;
        System.arraycopy(str, 0, contents, 1, str.length);
        writeElement(BIT_STRING_TYPE, contents);
    }

    @Override
    public void encodeBitString(final EnumSet<?> enumSet) {
        final BitSet bitSet = new BitSet();
        for (Enum<?> anEnum : enumSet) {
            bitSet.set(anEnum.ordinal());
        }
        encodeBitString(bitSet);
    }

    @Override
    public void encodeBitString(final BitSet bitSet) {
        final byte[] array = bitSet.toByteArray();
        // Pad the final octet: (-length) mod 8 unused bits
        final int unusedBits = - bitSet.length() & 0b111;
        for (int i = 0; i < array.length; i++) {
            // BitSet stores bit 0 in the least significant position; BIT STRING starts at the most significant bit
            array[i] = (byte) (Integer.reverse(array[i]) >> 24);
        }
        encodeBitString(array, unusedBits);
    }

    /**
     * Encode a bit string given as a string of {@code '0'}/{@code '1'} characters, most significant
     * bit first. Any character other than {@code '1'} is treated as a zero bit.
     *
     * @param binaryStr the bits to encode
     */
    @Override
    public void encodeBitString(String binaryStr) {
        final int numBits = binaryStr.length();
        final int remainder = numBits & 7;
        final int numBytes = (numBits >>> 3) + (remainder == 0 ? 0 : 1);
        final int numUnusedBits = (remainder == 0) ? 0 : 8 - remainder;

        // Java zero-initialises the array, so only the '1' bits need to be set
        final byte[] contents = new byte[numBytes + 1];
        contents[0] = (byte) numUnusedBits;
        for (int index = 0; index < numBits; index++) {
            if (binaryStr.charAt(index) == '1') {
                // Bit 'index' lives in content octet 1 + index / 8, counted from its most significant bit
                contents[1 + (index >>> 3)] |= BITS[7 - (index & 7)];
            }
        }
        writeElement(BIT_STRING_TYPE, contents);
    }

    @Override
    public void encodeBitString(BigInteger integer) {
        // The bit string's contents are the complete DER INTEGER encoding (tag, length and value)
        ByteStringBuilder encodedInteger = new ByteStringBuilder();
        new DEREncoder(encodedInteger).encodeInteger(integer);
        encodeBitString(encodedInteger.toArray());
    }

    private static final DateTimeFormatter GENERALIZED_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyyMMddHHmmssX");

    /**
     * Encode a GeneralizedTime value. DER (ITU-T X.690, 11.7) requires the value to be expressed in
     * UTC and to end in {@code 'Z'}, so the time is first converted to UTC. With a zero offset, the
     * {@code X} pattern letter prints {@code Z}.
     *
     * @param time the time to encode
     */
    @Override
    public void encodeGeneralizedTime(final ZonedDateTime time) {
        final ZonedDateTime utc = time.withZoneSameInstant(ZoneOffset.UTC);
        writeElement(GENERALIZED_TIME_TYPE, GENERALIZED_TIME_FORMAT.format(utc).getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Encode an object identifier given in dotted-decimal form (for example {@code "1.2.840.113549"}).
     * Each component is accumulated in a {@code long} and is promoted to {@link BigInteger} only when
     * the next digit would overflow. An empty component encodes as 0 (existing lenient behaviour).
     *
     * @param objectIdentifier the dotted OID string
     * @throws ASN1Exception if the OID has fewer than two components, contains a character other than
     *                       an ASCII digit or {@code '.'}, ends with {@code '.'}, or has invalid first
     *                       or second components
     */
    @Override
    public void encodeObjectIdentifier(String objectIdentifier) throws ASN1Exception {
        if (objectIdentifier == null || objectIdentifier.isEmpty()) {
            throw log.asnOidMustHaveAtLeast2Components();
        }
        final int len = objectIdentifier.length();
        final ByteStringBuilder contents = new ByteStringBuilder();
        int pos = 0;
        int numComponents = 0;
        int first = -1;

        for (;;) {
            // Parse one component: a run of ASCII digits terminated by '.' or the end of the string
            long smallValue = 0L;
            BigInteger bigValue = null;
            while (pos < len) {
                final char c = objectIdentifier.charAt(pos);
                if (c == '.') {
                    break;
                }
                if (c < '0' || c > '9') {
                    throw log.asnInvalidOidCharacter();
                }
                final int digit = c - '0';
                if (bigValue != null) {
                    bigValue = bigValue.multiply(BigInteger.TEN).add(digits[digit]);
                } else if (smallValue > (Long.MAX_VALUE - digit) / 10) {
                    // 10 * smallValue + digit would overflow a long
                    bigValue = BigInteger.valueOf(smallValue).multiply(BigInteger.TEN).add(digits[digit]);
                } else {
                    smallValue = 10L * smallValue + digit;
                }
                pos++;
            }

            final boolean lastComponent = (pos == len);
            if (numComponents == 0) {
                if (lastComponent) {
                    throw log.asnOidMustHaveAtLeast2Components();
                }
                // The first component is folded into the second one, so it is only validated here
                first = (bigValue == null) ? validateFirstOIDComponent(smallValue) : validateFirstOIDComponent(bigValue);
            } else if (bigValue == null) {
                encodeOIDComponent(smallValue, contents, numComponents, first);
            } else {
                encodeOIDComponent(bigValue, contents, numComponents, first);
            }
            numComponents++;

            if (lastComponent) {
                writeElement(OBJECT_IDENTIFIER_TYPE, contents);
                return;
            }
            pos++; // skip the '.' separator
            if (pos == len) {
                // A trailing '.' is not allowed
                throw log.asnInvalidOidCharacter();
            }
        }
    }

    @Override
    public void encodeNull() {
        writeElement(NULL_TYPE, NULL_CONTENTS);
    }

    @Override
    public void encodeImplicit(int number) {
        encodeImplicit(CONTEXT_SPECIFIC_MASK, number);
    }

    @Override
    public void encodeImplicit(int clazz, int number) {
        // Only the first (outermost) implicit tag applies until it is consumed by writeTag
        if (implicitTag == -1) {
            implicitTag = clazz | number;
        }
    }

    @Override
    public void encodeBoolean(final boolean value) {
        writeElement(BOOLEAN_TYPE, value ? BOOLEAN_TRUE_AS_BYTES : BOOLEAN_FALSE_AS_BYTES);
    }

    @Override
    public void encodeInteger(BigInteger integer) {
        // BigInteger.toByteArray() yields the minimal two's-complement big-endian form required by DER
        writeElement(INTEGER_TYPE, integer.toByteArray());
    }

    /**
     * Write an already-encoded element. If an implicit tag is pending, the element's complete
     * identifier (which may span several octets) is replaced by the implicit tag.
     *
     * @param encoded the complete encoding of one element
     */
    @Override
    public void writeEncoded(byte[] encoded) {
        EncoderState lastState = states.peekLast();
        final boolean inSet = (lastState != null) && (lastState.getTag() == SET_TYPE);
        // Mask off the sign extension so that context-specific/private classes sort after universal tags
        final int originalTag = encoded[0] & 0xff;
        if (inSet) {
            updateCurrentBuffer();
            lastState.addChildElement(effectiveTag(originalTag), currentBufferPos);
        }

        if (implicitTag != -1) {
            writeTag(originalTag, currentBuffer);
            final int identifierLength = identifierLength(encoded);
            currentBuffer.append(encoded, identifierLength, encoded.length - identifierLength);
        } else {
            currentBuffer.append(encoded);
        }

        // If this element's parent element is a set element, update the parent's accumulated length
        if (inSet) {
            lastState.addChildLength(currentBuffer.length());
        }
    }

    /**
     * Close every element that is still open, innermost first. Open sets are sorted by tag, as
     * {@link #endSet()} does.
     */
    @Override
    public void flush() {
        while (!states.isEmpty()) {
            final int tag = states.peekLast().getTag();
            if (tag == SEQUENCE_TYPE) {
                endSequence();
            } else if (tag == SET_TYPE) {
                endSet();
            } else {
                // Any other open state comes from startExplicit. Without this branch the loop never ended.
                endExplicit();
            }
        }
    }

    @Override
    public byte[] getEncoded() {
        return target.toArray();
    }

    private int validateFirstOIDComponent(long value) throws ASN1Exception {
        if (value < 0 || value > 2) {
            throw log.asnInvalidValueForFirstOidComponent();
        }
        return (int) value;
    }

    private int validateFirstOIDComponent(BigInteger value) throws ASN1Exception  {
        if ((value.signum() < 0) || (value.compareTo(BigInteger.valueOf(2)) > 0)) {
            throw log.asnInvalidValueForFirstOidComponent();
        }
        return value.intValue();
    }

    private void validateSecondOIDComponent(long second, int first) throws ASN1Exception  {
        // Under arcs 0 and 1 the second component must be at most 39
        if ((first < 2) && (second > 39)) {
            throw log.asnInvalidValueForSecondOidComponent();
        }
    }

    private void validateSecondOIDComponent(BigInteger second, int first) throws ASN1Exception {
        if ((first < 2) && (second.compareTo(BigInteger.valueOf(39)) > 0)) {
            throw log.asnInvalidValueForSecondOidComponent();
        }
    }

    /**
     * Encode an OID component. The second component is combined with the first as
     * {@code 40 * first + second}.
     */
    private void encodeOIDComponent(long value, ByteStringBuilder contents,
            int numComponents, int first) throws ASN1Exception {
        if (numComponents == 1) {
            validateSecondOIDComponent(value, first);
            final long offset = 40L * first;
            if (value > Long.MAX_VALUE - offset) {
                // value + offset would overflow a long (only possible for first == 2)
                encodeOIDComponent(BigInteger.valueOf(value).add(BigInteger.valueOf(offset)), contents);
            } else {
                encodeOIDComponent(value + offset, contents);
            }
        } else {
            encodeOIDComponent(value, contents);
        }
    }

    private void encodeOIDComponent(BigInteger value, ByteStringBuilder contents,
            int numComponents, int first) throws ASN1Exception {
        if (numComponents == 1) {
            validateSecondOIDComponent(value, first);
            encodeOIDComponent(value.add(BigInteger.valueOf(40 * first)), contents);
        } else {
            encodeOIDComponent(value, contents);
        }
    }

    /**
     * Append a non-negative value in base-128, most significant group first. Every octet except the
     * last has bit 8 set.
     */
    private void encodeOIDComponent(long value, ByteStringBuilder contents) {
        int shift = 56;
        int octet;
        while (shift > 0) {
            if (value >= (1L << shift)) {
                octet = (int) ((value >>> shift) | 0x80);
                contents.append((byte) octet);
            }
            shift = shift - 7;
        }
        octet = (int) (value & 0x7f);
        contents.append((byte) octet);
    }

    private void encodeOIDComponent(BigInteger value, ByteStringBuilder contents) {
        int numBytes = (value.bitLength() + 6) / 7;
        if (numBytes == 0) {
            contents.append((byte) 0);
        } else {
            byte[] result = new byte[numBytes];
            BigInteger currValue = value;
            for (int i = numBytes - 1; i >= 0; i--) {
                result[i] = (byte) ((currValue.intValue() & 0x7f) | 0x80);
                currValue = currValue.shiftRight(7);
            }
            // Clear the continuation bit on the final octet
            result[numBytes - 1] &= 0x7f;
            contents.append(result);
        }
    }

    private static final BigInteger[] digits = {
        BigInteger.ZERO,
        BigInteger.ONE,
        BigInteger.valueOf(2),
        BigInteger.valueOf(3),
        BigInteger.valueOf(4),
        BigInteger.valueOf(5),
        BigInteger.valueOf(6),
        BigInteger.valueOf(7),
        BigInteger.valueOf(8),
        BigInteger.valueOf(9),
    };

    /**
     * Return the tag that {@link #writeTag} will emit for {@code tag}, taking any pending implicit
     * tag into account. Unlike {@code writeTag}, this does not consume the pending implicit tag.
     */
    private int effectiveTag(int tag) {
        return (implicitTag == -1) ? tag : (implicitTag | (tag & CONSTRUCTED_MASK));
    }

    /**
     * Return the number of identifier octets at the start of {@code encoded}. This is 1 in the
     * low-tag-number form. In the high-tag-number form, the first octet is followed by octets with
     * bit 8 set, ending with one that has bit 8 clear.
     */
    private static int identifierLength(byte[] encoded) {
        if ((encoded[0] & 0x1f) != 0x1f) {
            return 1;
        }
        int i = 1;
        while (i < encoded.length && (encoded[i] & 0x80) != 0) {
            i++;
        }
        return Math.min(i + 1, encoded.length);
    }

    /**
     * Write identifier octets to {@code dest}, consuming any pending implicit tag. A pending implicit
     * tag replaces the class and number but keeps the constructed bit of {@code tag}.
     */
    private void writeTag(int tag, ByteStringBuilder dest) {
        int constructed = tag & CONSTRUCTED_MASK;
        if (implicitTag != -1) {
            tag = implicitTag | constructed;
            implicitTag = -1;
        }
        int tagClass = tag & CLASS_MASK;
        int tagNumber = tag & TAG_NUMBER_MASK;
        if (tagNumber < 31) {
            dest.append((byte) (tagClass | constructed | tagNumber));
        } else {
            // High-tag-number-form
            dest.append((byte) (tagClass | constructed | 0x1f));
            if (tagNumber < 128) {
                dest.append((byte) tagNumber);
            } else {
                int shift = 28;
                int octet;
                while (shift > 0) {
                    if (tagNumber >= (1 << shift)) {
                        octet = (tagNumber >>> shift) | 0x80;
                        dest.append((byte) octet);
                    }
                    shift = shift - 7;
                }
                octet = tagNumber & 0x7f;
                dest.append((byte) octet);
            }
        }
    }

    /**
     * Write DER length octets to {@code dest}. The short form is used for lengths up to 127; the
     * minimal long form is used otherwise.
     *
     * @return the number of length octets written
     * @throws ASN1Exception if {@code length} is negative
     */
    private int writeLength(int length, ByteStringBuilder dest) throws ASN1Exception {
        if (length < 0) {
            throw log.asnInvalidLength();
        }
        if (length <= 127) {
            // Short form
            dest.append((byte) length);
            return 1;
        }
        // Long form: the first octet has bit 8 set and bits 7-1 give the number of subsequent length octets
        final int numValueOctets = 4 - (Integer.numberOfLeadingZeros(length) >>> 3);
        dest.append((byte) (numValueOctets | 0x80));
        for (int shift = (numValueOctets - 1) * 8; shift >= 0; shift -= 8) {
            dest.append((byte) (length >>> shift));
        }
        return 1 + numValueOctets;
    }

    /** Advance to the next intermediate buffer, allocating one if none is available for reuse. */
    private void updateCurrentBuffer() {
        currentBufferPos += 1;
        if (currentBufferPos < buffers.size()) {
            currentBuffer = buffers.get(currentBufferPos);
        } else {
            ByteStringBuilder buffer = new ByteStringBuilder();
            buffers.add(buffer);
            currentBuffer = buffer;
        }
    }

    private void writeElement(int tag, byte[] contents) {
        EncoderState lastState = states.peekLast();
        final boolean inSet = (lastState != null) && (lastState.getTag() == SET_TYPE);
        if (inSet) {
            // Record the tag actually emitted (after implicit tagging) as the set's sort key
            updateCurrentBuffer();
            lastState.addChildElement(effectiveTag(tag), currentBufferPos);
        }

        writeTag(tag, currentBuffer);
        writeLength(contents.length, currentBuffer);
        currentBuffer.append(contents);

        // If this element's parent element is a set element, update the parent's accumulated length
        if (inSet) {
            lastState.addChildLength(currentBuffer.length());
        }
    }

    private void writeElement(int tag, ByteStringBuilder contents) {
        EncoderState lastState = states.peekLast();
        final boolean inSet = (lastState != null) && (lastState.getTag() == SET_TYPE);
        if (inSet) {
            // Record the tag actually emitted (after implicit tagging) as the set's sort key
            updateCurrentBuffer();
            lastState.addChildElement(effectiveTag(tag), currentBufferPos);
        }

        writeTag(tag, currentBuffer);
        writeLength(contents.length(), currentBuffer);
        currentBuffer.append(contents);

        // If this element's parent element is a set element, update the parent's accumulated length
        if (inSet) {
            lastState.addChildLength(currentBuffer.length());
        }
    }

    /**
     * A class used to maintain state information during DER encoding.
     */
    private class EncoderState {
        private final int tag;
        private final int bufferPos;
        private final LinkedList<EncoderState> childElements = new LinkedList<>();
        private int childLength = 0;

        public EncoderState(int tag, int bufferPos) {
            this.tag = tag;
            this.bufferPos = bufferPos;
        }

        public int getTag() {
            return tag;
        }

        public int getBufferPos() {
            return bufferPos;
        }

        public ByteStringBuilder getBuffer() {
            return buffers.get(getBufferPos());
        }

        public int getChildLength() {
            return childLength;
        }

        public LinkedList<EncoderState> getSortedChildElements(Comparator<? super EncoderState> comparator) {
            Collections.sort(childElements, comparator);
            return childElements;
        }

        public void addChildElement(int tag, int bufferPos) {
            childElements.add(new EncoderState(tag, bufferPos));
        }

        public void addChildLength(int length) {
            childLength += length;
        }
    }

    /**
     * A class that compares DER encodings based on their tags.
     */
    private static class TagComparator implements Comparator<EncoderState> {
        @Override
        public int compare(EncoderState state1, EncoderState state2) {
            // Ignore the constructed bit when comparing tags
            return Integer.compare(state1.getTag() | CONSTRUCTED_MASK, state2.getTag() | CONSTRUCTED_MASK);
        }
    }

    /**
     * A class that compares DER encodings using lexicographic order.
     */
    private static class LexicographicComparator implements Comparator<EncoderState> {
        @Override
        public int compare(EncoderState state1, EncoderState state2) {
            ByteStringBuilder bytes1 = state1.getBuffer();
            ByteStringBuilder bytes2 = state2.getBuffer();
            ByteIterator bi1 = bytes1.iterate();
            ByteIterator bi2 = bytes2.iterate();

            // Scan the two encodings from left to right until a difference is found
            int diff;
            while (bi1.hasNext() && bi2.hasNext()) {
                diff = (bi1.next() & 0xff) - (bi2.next() & 0xff);
                if (diff != 0) {
                    return diff;
                }
            }

            // The longer encoding is considered to be the bigger-valued encoding
            return bytes1.length() - bytes2.length();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy