All downloads are free. The search and download functionality uses the official Maven repository.

com.amazonaws.services.stepfunctions.builder.internal.validation.StateMachineValidator Maven / Gradle / Ivy

/*
 * Copyright 2011-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package com.amazonaws.services.stepfunctions.builder.internal.validation;

import com.amazonaws.services.stepfunctions.builder.ErrorCodes;
import com.amazonaws.services.stepfunctions.builder.StateMachine;
import com.amazonaws.services.stepfunctions.builder.conditions.BinaryCondition;
import com.amazonaws.services.stepfunctions.builder.conditions.Condition;
import com.amazonaws.services.stepfunctions.builder.conditions.NAryCondition;
import com.amazonaws.services.stepfunctions.builder.conditions.NotCondition;
import com.amazonaws.services.stepfunctions.builder.internal.PropertyNames;
import com.amazonaws.services.stepfunctions.builder.states.Branch;
import com.amazonaws.services.stepfunctions.builder.states.Iterator;
import com.amazonaws.services.stepfunctions.builder.states.Catcher;
import com.amazonaws.services.stepfunctions.builder.states.Choice;
import com.amazonaws.services.stepfunctions.builder.states.ChoiceState;
import com.amazonaws.services.stepfunctions.builder.states.FailState;
import com.amazonaws.services.stepfunctions.builder.states.NextStateTransition;
import com.amazonaws.services.stepfunctions.builder.states.ParallelState;
import com.amazonaws.services.stepfunctions.builder.states.MapState;
import com.amazonaws.services.stepfunctions.builder.states.PassState;
import com.amazonaws.services.stepfunctions.builder.states.Retrier;
import com.amazonaws.services.stepfunctions.builder.states.State;
import com.amazonaws.services.stepfunctions.builder.states.StateVisitor;
import com.amazonaws.services.stepfunctions.builder.states.SucceedState;
import com.amazonaws.services.stepfunctions.builder.states.TaskState;
import com.amazonaws.services.stepfunctions.builder.states.Transition;
import com.amazonaws.services.stepfunctions.builder.states.TransitionState;
import com.amazonaws.services.stepfunctions.builder.states.WaitFor;
import com.amazonaws.services.stepfunctions.builder.states.WaitForSeconds;
import com.amazonaws.services.stepfunctions.builder.states.WaitForSecondsPath;
import com.amazonaws.services.stepfunctions.builder.states.WaitForTimestamp;
import com.amazonaws.services.stepfunctions.builder.states.WaitForTimestampPath;
import com.amazonaws.services.stepfunctions.builder.states.WaitState;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Validator for a {@link StateMachine} object.
 *
 * <p>TODO: Does not check max nesting.
 *
 * <p>TODO: Does not validate ARNs against a regex.
 */
public class StateMachineValidator {

    /** Receives every problem reported while validating; consulted by {@link #validate()} to decide the outcome. */
    private final ProblemReporter problemReporter = new ProblemReporter();
    /** The state machine this validator checks. */
    private final StateMachine stateMachine;

    /**
     * Creates a validator for the given state machine. No validation happens until {@link #validate()} is called.
     *
     * @param stateMachine State machine to validate.
     */
    public StateMachineValidator(StateMachine stateMachine) {
        this.stateMachine = stateMachine;
    }

    /**
     * Runs all validations against the state machine supplied at construction.
     *
     * <p>The top-level properties and every state are checked first. The graph checks (cycles and
     * reachability of a terminal state) only run when those checks reported nothing, because the graph
     * may otherwise not be in a good enough shape to be walked. If any problem was reported, the
     * exception supplied by the problem reporter is thrown.
     *
     * @return The validated state machine when no problems were reported.
     */
    public StateMachine validate() {
        final ValidationContext rootContext = ValidationContext.builder()
                .problemReporter(problemReporter)
                .parentContext(null)
                .identifier("Root")
                .location(Location.StateMachine)
                .build();

        // Top-level properties of the state machine itself.
        rootContext.assertStringNotEmpty(stateMachine.getStartAt(), PropertyNames.START_AT);
        rootContext.assertIsPositiveIfPresent(stateMachine.getTimeoutSeconds(), PropertyNames.TIMEOUT_SECONDS);
        rootContext.assertNotEmpty(stateMachine.getStates(), PropertyNames.STATES);

        // Properties of each individual state (and anything nested inside it).
        validateStates(rootContext, stateMachine.getStates());

        final boolean startStateExists = stateMachine.getStates().containsKey(stateMachine.getStartAt());
        if (!startStateExists) {
            problemReporter.report(new Problem(rootContext, PropertyNames.START_AT + " state does not exist."));
        }

        // Only walk the graph when the basic checks passed; otherwise it may not be traversable.
        final boolean basicValidationPassed = !problemReporter.hasProblems();
        if (basicValidationPassed) {
            new GraphValidator(rootContext, stateMachine).validate();
        }

        if (!problemReporter.hasProblems()) {
            return stateMachine;
        }
        throw problemReporter.getException();
    }

    /**
     * Walks a state graph depth first from its initial state and reports a problem when it finds an
     * unrecoverable cycle (i.e. a cycle with no branching logic) or when no path leads to a terminal state.
     */
    private final class GraphValidator {
        /** States already visited by the traversals that spawned this one; empty for a fresh traversal. */
        private final Map<String, State> parentVisited;
        /** Name of the state this traversal starts from. */
        private final String initialState;
        /** All states of the graph being walked, keyed by state name. */
        private final Map<String, State> states;
        /** States visited so far by this traversal, keyed by state name. */
        private final Map<String, State> visited = new HashMap<String, State>();
        /** Context that problems found by this traversal are reported against. */
        private final ValidationContext currentContext;

        /**
         * Creates a validator that walks the given state machine from its start state with no parent history.
         *
         * @param context      Context to report problems against.
         * @param stateMachine State machine whose graph is walked.
         */
        public GraphValidator(ValidationContext context, StateMachine stateMachine) {
            this(context,
                 Collections.<String, State>emptyMap(),
                 stateMachine.getStartAt(),
                 stateMachine.getStates());
        }

        private GraphValidator(ValidationContext context,
                               Map<String, State> parentVisited,
                               String initialState,
                               Map<String, State> states) {
            this.currentContext = context;
            this.parentVisited = parentVisited;
            this.initialState = initialState;
            this.states = states;
        }

        /**
         * Walks the graph from the initial state.
         *
         * <p>A missing path to a terminal state is only reported when this traversal has no parent history;
         * a traversal spawned for a choice only returns its result and leaves the decision to its parent.
         *
         * @return True if a terminal state is reachable from the initial state, false otherwise.
         */
        public boolean validate() {
            boolean pathToTerminal = visit(initialState);
            if (parentVisited.isEmpty() && !pathToTerminal) {
                problemReporter.report(new Problem(currentContext, "No path to a terminal state exists."));
            }
            return pathToTerminal;
        }

        /**
         * Visits a single state and follows its outgoing transition(s).
         *
         * @param stateName Name of the state to visit.
         * @return True if a terminal state is reachable from the given state, false otherwise.
         */
        private boolean visit(String stateName) {
            ValidationContext stateContext = currentContext.state(stateName);
            final State state = states.get(stateName);
            if (!parentVisited.containsKey(stateName) && visited.containsKey(stateName)) {
                // Seen before within this traversal, so no branching happened in between.
                problemReporter.report(new Problem(stateContext, "Cycle detected."));
                return false;
            } else if (parentVisited.containsKey(stateName)) {
                // Cycle but to parent so we may be okay
                return false;
            }
            visited.put(stateName, state);
            // Nested graphs are validated independently of the graph that contains them.
            if (state instanceof ParallelState) {
                validateParallelState(stateContext, (ParallelState) state);
            }
            if (state instanceof MapState) {
                validateMapState(stateContext, (MapState) state);
            }
            if (state.isTerminalState()) {
                return true;
            } else if (state instanceof TransitionState) {
                final Transition transition = ((TransitionState) state).getTransition();
                return visit(((NextStateTransition) transition).getNextStateName());
            } else if (state instanceof ChoiceState) {
                return validateChoiceState(stateContext, (ChoiceState) state);
            } else {
                throw new RuntimeException("Unexpected state type: " + state.getClass().getName());
            }
        }

