// com.linkedin.photon.ml.hyperparameter.criteria.ExpectedImprovement.scala Maven / Gradle / Ivy
/*
* Copyright 2017 LinkedIn Corp. All rights reserved.
* Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. You may obtain a
* copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linkedin.photon.ml.hyperparameter.criteria
import breeze.linalg.DenseVector
import breeze.numerics.sqrt
import breeze.stats.distributions.Gaussian
import com.linkedin.photon.ml.hyperparameter.estimators.PredictionTransformation
/**
* Expected improvement selection criterion. This transformation produces the expected improvement of the model
* predictions (over the current "best" value).
*
* @see "Practical Bayesian Optimization of Machine Learning Algorithms" (PBO),
* https://papers.nips.cc/paper/4522-practical-bayesian-optimization-of-machine-learning-algorithms.pdf
*
* @param bestEvaluation The current best evaluation
*/
class ExpectedImprovement(bestEvaluation: Double) extends PredictionTransformation {

  /**
   * Whether the transformed values should be maximized. Always true here: maximizing the expected improvement
   * is how the evaluation value is minimized.
   */
  def isMaxOpt: Boolean = true

  // Standard normal distribution, used for the CDF / PDF terms of the EI formula.
  private val standardNormal = new Gaussian(0, 1)

  /**
   * Applies the expected improvement transformation to the model output.
   *
   * For each point with predictive mean `mu` and predictive standard deviation `sigma > 0`:
   * {{{
   *   gamma = (bestEvaluation - mu) / sigma                       (PBO Eq. 1)
   *   EI    = sigma * (gamma * cdf(gamma) + pdf(gamma))            (PBO Eq. 2)
   * }}}
   * When `sigma` is exactly zero the formula above degenerates to `0 * Infinity` (NaN), so the analytic limit
   * `max(bestEvaluation - mu, 0)` is returned for that point instead.
   *
   * @param predictiveMeans Predictive mean output from the model
   * @param predictiveVariances Predictive variance output from the model (same length as the means)
   * @return The expected improvement over the current best evaluation, one value per input point
   * @throws IllegalArgumentException if the two input vectors differ in length
   */
  def apply(
      predictiveMeans: DenseVector[Double],
      predictiveVariances: DenseVector[Double]): DenseVector[Double] = {

    require(
      predictiveMeans.length == predictiveVariances.length,
      s"Predictive means (length ${predictiveMeans.length}) and variances " +
        s"(length ${predictiveVariances.length}) must have the same length")

    val std = sqrt(predictiveVariances)

    DenseVector.tabulate(std.length) { i =>
      val sigma = std(i)
      // How far the predicted mean lies below the current best evaluation (positive means better).
      val improvement = -(predictiveMeans(i) - bestEvaluation)

      if (sigma == 0.0) {
        // No predictive uncertainty: the improvement is deterministic. Dividing by sigma here would produce NaN.
        math.max(improvement, 0.0)
      } else {
        // PBO Eq. 1
        val gamma = improvement / sigma

        // PBO Eq. 2
        sigma * ((gamma * standardNormal.cdf(gamma)) + standardNormal.pdf(gamma))
      }
    }
  }
}