// netflix.ocelli.util.StateMachine (Maven / Gradle / Ivy)
package netflix.ocelli.util;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Observable.OnSubscribe;
import rx.Subscriber;
import rx.functions.Action1;
import rx.functions.Action2;
import rx.functions.Func0;
import rx.functions.Func1;
import rx.subjects.PublishSubject;
public class StateMachine implements Action1 {
private static final Logger LOG = LoggerFactory.getLogger(StateMachine.class);
public static class State {
private String name;
private Func1> enter;
private Func1> exit;
private Map> transitions = new HashMap>();
private Set ignore = new HashSet();
public static State create(String name) {
return new State(name);
}
public State(String name) {
this.name = name;
}
public State onEnter(Func1> func) {
this.enter = func;
return this;
}
public State onExit(Func1> func) {
this.exit = func;
return this;
}
public State transition(E event, State state) {
transitions.put(event, state);
return this;
}
public State ignore(E event) {
ignore.add(event);
return this;
}
Observable enter(T context) {
if (enter != null)
return enter.call(context);
return Observable.empty();
}
Observable exit(T context) {
if (exit != null)
exit.call(context);
return Observable.empty();
}
State next(E event) {
return transitions.get(event);
}
public String toString() {
return name;
}
}
private volatile State state;
private final T context;
private final PublishSubject events = PublishSubject.create();
public static StateMachine create(T context, State initial) {
return new StateMachine(context, initial);
}
public StateMachine(T context, State initial) {
this.state = initial;
this.context = context;
}
public Observable start() {
return Observable.create(new OnSubscribe() {
@Override
public void call(Subscriber super Void> sub) {
sub.add(events.collect(new Func0() {
@Override
public T call() {
return context;
}
}, new Action2() {
@Override
public void call(T context, E event) {
LOG.trace("{} : {}({})", context, state, event);
final State next = state.next(event);
if (next != null) {
state.exit(context);
state = next;
next.enter(context).subscribe(StateMachine.this);
}
else if (!state.ignore.contains(event)) {
LOG.warn("Unexpected event {} in state {} for {} ", event, state, context);
}
}
})
.subscribe());
state.enter(context);
}
});
}
@Override
public void call(E event) {
events.onNext(event);
}
public State getState() {
return state;
}
public T getContext() {
return context;
}
}