        /**
         * Walks the graph of every branch of a parallel state, each as a fresh traversal with no parent history.
         */
        private void validateParallelState(ValidationContext stateContext, ParallelState state) {
            int index = 0;
            for (Branch branch : state.getBranches()) {
                new GraphValidator(stateContext.branch(index),
                                   Collections.<String, State>emptyMap(),
                                   branch.getStartAt(),
                                   branch.getStates()).validate();
                index++;
            }
        }

        /**
         * Walks the graph of a map state's iterator as a fresh traversal with no parent history.
         */
        private void validateMapState(ValidationContext stateContext, MapState state) {
            Iterator iterator = state.getIterator();
            new GraphValidator(stateContext.iterator(),
                               Collections.<String, State>emptyMap(),
                               iterator.getStartAt(),
                               iterator.getStates()).validate();
        }

        /**
         * Walks the default transition (when present) and every choice of a choice state. Each target is
         * walked by its own traversal that receives the states visited so far as its parent history.
         *
         * @return True if at least one of the walked targets can reach a terminal state.
         */
        private boolean validateChoiceState(ValidationContext stateContext, ChoiceState choiceState) {
            final Map<String, State> merged = mergeParentVisited();
            boolean hasPathToTerminal = false;
            if (choiceState.getDefaultStateName() != null) {
                hasPathToTerminal = new GraphValidator(stateContext, merged, choiceState.getDefaultStateName(), states)
                        .validate();
            }
            int index = 0;
            for (Choice choice : choiceState.getChoices()) {
                final String nextStateName = ((NextStateTransition) choice.getTransition()).getNextStateName();
                // It's important hasPathToTerminal is last in the OR so it doesn't short circuit the choice validation
                hasPathToTerminal = new GraphValidator(stateContext.choice(index), merged, nextStateName, states).validate()
                                    || hasPathToTerminal;
                index++;
            }
            return hasPathToTerminal;
        }

        /**
         * @return A new map holding the parent's visited states together with the ones visited by this traversal.
         */
        private Map<String, State> mergeParentVisited() {
            final Map<String, State> merged = new HashMap<String, State>(parentVisited.size() + visited.size());
            merged.putAll(parentVisited);
            merged.putAll(visited);
            return merged;
        }

    }

    /**
     * Validates every state in the given map. Each state name must be a non-empty string, and each state is
     * checked by a {@link StateValidationVisitor} that reports against a per-state child of the parent context.
     *
     * @param parentContext Context of the container that owns the states.
     * @param states        States to validate, keyed by state name.
     */
    private void validateStates(ValidationContext parentContext, Map<String, State> states) {
        for (Map.Entry<String, State> entry : states.entrySet()) {
            final String stateName = entry.getKey();
            parentContext.assertStringNotEmpty(stateName, "State Name");
            entry.getValue().accept(new StateValidationVisitor(states, parentContext.state(stateName)));
        }
    }

    /**
     * Validates all the supported states and their nested properties. Every problem found is reported to the
     * enclosing validator's problem reporter; the visit methods themselves always return null.
     */
    private class StateValidationVisitor extends StateVisitor<Void> {

        /** Context of the state being visited; problems are reported against it or one of its children. */
        private final ValidationContext currentContext;
        /** Sibling states of the visited state, keyed by name; used to resolve transition targets. */
        private final Map<String, State> states;

        private StateValidationVisitor(Map<String, State> states, ValidationContext context) {
            this.states = states;
            this.currentContext = context;
        }

        /**
         * Validates the paths, the optional default state and every choice (transition and condition)
         * of a choice state.
         */
        @Override
        public Void visit(ChoiceState choiceState) {
            currentContext.assertIsValidInputPath(choiceState.getInputPath());
            currentContext.assertIsValidOutputPath(choiceState.getOutputPath());
            if (choiceState.getDefaultStateName() != null) {
                currentContext.assertStringNotEmpty(choiceState.getDefaultStateName(), PropertyNames.DEFAULT_STATE);
                assertContainsState(choiceState.getDefaultStateName());
            }

            currentContext.assertNotEmpty(choiceState.getChoices(), PropertyNames.CHOICES);
            int index = 0;
            for (Choice choice : choiceState.getChoices()) {
                ValidationContext choiceContext = currentContext.choice(index);
                validateTransition(choiceContext, choice.getTransition());
                validateCondition(choiceContext, choice.getCondition());
                index++;
            }
            return null;
        }

        /**
         * Validates a condition, recursing into nested conditions.
         *
         * @throws RuntimeException If the condition is non-null and of an unsupported type.
         */
        private void validateCondition(ValidationContext context, Condition condition) {
            context.assertNotNull(condition, "Condition");
            if (condition instanceof BinaryCondition) {
                validateBinaryCondition(context, (BinaryCondition) condition);
            } else if (condition instanceof NAryCondition) {
                validateNAryCondition(context, (NAryCondition) condition);
            } else if (condition instanceof NotCondition) {
                validateCondition(context, ((NotCondition) condition).getCondition());
            } else if (condition != null) {
                // A null condition has already been reported by the assertion above.
                throw new RuntimeException("Unsupported condition type: " + condition.getClass());
            }
        }

        /** Validates that an n-ary condition has nested conditions and validates each of them. */
        private void validateNAryCondition(ValidationContext context, NAryCondition condition) {
            context.assertNotEmpty(condition.getConditions(), "Conditions");
            for (Condition nestedCondition : condition.getConditions()) {
                validateCondition(context, nestedCondition);
            }
        }

        /** Validates the variable (non-empty, valid JSON path) and the expected value (non-null). */
        private void validateBinaryCondition(ValidationContext context, BinaryCondition condition) {
            context.assertStringNotEmpty(condition.getVariable(), PropertyNames.VARIABLE);
            context.assertIsValidJsonPath(condition.getVariable(), PropertyNames.VARIABLE);
            context.assertNotNull(condition.getExpectedValue(), "ExpectedValue");
        }

        /** Validates that a fail state has a non-empty cause. */
        @Override
        public Void visit(FailState failState) {
            currentContext.assertStringNotEmpty(failState.getCause(), PropertyNames.CAUSE);
            return null;
        }

        /** Validates the paths, transition, retriers, catchers and branches of a parallel state. */
        @Override
        public Void visit(ParallelState parallelState) {
            currentContext.assertIsValidInputPath(parallelState.getInputPath());
            currentContext.assertIsValidOutputPath(parallelState.getOutputPath());
            currentContext.assertIsValidResultPath(parallelState.getResultPath());
            validateTransition(parallelState.getTransition());
            validateRetriers(parallelState.getRetriers());
            validateCatchers(parallelState.getCatchers());
            validateBranches(parallelState);
            return null;
        }

        /**
         * Validates the states of every branch and checks that each branch's start state exists in that branch.
         */
        private void validateBranches(ParallelState parallelState) {
            currentContext.assertNotEmpty(parallelState.getBranches(), PropertyNames.BRANCHES);
            int index = 0;
            for (Branch branch : parallelState.getBranches()) {
                ValidationContext branchContext = currentContext.branch(index);
                validateStates(branchContext, branch.getStates());
                if (!branch.getStates().containsKey(branch.getStartAt())) {
                    problemReporter.report(new Problem(branchContext, String.format("%s references a non existent state.",
                                                                                    PropertyNames.START_AT)));
                }
                index++;
            }
        }

        /** Validates the paths, transition, retriers, catchers and iterator of a map state. */
        @Override
        public Void visit(MapState mapState) {
            currentContext.assertIsValidItemsPath(mapState.getItemsPath());
            currentContext.assertIsValidInputPath(mapState.getInputPath());
            currentContext.assertIsValidOutputPath(mapState.getOutputPath());
            currentContext.assertIsValidResultPath(mapState.getResultPath());
            validateTransition(mapState.getTransition());
            validateRetriers(mapState.getRetriers());
            validateCatchers(mapState.getCatchers());
            validateIterator(mapState);
            return null;
        }

        /**
         * Validates the states of the iterator and checks that its start state exists among them.
         * A missing iterator is reported and nothing further is checked.
         */
        private void validateIterator(MapState mapState) {
            final Iterator iterator = mapState.getIterator();
            currentContext.assertNotNull(iterator, PropertyNames.ITERATOR);
            if (iterator == null) {
                // Already reported by the assertion above; there is nothing left to dereference.
                return;
            }
            ValidationContext iteratorContext = currentContext.iterator();
            validateStates(iteratorContext, iterator.getStates());
            if (!iterator.getStates().containsKey(iterator.getStartAt())) {
                problemReporter.report(new Problem(iteratorContext, String.format("%s references a non existent state.",
                        PropertyNames.START_AT)));
            }
        }

