All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.MTA Maven / Gradle / Ivy

package com.expleague.ml.methods;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;

import static com.expleague.commons.math.MathTools.sqr;

/**
 * Created by noxoomo on 23/10/14.
 * <p>
 * Estimators of per-task means that shrink the plain sample averages either towards a prior
 * (James–Stein style) or towards each other (multi-task averaging).
 * <ul>
 *   <li>multitask averaging — http://static.googleusercontent.com/media/research.google.com/ru//pubs/archive/42935.pdf</li>
 *   <li>stein — http://en.wikipedia.org/wiki/James%E2%80%93Stein_estimator</li>
 * </ul>
 * The instance keeps a reference to the sample arrays it was built from (no defensive copy) and
 * caches, per task, the sum, the sum of squares and the variance of the sample mean.
 */
public class MTA {
  /**
   * Quantile used for the Wilson-style bounds in {@link #bernoulliSimilarity()}.
   * With the value 0 both bounds coincide with the sample mean.
   */
  private static final double WILSON_Z = 0;

  private final double[][] tasks;
  private final double[] sum;
  private final double[] sum2;
  /** Per task: unbiased sample variance divided by the sample count (0 for a single sample). */
  private final double[] sigma;

  /**
   * Accumulates per-task statistics.
   *
   * @param samples one array of observations per task; the array is stored, not copied
   */
  public MTA(final double[][] samples) {
    this.tasks = samples;
    final int taskCount = samples.length;
    sum = new double[taskCount];
    sum2 = new double[taskCount];
    sigma = new double[taskCount];
    for (int task = 0; task < taskCount; ++task) {
      for (final double val : samples[task]) {
        sum[task] += val;
        sum2[task] += val * val;
      }
      final double n = samples[task].length;
      final double variance = n > 1 ? (sum2[task] - sum[task] * sum[task] / n) / (n - 1) : 0;
      sigma[task] = variance / n;
    }
  }

  /**
   * Shrinks the sample means towards a per-task prior.
   * <p>
   * The factor is {@code 1 - (sum(var) / max(var) - 2) / sum((mean - prior)^2 / var)}, clipped at 0
   * from below. When the denominator is not positive the factor is 0, matching the other
   * {@code stein} overloads; this avoids the NaN results a division by zero could produce.
   *
   * @param prior shrinkage target, one value per task
   * @return a new array with the shrunk means
   */
  public double[] stein(final double[] prior) {
    final double[] means = sampleMeans();
    final double[] variances = sampleVariances();
    final double excess = effectiveDimension(variances) - 2;
    double denum = 0;
    for (int i = 0; i < means.length; ++i) {
      final double diff = means[i] - prior[i];
      denum += diff * diff / variances[i];
    }
    final double lambda = positivePart(denum > 0 ? 1 - excess / denum : 0);
    for (int i = 0; i < means.length; ++i) {
      means[i] = prior[i] + lambda * (means[i] - prior[i]);
    }
    return means;
  }

  /**
   * Scales a copy of {@code means} by {@code 1 - (dim - 2) / ||means||^2}.
   *
   * @param means vector to shrink; left unmodified
   * @return the scaled copy
   */
  public static Vec naiveStein(final Vec means) {
    final Vec js = VecTools.copy(means);
    VecTools.scale(js, (1 - (js.dim() - 2) / MathTools.sqr(VecTools.norm(means))));
    return js;
  }

  /**
   * Shrinks the sample means towards their own average.
   * <p>
   * With fewer than four tasks the plain sample means are returned, the same rule
   * {@link #stein(double)} applies: the {@code - 3} correction cannot be positive there, so the
   * factor would be at least 1 and push the means apart instead of shrinking them.
   *
   * @return a new array with the shrunk means
   */
  public double[] stein() {
    final double[] means = sampleMeans();
    if (tasks.length < 4) {
      return means;
    }
    double prior = 0;
    for (final double mean : means) {
      prior += mean;
    }
    prior /= tasks.length;
    return shrinkTowardsScalar(means, prior);
  }

  /**
   * Shrinks the sample means towards a common prior value.
   *
   * @param prior shrinkage target shared by all tasks
   * @return a new array with the shrunk means; the plain sample means when there are fewer than four tasks
   */
  public double[] stein(final double prior) {
    final double[] means = sampleMeans();
    if (tasks.length < 4) {
      return means;
    }
    return shrinkTowardsScalar(means, prior);
  }

