All downloads are free. Search and download functionality uses the official Maven repository.

org.wildfly.security.sasl.SaslMechanismPredicate Maven / Gradle / Ivy

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2017 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.security.sasl;

import static org.wildfly.common.math.HashMath.*;

import java.util.Arrays;
import java.util.function.Predicate;

import javax.net.ssl.SSLSession;

import org.wildfly.common.Assert;
import org.wildfly.security.sasl.util.SaslMechanismInformation;

/**
 * @author David M. Lloyd
 */
/**
 * A composable predicate over SASL mechanism names which may also take the current TLS session into account.
 * <p>
 * Instances are obtained from the static {@code match*} factory methods and can be combined with
 * {@link #matchAll}, {@link #matchAny}, {@link #matchAllOrNone}, {@link #matchNot} and {@link #matchIf}.
 * Equality is structural: predicates built from equal parts compare equal and have equal hash codes.
 *
 * @author David M. Lloyd
 */
public abstract class SaslMechanismPredicate {
    /** Lazily computed hash code; {@code 0} means "not yet computed" (a computed 0 is stored as 1). */
    private int hashCode;

    SaslMechanismPredicate() {
    }

    /**
     * Test whether the given mechanism is accepted by this predicate.
     *
     * @param mechName the SASL mechanism name
     * @param sslSession the TLS session, or {@code null} when no TLS session is active
     *                   (see {@link #matchTLSActive()})
     * @return {@code true} if the mechanism is accepted
     */
    abstract boolean test(String mechName, final SSLSession sslSession);

    /**
     * Render this predicate in its textual expression form, e.g. {@code (#TLS&&!PLAIN)}.
     *
     * @return the textual form of this predicate
     */
    public String toString() {
        final StringBuilder b = new StringBuilder();
        toString(b);
        return b.toString();
    }

    /**
     * Append the textual form of this predicate to the given builder.
     *
     * @param b the builder to append to
     */
    abstract void toString(StringBuilder b);

    /**
     * Get a predicate which accepts every mechanism.
     *
     * @return the always-true predicate
     */
    public static SaslMechanismPredicate matchTrue() {
        return TRUE;
    }

    /**
     * Get a predicate which rejects every mechanism.
     *
     * @return the always-false predicate
     */
    public static SaslMechanismPredicate matchFalse() {
        return FALSE;
    }

    /**
     * Get a predicate which accepts a mechanism only if every given predicate accepts it.
     * An empty argument list accepts everything.
     *
     * @param predicates the predicates to combine (must not be {@code null} or contain {@code null}); the array is copied
     * @return the conjunction predicate
     */
    public static SaslMechanismPredicate matchAll(SaslMechanismPredicate... predicates) {
        return new AllPredicate(predicates);
    }

    /**
     * Get a predicate which accepts a mechanism when all given predicates agree, i.e. they all accept it or
     * they all reject it. An empty argument list accepts everything.
     *
     * @param predicates the predicates to combine (must not be {@code null} or contain {@code null}); the array is copied
     * @return the all-or-none predicate
     */
    public static SaslMechanismPredicate matchAllOrNone(SaslMechanismPredicate... predicates) {
        return new AllOrNonePredicate(predicates);
    }

    /**
     * Get a predicate which accepts a mechanism if at least one given predicate accepts it.
     * An empty argument list accepts nothing.
     *
     * @param predicates the predicates to combine (must not be {@code null} or contain {@code null}); the array is copied
     * @return the disjunction predicate
     */
    public static SaslMechanismPredicate matchAny(SaslMechanismPredicate... predicates) {
        return new AnyPredicate(predicates);
    }

    /**
     * Get the negation of the given predicate. Negating a constant yields the opposite constant, and
     * negating a negation yields the original predicate.
     *
     * @param predicate the predicate to negate (must not be {@code null})
     * @return the negated predicate
     */
    public static SaslMechanismPredicate matchNot(SaslMechanismPredicate predicate) {
        Assert.checkNotNullParam("predicate", predicate);
        return predicate.not();
    }