        /** Validates the paths and transition of a pass state. */
        @Override
        public Void visit(PassState passState) {
            currentContext.assertIsValidInputPath(passState.getInputPath());
            currentContext.assertIsValidOutputPath(passState.getOutputPath());
            currentContext.assertIsValidResultPath(passState.getResultPath());
            validateTransition(passState.getTransition());
            return null;
        }

        /** Validates the input and output paths of a succeed state. */
        @Override
        public Void visit(SucceedState succeedState) {
            currentContext.assertIsValidInputPath(succeedState.getInputPath());
            currentContext.assertIsValidOutputPath(succeedState.getOutputPath());
            return null;
        }

        /**
         * Validates the paths, timeouts, resource, retriers, catchers and transition of a task state.
         * When both are present, the heartbeat must be strictly smaller than the timeout.
         */
        @Override
        public Void visit(TaskState taskState) {
            currentContext.assertIsValidInputPath(taskState.getInputPath());
            currentContext.assertIsValidOutputPath(taskState.getOutputPath());
            currentContext.assertIsValidResultPath(taskState.getResultPath());
            currentContext.assertIsPositiveIfPresent(taskState.getTimeoutSeconds(), PropertyNames.TIMEOUT_SECONDS);
            currentContext.assertIsPositiveIfPresent(taskState.getHeartbeatSeconds(), PropertyNames.HEARTBEAT_SECONDS);
            if (taskState.getTimeoutSeconds() != null && taskState.getHeartbeatSeconds() != null) {
                if (taskState.getHeartbeatSeconds() >= taskState.getTimeoutSeconds()) {
                    problemReporter.report(new Problem(currentContext, String.format("%s must be smaller than %s",
                                                                                     PropertyNames.HEARTBEAT_SECONDS,
                                                                                     PropertyNames.TIMEOUT_SECONDS)));
                }
            }

            currentContext.assertStringNotEmpty(taskState.getResource(), PropertyNames.RESOURCE);
            validateRetriers(taskState.getRetriers());
            validateCatchers(taskState.getCatchers());
            validateTransition(taskState.getTransition());
            return null;
        }

        /**
         * Validates each retrier in order: a retrier directly following one that uses {@link ErrorCodes#ALL}
         * is reported, MaxAttempts must not be negative, IntervalSeconds must be positive and BackoffRate
         * must be at least 1.0.
         */
        private void validateRetriers(List<Retrier> retriers) {
            boolean hasRetryAll = false;
            int index = 0;
            for (Retrier retrier : retriers) {
                ValidationContext retrierContext = currentContext.retrier(index);
                if (hasRetryAll) {
                    problemReporter.report(
                            new Problem(retrierContext,
                                        String.format("When %s is used in must be in the last Retrier", ErrorCodes.ALL)));
                }
                // MaxAttempts may be zero
                retrierContext.assertIsNotNegativeIfPresent(retrier.getMaxAttempts(), PropertyNames.MAX_ATTEMPTS);
                retrierContext.assertIsPositiveIfPresent(retrier.getIntervalSeconds(), PropertyNames.INTERVAL_SECONDS);
                if (retrier.getBackoffRate() != null && retrier.getBackoffRate() < 1.0) {
                    problemReporter.report(new Problem(retrierContext, String.format("%s must be greater than or equal to 1.0",
                                                                                     PropertyNames.BACKOFF_RATE)));
                }
                hasRetryAll = validateErrorEquals(retrierContext, retrier.getErrorEquals());
                index++;
            }
        }


        /**
         * Validates each catcher in order: its result path and transition are checked, and a catcher
         * directly following one that uses {@link ErrorCodes#ALL} is reported.
         */
        private void validateCatchers(List<Catcher> catchers) {
            boolean hasCatchAll = false;
            int index = 0;
            for (Catcher catcher : catchers) {
                ValidationContext catcherContext = currentContext.catcher(index);
                catcherContext.assertIsValidResultPath(catcher.getResultPath());
                if (hasCatchAll) {
                    problemReporter.report(
                            new Problem(catcherContext,
                                        String.format("When %s is used in must be in the last Catcher", ErrorCodes.ALL)));
                }
                validateTransition(catcherContext, catcher.getTransition());
                hasCatchAll = validateErrorEquals(catcherContext, catcher.getErrorEquals());
                index++;
            }
        }