  /**
   * Same computation as {@link #bernoulliStein(double[], double[])}, using the cached per-task sums
   * and the per-task sample counts.
   *
   * @return a new array with the shrunk means
   */
  public double[] steinBernoulli() {
    final double[] counts = new double[tasks.length];
    for (int i = 0; i < counts.length; ++i) {
      counts[i] = tasks[i].length;
    }
    return bernoulliStein(sum, counts);
  }

  /**
   * Multi-task averaging with a pairwise similarity of {@code 2 / (mean_i - mean_j)^2}.
   * <p>
   * The similarity is capped at {@code 2e18}, which is also the value used for (almost) equal
   * means. The per-task variance is {@code p (1 - p) / (count - 1)}, or {@code p (1 - p)} when the
   * count is not greater than 1. The result is {@code (I + diag(variance) * L / T)^-1 * means}
   * with {@code L = MxTools.laplacian(A)}.
   *
   * @param sum    per-task sums of observations
   * @param counts per-task observation counts
   * @return the smoothed means
   */
  public static double[] bernoulliMTA(final double[] sum, final double[] counts) {
    final int count = sum.length;
    final Vec means = new ArrayVec(count);
    final double[] variances = new double[count];
    for (int i = 0; i < count; ++i) {
      final double p = sum[i] / counts[i];
      means.set(i, p);
      variances[i] = counts[i] > 1 ? p * (1 - p) / (counts[i] - 1) : p * (1 - p);
    }
    final Mx A = new VecBasedMx(count, count);
    for (int i = 0; i < count; ++i) {
      for (int j = i + 1; j < count; ++j) {
        final double dist = sum[i] / counts[i] - sum[j] / counts[j];
        final double similarity = inverseSquareSimilarity(dist, 2.0 * 1e18);
        A.set(i, j, similarity);
        A.set(j, i, similarity);
      }
    }
    final Mx inverse = smoothingInverse(A, variances);
    return MxTools.multiply(inverse, means).toArray();
  }

  /**
   * Shrinks Bernoulli means towards zero.
   * <p>
   * Each task's variance is taken as {@code 1 / (4 * count)} (presumably the worst-case variance of
   * a Bernoulli mean). The factor {@code 1 - (tr / lambdaMax - 2) / norm} is clipped to
   * {@code [0, 1]}; it is 1 when {@code norm} is not positive.
   *
   * @param sum    per-task sums of observations
   * @param counts per-task observation counts
   * @return a new array with the shrunk means
   */
  public static double[] bernoulliStein(final double[] sum, final double[] counts) {
    final double[] means = new double[sum.length];
    double norm = 0;
    double tr = 0;
    double lambdaMax = 0;
    for (int i = 0; i < means.length; ++i) {
      means[i] = sum[i] / counts[i];
      final double varianceBound = 1.0 / 4 / counts[i];
      lambdaMax = Math.max(lambdaMax, varianceBound);
      norm += means[i] * means[i] * counts[i] * 4;
      tr += varianceBound;
    }
    double lambda = norm > 0 ? 1 - (tr / lambdaMax - 2) / (norm) : 1.0;
    lambda = lambda > 0 ? lambda : 0;
    lambda = lambda < 1 ? lambda : 1;
    for (int i = 0; i < means.length; ++i) {
      means[i] = lambda * means[i];
    }
    return means;
  }

  /**
   * Scales a copy of {@code means} by {@code 1 - (dim - 2) * sigma^2 / ||means||^2} with a fixed
   * {@code sigma = 1}.
   *
   * @param means vector to shrink; left unmodified
   * @return the scaled copy
   */
  public static Vec stein(final Vec means) {
    final double sigma = 1;
    final Vec js = VecTools.copy(means);
    VecTools.scale(js, (1 - (js.dim() - 2) * sigma * sigma / MathTools.sqr(VecTools.norm(js))));
    return js;
  }

  /**
   * Constant-similarity averaging of Bernoulli parameters with the similarity
   * {@code a = 2 / (T * (max - min)^2)} ({@code 2e18} when the range is at most {@code 1e-9}) and
   * the per-task variance {@code w (1 - w)}.
   * <p>
   * The same variance is used in both terms of the closed form, as in {@link #mtaMiniMax()};
   * the earlier code used {@code a * weights.length} in the diagonal term only.
   *
   * @param weights per-task estimates; left unmodified
   * @return a new array with the smoothed estimates
   */
  public static double[] bernoulliConst(final double[] weights) {
    double min = Double.POSITIVE_INFINITY;
    double max = Double.NEGATIVE_INFINITY;
    final double[] variances = new double[weights.length];
    for (int i = 0; i < weights.length; ++i) {
      final double p = weights[i];
      min = p < min ? p : min;
      max = p > max ? p : max;
      variances[i] = p * (1 - p);
    }
    final double a = minimaxSimilarity(weights.length, min, max);
    return constantSimilarityEstimate(a, weights, variances);
  }

