// org.wildfly.security.sasl.SaslMechanismPredicate (Maven / Gradle / Ivy artifact source listing, newest version)
/*
* JBoss, Home of Professional Open Source.
* Copyright 2017 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wildfly.security.sasl;
import static org.wildfly.common.math.HashMath.*;
import java.util.Arrays;
import java.util.function.Predicate;
import javax.net.ssl.SSLSession;
import org.wildfly.common.Assert;
import org.wildfly.security.sasl.util.SaslMechanismInformation;
/**
 * A composable test over a SASL mechanism name and an optional {@link SSLSession}.
 * <p>
 * Instances are obtained from the static {@code match*} factory methods and can be combined with
 * {@link #matchAll}, {@link #matchAny}, {@link #matchAllOrNone}, {@link #matchNot} and {@link #matchIf}.
 * Every predicate renders itself through {@link #toString()} and supports structural
 * {@link #equals(Object)} / {@link #hashCode()}.
 *
 * @author David M. Lloyd
 */
public abstract class SaslMechanismPredicate {

    /**
     * Predicate used for family names that {@link #matchFamily(String)} does not recognize. It is a shared
     * constant so that two predicates built for the same unknown name hold the same instance and therefore
     * compare equal, independent of how the JVM allocates lambda instances.
     */
    private static final Predicate<String> UNKNOWN_FAMILY = s -> false;

    /** Cached result of {@link #calcHashCode()}; {@code 0} means "not computed yet". */
    private int hashCode;

    /** Package-private: only predicates declared in this package may extend this class. */
    SaslMechanismPredicate() {
    }

    /**
     * Evaluate this predicate.
     *
     * @param mechName the mechanism name to test
     * @param sslSession the SSL session, or {@code null} if there is none
     * @return {@code true} if the predicate matches
     */
    abstract boolean test(String mechName, final SSLSession sslSession);

    /**
     * Render this predicate as a string by delegating to {@link #toString(StringBuilder)}.
     *
     * @return the string form of this predicate
     */
    @Override
    public String toString() {
        StringBuilder b = new StringBuilder();
        toString(b);
        return b.toString();
    }

    /**
     * Append the string form of this predicate to the given builder.
     *
     * @param b the builder to append to
     */
    abstract void toString(StringBuilder b);

    /**
     * Get the predicate which always matches.
     *
     * @return the always-true predicate
     */
    public static SaslMechanismPredicate matchTrue() {
        return TRUE;
    }

    /**
     * Get the predicate which never matches.
     *
     * @return the always-false predicate
     */
    public static SaslMechanismPredicate matchFalse() {
        return FALSE;
    }

    /**
     * Get a predicate which matches only when every given predicate matches (true when none are given).
     *
     * @param predicates the predicates to combine (must not be {@code null} or contain {@code null})
     * @return the combined predicate
     */
    public static SaslMechanismPredicate matchAll(SaslMechanismPredicate... predicates) {
        return new AllPredicate(predicates);
    }

    /**
     * Get a predicate which matches when the given predicates all yield the same result, i.e. all match
     * or none match (true when none are given).
     *
     * @param predicates the predicates to combine (must not be {@code null} or contain {@code null})
     * @return the combined predicate
     */
    public static SaslMechanismPredicate matchAllOrNone(SaslMechanismPredicate... predicates) {
        return new AllOrNonePredicate(predicates);
    }

    /**
     * Get a predicate which matches when at least one given predicate matches (false when none are given).
     *
     * @param predicates the predicates to combine (must not be {@code null} or contain {@code null})
     * @return the combined predicate
     */
    public static SaslMechanismPredicate matchAny(SaslMechanismPredicate... predicates) {
        return new AnyPredicate(predicates);
    }

    /**
     * Get the inverse of the given predicate.
     *
     * @param predicate the predicate to invert (must not be {@code null})
     * @return the inverted predicate
     */
    public static SaslMechanismPredicate matchNot(SaslMechanismPredicate predicate) {
        Assert.checkNotNullParam("predicate", predicate);
        return predicate.not();
    }

    /**
     * Get a conditional predicate: the result of {@code truePredicate} is used when
     * {@code conditionPredicate} matches, otherwise the result of {@code falsePredicate}.
     *
     * @param conditionPredicate the condition (must not be {@code null})
     * @param truePredicate the predicate evaluated when the condition matches (must not be {@code null})
     * @param falsePredicate the predicate evaluated when the condition does not match (must not be {@code null})
     * @return the conditional predicate
     */
    public static SaslMechanismPredicate matchIf(SaslMechanismPredicate conditionPredicate, SaslMechanismPredicate truePredicate, SaslMechanismPredicate falsePredicate) {
        Assert.checkNotNullParam("conditionPredicate", conditionPredicate);
        Assert.checkNotNullParam("truePredicate", truePredicate);
        Assert.checkNotNullParam("falsePredicate", falsePredicate);
        return new IfPredicate(conditionPredicate, truePredicate, falsePredicate);
    }

    /**
     * Get a predicate which matches a mechanism whose name is exactly equal to the given name.
     *
     * @param name the mechanism name (must not be {@code null})
     * @return the predicate
     */
    public static SaslMechanismPredicate matchExact(String name) {
        Assert.checkNotNullParam("name", name);
        return new ExactPredicate(name);
    }

    /**
     * Get a predicate which matches mechanisms by digest name. The recognized names are {@code MD5},
     * {@code SHA-1}, {@code SHA-256}, {@code SHA-384}, {@code SHA-512} and {@code SHA-512-256}; any other
     * name yields a predicate which never matches.
     *
     * @param digest the digest name (must not be {@code null})
     * @return the predicate
     */
    public static SaslMechanismPredicate matchHashFunction(String digest) {
        Assert.checkNotNullParam("digest", digest);
        return new HashPredicate(digest);
    }

    /**
     * Get the predicate which matches when an SSL session is present and
     * {@code SaslMechanismInformation.BINDING} accepts the mechanism name.
     *
     * @return the predicate
     */
    public static SaslMechanismPredicate matchPlus() {
        return PLUS;
    }

    /**
     * Get the predicate which matches when {@code SaslMechanismInformation.MUTUAL} accepts the mechanism name.
     *
     * @return the predicate
     */
    public static SaslMechanismPredicate matchMutual() {
        return MUTUAL;
    }

    /**
     * Get a predicate which matches mechanisms of the named family. The recognized names are {@code DIGEST},
     * {@code EAP}, {@code GS2}, {@code SCRAM} and {@code IEC-ISO-9798}; any other name yields a predicate
     * which never matches.
     *
     * @param name the family name (must not be {@code null})
     * @return the predicate
     */
    public static SaslMechanismPredicate matchFamily(String name) {
        Assert.checkNotNullParam("name", name);
        final Predicate<String> predicate;
        switch (name) {
            case "DIGEST": predicate = SaslMechanismInformation.DIGEST; break;
            case "EAP": predicate = SaslMechanismInformation.EAP; break;
            case "GS2": predicate = SaslMechanismInformation.GS2; break;
            case "SCRAM": predicate = SaslMechanismInformation.SCRAM; break;
            case "IEC-ISO-9798": predicate = SaslMechanismInformation.IEC_ISO_9798; break;
            default: predicate = UNKNOWN_FAMILY; break;
        }
        return new FamilyPredicate(predicate, name);
    }

    /**
     * Get the predicate which matches whenever an SSL session is present.
     *
     * @return the predicate
     */
    public static SaslMechanismPredicate matchTLSActive() {
        return TLS_ACTIVE;
    }

    /**
     * Determine whether this predicate equals another object. Only other {@code SaslMechanismPredicate}
     * instances can be equal; the comparison itself is delegated to {@link #equals(SaslMechanismPredicate)}.
     *
     * @param obj the other object
     * @return {@code true} if the objects are equal
     */
    @Override
    public final boolean equals(final Object obj) {
        return obj instanceof SaslMechanismPredicate && equals((SaslMechanismPredicate) obj);
    }

    /**
     * Determine whether this predicate equals another predicate.
     *
     * @param other the other predicate (may be {@code null}, in which case the result is {@code false})
     * @return {@code true} if the predicates are equal
     */
    public abstract boolean equals(SaslMechanismPredicate other);

    /**
     * Get the hash code of this predicate. The value of {@link #calcHashCode()} is computed on first use
     * and cached; a computed value of {@code 0} is stored as {@code 1} because {@code 0} marks the cache
     * as empty.
     *
     * @return the hash code
     */
    @Override
    public final int hashCode() {
        int result = this.hashCode;
        if (result == 0) {
            result = calcHashCode();
            if (result == 0) {
                result = 1;
            }
            this.hashCode = result;
        }
        return result;
    }

    /**
     * Compute the hash code of this predicate; must be consistent with {@link #equals(SaslMechanismPredicate)}.
     *
     * @return the computed hash code
     */
    abstract int calcHashCode();

    /**
     * Get the inverse of this predicate. Subclasses override this where a simpler inverse exists.
     *
     * @return the inverted predicate
     */
    SaslMechanismPredicate not() {
        return new NotPredicate(this);
    }

    static final BooleanPredicate TRUE = new BooleanPredicate(true);

    static final BooleanPredicate FALSE = new BooleanPredicate(false);

    /** Matches whenever an SSL session is present; rendered as {@code #TLS}. */
    static final SaslMechanismPredicate TLS_ACTIVE = new SaslMechanismPredicate() {
        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            return sslSession != null;
        }

        @Override
        void toString(final StringBuilder b) {
            b.append("#TLS");
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other;
        }

        @Override
        int calcHashCode() {
            return getClass().hashCode();
        }
    };

    /** Matches when an SSL session is present and the mechanism passes {@code BINDING}; rendered as {@code #PLUS}. */
    static final SaslMechanismPredicate PLUS = new SaslMechanismPredicate() {
        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            return sslSession != null && SaslMechanismInformation.BINDING.test(mechName);
        }

        @Override
        void toString(final StringBuilder b) {
            b.append("#PLUS");
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other;
        }

        @Override
        int calcHashCode() {
            return getClass().hashCode();
        }
    };

    /** Matches when the mechanism passes {@code MUTUAL}; rendered as {@code #MUTUAL}. */
    static final SaslMechanismPredicate MUTUAL = new SaslMechanismPredicate() {
        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            return SaslMechanismInformation.MUTUAL.test(mechName);
        }

        @Override
        void toString(final StringBuilder b) {
            b.append("#MUTUAL");
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other;
        }

        @Override
        int calcHashCode() {
            return getClass().hashCode();
        }
    };

    /** A constant predicate; rendered as {@code true} or {@code false}. */
    static final class BooleanPredicate extends SaslMechanismPredicate {
        private final boolean value;

        BooleanPredicate(final boolean value) {
            this.value = value;
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            return value;
        }

        @Override
        void toString(final StringBuilder b) {
            b.append(value);
        }

        @Override
        SaslMechanismPredicate not() {
            // Inversion of a constant is the opposite constant; no wrapper is needed.
            return value ? FALSE : TRUE;
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other;
        }

        @Override
        int calcHashCode() {
            return getClass().hashCode() * 19 + (value ? 1 : 0);
        }
    }

    /** Base class for predicates which combine an ordered list of nested predicates with one operator. */
    abstract static class MultiPredicate extends SaslMechanismPredicate {
        final SaslMechanismPredicate[] predicates;

        /**
         * @param predicates the nested predicates (must not be {@code null} or contain {@code null});
         *      the array is copied so later changes by the caller cannot affect this predicate
         */
        MultiPredicate(final SaslMechanismPredicate[] predicates) {
            Assert.checkNotNullParam("predicates", predicates);
            // Defensive copy: the array usually comes from a varargs call site and the hash code is cached.
            final SaslMechanismPredicate[] copy = predicates.clone();
            for (int i = 0; i < copy.length; i++) {
                Assert.checkNotNullArrayParam("predicates", i, copy[i]);
            }
            this.predicates = copy;
        }

        @Override
        void toString(final StringBuilder b) {
            b.append('(');
            final int length = predicates.length;
            if (length > 0) {
                // Render directly into the shared builder instead of building intermediate strings.
                predicates[0].toString(b);
                for (int i = 1; i < length; i++) {
                    appendOperator(b);
                    predicates[i].toString(b);
                }
            }
            b.append(')');
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            // Null guard: this overload is public and is selected for a literal null argument.
            return other != null
                && other.getClass() == getClass()
                && Arrays.equals(predicates, ((MultiPredicate) other).predicates);
        }

        @Override
        int calcHashCode() {
            int hc = getClass().hashCode() * 19;
            for (SaslMechanismPredicate predicate : predicates) {
                // hashCode() rather than calcHashCode() so nested predicates reuse their cached value.
                hc = multiHashUnordered(hc, predicate.hashCode());
            }
            return hc;
        }

        /**
         * Append the operator which separates nested predicates in the string form.
         *
         * @param b the builder to append to
         */
        abstract void appendOperator(final StringBuilder b);
    }

    /** Matches when every nested predicate matches; rendered with {@code &&}. */
    static final class AllPredicate extends MultiPredicate {
        AllPredicate(final SaslMechanismPredicate[] predicates) {
            super(predicates);
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            for (SaslMechanismPredicate predicate : predicates) {
                if (! predicate.test(mechName, sslSession)) {
                    return false;
                }
            }
            return true;
        }

        @Override
        void appendOperator(final StringBuilder b) {
            b.append('&').append('&');
        }
    }

    /** Matches when all nested predicates yield the same result; rendered with {@code ==}. */
    static final class AllOrNonePredicate extends MultiPredicate {
        AllOrNonePredicate(final SaslMechanismPredicate[] predicates) {
            super(predicates);
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            final int length = predicates.length;
            if (length == 0) {
                return true;
            }
            // Every later result must agree with the first one.
            final boolean expected = predicates[0].test(mechName, sslSession);
            for (int i = 1; i < length; i++) {
                if (expected != predicates[i].test(mechName, sslSession)) {
                    return false;
                }
            }
            return true;
        }

        @Override
        void appendOperator(final StringBuilder b) {
            b.append('=').append('=');
        }
    }

    /** Matches when at least one nested predicate matches; rendered with {@code ||}. */
    static final class AnyPredicate extends MultiPredicate {
        AnyPredicate(final SaslMechanismPredicate[] predicates) {
            super(predicates);
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            for (SaslMechanismPredicate predicate : predicates) {
                if (predicate.test(mechName, sslSession)) {
                    return true;
                }
            }
            return false;
        }

        @Override
        void appendOperator(final StringBuilder b) {
            b.append('|').append('|');
        }
    }

    /** Inverts a nested predicate; rendered as {@code !} followed by the nested predicate. */
    static class NotPredicate extends SaslMechanismPredicate {
        private final SaslMechanismPredicate predicate;

        NotPredicate(final SaslMechanismPredicate predicate) {
            this.predicate = predicate;
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            return ! predicate.test(mechName, sslSession);
        }

        @Override
        void toString(final StringBuilder b) {
            b.append('!');
            predicate.toString(b);
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return other instanceof NotPredicate && predicate.equals(((NotPredicate) other).predicate);
        }

        @Override
        int calcHashCode() {
            // hashCode() rather than calcHashCode() so the nested predicate reuses its cached value.
            return getClass().hashCode() * 19 + predicate.hashCode();
        }

        @Override
        SaslMechanismPredicate not() {
            // Double negation unwraps to the original predicate.
            return predicate;
        }
    }

    /** Matches one exact mechanism name; rendered as that name. */
    static class ExactPredicate extends SaslMechanismPredicate {
        private final String mechName;

        ExactPredicate(final String mechName) {
            this.mechName = mechName;
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            return this.mechName.equals(mechName);
        }

        @Override
        void toString(final StringBuilder b) {
            b.append(mechName);
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return other instanceof ExactPredicate && this.mechName.equals(((ExactPredicate) other).mechName);
        }

        @Override
        int calcHashCode() {
            return getClass().hashCode() * 19 + mechName.hashCode();
        }
    }

    /** Matches mechanism names accepted by a family predicate; rendered as {@code #FAMILY(name)}. */
    static class FamilyPredicate extends SaslMechanismPredicate {
        private final Predicate<String> predicate;
        private final String name;

        FamilyPredicate(final Predicate<String> predicate, final String name) {
            this.predicate = predicate;
            this.name = name;
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            return predicate.test(mechName);
        }

        @Override
        void toString(final StringBuilder b) {
            b.append("#FAMILY(").append(name).append(')');
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            if (! (other instanceof FamilyPredicate)) {
                return false;
            }
            final FamilyPredicate that = (FamilyPredicate) other;
            return predicate.equals(that.predicate) && name.equals(that.name);
        }

        @Override
        int calcHashCode() {
            return multiHashOrdered(multiHashOrdered(getClass().hashCode(), predicate.hashCode()), name.hashCode());
        }
    }

    /** Conditional predicate; rendered as {@code (condition?whenTrue:whenFalse)}. */
    static class IfPredicate extends SaslMechanismPredicate {
        private final SaslMechanismPredicate conditionPredicate;
        private final SaslMechanismPredicate truePredicate;
        private final SaslMechanismPredicate falsePredicate;

        IfPredicate(final SaslMechanismPredicate conditionPredicate, final SaslMechanismPredicate truePredicate, final SaslMechanismPredicate falsePredicate) {
            this.conditionPredicate = conditionPredicate;
            this.truePredicate = truePredicate;
            this.falsePredicate = falsePredicate;
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            final SaslMechanismPredicate branch = conditionPredicate.test(mechName, sslSession) ? truePredicate : falsePredicate;
            return branch.test(mechName, sslSession);
        }

        @Override
        void toString(final StringBuilder b) {
            // Render directly into the shared builder instead of building intermediate strings.
            b.append('(');
            conditionPredicate.toString(b);
            b.append('?');
            truePredicate.toString(b);
            b.append(':');
            falsePredicate.toString(b);
            b.append(')');
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            if (this == other) {
                return true;
            }
            if (! (other instanceof IfPredicate)) {
                return false;
            }
            final IfPredicate that = (IfPredicate) other;
            return conditionPredicate.equals(that.conditionPredicate)
                && truePredicate.equals(that.truePredicate)
                && falsePredicate.equals(that.falsePredicate);
        }

        @Override
        int calcHashCode() {
            int hc = getClass().hashCode();
            hc = multiHashOrdered(hc, conditionPredicate.hashCode());
            hc = multiHashOrdered(hc, truePredicate.hashCode());
            hc = multiHashOrdered(hc, falsePredicate.hashCode());
            return hc;
        }
    }

    /** Matches mechanisms by digest name; rendered as {@code #HASH(digest)}. */
    static class HashPredicate extends SaslMechanismPredicate {
        private final String digest;

        HashPredicate(final String digest) {
            this.digest = digest;
        }

        @Override
        boolean test(final String mechName, final SSLSession sslSession) {
            switch (digest) {
                case "MD5": return SaslMechanismInformation.HASH_MD5.test(mechName);
                case "SHA-1": return SaslMechanismInformation.HASH_SHA.test(mechName);
                case "SHA-256": return SaslMechanismInformation.HASH_SHA_256.test(mechName);
                case "SHA-384": return SaslMechanismInformation.HASH_SHA_384.test(mechName);
                case "SHA-512": return SaslMechanismInformation.HASH_SHA_512.test(mechName);
                case "SHA-512-256": return SaslMechanismInformation.HASH_SHA_512_256.test(mechName);
                default: return false; // unrecognized digest names never match
            }
        }

        @Override
        void toString(final StringBuilder b) {
            b.append("#HASH(").append(digest).append(')');
        }

        @Override
        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return other instanceof HashPredicate && digest.equals(((HashPredicate) other).digest);
        }

        @Override
        int calcHashCode() {
            // Hash the digest name (the only state compared by equals) so different digests spread
            // across hash buckets instead of all sharing one constant value.
            return getClass().hashCode() * 19 + digest.hashCode();
        }
    }
}