        /**
         * Validates the error codes of a retrier or catcher. When {@link ErrorCodes#ALL} is present it must
         * be the only code in the list.
         *
         * @return True if the list contains {@link ErrorCodes#ALL}, false otherwise.
         */
        private boolean validateErrorEquals(ValidationContext currentContext, List<String> errorEquals) {
            currentContext.assertNotEmpty(errorEquals, PropertyNames.ERROR_EQUALS);
            if (errorEquals.contains(ErrorCodes.ALL)) {
                if (errorEquals.size() != 1) {
                    problemReporter.report(new Problem(currentContext, String.format(
                            "When %s is used in %s, it must be the only error code in the array",
                            ErrorCodes.ALL, PropertyNames.ERROR_EQUALS)));
                }
                return true;
            }
            return false;
        }

        /** Validates the paths, transition and wait strategy of a wait state. */
        @Override
        public Void visit(WaitState waitState) {
            currentContext.assertIsValidInputPath(waitState.getInputPath());
            currentContext.assertIsValidOutputPath(waitState.getOutputPath());
            validateTransition(waitState.getTransition());
            validateWaitFor(waitState.getWaitFor());
            return null;
        }

        /**
         * Validates the wait strategy according to its concrete type.
         *
         * @throws RuntimeException If the strategy is non-null and of an unsupported type.
         */
        private void validateWaitFor(WaitFor waitFor) {
            currentContext.assertNotNull(waitFor, "WaitFor");
            if (waitFor instanceof WaitForSeconds) {
                currentContext.assertIsPositiveIfPresent(((WaitForSeconds) waitFor).getSeconds(), PropertyNames.SECONDS);
            } else if (waitFor instanceof WaitForSecondsPath) {
                assertWaitForPath(((WaitForSecondsPath) waitFor).getSecondsPath(), PropertyNames.SECONDS_PATH);
            } else if (waitFor instanceof WaitForTimestamp) {
                currentContext.assertNotNull(((WaitForTimestamp) waitFor).getTimestamp(), PropertyNames.TIMESTAMP);
            } else if (waitFor instanceof WaitForTimestampPath) {
                assertWaitForPath(((WaitForTimestampPath) waitFor).getTimestampPath(), PropertyNames.TIMESTAMP_PATH);
            } else if (waitFor != null) {
                // A null strategy has already been reported by the assertion above.
                throw new RuntimeException("Unsupported WaitFor strategy: " + waitFor.getClass());
            }
        }

        /**
         * TimestampPath and SecondsPath must have a valid reference path.
         */
        private void assertWaitForPath(String pathValue, String propertyName) {
            currentContext.assertNotNull(pathValue, propertyName);
            currentContext.assertIsValidReferencePath(pathValue, propertyName);
        }

        /** Validates a transition against the context of the visited state. */
        private void validateTransition(Transition transition) {
            validateTransition(currentContext, transition);
        }

        /**
         * Validates that a transition is present and, when it points at a next state, that the state name
         * is set and refers to an existing sibling state.
         */
        private void validateTransition(ValidationContext context, Transition transition) {
            context.assertNotNull(transition, "Transition");
            if (transition instanceof NextStateTransition) {
                final String nextStateName = ((NextStateTransition) transition).getNextStateName();
                context.assertNotNull(nextStateName, PropertyNames.NEXT);
                assertContainsState(context, nextStateName);
            }
        }

        /** Reports a problem against the visited state's context if no sibling state has the given name. */
        private void assertContainsState(String nextStateName) {
            assertContainsState(currentContext, nextStateName);
        }

        /** Reports a problem against the given context if no sibling state has the given name. */
        private void assertContainsState(ValidationContext context, String nextStateName) {
            if (!states.containsKey(nextStateName)) {
                problemReporter.report(new Problem(context, String.format("%s is not a valid state", nextStateName)));
            }
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy