All Downloads are FREE. Search and download functionalities use the official Maven repository.

burlap.behavior.singleagent.learning.lspi.SARSCollector Maven / Gradle / Ivy

Go to download

The Brown-UMBC Reinforcement Learning and Planning (BURLAP) Java code library is for the use and development of single or multi-agent planning and learning algorithms and domains to accompany them. The library uses a highly flexible state/observation representation where you define states with your own Java classes, enabling support for domains that are discrete, continuous, relational, or anything else. Planning and learning algorithms range from classic forward search planning to value-function-based stochastic planning and learning algorithms.

The newest version!
package burlap.behavior.singleagent.learning.lspi;

import burlap.debugtools.RandomFactory;
import burlap.mdp.auxiliary.StateGenerator;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.core.action.ActionType;
import burlap.mdp.core.action.ActionUtils;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.model.SampleModel;

import java.util.List;


/**
 * This object is used to collect {@link SARSData} (state-action-reward-state tuples) that can then be used by algorithms like LSPI for learning.
 * @author James MacGlashan
 *
 */
public abstract class SARSCollector {
	
	/**
	 * The actions used for collecting data.
	 */
	protected List actionTypes;
	
	
	/**
	 * Initializes the collector's action set using the actions that are part of the domain.
	 * @param domain the domain containing the actions to use
	 */
	public SARSCollector(SADomain domain){
		this.actionTypes = domain.getActionTypes();
	}
	
	/**
	 * Initializes this collector's action set to use for collecting data.
	 * @param actionTypes the action set to use for collecting data.
	 */
	public SARSCollector(List actionTypes){
		this.actionTypes = actionTypes;
	}
	
	/**
	 * Collects data from an initial state until either a terminal state is reached or until the maximum number of steps is taken.
	 * Data is stored into the dataset intoDataset and returned. If intoDataset is null, then it is first created.
	 * @param s the initial state from which data should be collected.
	 * @param model the model of the world to use
	 * @param maxSteps the maximum number of steps that can be taken.
	 * @param intoDataset the dataset into which data will be stored. If null, a dataset is created.
	 * @return the resulting dataset with the newly collected data.
	 */
	public abstract SARSData collectDataFrom(State s, SampleModel model, int maxSteps, SARSData intoDataset);

	/**
	 * Collects data from an {@link burlap.mdp.singleagent.environment.Environment}'s current state until either the maximum
	 * number of steps is taken or a terminal state is reached.
	 * Data is stored into the dataset intoDataset and returned. If intoDataset is null, then it is first created.
	 * @param env the {@link burlap.mdp.singleagent.environment.Environment} from which data will be collected.
	 * @param maxSteps the maximum number of steps to take in the environment.
	 * @param intoDataset the dataset into which data will be stored. If null, a dataset is created.
	 * @return the resulting dataset with the newly collected data.
	 */
	public abstract SARSData collectDataFrom(Environment env, int maxSteps, SARSData intoDataset);
	
	
	/**
	 * Collects nSamples of SARS tuples and returns it in a {@link SARSData} object.
	 * @param sg a state generator for finding initial state from which data can be collected.
	 * @param model the model of the world to use
	 * @param nSamples the number of SARS samples to collect.
	 * @param maxEpisodeSteps the maximum number of steps that can be taken when rolling out from a state generated by {@link StateGenerator} sg, before a new rollout is started.
	 * @param intoDataset the dataset into which the results will be collected. If null, a new dataset is created.
	 * @return the intoDataset object, which is created if it is input as null.
	 */
	public SARSData collectNInstances(StateGenerator sg, SampleModel model, int nSamples, int maxEpisodeSteps, SARSData intoDataset){
		
		if(intoDataset == null){
			intoDataset = new SARSData(nSamples);
		}
		
		while(nSamples > 0){
			int maxSteps = Math.min(nSamples, maxEpisodeSteps);
			int oldSize = intoDataset.size();
			this.collectDataFrom(sg.generateState(), model, maxSteps, intoDataset);
			int delta = intoDataset.size() - oldSize;
			nSamples -= delta;
		}
		
		return intoDataset;
		
	}


	/**
	 * Collects nSamples of SARS tuples from an {@link burlap.mdp.singleagent.environment.Environment} and returns it in a {@link burlap.behavior.singleagent.learning.lspi.SARSData} object.
	 * Each sequence of samples is no longer than maxEpisodeSteps and samples are collected using this object's {@link #collectDataFrom(burlap.mdp.singleagent.environment.Environment, int, SARSData)}
	 * method. After each call to {@link #collectDataFrom(burlap.mdp.singleagent.environment.Environment, int, SARSData)}, the provided {@link burlap.mdp.singleagent.environment.Environment}
	 * is sent the {@link burlap.mdp.singleagent.environment.Environment#resetEnvironment()} message.
	 * @param env The {@link burlap.mdp.singleagent.environment.Environment} from which samples should be collected.
	 * @param nSamples The number of samples to generate.
	 * @param maxEpisodeSteps the maximum number of steps to take from any initial state of the {@link burlap.mdp.singleagent.environment.Environment}.
	 * @param intoDataset the dataset into which the results will be collected. If null, a new dataset is created.
	 * @return the intoDataset object, which is created if it is input as null.
	 */
	public SARSData collectNInstances(Environment env, int nSamples, int maxEpisodeSteps, SARSData intoDataset){

		if(intoDataset == null){
			intoDataset = new SARSData(nSamples);
		}

		while(nSamples > 0 && !env.isInTerminalState()){
			int maxSteps = Math.min(nSamples, maxEpisodeSteps);
			int oldSize = intoDataset.size();
			this.collectDataFrom(env, maxSteps, intoDataset);
			int delta = intoDataset.size() - oldSize;
			nSamples -= delta;
			env.resetEnvironment();
		}

		return intoDataset;

	}
	
	
	
	/**
	 * Collects SARS data from source states generated by a {@link StateGenerator} by choosing actions uniformly at random.
	 * @author James MacGlashan
	 *
	 */
	public static class UniformRandomSARSCollector extends SARSCollector{

		/**
		 * Initializes the collector's action set using the actions that are part of the domain.
		 * @param domain the domain containing the actions to use
		 */
		public UniformRandomSARSCollector(SADomain domain) {
			super(domain);
		}
		
		/**
		 * Initializes this collector's action set to use for collecting data.
		 * @param actionTypes the action set to use for collecting data.
		 */
		public UniformRandomSARSCollector(List actionTypes) {
			super(actionTypes);
		}

		@Override
		public SARSData collectDataFrom(State s, SampleModel model, int maxSteps, SARSData intoDataset) {
			
			if(intoDataset == null){
				intoDataset = new SARSData();
			}
			
			State curState = s;
			int nsteps = 0;
			boolean terminated = model.terminal(s);
			while(!terminated && nsteps < maxSteps){
				
				List gas = ActionUtils.allApplicableActionsForTypes(this.actionTypes, curState);
				Action ga = gas.get(RandomFactory.getMapped(0).nextInt(gas.size()));
				EnvironmentOutcome eo = model.sample(curState, ga);
				intoDataset.add(curState, ga, eo.r, eo.op);
				curState = eo.op;
				terminated = eo.terminated;
				nsteps++;
				
			}
			
			
			return intoDataset;
			
		}

		@Override
		public SARSData collectDataFrom(Environment env, int maxSteps, SARSData intoDataset) {

			if(intoDataset == null){
				intoDataset = new SARSData();
			}

			int nsteps = 0;
			while(!env.isInTerminalState() && nsteps < maxSteps){
				List gas = ActionUtils.allApplicableActionsForTypes(this.actionTypes, env.currentObservation());
				Action ga = gas.get(RandomFactory.getMapped(0).nextInt(gas.size()));
				EnvironmentOutcome eo = env.executeAction(ga);
				intoDataset.add(eo.o, eo.a, eo.r, eo.op);

				nsteps++;
			}

			return intoDataset;
		}
	}
	
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy