// All Downloads are FREE. The search and download functionality uses the official Maven repository.

// org.openimaj.workinprogress.sgdsvm.SvmSgd Maven / Gradle / Ivy

/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * 	Redistributions of source code must retain the above copyright notice,
 * 	this list of conditions and the following disclaimer.
 *
 *   *	Redistributions in binary form must reproduce the above copyright notice,
 * 	this list of conditions and the following disclaimer in the documentation
 * 	and/or other materials provided with the distribution.
 *
 *   *	Neither the name of the University of Southampton nor the names of its
 * 	contributors may be used to endorse or promote products derived from this
 * 	software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.workinprogress.sgdsvm;

import java.util.List;

import org.apache.commons.math.random.MersenneTwister;
import org.openimaj.feature.FloatFV;
import org.openimaj.feature.FloatFVComparison;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.array.SparseFloatArray;
import org.openimaj.util.array.SparseFloatArray.Entry;
import org.openimaj.util.array.SparseHashedFloatArray;

import gnu.trove.list.array.TDoubleArrayList;

/**
 * Linear classifier trained by stochastic gradient descent (SGD) on sparse
 * float feature vectors.
 * <p>
 * The weight vector is stored in scaled form: the effective weights are
 * {@code w / wDivisor}. This lets the regularisation shrinkage of every step be
 * applied by updating the single scalar {@link #wDivisor} instead of touching
 * every element of {@link #w}; {@link #renorm()} folds the divisor back into
 * the stored weights.
 * <p>
 * The learning rate used by the {@code train} methods is
 * {@code eta0 / (1 + lambda * eta0 * t)}, so {@link #eta0} must be positive
 * before training (either pass it to the constructor or call one of the
 * {@code determineEta0} methods). The two-argument constructor leaves it at 0.
 */
public class SvmSgd implements Cloneable {
	/** Loss used for evaluation ({@code loss}) and for the update ({@code dloss}). */
	Loss LOSS = LossFunctions.HingeLoss;

	/** When true, {@link #trainOne} also updates the bias term {@link #wBias}. */
	boolean BIAS = true;

	/**
	 * When true, the bias is shrunk by the regularisation term during training
	 * and its square is included in {@link #wnorm()}.
	 */
	boolean REGULARIZED_BIAS = false;

	/** Regularisation weight. */
	public double lambda;

	/** Initial learning rate; see the class description for the decay schedule. */
	public double eta0;

	/** Stored (scaled) weights; the effective weights are {@code w / wDivisor}. */
	FloatFV w;

	/** Scale factor dividing {@link #w} to obtain the effective weights. */
	double wDivisor;

	/** Bias term added to the decision value. */
	double wBias;

	/** Number of samples processed so far by the {@code train} methods. */
	double t;

	/**
	 * Read-only view over a set of labelled samples, so the array-based and
	 * list-based public overloads can share a single implementation.
	 */
	private interface Samples {
		/** @return the feature vector of sample {@code i} */
		SparseFloatArray x(int i);

		/** @return the label of sample {@code i} */
		double y(int i);
	}

	/** Wraps parallel arrays of features and labels as {@link Samples}. */
	private static Samples samplesOf(final SparseFloatArray[] xp, final double[] yp) {
		return new Samples() {
			@Override
			public SparseFloatArray x(int i) {
				return xp[i];
			}

			@Override
			public double y(int i) {
				return yp[i];
			}
		};
	}

	/** Wraps a feature list and a label list as {@link Samples}. */
	private static Samples samplesOf(final List<SparseFloatArray> xp, final TDoubleArrayList yp) {
		return new Samples() {
			@Override
			public SparseFloatArray x(int i) {
				return xp.get(i);
			}

			@Override
			public double y(int i) {
				return yp.get(i);
			}
		};
	}

	/**
	 * Construct with the given dimensionality and regularisation weight. The
	 * initial learning rate is left at 0 and must be set before training.
	 *
	 * @param dim
	 *            length of the weight vector
	 * @param lambda
	 *            regularisation weight
	 */
	public SvmSgd(int dim, double lambda) {
		this(dim, lambda, 0);
	}