  /**
   * Builds a symmetric pairwise similarity matrix {@code 2 / d^2}, capped at {@code 2e12}, where
   * {@code d} is the larger of {@code |upper_j - lower_i|} and {@code |upper_i - lower_j|}.
   * <p>
   * The bounds are Wilson-style interval bounds around each task's sample mean, clipped to
   * {@code [0, 1]}; with {@link #WILSON_Z} equal to 0 they both equal the sample mean. Diagonal
   * entries are left at the matrix default.
   *
   * @return the similarity matrix, one row and column per task
   */
  public Mx bernoulliSimilarity() {
    final int count = sum.length;
    final double[] upperBounds = new double[count];
    final double[] lowerBounds = new double[count];
    final double z2 = WILSON_Z * WILSON_Z;
    for (int i = 0; i < count; ++i) {
      final double n = tasks[i].length;
      final double p = sum[i] / n;
      final double scale = 1 / (1 + z2 / n);
      final double center = p + z2 / (2 * n);
      final double halfWidth = WILSON_Z * Math.sqrt(p * (1 - p) / n + z2 / (4 * n * n));
      upperBounds[i] = Math.min(scale * (center + halfWidth), 1);
      lowerBounds[i] = Math.max(scale * (center - halfWidth), 0);
    }
    final Mx A = new VecBasedMx(count, count);
    for (int i = 0; i < count; ++i) {
      for (int j = i + 1; j < count; ++j) {
        final double upDist = Math.max(
            Math.abs(upperBounds[j] - lowerBounds[i]),
            Math.abs(upperBounds[i] - lowerBounds[j]));
        final double similarity = inverseSquareSimilarity(upDist, 2.0 * 1e12);
        A.set(i, j, similarity);
        A.set(j, i, similarity);
      }
    }
    return A;
  }

  /**
   * Multi-task averaging for a caller-supplied similarity matrix:
   * {@code (I + diag(sigma) * L / T)^-1 * means} with {@code L = MxTools.laplacian(A)}.
   *
   * @param A pairwise task similarity
   * @return the smoothed means
   */
  public Vec oracle(final Mx A) {
    final Mx inverse = smoothingInverse(A, sigma);
    final Vec means = new ArrayVec(sum.length);
    for (int i = 0; i < sum.length; ++i) {
      means.set(i, sum[i] / tasks[i].length);
    }
    return MxTools.multiply(inverse, means);
  }

  /**
   * @return a new array with the plain per-task sample means
   */
  public double[] classic() {
    return sampleMeans();
  }

  /**
   * Constant-similarity averaging with {@code a = T (T - 1) / sum_{i<j} (mean_i - mean_j)^2};
   * {@code 1e-12} replaces the sum when it is not positive.
   *
   * @return a new array with the smoothed means
   */
  public double[] mtaConst() {
    final double[] means = sampleMeans();
    double squaredDistances = 0;
    for (int i = 0; i < means.length; ++i) {
      for (int j = i + 1; j < means.length; ++j) {
        final double val = means[i] - means[j];
        squaredDistances += val * val;
      }
    }
    // Computed in double so that the pair count cannot overflow int.
    final double pairs = (double) tasks.length * (tasks.length - 1);
    final double a = squaredDistances > 0 ? pairs / squaredDistances : pairs / 1e-12;
    return constantSimilarityEstimate(a, means, sigma);
  }

  /**
   * Constant-similarity averaging with {@code a = 2 / (T * (max - min)^2)} over the sample means;
   * {@code 2e18} when the range is at most {@code 1e-9}.
   *
   * @return a new array with the smoothed means
   */
  public double[] mtaMiniMax() {
    final double[] means = sampleMeans();
    double min = Double.POSITIVE_INFINITY;
    double max = Double.NEGATIVE_INFINITY;
    for (final double p : means) {
      min = p < min ? p : min;
      max = p > max ? p : max;
    }
    return constantSimilarityEstimate(minimaxSimilarity(tasks.length, min, max), means, sigma);
  }

  /**
   * Like {@link #mtaMiniMax()}, but with the range fixed to {@code [0, 1]} instead of being taken
   * from the observed means.
   *
   * @return a new array with the smoothed means
   */
  public double[] mtaMiniMaxBernoulli() {
    final double min = 0;
    final double max = 1;
    return constantSimilarityEstimate(minimaxSimilarity(tasks.length, min, max), sampleMeans(), sigma);
  }

  /** Returns a fresh array holding {@code sum / count} for every task. */
  private double[] sampleMeans() {
    final double[] means = new double[sum.length];
    for (int i = 0; i < means.length; ++i) {
      means[i] = sum[i] / tasks[i].length;
    }
    return means;
  }

  /** Returns the unbiased sample variance per task; {@code sum2 - sum^2} when a task has at most one sample. */
  private double[] sampleVariances() {
    final double[] variances = new double[sum.length];
    for (int i = 0; i < variances.length; ++i) {
      final int n = tasks[i].length;
      variances[i] = n > 1 ? (sum2[i] - sum[i] * sum[i] / n) / (n - 1) : sum2[i] - sum[i] * sum[i];
    }
    return variances;
  }

  /** Returns the sum of the values divided by their maximum. */
  private static double effectiveDimension(final double[] variances) {
    double max = Double.NEGATIVE_INFINITY;
    double total = 0;
    for (final double variance : variances) {
      max = Math.max(max, variance);
      total += variance;
    }
    return total / max;
  }

  /** Returns {@code value} when it is positive and 0 otherwise (NaN included). */
  private static double positivePart(final double value) {
    return value > 0 ? value : 0;
  }

  /** Shrinks {@code means} in place towards {@code prior}, using the {@code - 3} correction, and returns it. */
  private double[] shrinkTowardsScalar(final double[] means, final double prior) {
    final double[] variances = sampleVariances();
    final double excess = effectiveDimension(variances) - 3;
    double denum = 0;
    for (int i = 0; i < means.length; ++i) {
      // NOTE(review): the deviation is scaled by the sample count before squaring, unlike
      // stein(double[]); kept as in the original — confirm the intended normalisation.
      final double diff = (means[i] - prior) * tasks[i].length;
      denum += diff * diff / variances[i];
    }
    final double lambda = positivePart(denum > 0 ? 1 - excess / denum : 0);
    for (int i = 0; i < means.length; ++i) {
      means[i] = prior + lambda * (means[i] - prior);
    }
    return means;
  }

  /** Returns {@code 2 / (count * (min - max)^2)}, or {@code 2e18} when the range is at most {@code 1e-9}. */
  private static double minimaxSimilarity(final int count, final double min, final double max) {
    final double range = min - max;
    return Math.abs(range) > 1e-9 ? 2.0 / (count * range * range) : 2e18;
  }

  /** Returns {@code 2 / dist^2} limited to {@code cap}; {@code cap} when the squared distance is not positive. */
  private static double inverseSquareSimilarity(final double dist, final double cap) {
    final double squared = dist * dist;
    return squared > 0 ? Math.min(2.0 / squared, cap) : cap;
  }

  /**
   * Evaluates {@code Z y + Z z * sum(Z y) / (1 - sum(Z z))} with {@code Z_i = 1 / (1 + a * var_i)}
   * and {@code z_i = a * var_i / T} — presumably the closed form of the multi-task averaging
   * estimate when every pair of tasks has the same similarity {@code a}. Inputs are not modified.
   */
  private static double[] constantSimilarityEstimate(final double a, final double[] means, final double[] variances) {
    final int count = means.length;
    final double[] result = new double[count];
    final double[] correction = new double[count];
    double denum = 1;
    double num = 0;
    for (int i = 0; i < count; ++i) {
      final double scaled = a * variances[i];
      final double inverse = 1.0 / (1 + scaled);
      correction[i] = inverse * (scaled / count);
      denum -= correction[i];
      result[i] = inverse * means[i];
      num += result[i];
    }
    for (int i = 0; i < count; ++i) {
      result[i] += num * correction[i] / denum;
    }
    return result;
  }

  /** Returns the inverse of {@code I + diag(variances) * laplacian(A) / T}, where {@code T = variances.length}. */
  private static Mx smoothingInverse(final Mx A, final double[] variances) {
    final Mx L = MxTools.laplacian(A);
    final Mx W = new VecBasedMx(L.rows(), L.columns());
    for (int row = 0; row < L.rows(); ++row) {
      for (int col = 0; col < L.columns(); ++col) {
        W.set(row, col, (row == col ? 1 : 0) + variances[row] * L.get(row, col) / variances.length);
      }
    }
    return MxTools.inverse(W);
  }
}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy