// org.apache.mahout.ep.State Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of mahout-mr Show documentation
// Scalable machine learning libraries
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.ep;
import com.google.common.collect.Lists;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.common.RandomUtils;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Locale;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Records evolutionary state and provides a mutation operation for recorded-step meta-mutation.
*
* You provide the payload, this class provides the mutation operations. During mutation,
* the payload is copied and after the state variables are changed, they are passed to the
* payload.
*
* Parameters are internally mutated in a state space that spans all of R^n, but parameters
* passed to the payload are transformed as specified by a call to setMap(). The default
* mapping is the identity map, but uniform-ish or exponential-ish coverage of a range are
* also supported.
*
* More information on the underlying algorithm can be found in the following paper
*
* http://arxiv.org/abs/0803.3838
*
* @see Mapping
*/
public class State, U> implements Comparable>, Writable {
// object count is kept to break ties in comparison.
private static final AtomicInteger OBJECT_COUNT = new AtomicInteger();
private int id = OBJECT_COUNT.getAndIncrement();
private Random gen = RandomUtils.getRandom();
// current state
private double[] params;
// mappers to transform state
private Mapping[] maps;
// omni-directional mutation
private double omni;
// directional mutation
private double[] step;
// current fitness value
private double value;
private T payload;
public State() {
}
/**
* Invent a new state with no momentum (yet).
*/
public State(double[] x0, double omni) {
params = Arrays.copyOf(x0, x0.length);
this.omni = omni;
step = new double[params.length];
maps = new Mapping[params.length];
}
/**
* Deep copies a state, useful in mutation.
*/
public State copy() {
State r = new State<>();
r.params = Arrays.copyOf(this.params, this.params.length);
r.omni = this.omni;
r.step = Arrays.copyOf(this.step, this.step.length);
r.maps = Arrays.copyOf(this.maps, this.maps.length);
if (this.payload != null) {
r.payload = (T) this.payload.copy();
}
r.gen = this.gen;
return r;
}
/**
* Clones this state with a random change in position. Copies the payload and
* lets it know about the change.
*
* @return A new state.
*/
public State mutate() {
double sum = 0;
for (double v : step) {
sum += v * v;
}
sum = Math.sqrt(sum);
double lambda = 1 + gen.nextGaussian();
State r = this.copy();
double magnitude = 0.9 * omni + sum / 10;
r.omni = magnitude * -Math.log1p(-gen.nextDouble());
for (int i = 0; i < step.length; i++) {
r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian();
r.params[i] += r.step[i];
}
if (this.payload != null) {
r.payload.update(r.getMappedParams());
}
return r;
}
/**
* Defines the transformation for a parameter.
* @param i Which parameter's mapping to define.
* @param m The mapping to use.
* @see org.apache.mahout.ep.Mapping
*/
public void setMap(int i, Mapping m) {
maps[i] = m;
}
/**
* Returns a transformed parameter.
* @param i The parameter to return.
* @return The value of the parameter.
*/
public double get(int i) {
Mapping m = maps[i];
return m == null ? params[i] : m.apply(params[i]);
}
public int getId() {
return id;
}
public double[] getParams() {
return params;
}
public Mapping[] getMaps() {
return maps;
}
/**
* Returns all the parameters in mapped form.
* @return An array of parameters.
*/
public double[] getMappedParams() {
double[] r = Arrays.copyOf(params, params.length);
for (int i = 0; i < params.length; i++) {
r[i] = get(i);
}
return r;
}
public double getOmni() {
return omni;
}
public double[] getStep() {
return step;
}
public T getPayload() {
return payload;
}
public double getValue() {
return value;
}
public void setOmni(double omni) {
this.omni = omni;
}
public void setId(int id) {
this.id = id;
}
public void setStep(double[] step) {
this.step = step;
}
public void setMaps(Mapping[] maps) {
this.maps = maps;
}
public void setMaps(Iterable maps) {
Collection list = Lists.newArrayList(maps);
this.maps = list.toArray(new Mapping[list.size()]);
}
public void setValue(double v) {
value = v;
}
public void setPayload(T payload) {
this.payload = payload;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof State)) {
return false;
}
State,?> other = (State,?>) o;
return id == other.id && value == other.value;
}
@Override
public int hashCode() {
return RandomUtils.hashDouble(value) ^ id;
}
/**
* Natural order is to sort in descending order of score. Creation order is used as a
* tie-breaker.
*
* @param other The state to compare with.
* @return -1, 0, 1 if the other state is better, identical or worse than this one.
*/
@Override
public int compareTo(State other) {
int r = Double.compare(other.value, this.value);
if (r != 0) {
return r;
}
if (this.id < other.id) {
return -1;
}
if (this.id > other.id) {
return 1;
}
return 0;
}
@Override
public String toString() {
double sum = 0;
for (double v : step) {
sum += v * v;
}
return String.format(Locale.ENGLISH, "", payload, omni + Math.sqrt(sum), value);
}
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(id);
out.writeInt(params.length);
for (double v : params) {
out.writeDouble(v);
}
for (Mapping map : maps) {
PolymorphicWritable.write(out, map);
}
out.writeDouble(omni);
for (double v : step) {
out.writeDouble(v);
}
out.writeDouble(value);
PolymorphicWritable.write(out, payload);
}
@Override
public void readFields(DataInput input) throws IOException {
id = input.readInt();
int n = input.readInt();
params = new double[n];
for (int i = 0; i < n; i++) {
params[i] = input.readDouble();
}
maps = new Mapping[n];
for (int i = 0; i < n; i++) {
maps[i] = PolymorphicWritable.read(input, Mapping.class);
}
omni = input.readDouble();
step = new double[n];
for (int i = 0; i < n; i++) {
step[i] = input.readDouble();
}
value = input.readDouble();
payload = (T) PolymorphicWritable.read(input, Payload.class);
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy