org.wildfly.security.asn1.DEREncoder Maven / Gradle / Ivy
The newest version!
/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wildfly.security.asn1;
import static org.wildfly.security.asn1.ElytronMessages.log;
import static org.wildfly.security.asn1.ASN1.*;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.LinkedList;
import org.wildfly.common.bytes.ByteStringBuilder;
import org.wildfly.common.iteration.ByteIterator;
/**
* A class used to encode ASN.1 values using the Distinguished Encoding Rules (DER), as specified
* in ITU-T X.690.
*
* @author Farah Juma
*/
public class DEREncoder implements ASN1Encoder {
private static final int[] BITS = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
private static final long LARGEST_UNSHIFTED_LONG = Long.MAX_VALUE / 10L;
private static final byte[] NULL_CONTENTS = new byte[0];
private static final TagComparator TAG_COMPARATOR = new TagComparator();
private static final LexicographicComparator LEXICOGRAPHIC_COMPARATOR = new LexicographicComparator();
private static final byte[] BOOLEAN_TRUE_AS_BYTES = new byte[] { ~0 };
private static final byte[] BOOLEAN_FALSE_AS_BYTES = new byte[] { 0 };
private final ArrayDeque states = new ArrayDeque();
private final ArrayList buffers = new ArrayList();
private ByteStringBuilder currentBuffer;
private int currentBufferPos = -1;
private final ByteStringBuilder target;
private int implicitTag = -1;
/**
* Create a DER encoder.
*/
public DEREncoder() {
this(new ByteStringBuilder());
}
/**
* Create a DER encoder that writes its output to the given {@code ByteStringBuilder}.
*
* @param target the {@code ByteStringBuilder} to which the DER encoded values are written
*/
DEREncoder(ByteStringBuilder target) {
this.target = target;
currentBuffer = target;
}
@Override
public void startSequence() {
startConstructedElement(SEQUENCE_TYPE);
}
@Override
public void startSet() {
startConstructedElement(SET_TYPE);
}
@Override
public void startSetOf() {
startSet();
}
@Override
public void startExplicit(int number) {
startExplicit(CONTEXT_SPECIFIC_MASK, number);
}
@Override
public void startExplicit(int clazz, int number) {
int explicitTag = clazz | CONSTRUCTED_MASK | number;
startConstructedElement(explicitTag);
}
private void startConstructedElement(int tag) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(tag, currentBufferPos);
}
writeTag(tag, currentBuffer);
if (tag != SET_TYPE) {
updateCurrentBuffer();
}
states.add(new EncoderState(tag, currentBufferPos));
}
@Override
public void endSequence() throws IllegalStateException {
EncoderState lastState = states.peekLast();
if ((lastState == null) || (lastState.getTag() != SEQUENCE_TYPE)) {
throw log.noSequenceToEnd();
}
endConstructedElement();
}
@Override
public void endExplicit() throws IllegalStateException {
EncoderState lastState = states.peekLast();
if ((lastState == null) || (lastState.getTag() == SEQUENCE_TYPE)
|| (lastState.getTag() == SET_TYPE) || ((lastState.getTag() & CONSTRUCTED_MASK) == 0)) {
throw log.noExplicitlyTaggedElementToEnd();
}
endConstructedElement();
}
private void endConstructedElement() {
ByteStringBuilder dest;
if (currentBufferPos > 0) {
// Output the element to its parent buffer
dest = buffers.get(currentBufferPos - 1);
} else {
// Output the element directly to the target destination
dest = target;
}
int length = currentBuffer.length();
int numLengthOctets = writeLength(length, dest);
dest.append(currentBuffer);
currentBuffer.setLength(0);
currentBuffer = dest;
currentBufferPos -= 1;
states.removeLast();
// If this element's parent element is a set element, update the parent's accumulated length
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(1 + numLengthOctets + length);
}
}
@Override
public void endSet() throws IllegalStateException {
endSet(TAG_COMPARATOR);
}
@Override
public void endSetOf() throws IllegalStateException {
endSet(LEXICOGRAPHIC_COMPARATOR);
}
private void endSet(Comparator comparator) {
EncoderState lastState = states.peekLast();
if ((lastState == null) || (lastState.getTag() != SET_TYPE)) {
throw log.noSetToEnd();
}
// The child elements of a set must be encoded in ascending order by tag
LinkedList childElements = lastState.getSortedChildElements(comparator);
int setBufferPos = lastState.getBufferPos();
ByteStringBuilder dest;
if (setBufferPos >= 0) {
dest = buffers.get(setBufferPos);
} else {
dest = target;
}
ByteStringBuilder contents;
int childLength = lastState.getChildLength();
int numLengthOctets = writeLength(lastState.getChildLength(), dest);
for (EncoderState element : childElements) {
contents = buffers.get(element.getBufferPos());
dest.append(contents);
contents.setLength(0);
}
currentBuffer = dest;
currentBufferPos = setBufferPos;
states.removeLast();
// If this set's parent element is a set element, update the parent's accumulated length
lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(1 + numLengthOctets + childLength);
}
}
@Override
public void encodeOctetString(String str) {
encodeOctetString(str.getBytes(StandardCharsets.UTF_8));
}
@Override
public void encodeOctetString(byte[] str) {
writeElement(OCTET_STRING_TYPE, str);
}
void encodeOctetString(ByteStringBuilder str) {
writeElement(OCTET_STRING_TYPE, str);
}
@Override
public void encodeIA5String(String str) {
encodeIA5String(str.getBytes(StandardCharsets.US_ASCII));
}
@Override
public void encodeIA5String(byte[] str) {
writeElement(IA5_STRING_TYPE, str);
}
void encodeIA5String(ByteStringBuilder str) {
writeElement(IA5_STRING_TYPE, str);
}
@Override
public void encodePrintableString(final byte[] str) {
for (byte b : str) {
validatePrintableByte(b & 0xff);
}
writeElement(PRINTABLE_STRING_TYPE, str);
}
@Override
public void encodePrintableString(final String str) {
for (int i = 0; i < str.length(); i = str.offsetByCodePoints(i, 1)) {
validatePrintableByte(str.codePointAt(i));
}
writeElement(PRINTABLE_STRING_TYPE, str.getBytes(StandardCharsets.US_ASCII));
}
@Override
public void encodeUTF8String(final String str) {
writeElement(UTF8_STRING_TYPE, str.getBytes(StandardCharsets.UTF_8));
}
@Override
public void encodeBMPString(final String str) {
// technically this may fail if str contains a code point outside of the BMP
writeElement(BMP_STRING_TYPE, str.getBytes(StandardCharsets.UTF_16BE));
}
private static final Charset UTF_32BE = Charset.forName("UTF-32BE");
@Override
public void encodeUniversalString(final String str) {
writeElement(UNIVERSAL_STRING_TYPE, str.getBytes(UTF_32BE));
}
@Override
public void encodeBitString(byte[] str) {
encodeBitString(str, 0); // All bits will be used
}
@Override
public void encodeBitString(byte[] str, int numUnusedBits) {
byte[] contents = new byte[str.length + 1];
contents[0] = (byte) numUnusedBits;
System.arraycopy(str, 0, contents, 1, str.length);
writeElement(BIT_STRING_TYPE, contents);
}
@Override
public void encodeBitString(final EnumSet> enumSet) {
int ord;
final BitSet bitSet = new BitSet();
for (Enum> anEnum : enumSet) {
ord = anEnum.ordinal();
bitSet.set(ord);
}
encodeBitString(bitSet);
}
@Override
public void encodeBitString(final BitSet bitSet) {
final byte[] array = bitSet.toByteArray();
final int unusedBits = - bitSet.length() & 0b111;
for (int i = 0; i < array.length; i++) {
array[i] = (byte) (Integer.reverse(array[i]) >> 24);
}
encodeBitString(array, unusedBits);
}
@Override
public void encodeBitString(String binaryStr) {
int numBits = binaryStr.length();
int numBytes = numBits >> 3;
int remainder = numBits % 8;
int numUnusedBits = 0;
if (remainder != 0) {
numBytes = numBytes + 1;
numUnusedBits = 8 - remainder;
}
byte[] contents = new byte[numBytes + 1];
contents[0] = (byte) numUnusedBits;
for (int i = 1; i <= numBytes; i++) {
contents[i] = (byte) 0;
}
char[] binaryStrChars = binaryStr.toCharArray();
int index = 0;
for (int i = 1; i <= numBytes && index < numBits; i++) {
for (int bit = 7; bit >= 0 && index < numBits; bit--) {
if ((i == numBytes) && (bit < numUnusedBits)) {
continue;
}
if (binaryStrChars[index++] == '1') {
contents[i] |= BITS[bit];
}
}
}
writeElement(BIT_STRING_TYPE, contents);
}
@Override
public void encodeBitString(BigInteger integer) {
ByteStringBuilder target = new ByteStringBuilder();
new DEREncoder(target).encodeInteger(integer);
encodeBitString(target.toArray());
}
private static final DateTimeFormatter GENERALIZED_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyyMMddHHmmssX");
@Override
public void encodeGeneralizedTime(final ZonedDateTime time) {
writeElement(GENERALIZED_TIME_TYPE, GENERALIZED_TIME_FORMAT.format(time).getBytes(StandardCharsets.UTF_8));
}
@Override
public void encodeObjectIdentifier(String objectIdentifier) throws ASN1Exception {
if (objectIdentifier == null || objectIdentifier.length() == 0) {
throw log.asnOidMustHaveAtLeast2Components();
}
int len = objectIdentifier.length();
int offs = 0;
int idx = 0;
long t = 0L;
char c;
int numComponents = 0;
int first = -1;
ByteStringBuilder contents = new ByteStringBuilder();
a: for (;;) {
c = objectIdentifier.charAt(offs + idx ++);
if (Character.isDigit(c)) {
int digit = Character.digit(c, 10);
if (t > LARGEST_UNSHIFTED_LONG) {
BigInteger bi = BigInteger.valueOf(t).multiply(BigInteger.TEN).add(digits[digit]);
t = 0L;
for (;;) {
c = objectIdentifier.charAt(offs + idx ++);
if (Character.isDigit(c)) {
digit = Character.digit(c, 10);
bi = bi.multiply(BigInteger.TEN).add(digits[digit]);
} else if (c == '.') {
if (numComponents == 0) {
first = validateFirstOIDComponent(bi);
} else {
encodeOIDComponent(bi, contents, numComponents, first);
}
numComponents++;
continue a;
} else {
throw log.asnInvalidOidCharacter();
}
if (idx == len) {
if (numComponents == 0) {
throw log.asnOidMustHaveAtLeast2Components();
}
encodeOIDComponent(bi, contents, numComponents, first);
writeElement(OBJECT_IDENTIFIER_TYPE, contents);
return;
}
}
} else {
t = 10L * t + (long) digit;
}
} else if (c == '.') {
if (numComponents == 0) {
first = validateFirstOIDComponent(t);
} else {
encodeOIDComponent(t, contents, numComponents, first);
}
numComponents++;
t = 0L;
} else {
throw log.asnInvalidOidCharacter();
}
if (idx == len) {
if (c == '.') {
throw log.asnInvalidOidCharacter();
}
if (numComponents == 0) {
throw log.asnOidMustHaveAtLeast2Components();
}
encodeOIDComponent(t, contents, numComponents, first);
writeElement(OBJECT_IDENTIFIER_TYPE, contents);
return;
}
}
}
@Override
public void encodeNull() {
writeElement(NULL_TYPE, NULL_CONTENTS);
}
@Override
public void encodeImplicit(int number) {
encodeImplicit(CONTEXT_SPECIFIC_MASK, number);
}
@Override
public void encodeImplicit(int clazz, int number) {
if (implicitTag == -1) {
implicitTag = clazz | number;
}
}
@Override
public void encodeBoolean(final boolean value) {
writeElement(BOOLEAN_TYPE, value ? BOOLEAN_TRUE_AS_BYTES : BOOLEAN_FALSE_AS_BYTES);
}
@Override
public void encodeInteger(BigInteger integer) {
writeElement(INTEGER_TYPE, integer.toByteArray());
}
@Override
public void writeEncoded(byte[] encoded) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(encoded[0], currentBufferPos);
}
if (implicitTag != -1) {
writeTag(encoded[0], currentBuffer);
currentBuffer.append(encoded, 1, encoded.length - 1);
} else {
currentBuffer.append(encoded);
}
// If this element's parent element is a set element, update the parent's accumulated length
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(currentBuffer.length());
}
}
@Override
public void flush() {
while (states.size() != 0) {
EncoderState lastState = states.peekLast();
if (lastState.getTag() == SEQUENCE_TYPE) {
endSequence();
} else if (lastState.getTag() == SET_TYPE) {
endSet();
}
}
}
@Override
public byte[] getEncoded() {
return target.toArray();
}
private int validateFirstOIDComponent(long value) throws ASN1Exception {
if (value < 0 || value > 2) {
throw log.asnInvalidValueForFirstOidComponent();
}
return (int) value;
}
private int validateFirstOIDComponent(BigInteger value) throws ASN1Exception {
if ((value.compareTo(BigInteger.valueOf(0)) == -1)
|| (value.compareTo(BigInteger.valueOf(2)) == 1)) {
throw log.asnInvalidValueForFirstOidComponent();
}
return value.intValue();
}
private void validateSecondOIDComponent(long second, int first) throws ASN1Exception {
if ((first < 2) && (second > 39)) {
throw log.asnInvalidValueForSecondOidComponent();
}
}
private void validateSecondOIDComponent(BigInteger second, int first) throws ASN1Exception {
if ((first < 2) && (second.compareTo(BigInteger.valueOf(39)) == 1)) {
throw log.asnInvalidValueForSecondOidComponent();
}
}
private void encodeOIDComponent(long value, ByteStringBuilder contents,
int numComponents, int first) throws ASN1Exception {
if (numComponents == 1) {
validateSecondOIDComponent(value, first);
encodeOIDComponent(value + (40 * first), contents);
} else {
encodeOIDComponent(value, contents);
}
}
private void encodeOIDComponent(BigInteger value, ByteStringBuilder contents,
int numComponents, int first) throws ASN1Exception {
if (numComponents == 1) {
validateSecondOIDComponent(value, first);
encodeOIDComponent(value.add(BigInteger.valueOf(40 * first)), contents);
} else {
encodeOIDComponent(value, contents);
}
}
private void encodeOIDComponent(long value, ByteStringBuilder contents) {
int shift = 56;
int octet;
while (shift > 0) {
if (value >= (1L << shift)) {
octet = (int) ((value >>> shift) | 0x80);
contents.append((byte) octet);
}
shift = shift - 7;
}
octet = (int) (value & 0x7f);
contents.append((byte) octet);
}
private void encodeOIDComponent(BigInteger value, ByteStringBuilder contents) {
int numBytes = (value.bitLength() + 6) / 7;
if (numBytes == 0) {
contents.append((byte) 0);
} else {
byte[] result = new byte[numBytes];
BigInteger currValue = value;
for (int i = numBytes - 1; i >= 0; i--) {
result[i] = (byte) ((currValue.intValue() & 0x7f) | 0x80);
currValue = currValue.shiftRight(7);
}
result[numBytes - 1] &= 0x7f;
contents.append(result);
}
}
private static final BigInteger[] digits = {
BigInteger.ZERO,
BigInteger.ONE,
BigInteger.valueOf(2),
BigInteger.valueOf(3),
BigInteger.valueOf(4),
BigInteger.valueOf(5),
BigInteger.valueOf(6),
BigInteger.valueOf(7),
BigInteger.valueOf(8),
BigInteger.valueOf(9),
};
private void writeTag(int tag, ByteStringBuilder dest) {
int constructed = tag & CONSTRUCTED_MASK;
if (implicitTag != -1) {
tag = implicitTag | constructed;
implicitTag = -1;
}
int tagClass = tag & CLASS_MASK;
int tagNumber = tag & TAG_NUMBER_MASK;
if (tagNumber < 31) {
dest.append((byte) (tagClass | constructed | tagNumber));
} else {
// High-tag-number-form
dest.append((byte) (tagClass | constructed | 0x1f));
if (tagNumber < 128) {
dest.append((byte) tagNumber);
} else {
int shift = 28;
int octet;
while (shift > 0) {
if (tagNumber >= (1 << shift)) {
octet = (tagNumber >>> shift) | 0x80;
dest.append((byte) octet);
}
shift = shift - 7;
}
octet = tagNumber & 0x7f;
dest.append((byte) octet);
}
}
}
private int writeLength(int length, ByteStringBuilder dest) throws ASN1Exception {
int numLengthOctets;
if (length < 0) {
throw log.asnInvalidLength();
} else if (length <= 127) {
// Short form
numLengthOctets = 1;
} else {
// Long form
numLengthOctets = 1;
int value = length;
while ((value >>>= 8) != 0) {
numLengthOctets += 1;
}
}
if (length > 127) {
// bit 8 of the first octet has value "1" and bits 7-1 give the number of additional length octets
dest.append((byte) (numLengthOctets | 0x80));
}
for (int i = (numLengthOctets - 1) * 8; i >= 0; i -= 8) {
dest.append((byte) (length >> i));
}
if (length > 127) {
// include the first octet
return 1 + numLengthOctets;
} else {
return numLengthOctets;
}
}
private void updateCurrentBuffer() {
currentBufferPos += 1;
if (currentBufferPos < buffers.size()) {
currentBuffer = buffers.get(currentBufferPos);
} else {
ByteStringBuilder buffer = new ByteStringBuilder();
buffers.add(buffer);
currentBuffer = buffer;
}
}
private void writeElement(int tag, byte[] contents) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(tag, currentBufferPos);
}
writeTag(tag, currentBuffer);
writeLength(contents.length, currentBuffer);
currentBuffer.append(contents);
// If this element's parent element is a set element, update the parent's accumulated length
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(currentBuffer.length());
}
}
private void writeElement(int tag, ByteStringBuilder contents) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(tag, currentBufferPos);
}
writeTag(tag, currentBuffer);
writeLength(contents.length(), currentBuffer);
currentBuffer.append(contents);
// If this element's parent element is a set element, update the parent's accumulated length
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(currentBuffer.length());
}
}
/**
* A class used to maintain state information during DER encoding.
*/
private class EncoderState {
private final int tag;
private final int bufferPos;
private LinkedList childElements = new LinkedList();
private int childLength = 0;
public EncoderState(int tag, int bufferPos) {
this.tag = tag;
this.bufferPos = bufferPos;
}
public int getTag() {
return tag;
}
public int getBufferPos() {
return bufferPos;
}
public ByteStringBuilder getBuffer() {
return buffers.get(getBufferPos());
}
public int getChildLength() {
return childLength;
}
public LinkedList getSortedChildElements(Comparator comparator) {
Collections.sort(childElements, comparator);
return childElements;
}
public void addChildElement(int tag, int bufferPos) {
childElements.add(new EncoderState(tag, bufferPos));
}
public void addChildLength(int length) {
childLength += length;
}
}
/**
* A class that compares DER encodings based on their tags.
*/
private static class TagComparator implements Comparator {
@Override
public int compare(EncoderState state1, EncoderState state2) {
// Ignore the constructed bit when comparing tags
return (state1.getTag() | CONSTRUCTED_MASK) - (state2.getTag() | CONSTRUCTED_MASK);
}
}
/**
* A class that compares DER encodings using lexicographic order.
*/
private static class LexicographicComparator implements Comparator {
@Override
public int compare(EncoderState state1, EncoderState state2) {
ByteStringBuilder bytes1 = state1.getBuffer();
ByteStringBuilder bytes2 = state2.getBuffer();
ByteIterator bi1 = bytes1.iterate();
ByteIterator bi2 = bytes2.iterate();
// Scan the two encodings from left to right until a difference is found
int diff;
while (bi1.hasNext() && bi2.hasNext()) {
diff = (bi1.next() & 0xff) - (bi2.next() & 0xff);
if (diff != 0) {
return diff;
}
}
// The longer encoding is considered to be the bigger-valued encoding
return bytes1.length() - bytes2.length();
}
}
}