	/**
	 * Construct with the given dimensionality, regularisation weight and
	 * initial learning rate.
	 *
	 * @param dim
	 *            length of the weight vector
	 * @param lambda
	 *            regularisation weight
	 * @param eta0
	 *            initial learning rate
	 */
	public SvmSgd(int dim, double lambda, double eta0) {
		this.lambda = lambda;
		this.eta0 = eta0;
		this.w = new FloatFV(dim);
		this.wDivisor = 1;
		this.wBias = 0;
		this.t = 0;
	}

	/** Inner product of a dense vector with a sparse vector. */
	private double dot(FloatFV v1, SparseFloatArray v2) {
		double sum = 0;
		for (final Entry e : v2.entries()) {
			sum += e.value * v1.values[e.index];
		}
		return sum;
	}

	/** Inner product of two dense vectors. */
	private double dot(FloatFV v1, FloatFV v2) {
		return FloatFVComparison.INNER_PRODUCT.compare(v1, v2);
	}

	/**
	 * In-place scaled addition: {@code y = y + d * x}. Only the entries present
	 * in the sparse vector {@code x} are touched.
	 */
	private void add(FloatFV y, SparseFloatArray x, double d) {
		for (final Entry e : x.entries()) {
			// compound assignment narrows the double product back to float
			y.values[e.index] += e.value * d;
		}
	}

	/**
	 * Renormalise the weights: fold {@link #wDivisor} into the stored weights
	 * and reset the divisor to 1. The effective weights are unchanged (up to
	 * float rounding).
	 */
	public void renorm() {
		if (wDivisor != 1.0) {
			ArrayUtils.multiply(w.values, (float) (1.0 / wDivisor));
			wDivisor = 1.0;
		}
	}

	/**
	 * Compute the squared Euclidean norm of the effective weights, plus the
	 * squared bias when {@link #REGULARIZED_BIAS} is set.
	 *
	 * @return the squared norm
	 */
	public double wnorm() {
		double norm = dot(w, w) / wDivisor / wDivisor;

		if (REGULARIZED_BIAS) {
			norm += wBias * wBias;
		}
		return norm;
	}

	/**
	 * Compute the output for one example.
	 *
	 * @param x
	 *            the feature vector
	 * @param y
	 *            the true label
	 * @param ploss
	 *            if non-null, the loss for this example is added to element 0
	 * @param pnerr
	 *            if non-null, element 0 is incremented when {@code s * y <= 0}
	 * @return the decision value {@code s}
	 */
	public double testOne(final SparseFloatArray x, double y, double[] ploss, double[] pnerr) {
		final double s = dot(w, x) / wDivisor + wBias;
		if (ploss != null) {
			ploss[0] += LOSS.loss(s, y);
		}
		if (pnerr != null) {
			pnerr[0] += (s * y <= 0) ? 1 : 0;
		}
		return s;
	}

	/**
	 * Perform one iteration of the SGD algorithm with the specified gain.
	 *
	 * @param x
	 *            the feature vector
	 * @param y
	 *            the true label
	 * @param eta
	 *            the learning rate for this step
	 */
	public void trainOne(final SparseFloatArray x, double y, double eta) {
		final double s = dot(w, x) / wDivisor + wBias;

		// update for regularization term: shrinking the effective weights by
		// (1 - eta * lambda) is done by growing the divisor instead
		wDivisor = wDivisor / (1 - eta * lambda);
		if (wDivisor > 1e5) {
			// fold the divisor back into the weights once it has grown large
			renorm();
		}

		// update for loss term; pre-multiply by the divisor so the step is
		// expressed in the stored (scaled) representation
		final double d = LOSS.dloss(s, y);
		if (d != 0) {
			add(w, x, eta * d * wDivisor);
		}

		// same for the bias, which uses a learning rate 100x smaller
		if (BIAS) {
			final double etab = eta * 0.01;
			if (REGULARIZED_BIAS) {
				wBias *= (1 - etab * lambda);
			}
			wBias += etab * d;
		}
	}

	/**
	 * Copy this model. The weight vector is cloned so that training the copy
	 * does not modify this instance's weights.
	 */
	@Override
	protected SvmSgd clone() {
		final SvmSgd copy;
		try {
			copy = (SvmSgd) super.clone();
		} catch (final CloneNotSupportedException e) {
			throw new RuntimeException(e);
		}
		copy.w = copy.w.clone();
		return copy;
	}

	/**
	 * Perform a training epoch over samples {@code imin..imax} (inclusive),
	 * printing progress to standard output.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 */
	public void train(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
		trainRange(imin, imax, samplesOf(xp, yp));
	}

	/**
	 * Perform a training epoch over samples {@code imin..imax} (inclusive),
	 * printing progress to standard output.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 */
	public void train(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
		trainRange(imin, imax, samplesOf(xp, yp));
	}

	/** Shared implementation of the two {@code train} overloads. */
	private void trainRange(int imin, int imax, Samples data) {
		System.out.println("Training on [" + imin + ", " + imax + "].");
		assert (imin <= imax) : "imin must not exceed imax";
		assert (eta0 > 0) : "eta0 must be positive before training";
		for (int i = imin; i <= imax; i++) {
			final double eta = eta0 / (1 + lambda * eta0 * t);
			trainOne(data.x(i), data.y(i), eta);
			t += 1;
		}
		System.out.format("wNorm=%.6f", wnorm());
		if (BIAS) {
			System.out.format(" wBias=%.6f", wBias);
		}
		System.out.println();
	}

	/**
	 * Perform a test pass over samples {@code imin..imax} (inclusive) and print
	 * the mean loss, the regularised cost and the misclassification rate.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 * @param prefix
	 *            text prepended to each printed line
	 */
	public void test(int imin, int imax, SparseFloatArray[] xp, double[] yp, String prefix) {
		testRange(imin, imax, samplesOf(xp, yp), prefix);
	}

	/**
	 * Perform a test pass over samples {@code imin..imax} (inclusive) and print
	 * the mean loss, the regularised cost and the misclassification rate.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 * @param prefix
	 *            text prepended to each printed line
	 */
	public void test(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, String prefix) {
		testRange(imin, imax, samplesOf(xp, yp), prefix);
	}

	/** Shared implementation of the two {@code test} overloads. */
	private void testRange(int imin, int imax, Samples data, String prefix) {
		System.out.println(prefix + "Testing on [" + imin + ", " + imax + "].");
		assert (imin <= imax) : "imin must not exceed imax";
		final double[] nerr = { 0 };
		final double[] loss = { 0 };
		for (int i = imin; i <= imax; i++) {
			testOne(data.x(i), data.y(i), loss, nerr);
		}
		final int count = imax - imin + 1;
		final double meanErr = nerr[0] / count;
		final double meanLoss = loss[0] / count;
		final double cost = meanLoss + 0.5 * lambda * wnorm();
		System.out.println(prefix + "Loss=" + meanLoss + " Cost=" + cost + " Misclassification="
				+ String.format("%2.4f", 100 * meanErr) + "%");
	}

	/**
	 * Perform one epoch with fixed eta on a copy of this model and return the
	 * resulting cost. This instance is not modified.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 * @param eta
	 *            the fixed learning rate to try
	 * @return mean loss plus {@code 0.5 * lambda * wnorm()} of the trained copy
	 */
	public double evaluateEta(int imin, int imax, SparseFloatArray[] xp, double[] yp, double eta) {
		return costAfterEpoch(imin, imax, samplesOf(xp, yp), eta);
	}

	/**
	 * Perform one epoch with fixed eta on a copy of this model and return the
	 * resulting cost. This instance is not modified.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 * @param eta
	 *            the fixed learning rate to try
	 * @return mean loss plus {@code 0.5 * lambda * wnorm()} of the trained copy
	 */
	public double evaluateEta(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp, double eta) {
		return costAfterEpoch(imin, imax, samplesOf(xp, yp), eta);
	}