    /**
     * Get a conditional predicate: if {@code conditionPredicate} accepts the mechanism, the result of
     * {@code truePredicate} is used, otherwise the result of {@code falsePredicate}.
     *
     * @param conditionPredicate the condition (must not be {@code null})
     * @param truePredicate the predicate used when the condition holds (must not be {@code null})
     * @param falsePredicate the predicate used when the condition does not hold (must not be {@code null})
     * @return the conditional predicate
     */
    public static SaslMechanismPredicate matchIf(SaslMechanismPredicate conditionPredicate, SaslMechanismPredicate truePredicate, SaslMechanismPredicate falsePredicate) {
        Assert.checkNotNullParam("conditionPredicate", conditionPredicate);
        Assert.checkNotNullParam("truePredicate", truePredicate);
        Assert.checkNotNullParam("falsePredicate", falsePredicate);
        return new IfPredicate(conditionPredicate, truePredicate, falsePredicate);
    }

    /**
     * Get a predicate which accepts exactly the named mechanism (case-sensitive comparison).
     *
     * @param name the mechanism name (must not be {@code null})
     * @return the exact-match predicate
     */
    public static SaslMechanismPredicate matchExact(String name) {
        Assert.checkNotNullParam("name", name);
        return new ExactPredicate(name);
    }

    /**
     * Get a predicate which accepts mechanisms using the given hash function. Recognized digest names are
     * {@code MD5}, {@code SHA-1}, {@code SHA-256}, {@code SHA-384}, {@code SHA-512} and {@code SHA-512-256};
     * any other name yields a predicate that accepts nothing.
     *
     * @param digest the digest name (must not be {@code null})
     * @return the hash-function predicate
     */
    public static SaslMechanismPredicate matchHashFunction(String digest) {
        Assert.checkNotNullParam("digest", digest);
        return new HashPredicate(digest);
    }

    /**
     * Get a predicate which accepts a mechanism only when a TLS session is present and the mechanism is accepted
     * by {@link SaslMechanismInformation#BINDING}.
     *
     * @return the "PLUS" predicate
     */
    public static SaslMechanismPredicate matchPlus() {
        return PLUS;
    }

    /**
     * Get a predicate which accepts mechanisms accepted by {@link SaslMechanismInformation#MUTUAL}.
     *
     * @return the mutual-authentication predicate
     */
    public static SaslMechanismPredicate matchMutual() {
        return MUTUAL;
    }

    /**
     * Get a predicate which accepts members of the named mechanism family. Recognized family names are
     * {@code DIGEST}, {@code EAP}, {@code GS2}, {@code SCRAM} and {@code IEC-ISO-9798}; any other name
     * yields a predicate that accepts nothing.
     *
     * @param name the family name (must not be {@code null})
     * @return the family predicate
     */
    public static SaslMechanismPredicate matchFamily(String name) {
        Assert.checkNotNullParam("name", name);
        final Predicate<? super String> predicate;
        switch (name) {
            case "DIGEST": predicate = SaslMechanismInformation.DIGEST; break;
            case "EAP": predicate = SaslMechanismInformation.EAP; break;
            case "GS2": predicate = SaslMechanismInformation.GS2; break;
            case "SCRAM": predicate = SaslMechanismInformation.SCRAM; break;
            case "IEC-ISO-9798": predicate = SaslMechanismInformation.IEC_ISO_9798; break;
            // Use a shared instance so equality of FamilyPredicates for unknown names does not
            // depend on whether the JVM happens to reuse lambda instances.
            default: predicate = NO_FAMILY; break;
        }
        return new FamilyPredicate(predicate, name);
    }

    /**
     * Get a predicate which accepts any mechanism as long as a TLS session is present.
     *
     * @return the TLS-active predicate
     */
    public static SaslMechanismPredicate matchTLSActive() {
        return TLS_ACTIVE;
    }

