package aima.core.learning.reinforcement.agent;
import java.util.HashMap;
import java.util.Map;
import aima.core.agent.Action;
import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.probability.mdp.ActionsFunction;
import aima.core.util.FrequencyCounter;
import aima.core.util.datastructure.Pair;
/**
* Artificial Intelligence A Modern Approach (3rd Edition): page 844.
*
*
*
* function Q-LEARNING-AGENT(percept) returns an action
*   inputs: percept, a percept indicating the current state s' and reward signal r'
*   persistent: Q, a table of action values indexed by state and action, initially zero
*               Nsa, a table of frequencies for state-action pairs, initially zero
*               s, a, r, the previous state, action, and reward, initially null
*
*   if TERMINAL?(s) then Q[s,None] <- r'
*   if s is not null then
*       increment Nsa[s,a]
*       Q[s,a] <- Q[s,a] + α(Nsa[s,a])(r + γ max_a' Q[s',a'] - Q[s,a])
*   s, a, r <- s', argmax_a' f(Q[s',a'], Nsa[s',a']), r'
*   return a
*
*
* Figure 21.8 An exploratory Q-learning agent. It is an active learner that
* learns the value Q(s,a) of each action in each situation. It uses the same
* exploration function f as the exploratory ADP agent, but avoids having to
* learn the transition model because the Q-value of a state can be related
* directly to those of its neighbors.
*
* Note: There appear to be two minor defects in the algorithm outlined
* in the book:
*
*   if TERMINAL?(s) then Q[s,None] <- r'
* should be:
*   if TERMINAL?(s') then Q[s',None] <- r'
* so that the correct value for Q[s',a'] is used in the Q[s,a] update rule when
* a terminal state is reached.
*
*   s, a, r <- s', argmax_a' f(Q[s',a'], Nsa[s',a']), r'
* should be:
*   if TERMINAL?(s') then s, a, r <- null else
*   s, a, r <- s', argmax_a' f(Q[s',a'], Nsa[s',a']), r'
*
* otherwise, at the beginning of the next trial, s will be the prior terminal
* state and is what will be updated in Q[s,a], which appears not to be correct:
* no action was performed in the terminal state, and the initial state of the
* new trial is not reachable from the prior terminal state. Comments welcome.
*
* @param <S>
*            the state type.
* @param <A>
*            the action type.
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
*
*/
public class QLearningAgent<S, A extends Action> extends
ReinforcementAgent<S, A> {
// persistent: Q, a table of action values indexed by state and action,
// initially zero
Map<Pair<S, A>, Double> Q = new HashMap<Pair<S, A>, Double>();
// Nsa, a table of frequencies for state-action pairs, initially
// zero
private FrequencyCounter<Pair<S, A>> Nsa = new FrequencyCounter<Pair<S, A>>();
// s,a,r, the previous state, action, and reward, initially null
private S s = null;
private A a = null;
private Double r = null;
//
private ActionsFunction<S, A> actionsFunction = null;
private A noneAction = null;
private double alpha = 0.0;
private double gamma = 0.0;
private int Ne = 0;
private double Rplus = 0.0;
/**
* Constructor.
*
* @param actionsFunction
* a function that lists the legal actions from a state.
* @param noneAction
* an action representing None, i.e. a NoOp.
* @param alpha
* a fixed learning rate.
* @param gamma
* discount to be used.
* @param Ne
*            a fixed parameter for use in the exploration function f(u, n).
* @param Rplus
* R+ is an optimistic estimate of the best possible reward
* obtainable in any state, which is used in the method f(u, n).
*/
public QLearningAgent(ActionsFunction<S, A> actionsFunction,
A noneAction, double alpha,
double gamma, int Ne, double Rplus) {
this.actionsFunction = actionsFunction;
this.noneAction = noneAction;
this.alpha = alpha;
this.gamma = gamma;
this.Ne = Ne;
this.Rplus = Rplus;
}
/**
* An exploratory Q-learning agent. It is an active learner that learns the
* value Q(s,a) of each action in each situation. It uses the same
* exploration function f as the exploratory ADP agent, but avoids having to
* learn the transition model because the Q-value of a state can be related
* directly to those of its neighbors.
*
* @param percept
* a percept indicating the current state s' and reward signal
* r'.
* @return an action
*/
@Override
public A execute(PerceptStateReward<S> percept) {
S sPrime = percept.state();
double rPrime = percept.reward();
// if TERMINAL?(s') then Q[s',None] <- r'
if (isTerminal(sPrime)) {
Q.put(new Pair<S, A>(sPrime, noneAction), rPrime);
}
// if s is not null then
if (null != s) {
// increment Nsa[s,a]
Pair<S, A> sa = new Pair<S, A>(s, a);
Nsa.incrementFor(sa);
// Q[s,a] <- Q[s,a] + α(Nsa[s,a])(r + γ max_a' Q[s',a'] - Q[s,a])
Double Q_sa = Q.get(sa);
if (null == Q_sa) {
Q_sa = 0.0;
}
Q.put(sa, Q_sa + alpha(Nsa, s, a)
* (r + gamma * maxAPrime(sPrime) - Q_sa));
}
// if TERMINAL?(s') then s,a,r <- null else
// s,a,r <- s', argmax_a' f(Q[s',a'], Nsa[s',a']), r'
if (isTerminal(sPrime)) {
s = null;
a = null;
r = null;
} else {
s = sPrime;
a = argmaxAPrime(sPrime);
r = rPrime;
}
// return a
return a;
}
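// Worked example of the update performed in execute() (hypothetical numbers,
// not from the book): with alpha(Nsa, s, a) = 0.5, gamma = 1.0, a stored
// Q[s,a] of 0.0, r = -0.04 and max_a' Q[s',a'] = 0.5, the new entry is
// 0.0 + 0.5 * (-0.04 + 1.0 * 0.5 - 0.0) = 0.23.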
@Override
public void reset() {
Q.clear();
Nsa.clear();
s = null;
a = null;
r = null;
}
@Override
public Map<S, Double> getUtility() {
// Q-values are directly related to utility values as follows
// (AIMA3e pg. 843 - 21.6):
// U(s) = max_a Q(s,a).
Map<S, Double> U = new HashMap<S, Double>();
for (Pair<S, A> sa : Q.keySet()) {
Double q = Q.get(sa);
Double u = U.get(sa.getFirst());
if (null == u || u < q) {
U.put(sa.getFirst(), q);
}
}
return U;
}
//
// PROTECTED METHODS
//
/**
* AIMA3e pg. 836 'if we change α from a fixed parameter to a function
* that decreases as the number of times a state-action pair has been observed
* increases, then U^π(s) itself will converge to the correct
* value.'
*
* Note: override this method to obtain the desired behavior.
*
* @param Nsa
* a frequency counter of observed state action pairs.
* @param s
* the current state.
* @param a the current action.
* @return the learning rate to use based on the frequency of the state
* passed in.
*/
protected double alpha(FrequencyCounter<Pair<S, A>> Nsa, S s, A a) {
// Default implementation is just to return a fixed parameter value
// irrespective of the # of times a state action has been encountered
return alpha;
}
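// A hypothetical sketch of overriding alpha() in a subclass with a decaying
// learning rate, as the Javadoc above suggests. The 60/(59 + n) schedule is
// the one the book uses for its passive learning examples; any schedule that
// decreases appropriately with the visit count would do:
//
// @Override
// protected double alpha(FrequencyCounter<Pair<S, A>> Nsa, S s, A a) {
//     int n = Nsa.getCount(new Pair<S, A>(s, a));
//     return 60.0 / (59.0 + n);
// }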
/**
* AIMA3e pg. 842 'f(u, n) is called the exploration function. It
* determines how greed (preferences for high values of u) is traded off
* against curiosity (preferences for actions that have not been tried often
* and have low n). The function f(u, n) should be increasing in u and
* decreasing in n.'
*
*
* Note: Override this method to obtain desired behavior.
*
* @param u
* the currently estimated utility.
* @param n
* the number of times this situation has been encountered.
* @return the exploration value.
*/
protected double f(Double u, int n) {
// A Simple definition of f(u, n):
if (null == u || n < Ne) {
return Rplus;
}
return u;
}
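// A hypothetical alternative override (purely illustrative, not from the
// book): instead of the hard Ne cut-off above, fade an optimism bonus as the
// state-action count grows:
//
// @Override
// protected double f(Double u, int n) {
//     double estimate = (null == u) ? 0.0 : u;
//     return estimate + Rplus / (1.0 + n);
// }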
//
// PRIVATE METHODS
//
private boolean isTerminal(S s) {
boolean terminal = false;
if (null != s && actionsFunction.actions(s).size() == 0) {
// A state with no possible actions is considered terminal.
terminal = true;
}
return terminal;
}
private double maxAPrime(S sPrime) {
double max = Double.NEGATIVE_INFINITY;
if (actionsFunction.actions(sPrime).size() == 0) {
// a terminal state
max = Q.get(new Pair<S, A>(sPrime, noneAction));
} else {
for (A aPrime : actionsFunction.actions(sPrime)) {
Double Q_sPrimeAPrime = Q.get(new Pair<S, A>(sPrime, aPrime));
if (null != Q_sPrimeAPrime && Q_sPrimeAPrime > max) {
max = Q_sPrimeAPrime;
}
}
}
if (max == Double.NEGATIVE_INFINITY) {
// Assign 0, as this mimics Q being initialized to 0 up front.
max = 0.0;
}
return max;
}
// argmax_a' f(Q[s',a'], Nsa[s',a'])
private A argmaxAPrime(S sPrime) {
A a = null;
double max = Double.NEGATIVE_INFINITY;
for (A aPrime : actionsFunction.actions(sPrime)) {
Pair<S, A> sPrimeAPrime = new Pair<S, A>(sPrime, aPrime);
double explorationValue = f(Q.get(sPrimeAPrime), Nsa
.getCount(sPrimeAPrime));
if (explorationValue > max) {
max = explorationValue;
a = aPrime;
}
}
return a;
}
}
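// A minimal usage sketch (hypothetical: GridState, GridAction, numTrials and
// the environment object with its initialPercept()/executeAction() methods
// are stand-ins for whatever supplies the ActionsFunction and the stream of
// PerceptStateReward percepts; they are not part of this class):
//
// QLearningAgent<GridState, GridAction> agent = new QLearningAgent<GridState, GridAction>(
//         actionsFunction, GridAction.NONE, 0.2, 1.0, 5, 2.0);
// for (int trial = 0; trial < numTrials; trial++) {
//     PerceptStateReward<GridState> percept = environment.initialPercept();
//     GridAction action = agent.execute(percept);
//     while (action != null) {
//         percept = environment.executeAction(action);
//         action = agent.execute(percept);
//     }
// }
// Map<GridState, Double> utilities = agent.getUtility();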