All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.sri.ai.util.experiment.Experiment Maven / Gradle / Ivy

/*
 * Copyright (c) 2013, SRI International
 * All rights reserved.
 * Licensed under the The BSD 3-Clause License;
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 * 
 * http://opensource.org/licenses/BSD-3-Clause
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 
 * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 
 * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 * 
 * Neither the name of the aic-util nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package com.sri.ai.util.experiment;

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import com.google.common.annotations.Beta;
import com.sri.ai.util.Util;
import com.sri.ai.util.gnuplot.DataSeries;
import com.sri.ai.util.gnuplot.Gnuplot;
import com.sri.ai.util.rangeoperation.api.DAEFunction;
import com.sri.ai.util.rangeoperation.api.DependencyAwareEnvironment;
import com.sri.ai.util.rangeoperation.api.Range;
import com.sri.ai.util.rangeoperation.api.RangeOperation;
import com.sri.ai.util.rangeoperation.core.AbstractDAEFunction;
import com.sri.ai.util.rangeoperation.core.RangeOperationsInterpreter;
import com.sri.ai.util.rangeoperation.library.rangeoperations.Dimension;

@Beta
public class Experiment {

	public static String[] guaranteedPreCommands = {
		"set xlabel font 'Arial, 10",
		"set ylabel font 'Arial, 10'",
		"set title  font 'Arial, 20'"
	};

	/** 
	 * Runs an experiment and generates a gnuplot with its results.
	 * 

* The method works by generating a data matrix (according to {@link RangeOperationsInterpreter}) * with one or two {@link Dimension}s. * The plot may contain one or more data series (shown as lines on the plot). * A {@link DataSeriesSpec} object defines which {@link Dimension} is used to define multiple data series, * as well as data series titles and format. *

* There is a commented example at the end of this documentation block. *

* Arguments to this method can be of three possible types: *

    *
  • String values representing variable names * followed by another argument representing the variable's value. * These variables can either be gnuplot parameters (see reference below), * or fixed parameters stored in a {@link DependencyAwareEnvironment} to be used * by the experiment (via {@link DAEFunction}s -- see below). *
  • {@link RangeOperation}s, which work like for-loops (creating data matrix {@link Dimension}s) * or aggregate operations for a specified variable. *
  • a single occurrence of a {@link DAEFunction}; this is the main argument since * this function is responsible for computing the experiment's reported results. * It can run whatever code one wishes, including calls to * {@link DependencyAwareEnvironment#get(String)}, * {@link DependencyAwareEnvironment#getOrUseDefault(String, Object)} and * {@link DependencyAwareEnvironment#getResultOrRecompute(DAEFunction)} * to gain access to the variables defined by other arguments * (both fixed and iterated by range operations). * The interaction between range operations and a DAEFunction * produces a generalized (multidimensional) matrix (see {@link RangeOperationsInterpreter} for details). * In {@link Experiment}, the range operations and DAEFunction must be structured * so that the resulting matrix is a regular rows-and-columns matrix * in which each row is the data for an individual data series to appear in the plot. *
  • a single {@link DataSeriesSpec}, which declares gnuplot directives for each of * the data series (the several graph lines) in the graph ({@link DataSeries}). * We obtain one data series per value of a variable specified by * the {@link DataSeriesSpec}. *
* Arguments can be provided in any order, although the order of range operations does matter * (their iterations are nested in the order they are given). *

* The variable on the graph's x-axis is the one specified by the first {@link Dimension} argument * that is not the variable labeling the multiple data series. *

* The recognized gnuplot parameters are the following: *

    *
  • title: the graph's title *
  • xlabel: the label of the x axis (default is variable name in x-{@link Dimension}). *
  • ylabel: the label of the y axis (default is empty string). *
  • filename: the name of a file (without an extension) to which to record the graph; * the extension ".ps" is automatically added. *
  • file: boolean value indicating whether to record the graph, using title as filename. *
  • print: same as file *
* If file recording is disabled, gnuplot persists (keeps open) and shows the graph; * otherwise, it closes as soon as the graph is recorded. *

* Consider the example: *

	 * 		experiment(
	 *  		"file", "false",
	 *  		"title", "The more samples, the less variance",
	 *  		"xlabel", "Number of samples",
	 *  		"ylabel", "Average of Uniform[0,1] + mean - 0.5",
	 *  		"some unused variable that could be used if we wanted to", 10,
	 *  		new {@link Dimension}("mean", Util.list(2, 3)),
	 *  		new {@link Dimension}("numSamples", 1, 1000, 1),
	 *  		averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive,
	 *  		{@link DataSeriesSpec}("mean", Util.list(
	 *  				Util.list("title 'mean 2'", "w linespoints"),
	 *  				Util.list("title 'mean 3'", "w linespoints"))));
	 * 
* Here, some gnuplot parameters and a (unused) variable are introduced. * Then ranging operations {@link Dimension} is used to vary variables "mean" and "numSamples" * across a range of values. * The function averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive uses them to compute * the elements of a matrix. * The {@link DataSeriesSpec} specifies that the "mean" dimension is the one determining * the individual data series for the plot * (and as a consequence the "numSamples" dimension is selected for the x-axis), * and specifies their gnuplot labels and styles as well. * Note that the plot's y-axis does not correspond to any dimensions of the matrix, but to its values * (the ones computed by the {@link DAEFunction}). * * @param arguments * the experiments arguments. */ public static void experiment(Object ... arguments) { List> dimensions = getDimensions(arguments); DataSeriesSpec dataSeriesDimensionSpec = getDataSeriesDimensionSpec(arguments); Range xSeries = getXDimensionRange(dimensions, dataSeriesDimensionSpec); List data = (List) RangeOperationsInterpreter.apply(arguments); List> dataSeriesList = getDataSeriesList(data, dimensions, dataSeriesDimensionSpec); Map properties = Util.getMapWithStringKeys(arguments); if ( ! properties.containsKey("xlabel")) { final Dimension xDimension = getXDimension(dimensions, dataSeriesDimensionSpec); String name = xDimension == null? "x" : xDimension.getRange().getName(); properties.put("xlabel", name); } LinkedList preCommands = getPreCommands(properties); Gnuplot.plot(preCommands, xSeries, dataSeriesList); } private static LinkedList getPreCommands(Map properties) { LinkedList preCommands = new LinkedList(); for (String preCommand : guaranteedPreCommands) { preCommands.add(preCommand); } if (properties.containsKey("title")) { preCommands.add("set title '" + properties.get("title") + "'"); } if (properties.containsKey("xlabel")) { preCommands.add("set xlabel '" + properties.get("xlabel") + "'"); } if (properties.containsKey("ylabel")) { preCommands.add("set ylabel '" + properties.get("ylabel") + "'"); } if (writesToFile(properties)) { preCommands.add("set term postscript color"); preCommands.add("set output '" + filename(properties) + ".ps'"); } else { preCommands.add("persist"); } return preCommands; } private static boolean writesToFile(Map properties) { boolean result = properties.containsKey("filename") || Util.getOrUseDefault(properties, "print", "false").equals("true") || Util.getOrUseDefault(properties, "file", "false").equals("true"); return result; } private static String filename(Map properties) { String filename = (String) Util.getOrUseDefault(properties, (String) properties.get("filename"), properties.get("title")); if (filename == null) { filename = "unnamed"; } return filename; } private static List> getDimensions(Object ... arguments) { List> result = new LinkedList>(); for (Object object : arguments) { if (object instanceof Dimension) { @SuppressWarnings("unchecked") Dimension dimension = (Dimension) object; result.add(dimension); } } return result; } /** * A class indicating the variable corresponding to a data series in a graph, * as well as its directives (see {@link Gnuplot}). */ public static class DataSeriesSpec { private String variable; private List> directivesList; public DataSeriesSpec(String variable, List> directivesList) { this.variable = variable; this.directivesList = directivesList; } public String getName() { return variable; } } private static DataSeriesSpec getDataSeriesDimensionSpec(Object... arguments) { return (DataSeriesSpec) Util.getObjectOfClass(DataSeriesSpec.class, arguments); } private static Range getXDimensionRange(List> dimensions, DataSeriesSpec dataSeriesDimensionSpec) { Dimension xDimension = getXDimension(dimensions, dataSeriesDimensionSpec); final Range result = xDimension != null ? xDimension.getRange() : null; return result; } private static Dimension getXDimension(List> dimensions, DataSeriesSpec dataSeriesDimensionSpec) { for (Dimension dimension : dimensions) { if ( ! dimension.getRange().getName().equals(dataSeriesDimensionSpec.variable)) { return dimension; } } return null; } private static Dimension getDataSeriesDimension(List> dimensions, DataSeriesSpec dataSeriesSpec) { for (Dimension dimension : dimensions) { if (dimension.getRange().getName().equals(dataSeriesSpec.variable)) { return dimension; } } return null; } private static List> getDataSeriesList(List data, List> dimensions, DataSeriesSpec dataSeriesSpec) { List> dataSeriesList = new LinkedList>(); Dimension dataSeriesDimension = getDataSeriesDimension(dimensions, dataSeriesSpec); if (dataSeriesDimension != null) { int dimension = dimensions.indexOf(dataSeriesDimension); Iterator rangeIterator = dataSeriesDimension.getRange().apply(); Iterator> directiveIterator = dataSeriesSpec.directivesList.iterator(); int sliceIndex = 0; while(rangeIterator.hasNext()) { rangeIterator.next(); if ( ! directiveIterator.hasNext()) { throw new Error("DataSeriesSpec on '" + dataSeriesSpec.getName() + "' does not have enough directives (it needs one per value of '" + dataSeriesSpec.getName() + "')"); } List directives = directiveIterator.next(); @SuppressWarnings("unchecked") List dataSeriesData = Util.matrixSlice((List>) data, dimension, sliceIndex); dataSeriesList.add(new DataSeries(directives, dataSeriesData)); sliceIndex++; } } else { if (dimensions.size() > 1) { Util.fatalError("DataSeriesSpec " + dataSeriesSpec + " does not refer to any present dimension and data is multidimensional."); } List directives = Util.getFirst(dataSeriesSpec.directivesList); @SuppressWarnings("unchecked") List dataList = data; dataSeriesList.add(new DataSeries(directives, dataList)); } return dataSeriesList; } /** * An extension of {@link List} for keeping pre-commands for a {@link Gnuplot} graph. */ @SuppressWarnings("serial") private static class PreCommands extends LinkedList { public PreCommands(String ... preCommands) { addAll(Arrays.asList(preCommands)); } } public static PreCommands preCommands(String ... preCommands) { return new PreCommands(preCommands); } public static PreCommands getPreCommands(Object ... arguments) { PreCommands preCommands = (PreCommands) Util.getObjectOfClass(PreCommands.class, arguments); if (preCommands == null) { preCommands = new PreCommands(); } return preCommands; } private static class Title { public Title(String value) { buffer.append(value); } @Override public String toString() { return buffer.toString(); } private StringBuffer buffer = new StringBuffer(); } public static Title Title(String value) { return new Title(value); } public static String getTitle(Object ... args) { return Util.getObjectOfClass(Title.class, args).toString(); } private static DAEFunction averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive = new AbstractDAEFunction() { @Override public Object apply(DependencyAwareEnvironment environment) { int numberOfSamples = environment.getInt("numSamples"); int mean = environment.getInt("mean"); double sum = 0; for (int i = 0; i != numberOfSamples; i++) { sum += (mean - 0.5) + Math.random(); } double result = sum/numberOfSamples; return result; } }; @SuppressWarnings("unchecked") public static void main(String[] args) { experiment( "file", "false", "title", "The more samples, the less variance", "xlabel", "Number of samples", "ylabel", "Average of Uniform[0,1] + mean - 0.5", "some unused variable that could be used if we wanted to", 10, new Dimension("mean", Util.list(2, 3)), new Dimension("numSamples", 1, 200, 1), averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive, new DataSeriesSpec("mean", Util.list( Util.list("title 'mean 2'", "w linespoints"), Util.list("title 'mean 3'", "w linespoints")))); } }