    /**
     * Determine whether this predicate is structurally equal to the given object.
     *
     * @param obj the object to compare
     * @return {@code true} if {@code obj} is an equal predicate
     */
    public final boolean equals(final Object obj) {
        return obj instanceof SaslMechanismPredicate && equals((SaslMechanismPredicate) obj);
    }

    /**
     * Determine whether this predicate is structurally equal to another predicate.
     *
     * @param other the other predicate (may be {@code null}, in which case the result is {@code false})
     * @return {@code true} if the predicates are equal
     */
    public abstract boolean equals(SaslMechanismPredicate other);

    /**
     * Get the hash code of this predicate, computing it on first use and caching it afterwards.
     *
     * @return the hash code
     */
    public final int hashCode() {
        int hashCode = this.hashCode;
        if (hashCode == 0) {
            hashCode = calcHashCode();
            if (hashCode == 0) {
                // 0 is the "not computed" marker, so remap a genuine 0 to keep the cache effective.
                hashCode = 1;
            }
            return this.hashCode = hashCode;
        }
        return hashCode;
    }

    /**
     * Compute the (uncached) hash code of this predicate; must be consistent with {@link #equals(SaslMechanismPredicate)}.
     *
     * @return the computed hash code
     */
    abstract int calcHashCode();

    /**
     * Get the negation of this predicate; subclasses may override to simplify the result.
     *
     * @return the negated predicate
     */
    SaslMechanismPredicate not() {
        return new NotPredicate(this);
    }

    static final BooleanPredicate TRUE = new BooleanPredicate(true);
    static final BooleanPredicate FALSE = new BooleanPredicate(false);

    /** Fallback family test for unrecognized family names: accepts nothing. */
    private static final Predicate<String> NO_FAMILY = s -> false;

    static final SaslMechanismPredicate TLS_ACTIVE = new SaslMechanismPredicate() {
        boolean test(final String mechName, final SSLSession sslSession) {
            return sslSession != null;
        }

        void toString(final StringBuilder b) {
            b.append("#TLS");
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other;
        }

        int calcHashCode() {
            return getClass().hashCode();
        }
    };

    static final SaslMechanismPredicate PLUS = new SaslMechanismPredicate() {
        boolean test(final String mechName, final SSLSession sslSession) {
            return sslSession != null && SaslMechanismInformation.BINDING.test(mechName);
        }

        void toString(final StringBuilder b) {
            b.append("#PLUS");
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other;
        }

        int calcHashCode() {
            return getClass().hashCode();
        }
    };

    static final SaslMechanismPredicate MUTUAL = new SaslMechanismPredicate() {
        boolean test(final String mechName, final SSLSession sslSession) {
            return SaslMechanismInformation.MUTUAL.test(mechName);
        }

        void toString(final StringBuilder b) {
            b.append("#MUTUAL");
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other;
        }

        int calcHashCode() {
            return getClass().hashCode();
        }
    };

    /**
     * Constant predicate; only the {@link #TRUE} and {@link #FALSE} instances are created.
     */
    static final class BooleanPredicate extends SaslMechanismPredicate {
        private final boolean value;

        BooleanPredicate(final boolean value) {
            this.value = value;
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            return value;
        }

        void toString(final StringBuilder b) {
            b.append(value);
        }

        SaslMechanismPredicate not() {
            return value ? FALSE : TRUE;
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            // Identity suffices because only the two shared constants exist.
            return this == other;
        }

        int calcHashCode() {
            return getClass().hashCode() * 19 + (value ? 1 : 0);
        }
    }

    /**
     * Base for predicates combining an ordered list of sub-predicates with a binary operator.
     */
    abstract static class MultiPredicate extends SaslMechanismPredicate {
        final SaslMechanismPredicate[] predicates;

        MultiPredicate(final SaslMechanismPredicate[] predicates) {
            Assert.checkNotNullParam("predicates", predicates);
            // Copy first, then validate the copy: the caller's (varargs) array must not be able to
            // change this predicate's behaviour, equality or cached hash code after construction.
            final SaslMechanismPredicate[] copy = predicates.clone();
            for (int i = 0; i < copy.length; i++) {
                Assert.checkNotNullArrayParam("predicates", i, copy[i]);
            }
            this.predicates = copy;
        }

        void toString(final StringBuilder b) {
            b.append('(');
            for (int i = 0; i < predicates.length; i++) {
                if (i > 0) {
                    appendOperator(b);
                }
                predicates[i].toString(b);
            }
            b.append(')');
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other
                || other != null && other.getClass() == getClass() && Arrays.equals(predicates, ((MultiPredicate) other).predicates);
        }

        int calcHashCode() {
            // Unordered combination is weaker than the ordered equals() but still consistent with it.
            int hc = getClass().hashCode() * 19;
            for (SaslMechanismPredicate predicate : predicates) {
                hc = multiHashUnordered(hc, predicate.hashCode());
            }
            return hc;
        }

        /**
         * Append the textual operator placed between sub-predicates.
         *
         * @param b the builder to append to
         */
        abstract void appendOperator(final StringBuilder b);
    }

    /** Accepts a mechanism only if every sub-predicate accepts it (vacuously true when empty). */
    static final class AllPredicate extends MultiPredicate {
        AllPredicate(final SaslMechanismPredicate[] predicates) {
            super(predicates);
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            for (SaslMechanismPredicate predicate : predicates) {
                if (! predicate.test(mechName, sslSession)) {
                    return false;
                }
            }
            return true;
        }

        void appendOperator(final StringBuilder b) {
            b.append('&').append('&');
        }
    }

    /** Accepts a mechanism when all sub-predicates return the same result (vacuously true when empty). */
    static final class AllOrNonePredicate extends MultiPredicate {
        AllOrNonePredicate(final SaslMechanismPredicate[] predicates) {
            super(predicates);
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            final int length = predicates.length;
            if (length == 0) {
                return true;
            }
            boolean val = predicates[0].test(mechName, sslSession);
            for (int i = 1; i < length; i++) {
                final SaslMechanismPredicate predicate = predicates[i];
                if (val != predicate.test(mechName, sslSession)) {
                    return false;
                }
            }
            return true;
        }

        void appendOperator(final StringBuilder b) {
            b.append('=').append('=');
        }
    }

    /** Accepts a mechanism if any sub-predicate accepts it (false when empty). */
    static final class AnyPredicate extends MultiPredicate {
        AnyPredicate(final SaslMechanismPredicate[] predicates) {
            super(predicates);
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            for (SaslMechanismPredicate predicate : predicates) {
                if (predicate.test(mechName, sslSession)) {
                    return true;
                }
            }
            return false;
        }

        void appendOperator(final StringBuilder b) {
            b.append('|').append('|');
        }
    }

    /** Logical negation of a wrapped predicate. */
    static class NotPredicate extends SaslMechanismPredicate {
        private final SaslMechanismPredicate predicate;

        NotPredicate(final SaslMechanismPredicate predicate) {
            this.predicate = predicate;
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            return ! predicate.test(mechName, sslSession);
        }

        void toString(final StringBuilder b) {
            b.append('!');
            predicate.toString(b);
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return other instanceof NotPredicate && predicate.equals(((NotPredicate) other).predicate);
        }

        int calcHashCode() {
            return getClass().hashCode() * 19 + predicate.hashCode();
        }

        SaslMechanismPredicate not() {
            // Double negation collapses to the original predicate.
            return predicate;
        }
    }

    /** Accepts exactly one mechanism name (case-sensitive). */
    static class ExactPredicate extends SaslMechanismPredicate {
        private final String mechName;

        ExactPredicate(final String mechName) {
            this.mechName = mechName;
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            return this.mechName.equals(mechName);
        }

        void toString(final StringBuilder b) {
            b.append(mechName);
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return other instanceof ExactPredicate && this.mechName.equals(((ExactPredicate) other).mechName);
        }

        int calcHashCode() {
            return getClass().hashCode() * 19 + mechName.hashCode();
        }
    }

    /** Accepts mechanisms belonging to a named family, as decided by the supplied name test. */
    static class FamilyPredicate extends SaslMechanismPredicate {
        private final Predicate<? super String> predicate;
        private final String name;

        FamilyPredicate(final Predicate<? super String> predicate, final String name) {
            this.predicate = predicate;
            this.name = name;
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            return predicate.test(mechName);
        }

        void toString(final StringBuilder b) {
            b.append("#FAMILY(").append(name).append(')');
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return other instanceof FamilyPredicate && equals((FamilyPredicate) other);
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        private boolean equals(final FamilyPredicate other) {
            return predicate.equals(other.predicate) && name.equals(other.name);
        }

        int calcHashCode() {
            return multiHashOrdered(multiHashOrdered(getClass().hashCode(), predicate.hashCode()), name.hashCode());
        }
    }

    /** Ternary predicate: {@code condition ? truePredicate : falsePredicate}. */
    static class IfPredicate extends SaslMechanismPredicate {
        private final SaslMechanismPredicate conditionPredicate;
        private final SaslMechanismPredicate truePredicate;
        private final SaslMechanismPredicate falsePredicate;

        IfPredicate(final SaslMechanismPredicate conditionPredicate, final SaslMechanismPredicate truePredicate, final SaslMechanismPredicate falsePredicate) {
            this.conditionPredicate = conditionPredicate;
            this.truePredicate = truePredicate;
            this.falsePredicate = falsePredicate;
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            return conditionPredicate.test(mechName, sslSession) ? truePredicate.test(mechName, sslSession) : falsePredicate.test(mechName, sslSession);
        }

        void toString(final StringBuilder b) {
            b.append('(');
            conditionPredicate.toString(b);
            b.append('?');
            truePredicate.toString(b);
            b.append(':');
            falsePredicate.toString(b);
            b.append(')');
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return this == other || other instanceof IfPredicate && equals((IfPredicate) other);
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        private boolean equals(final IfPredicate other) {
            return conditionPredicate.equals(other.conditionPredicate) && truePredicate.equals(other.truePredicate) && falsePredicate.equals(other.falsePredicate);
        }

        int calcHashCode() {
            return multiHashOrdered(multiHashOrdered(multiHashOrdered(getClass().hashCode(), conditionPredicate.hashCode()), truePredicate.hashCode()), falsePredicate.hashCode());
        }
    }

    /** Accepts mechanisms that use the named digest; unknown digest names accept nothing. */
    static class HashPredicate extends SaslMechanismPredicate {
        private final String digest;

        HashPredicate(final String digest) {
            this.digest = digest;
        }

        boolean test(final String mechName, final SSLSession sslSession) {
            switch (digest) {
                case "MD5": return SaslMechanismInformation.HASH_MD5.test(mechName);
                case "SHA-1": return SaslMechanismInformation.HASH_SHA.test(mechName);
                case "SHA-256": return SaslMechanismInformation.HASH_SHA_256.test(mechName);
                case "SHA-384": return SaslMechanismInformation.HASH_SHA_384.test(mechName);
                case "SHA-512": return SaslMechanismInformation.HASH_SHA_512.test(mechName);
                case "SHA-512-256": return SaslMechanismInformation.HASH_SHA_512_256.test(mechName);
                default: return false;
            }
        }

        void toString(final StringBuilder b) {
            b.append("#HASH(").append(digest).append(')');
        }

        @SuppressWarnings("checkstyle:equalshashcode")
        public boolean equals(final SaslMechanismPredicate other) {
            return other instanceof HashPredicate && digest.equals(((HashPredicate) other).digest);
        }

        int calcHashCode() {
            // Include the digest (which equals() compares) so distinct hash predicates do not all collide.
            return getClass().hashCode() * 19 + digest.hashCode();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy