org.wildfly.security.asn1.DEREncoder Maven / Gradle / Ivy
Go to download
This artifact provides a single jar that contains all classes required to use remote EJB and JMS, including
all dependencies. It is intended for use by those not using maven, maven users should just import the EJB and
JMS BOM's instead (shaded JAR's cause lots of problems with maven, as it is very easy to inadvertently end up
with different versions on classes on the class path).
The newest version!
/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wildfly.security.asn1;
import static org.wildfly.security.asn1.ElytronMessages.log;
import static org.wildfly.security.asn1.ASN1.*;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.LinkedList;
import org.wildfly.common.bytes.ByteStringBuilder;
import org.wildfly.common.iteration.ByteIterator;
/**
* A class used to encode ASN.1 values using the Distinguished Encoding Rules (DER), as specified
* in ITU-T X.690.
*
* @author Farah Juma
*/
public class DEREncoder implements ASN1Encoder {
private static final int[] BITS = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
private static final long LARGEST_UNSHIFTED_LONG = Long.MAX_VALUE / 10L;
private static final byte[] NULL_CONTENTS = new byte[0];
private static final TagComparator TAG_COMPARATOR = new TagComparator();
private static final LexicographicComparator LEXICOGRAPHIC_COMPARATOR = new LexicographicComparator();
private static final byte[] BOOLEAN_TRUE_AS_BYTES = new byte[] { ~0 };
private static final byte[] BOOLEAN_FALSE_AS_BYTES = new byte[] { 0 };
private final ArrayDeque states = new ArrayDeque();
private final ArrayList buffers = new ArrayList();
private ByteStringBuilder currentBuffer;
private int currentBufferPos = -1;
private final ByteStringBuilder target;
private int implicitTag = -1;
/**
* Create a DER encoder.
*/
public DEREncoder() {
this(new ByteStringBuilder());
}
/**
* Create a DER encoder that writes its output to the given {@code ByteStringBuilder}.
*
* @param target the {@code ByteStringBuilder} to which the DER encoded values are written
*/
DEREncoder(ByteStringBuilder target) {
this.target = target;
currentBuffer = target;
}
@Override
public void startSequence() {
startConstructedElement(SEQUENCE_TYPE);
}
@Override
public void startSet() {
startConstructedElement(SET_TYPE);
}
@Override
public void startSetOf() {
startSet();
}
@Override
public void startExplicit(int number) {
startExplicit(CONTEXT_SPECIFIC_MASK, number);
}
@Override
public void startExplicit(int clazz, int number) {
int explicitTag = clazz | CONSTRUCTED_MASK | number;
startConstructedElement(explicitTag);
}
private void startConstructedElement(int tag) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(tag, currentBufferPos);
}
writeTag(tag, currentBuffer);
if (tag != SET_TYPE) {
updateCurrentBuffer();
}
states.add(new EncoderState(tag, currentBufferPos));
}
@Override
public void endSequence() throws IllegalStateException {
EncoderState lastState = states.peekLast();
if ((lastState == null) || (lastState.getTag() != SEQUENCE_TYPE)) {
throw log.noSequenceToEnd();
}
endConstructedElement();
}
@Override
public void endExplicit() throws IllegalStateException {
EncoderState lastState = states.peekLast();
if ((lastState == null) || (lastState.getTag() == SEQUENCE_TYPE)
|| (lastState.getTag() == SET_TYPE) || ((lastState.getTag() & CONSTRUCTED_MASK) == 0)) {
throw log.noExplicitlyTaggedElementToEnd();
}
endConstructedElement();
}
private void endConstructedElement() {
ByteStringBuilder dest;
if (currentBufferPos > 0) {
// Output the element to its parent buffer
dest = buffers.get(currentBufferPos - 1);
} else {
// Output the element directly to the target destination
dest = target;
}
int length = currentBuffer.length();
int numLengthOctets = writeLength(length, dest);
dest.append(currentBuffer);
currentBuffer.setLength(0);
currentBuffer = dest;
currentBufferPos -= 1;
states.removeLast();
// If this element's parent element is a set element, update the parent's accumulated length
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(1 + numLengthOctets + length);
}
}
@Override
public void endSet() throws IllegalStateException {
endSet(TAG_COMPARATOR);
}
@Override
public void endSetOf() throws IllegalStateException {
endSet(LEXICOGRAPHIC_COMPARATOR);
}
private void endSet(Comparator comparator) {
EncoderState lastState = states.peekLast();
if ((lastState == null) || (lastState.getTag() != SET_TYPE)) {
throw log.noSetToEnd();
}
// The child elements of a set must be encoded in ascending order by tag
LinkedList childElements = lastState.getSortedChildElements(comparator);
int setBufferPos = lastState.getBufferPos();
ByteStringBuilder dest;
if (setBufferPos >= 0) {
dest = buffers.get(setBufferPos);
} else {
dest = target;
}
ByteStringBuilder contents;
int childLength = lastState.getChildLength();
int numLengthOctets = writeLength(lastState.getChildLength(), dest);
for (EncoderState element : childElements) {
contents = buffers.get(element.getBufferPos());
dest.append(contents);
contents.setLength(0);
}
currentBuffer = dest;
currentBufferPos = setBufferPos;
states.removeLast();
// If this set's parent element is a set element, update the parent's accumulated length
lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(1 + numLengthOctets + childLength);
}
}
@Override
public void encodeOctetString(String str) {
encodeOctetString(str.getBytes(StandardCharsets.UTF_8));
}
@Override
public void encodeOctetString(byte[] str) {
writeElement(OCTET_STRING_TYPE, str);
}
void encodeOctetString(ByteStringBuilder str) {
writeElement(OCTET_STRING_TYPE, str);
}
@Override
public void encodeIA5String(String str) {
encodeIA5String(str.getBytes(StandardCharsets.US_ASCII));
}
@Override
public void encodeIA5String(byte[] str) {
writeElement(IA5_STRING_TYPE, str);
}
void encodeIA5String(ByteStringBuilder str) {
writeElement(IA5_STRING_TYPE, str);
}
@Override
public void encodePrintableString(final byte[] str) {
for (byte b : str) {
validatePrintableByte(b & 0xff);
}
writeElement(PRINTABLE_STRING_TYPE, str);
}
@Override
public void encodePrintableString(final String str) {
for (int i = 0; i < str.length(); i = str.offsetByCodePoints(i, 1)) {
validatePrintableByte(str.codePointAt(i));
}
writeElement(PRINTABLE_STRING_TYPE, str.getBytes(StandardCharsets.US_ASCII));
}
@Override
public void encodeUTF8String(final String str) {
writeElement(UTF8_STRING_TYPE, str.getBytes(StandardCharsets.UTF_8));
}
@Override
public void encodeBMPString(final String str) {
// technically this may fail if str contains a code point outside of the BMP
writeElement(BMP_STRING_TYPE, str.getBytes(StandardCharsets.UTF_16BE));
}
private static final Charset UTF_32BE = Charset.forName("UTF-32BE");
@Override
public void encodeUniversalString(final String str) {
writeElement(UNIVERSAL_STRING_TYPE, str.getBytes(UTF_32BE));
}
@Override
public void encodeBitString(byte[] str) {
encodeBitString(str, 0); // All bits will be used
}
@Override
public void encodeBitString(byte[] str, int numUnusedBits) {
byte[] contents = new byte[str.length + 1];
contents[0] = (byte) numUnusedBits;
System.arraycopy(str, 0, contents, 1, str.length);
writeElement(BIT_STRING_TYPE, contents);
}
@Override
public void encodeBitString(final EnumSet> enumSet) {
int ord;
final BitSet bitSet = new BitSet();
for (Enum> anEnum : enumSet) {
ord = anEnum.ordinal();
bitSet.set(ord);
}
encodeBitString(bitSet);
}
@Override
public void encodeBitString(final BitSet bitSet) {
final byte[] array = bitSet.toByteArray();
final int unusedBits = - bitSet.length() & 0b111;
for (int i = 0; i < array.length; i++) {
array[i] = (byte) (Integer.reverse(array[i]) >> 24);
}
encodeBitString(array, unusedBits);
}
@Override
public void encodeBitString(String binaryStr) {
int numBits = binaryStr.length();
int numBytes = numBits >> 3;
int remainder = numBits % 8;
int numUnusedBits = 0;
if (remainder != 0) {
numBytes = numBytes + 1;
numUnusedBits = 8 - remainder;
}
byte[] contents = new byte[numBytes + 1];
contents[0] = (byte) numUnusedBits;
for (int i = 1; i <= numBytes; i++) {
contents[i] = (byte) 0;
}
char[] binaryStrChars = binaryStr.toCharArray();
int index = 0;
for (int i = 1; i <= numBytes && index < numBits; i++) {
for (int bit = 7; bit >= 0 && index < numBits; bit--) {
if ((i == numBytes) && (bit < numUnusedBits)) {
continue;
}
if (binaryStrChars[index++] == '1') {
contents[i] |= BITS[bit];
}
}
}
writeElement(BIT_STRING_TYPE, contents);
}
@Override
public void encodeBitString(BigInteger integer) {
ByteStringBuilder target = new ByteStringBuilder();
new DEREncoder(target).encodeInteger(integer);
encodeBitString(target.toArray());
}
private static final DateTimeFormatter GENERALIZED_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyyMMddHHmmssX");
@Override
public void encodeGeneralizedTime(final ZonedDateTime time) {
writeElement(GENERALIZED_TIME_TYPE, GENERALIZED_TIME_FORMAT.format(time).getBytes(StandardCharsets.UTF_8));
}
@Override
public void encodeObjectIdentifier(String objectIdentifier) throws ASN1Exception {
if (objectIdentifier == null || objectIdentifier.length() == 0) {
throw log.asnOidMustHaveAtLeast2Components();
}
int len = objectIdentifier.length();
int offs = 0;
int idx = 0;
long t = 0L;
char c;
int numComponents = 0;
int first = -1;
ByteStringBuilder contents = new ByteStringBuilder();
a: for (;;) {
c = objectIdentifier.charAt(offs + idx ++);
if (Character.isDigit(c)) {
int digit = Character.digit(c, 10);
if (t > LARGEST_UNSHIFTED_LONG) {
BigInteger bi = BigInteger.valueOf(t).multiply(BigInteger.TEN).add(digits[digit]);
t = 0L;
for (;;) {
c = objectIdentifier.charAt(offs + idx ++);
if (Character.isDigit(c)) {
digit = Character.digit(c, 10);
bi = bi.multiply(BigInteger.TEN).add(digits[digit]);
} else if (c == '.') {
if (numComponents == 0) {
first = validateFirstOIDComponent(bi);
} else {
encodeOIDComponent(bi, contents, numComponents, first);
}
numComponents++;
continue a;
} else {
throw log.asnInvalidOidCharacter();
}
if (idx == len) {
if (numComponents == 0) {
throw log.asnOidMustHaveAtLeast2Components();
}
encodeOIDComponent(bi, contents, numComponents, first);
writeElement(OBJECT_IDENTIFIER_TYPE, contents);
return;
}
}
} else {
t = 10L * t + (long) digit;
}
} else if (c == '.') {
if (numComponents == 0) {
first = validateFirstOIDComponent(t);
} else {
encodeOIDComponent(t, contents, numComponents, first);
}
numComponents++;
t = 0L;
} else {
throw log.asnInvalidOidCharacter();
}
if (idx == len) {
if (c == '.') {
throw log.asnInvalidOidCharacter();
}
if (numComponents == 0) {
throw log.asnOidMustHaveAtLeast2Components();
}
encodeOIDComponent(t, contents, numComponents, first);
writeElement(OBJECT_IDENTIFIER_TYPE, contents);
return;
}
}
}
@Override
public void encodeNull() {
writeElement(NULL_TYPE, NULL_CONTENTS);
}
@Override
public void encodeImplicit(int number) {
encodeImplicit(CONTEXT_SPECIFIC_MASK, number);
}
@Override
public void encodeImplicit(int clazz, int number) {
if (implicitTag == -1) {
implicitTag = clazz | number;
}
}
@Override
public void encodeBoolean(final boolean value) {
writeElement(BOOLEAN_TYPE, value ? BOOLEAN_TRUE_AS_BYTES : BOOLEAN_FALSE_AS_BYTES);
}
@Override
public void encodeInteger(BigInteger integer) {
writeElement(INTEGER_TYPE, integer.toByteArray());
}
@Override
public void writeEncoded(byte[] encoded) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(encoded[0], currentBufferPos);
}
if (implicitTag != -1) {
writeTag(encoded[0], currentBuffer);
currentBuffer.append(encoded, 1, encoded.length - 1);
} else {
currentBuffer.append(encoded);
}
// If this element's parent element is a set element, update the parent's accumulated length
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(currentBuffer.length());
}
}
@Override
public void flush() {
while (states.size() != 0) {
EncoderState lastState = states.peekLast();
if (lastState.getTag() == SEQUENCE_TYPE) {
endSequence();
} else if (lastState.getTag() == SET_TYPE) {
endSet();
}
}
}
@Override
public byte[] getEncoded() {
return target.toArray();
}
private int validateFirstOIDComponent(long value) throws ASN1Exception {
if (value < 0 || value > 2) {
throw log.asnInvalidValueForFirstOidComponent();
}
return (int) value;
}
private int validateFirstOIDComponent(BigInteger value) throws ASN1Exception {
if ((value.compareTo(BigInteger.valueOf(0)) == -1)
|| (value.compareTo(BigInteger.valueOf(2)) == 1)) {
throw log.asnInvalidValueForFirstOidComponent();
}
return value.intValue();
}
private void validateSecondOIDComponent(long second, int first) throws ASN1Exception {
if ((first < 2) && (second > 39)) {
throw log.asnInvalidValueForSecondOidComponent();
}
}
private void validateSecondOIDComponent(BigInteger second, int first) throws ASN1Exception {
if ((first < 2) && (second.compareTo(BigInteger.valueOf(39)) == 1)) {
throw log.asnInvalidValueForSecondOidComponent();
}
}
private void encodeOIDComponent(long value, ByteStringBuilder contents,
int numComponents, int first) throws ASN1Exception {
if (numComponents == 1) {
validateSecondOIDComponent(value, first);
encodeOIDComponent(value + (40 * first), contents);
} else {
encodeOIDComponent(value, contents);
}
}
private void encodeOIDComponent(BigInteger value, ByteStringBuilder contents,
int numComponents, int first) throws ASN1Exception {
if (numComponents == 1) {
validateSecondOIDComponent(value, first);
encodeOIDComponent(value.add(BigInteger.valueOf(40 * first)), contents);
} else {
encodeOIDComponent(value, contents);
}
}
private void encodeOIDComponent(long value, ByteStringBuilder contents) {
int shift = 56;
int octet;
while (shift > 0) {
if (value >= (1L << shift)) {
octet = (int) ((value >>> shift) | 0x80);
contents.append((byte) octet);
}
shift = shift - 7;
}
octet = (int) (value & 0x7f);
contents.append((byte) octet);
}
private void encodeOIDComponent(BigInteger value, ByteStringBuilder contents) {
int numBytes = (value.bitLength() + 6) / 7;
if (numBytes == 0) {
contents.append((byte) 0);
} else {
byte[] result = new byte[numBytes];
BigInteger currValue = value;
for (int i = numBytes - 1; i >= 0; i--) {
result[i] = (byte) ((currValue.intValue() & 0x7f) | 0x80);
currValue = currValue.shiftRight(7);
}
result[numBytes - 1] &= 0x7f;
contents.append(result);
}
}
private static final BigInteger[] digits = {
BigInteger.ZERO,
BigInteger.ONE,
BigInteger.valueOf(2),
BigInteger.valueOf(3),
BigInteger.valueOf(4),
BigInteger.valueOf(5),
BigInteger.valueOf(6),
BigInteger.valueOf(7),
BigInteger.valueOf(8),
BigInteger.valueOf(9),
};
private void writeTag(int tag, ByteStringBuilder dest) {
int constructed = tag & CONSTRUCTED_MASK;
if (implicitTag != -1) {
tag = implicitTag | constructed;
implicitTag = -1;
}
int tagClass = tag & CLASS_MASK;
int tagNumber = tag & TAG_NUMBER_MASK;
if (tagNumber < 31) {
dest.append((byte) (tagClass | constructed | tagNumber));
} else {
// High-tag-number-form
dest.append((byte) (tagClass | constructed | 0x1f));
if (tagNumber < 128) {
dest.append((byte) tagNumber);
} else {
int shift = 28;
int octet;
while (shift > 0) {
if (tagNumber >= (1 << shift)) {
octet = (tagNumber >>> shift) | 0x80;
dest.append((byte) octet);
}
shift = shift - 7;
}
octet = tagNumber & 0x7f;
dest.append((byte) octet);
}
}
}
private int writeLength(int length, ByteStringBuilder dest) throws ASN1Exception {
int numLengthOctets;
if (length < 0) {
throw log.asnInvalidLength();
} else if (length <= 127) {
// Short form
numLengthOctets = 1;
} else {
// Long form
numLengthOctets = 1;
int value = length;
while ((value >>>= 8) != 0) {
numLengthOctets += 1;
}
}
if (length > 127) {
// bit 8 of the first octet has value "1" and bits 7-1 give the number of additional length octets
dest.append((byte) (numLengthOctets | 0x80));
}
for (int i = (numLengthOctets - 1) * 8; i >= 0; i -= 8) {
dest.append((byte) (length >> i));
}
if (length > 127) {
// include the first octet
return 1 + numLengthOctets;
} else {
return numLengthOctets;
}
}
private void updateCurrentBuffer() {
currentBufferPos += 1;
if (currentBufferPos < buffers.size()) {
currentBuffer = buffers.get(currentBufferPos);
} else {
ByteStringBuilder buffer = new ByteStringBuilder();
buffers.add(buffer);
currentBuffer = buffer;
}
}
private void writeElement(int tag, byte[] contents) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(tag, currentBufferPos);
}
writeTag(tag, currentBuffer);
writeLength(contents.length, currentBuffer);
currentBuffer.append(contents);
// If this element's parent element is a set element, update the parent's accumulated length
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(currentBuffer.length());
}
}
private void writeElement(int tag, ByteStringBuilder contents) {
EncoderState lastState = states.peekLast();
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
updateCurrentBuffer();
lastState.addChildElement(tag, currentBufferPos);
}
writeTag(tag, currentBuffer);
writeLength(contents.length(), currentBuffer);
currentBuffer.append(contents);
// If this element's parent element is a set element, update the parent's accumulated length
if ((lastState != null) && (lastState.getTag() == SET_TYPE)) {
lastState.addChildLength(currentBuffer.length());
}
}
/**
* A class used to maintain state information during DER encoding.
*/
private class EncoderState {
private final int tag;
private final int bufferPos;
private LinkedList childElements = new LinkedList();
private int childLength = 0;
public EncoderState(int tag, int bufferPos) {
this.tag = tag;
this.bufferPos = bufferPos;
}
public int getTag() {
return tag;
}
public int getBufferPos() {
return bufferPos;
}
public ByteStringBuilder getBuffer() {
return buffers.get(getBufferPos());
}
public int getChildLength() {
return childLength;
}
public LinkedList getSortedChildElements(Comparator comparator) {
Collections.sort(childElements, comparator);
return childElements;
}
public void addChildElement(int tag, int bufferPos) {
childElements.add(new EncoderState(tag, bufferPos));
}
public void addChildLength(int length) {
childLength += length;
}
}
/**
* A class that compares DER encodings based on their tags.
*/
private static class TagComparator implements Comparator {
@Override
public int compare(EncoderState state1, EncoderState state2) {
// Ignore the constructed bit when comparing tags
return (state1.getTag() | CONSTRUCTED_MASK) - (state2.getTag() | CONSTRUCTED_MASK);
}
}
/**
* A class that compares DER encodings using lexicographic order.
*/
private static class LexicographicComparator implements Comparator {
@Override
public int compare(EncoderState state1, EncoderState state2) {
ByteStringBuilder bytes1 = state1.getBuffer();
ByteStringBuilder bytes2 = state2.getBuffer();
ByteIterator bi1 = bytes1.iterate();
ByteIterator bi2 = bytes2.iterate();
// Scan the two encodings from left to right until a difference is found
int diff;
while (bi1.hasNext() && bi2.hasNext()) {
diff = (bi1.next() & 0xff) - (bi2.next() & 0xff);
if (diff != 0) {
return diff;
}
}
// The longer encoding is considered to be the bigger-valued encoding
return bytes1.length() - bytes2.length();
}
}
}