	/** Shared implementation of the two {@code evaluateEta} overloads. */
	private double costAfterEpoch(int imin, int imax, Samples data, double eta) {
		final SvmSgd copy = this.clone(); // take a copy of the current state
		assert (imin <= imax) : "imin must not exceed imax";
		for (int i = imin; i <= imax; i++) {
			copy.trainOne(data.x(i), data.y(i), eta);
		}
		final double[] loss = { 0 };
		for (int i = imin; i <= imax; i++) {
			copy.testOne(data.x(i), data.y(i), loss, null);
		}
		final double meanLoss = loss[0] / (imax - imin + 1);
		final double cost = meanLoss + 0.5 * lambda * copy.wnorm();
		System.out.println("Trying eta=" + eta + " yields cost " + cost);
		return cost;
	}

	/**
	 * Choose {@link #eta0} by trying learning rates that differ by factors of
	 * two, starting from 1, and moving in whichever direction lowers the cost
	 * reported by {@code evaluateEta} until it stops improving.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 */
	public void determineEta0(int imin, int imax, SparseFloatArray[] xp, double[] yp) {
		searchEta0(imin, imax, samplesOf(xp, yp));
	}

	/**
	 * Choose {@link #eta0} by trying learning rates that differ by factors of
	 * two, starting from 1, and moving in whichever direction lowers the cost
	 * reported by {@code evaluateEta} until it stops improving.
	 *
	 * @param imin
	 *            index of the first sample
	 * @param imax
	 *            index of the last sample (inclusive)
	 * @param xp
	 *            the feature vectors
	 * @param yp
	 *            the labels
	 */
	public void determineEta0(int imin, int imax, List<SparseFloatArray> xp, TDoubleArrayList yp) {
		searchEta0(imin, imax, samplesOf(xp, yp));
	}

	/** Shared implementation of the two {@code determineEta0} overloads. */
	private void searchEta0(int imin, int imax, Samples data) {
		final double factor = 2.0;
		double loEta = 1;
		double loCost = costAfterEpoch(imin, imax, data, loEta);
		double hiEta = loEta * factor;
		double hiCost = costAfterEpoch(imin, imax, data, hiEta);
		if (loCost < hiCost) {
			// smaller rates are better: keep halving while the cost improves
			while (loCost < hiCost) {
				hiEta = loEta;
				hiCost = loCost;
				loEta = hiEta / factor;
				loCost = costAfterEpoch(imin, imax, data, loEta);
			}
		} else if (hiCost < loCost) {
			// larger rates are better: keep doubling while the cost improves
			while (hiCost < loCost) {
				loEta = hiEta;
				loCost = hiCost;
				hiEta = loEta * factor;
				hiCost = costAfterEpoch(imin, imax, data, hiEta);
			}
		}
		// the lower end of the final bracket is used as the initial rate
		eta0 = loEta;
		System.out.println("# Using eta0=" + eta0 + "\n");
	}

	/**
	 * Demo: trains on two 2-D Gaussian clusters centred at (-2,-2) with label -1
	 * and (+2,+2) with label +1, printing progress for ten epochs.
	 *
	 * @param args
	 *            ignored
	 */
	public static void main(String[] args) {
		final MersenneTwister mt = new MersenneTwister();
		final SparseFloatArray[] tr = new SparseFloatArray[10000];
		final double[] clz = new double[tr.length];
		for (int i = 0; i < tr.length; i++) {
			tr[i] = new SparseHashedFloatArray(2);

			if (i < tr.length / 2) {
				tr[i].set(0, (float) (mt.nextGaussian() - 2));
				tr[i].set(1, (float) (mt.nextGaussian() - 2));
				clz[i] = -1;
			} else {
				tr[i].set(0, (float) (mt.nextGaussian() + 2));
				tr[i].set(1, (float) (mt.nextGaussian() + 2));
				clz[i] = 1;
			}
			System.out.println(tr[i].values()[0] + " " + clz[i]);
		}

		final SvmSgd svm = new SvmSgd(2, 1e-5);
		svm.BIAS = true;
		svm.REGULARIZED_BIAS = false;
		svm.determineEta0(0, tr.length - 1, tr, clz);
		for (int i = 0; i < 10; i++) {
			System.out.println();
			svm.train(0, tr.length - 1, tr, clz);
			svm.test(0, tr.length - 1, tr, clz, "training ");
			System.out.println(svm.w);
			System.out.println(svm.wBias);
			System.out.println(svm.wDivisor);
		}
	}
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy