
com.github.chen0040.sparkml.gp.SparkLGP Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of spark-ml-genetic-programming Show documentation
Show all versions of spark-ml-genetic-programming Show documentation
Genetic Programming algorithms implemented in Java and for Apache Spark
The newest version!
package com.github.chen0040.sparkml.gp;
import com.github.chen0040.gp.commons.BasicObservation;
import com.github.chen0040.gp.lgp.LGP;
import com.github.chen0040.gp.lgp.gp.Population;
import com.github.chen0040.gp.lgp.program.Program;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;
/**
* Created by xschen on 6/6/2017.
*/
public class SparkLGP extends LGP {
private JavaRDD observationRdd;
private Function, Double> perObservationCostEvaluator;
public void setPerObservationCostEvaluator(Function, Double> perObservationCostEvaluator) {
this.perObservationCostEvaluator = perObservationCostEvaluator;
}
public void setObservationRdd(JavaRDD observationRdd) {
setObservationRdd(observationRdd, -1);
}
public void setObservationRdd(JavaRDD observationRdd, int partitionCount) {
if(partitionCount == -1) {
this.observationRdd = observationRdd.cache();
} else {
this.observationRdd = observationRdd.coalesce(partitionCount).cache();
}
}
@Override
public double evaluateCost(Program program) {
program.markStructuralIntrons(this);
program = program.makeEffectiveCopy();
JavaSparkContext context = JavaSparkContext.fromSparkContext(observationRdd.context());
Broadcast programBroadcast = context.broadcast(program);
Tuple2 result = observationRdd.map(observation -> {
Program p = programBroadcast.getValue();
return new Tuple2<>(p, observation);
}).map(perObservationCostEvaluator)
.map(a -> new Tuple2<>(a, 1))
.reduce((a, b) -> new Tuple2<>(a._1() + b._1(), a._2() + b._2()));
double cost = result._1() / result._2();
programBroadcast.destroy();
return cost;
}
public Program fit(JavaRDD trainingData) {
setObservationRdd(trainingData);
long startTime = System.currentTimeMillis();
Population pop = this.newPopulation();
pop.initialize();
while (!pop.isTerminated())
{
pop.evolve();
if(getDisplayEvery() > 0 && pop.getCurrentGeneration() % getDisplayEvery() == 0) {
long seconds = (System.currentTimeMillis() - startTime) / 1000;
System.out.println("Generation: " + pop.getCurrentGeneration() + " (Pop: " + pop.size() + "), elapsed: " + seconds + " seconds");
System.out.println("Global Cost: " + pop.getGlobalBestProgram().getCost() + "\tCurrent Cost: " + pop.getCostInCurrentGeneration());
}
}
return pop.getGlobalBestProgram();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy