
aima.core.learning.reinforcement.agent.QLearningAgent


AIMA-Java Core Algorithms from the book Artificial Intelligence: A Modern Approach, 3rd Edition.

package aima.core.learning.reinforcement.agent;

import java.util.HashMap;
import java.util.Map;

import aima.core.agent.Action;
import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.probability.mdp.ActionsFunction;
import aima.core.util.FrequencyCounter;
import aima.core.util.datastructure.Pair;

/**
 * Artificial Intelligence A Modern Approach (3rd Edition): page 844.
 *
 * function Q-LEARNING-AGENT(percept) returns an action
 *   inputs: percept, a percept indicating the current state s' and reward signal r'
 *   persistent: Q, a table of action values indexed by state and action, initially zero
 *               Nsa, a table of frequencies for state-action pairs, initially zero
 *               s,a,r, the previous state, action, and reward, initially null
 *
 *   if TERMINAL?(s) then Q[s,None] <- r'
 *   if s is not null then
 *       increment Nsa[s,a]
 *       Q[s,a] <- Q[s,a] + α(Nsa[s,a])(r + γ max_a' Q[s',a'] - Q[s,a])
 *   s,a,r <- s',argmax_a' f(Q[s',a'],Nsa[s',a']),r'
 *   return a
 *
 * Figure 21.8 An exploratory Q-learning agent. It is an active learner that
 * learns the value Q(s,a) of each action in each situation. It uses the same
 * exploration function f as the exploratory ADP agent, but avoids having to
 * learn the transition model because the Q-value of a state can be related
 * directly to those of its neighbors.
 *
 * Note: there appear to be two minor defects in the algorithm outlined in the
 * book:
 *
 *   if TERMINAL?(s) then Q[s,None] <- r'
 *
 * should be:
 *
 *   if TERMINAL?(s') then Q[s',None] <- r'
 *
 * so that the correct value for Q[s',a'] is used in the Q[s,a] update rule
 * when a terminal state is reached. Similarly,
 *
 *   s,a,r <- s',argmax_a' f(Q[s',a'],Nsa[s',a']),r'
 *
 * should be:
 *
 *   if s'.TERMINAL? then s,a,r <- null else s,a,r <- s',argmax_a' f(Q[s',a'],Nsa[s',a']),r'
 *
 * otherwise, at the beginning of a subsequent trial, s will be the prior
 * terminal state and is what will be updated in Q[s,a], which appears to be
 * incorrect, as no action was performed in the terminal state and the initial
 * state is not reachable from the prior terminal state. Comments welcome.
 *
 * @param <S>
 *            the state type.
 * @param <A>
 *            the action type.
 *
 * @author Ciaran O'Reilly
 * @author Ravi Mohan
 */
public class QLearningAgent<S, A extends Action> extends
        ReinforcementAgent<S, A> {
    // persistent: Q, a table of action values indexed by state and action,
    // initially zero
    Map<Pair<S, A>, Double> Q = new HashMap<Pair<S, A>, Double>();
    // Nsa, a table of frequencies for state-action pairs, initially zero
    private FrequencyCounter<Pair<S, A>> Nsa = new FrequencyCounter<Pair<S, A>>();
    // s,a,r, the previous state, action, and reward, initially null
    private S s = null;
    private A a = null;
    private Double r = null;
    //
    private ActionsFunction<S, A> actionsFunction = null;
    private A noneAction = null;
    private double alpha = 0.0;
    private double gamma = 0.0;
    private int Ne = 0;
    private double Rplus = 0.0;

    /**
     * Constructor.
     * 
     * @param actionsFunction
     *            a function that lists the legal actions from a state.
     * @param noneAction
     *            an action representing None, i.e. a NoOp.
     * @param alpha
     *            a fixed learning rate.
     * @param gamma
     *            the discount to be used.
     * @param Ne
     *            a fixed parameter for use in the method f(u, n).
     * @param Rplus
     *            R+ is an optimistic estimate of the best possible reward
     *            obtainable in any state, which is used in the method f(u, n).
     */
    public QLearningAgent(ActionsFunction<S, A> actionsFunction, A noneAction,
            double alpha, double gamma, int Ne, double Rplus) {
        this.actionsFunction = actionsFunction;
        this.noneAction = noneAction;
        this.alpha = alpha;
        this.gamma = gamma;
        this.Ne = Ne;
        this.Rplus = Rplus;
    }

    /**
     * An exploratory Q-learning agent. It is an active learner that learns the
     * value Q(s,a) of each action in each situation. It uses the same
     * exploration function f as the exploratory ADP agent, but avoids having
     * to learn the transition model because the Q-value of a state can be
     * related directly to those of its neighbors.
     * 
     * @param percept
     *            a percept indicating the current state s' and reward signal
     *            r'.
     * @return an action
     */
    @Override
    public A execute(PerceptStateReward<S> percept) {
        S sPrime = percept.state();
        double rPrime = percept.reward();

        // if TERMINAL?(s') then Q[s',None] <- r'
        if (isTerminal(sPrime)) {
            Q.put(new Pair<S, A>(sPrime, noneAction), rPrime);
        }

        // if s is not null then
        if (null != s) {
            // increment Nsa[s,a]
            Pair<S, A> sa = new Pair<S, A>(s, a);
            Nsa.incrementFor(sa);
            // Q[s,a] <- Q[s,a] + α(Nsa[s,a])(r + γ max_a' Q[s',a'] - Q[s,a])
            Double Q_sa = Q.get(sa);
            if (null == Q_sa) {
                Q_sa = 0.0;
            }
            Q.put(sa, Q_sa + alpha(Nsa, s, a)
                    * (r + gamma * maxAPrime(sPrime) - Q_sa));
        }

        // if s'.TERMINAL? then s,a,r <- null else
        // s,a,r <- s',argmax_a' f(Q[s',a'],Nsa[s',a']),r'
        if (isTerminal(sPrime)) {
            s = null;
            a = null;
            r = null;
        } else {
            s = sPrime;
            a = argmaxAPrime(sPrime);
            r = rPrime;
        }

        // return a
        return a;
    }

    @Override
    public void reset() {
        Q.clear();
        Nsa.clear();
        s = null;
        a = null;
        r = null;
    }

    @Override
    public Map<S, Double> getUtility() {
        // Q-values are directly related to utility values as follows
        // (AIMA3e pg. 843 - 21.6):
        // U(s) = max_a Q(s,a).
        Map<S, Double> U = new HashMap<S, Double>();
        for (Pair<S, A> sa : Q.keySet()) {
            Double q = Q.get(sa);
            Double u = U.get(sa.getFirst());
            if (null == u || u < q) {
                U.put(sa.getFirst(), q);
            }
        }

        return U;
    }

    //
    // PROTECTED METHODS
    //
    /**
     * AIMA3e pg. 836: 'if we change α from a fixed parameter to a function
     * that decreases as the number of times a state-action pair has been
     * observed increases, then U^π(s) itself will converge to the correct
     * value.'
     *
     * Note: override this method to obtain the desired behavior.
     * 
     * @param Nsa
     *            a frequency counter of observed state-action pairs.
     * @param s
     *            the current state.
     * @param a
     *            the current action.
     * @return the learning rate to use based on the frequency of the
     *         state-action pair passed in.
     */
    protected double alpha(FrequencyCounter<Pair<S, A>> Nsa, S s, A a) {
        // Default implementation is just to return a fixed parameter value
        // irrespective of the # of times a state-action pair has been
        // encountered.
        return alpha;
    }

    /**
     * AIMA3e pg. 842: 'f(u, n) is called the exploration function. It
     * determines how greed (preferences for high values of u) is traded off
     * against curiosity (preferences for actions that have not been tried
     * often and have low n). The function f(u, n) should be increasing in u
     * and decreasing in n.'
     *
     * Note: override this method to obtain the desired behavior.
     * 
     * @param u
     *            the currently estimated utility.
     * @param n
     *            the number of times this situation has been encountered.
     * @return the exploration value.
     */
    protected double f(Double u, int n) {
        // A simple definition of f(u, n):
        if (null == u || n < Ne) {
            return Rplus;
        }
        return u;
    }

    //
    // PRIVATE METHODS
    //
    private boolean isTerminal(S s) {
        boolean terminal = false;
        if (null != s && actionsFunction.actions(s).size() == 0) {
            // A state with no possible actions is considered terminal.
            terminal = true;
        }
        return terminal;
    }

    private double maxAPrime(S sPrime) {
        double max = Double.NEGATIVE_INFINITY;
        if (actionsFunction.actions(sPrime).size() == 0) {
            // a terminal state
            max = Q.get(new Pair<S, A>(sPrime, noneAction));
        } else {
            for (A aPrime : actionsFunction.actions(sPrime)) {
                Double Q_sPrimeAPrime = Q.get(new Pair<S, A>(sPrime, aPrime));
                if (null != Q_sPrimeAPrime && Q_sPrimeAPrime > max) {
                    max = Q_sPrimeAPrime;
                }
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            // Assign 0, as this mimics Q being initialized to 0 up front.
            max = 0.0;
        }
        return max;
    }

    // argmax_a' f(Q[s',a'],Nsa[s',a'])
    private A argmaxAPrime(S sPrime) {
        A a = null;
        double max = Double.NEGATIVE_INFINITY;
        for (A aPrime : actionsFunction.actions(sPrime)) {
            Pair<S, A> sPrimeAPrime = new Pair<S, A>(sPrime, aPrime);
            double explorationValue = f(Q.get(sPrimeAPrime),
                    Nsa.getCount(sPrimeAPrime));
            if (explorationValue > max) {
                max = explorationValue;
                a = aPrime;
            }
        }
        return a;
    }
}
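The Javadoc on alpha(...) above suggests replacing the fixed learning rate with one that decays as a state-action pair is observed more often, so that the learned values converge. Below is a minimal sketch of such an override, assuming only the API shown in this file (the QLearningAgent constructor, the protected alpha(...) hook, and FrequencyCounter.getCount(...)); the k/(k + n) schedule and the class name DecayingAlphaQLearningAgent are illustrative choices, not part of the library.

package aima.core.learning.reinforcement.agent;

import aima.core.agent.Action;
import aima.core.probability.mdp.ActionsFunction;
import aima.core.util.FrequencyCounter;
import aima.core.util.datastructure.Pair;

public class DecayingAlphaQLearningAgent<S, A extends Action> extends
        QLearningAgent<S, A> {

    // Controls how quickly the learning rate decays; k/(k + n) is an
    // illustrative O(1/n) schedule, not something defined by the library.
    private final double k;

    public DecayingAlphaQLearningAgent(ActionsFunction<S, A> actionsFunction,
            A noneAction, double gamma, int Ne, double Rplus, double k) {
        // The fixed alpha passed to the superclass is irrelevant once
        // alpha(...) is overridden below.
        super(actionsFunction, noneAction, 1.0, gamma, Ne, Rplus);
        this.k = k;
    }

    @Override
    protected double alpha(FrequencyCounter<Pair<S, A>> Nsa, S s, A a) {
        // Decay with the visit count of (s,a); execute(...) increments
        // Nsa[s,a] before calling this, so n >= 1 here.
        int n = Nsa.getCount(new Pair<S, A>(s, a));
        return k / (k + n);
    }
}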
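To see the agent run end-to-end, the sketch below wires it to a tiny deterministic three-state corridor. This is a toy driver under stated assumptions: DynamicAction (aima.core.agent.impl) is used as the action type, ActionsFunction is assumed to declare a single Set<A> actions(S s) method, and PerceptStateReward is assumed to declare only the state() and reward() accessors used by execute(...) above; the corridor itself (states "A", "B", "T"), the parameter values, and the class name QLearningAgentToyDemo are hypothetical.

package aima.core.learning.reinforcement.agent;

import java.util.HashSet;
import java.util.Set;

import aima.core.agent.impl.DynamicAction;
import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.probability.mdp.ActionsFunction;

public class QLearningAgentToyDemo {

    public static void main(String[] args) {
        final DynamicAction right = new DynamicAction("Right");
        DynamicAction none = new DynamicAction("None");

        // Three-state corridor A -> B -> T; T has no actions and is
        // therefore treated as terminal by isTerminal(...).
        ActionsFunction<String, DynamicAction> actions = new ActionsFunction<String, DynamicAction>() {
            public Set<DynamicAction> actions(String s) {
                Set<DynamicAction> result = new HashSet<DynamicAction>();
                if (!"T".equals(s)) {
                    result.add(right);
                }
                return result;
            }
        };

        // alpha=0.2, gamma=1.0, Ne=5, R+=1.0 are illustrative values only.
        QLearningAgent<String, DynamicAction> agent = new QLearningAgent<String, DynamicAction>(
                actions, none, 0.2, 1.0, 5, 1.0);

        for (int trial = 0; trial < 100; trial++) {
            String state = "A";
            double reward = -0.04;
            while (true) {
                DynamicAction a = agent.execute(percept(state, reward));
                if (null == a) {
                    break; // terminal percept processed; start the next trial
                }
                // Deterministic transition model of the toy corridor.
                state = "A".equals(state) ? "B" : "T";
                reward = "T".equals(state) ? 1.0 : -0.04;
            }
        }

        // U(s) = max_a Q(s,a) for the states encountered.
        System.out.println(agent.getUtility());
    }

    // Wraps a (state, reward) pair as the percept execute(...) expects.
    private static PerceptStateReward<String> percept(final String state,
            final double reward) {
        return new PerceptStateReward<String>() {
            public String state() {
                return state;
            }

            public double reward() {
                return reward;
            }
        };
    }
}

Because "T" has no legal actions, execute(...) returns null there, which the inner loop uses as the end-of-trial signal, matching the corrected "if s'.TERMINAL? then s,a,r <- null" step discussed in the class Javadoc.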