All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ciir.umass.edu.learning.neuralnet.ListNet Maven / Gradle / Ivy

There is a newer version: 2.10.1
Show newest version
/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.neuralnet;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

public class ListNet extends RankNet {
    private static final Logger logger = Logger.getLogger(ListNet.class.getName());

    //Parameters
    public static int nIteration = 1500;
    public static double learningRate = 0.00001;
    public static int nHiddenLayer = 0;//FIXED, it doesn't work with hidden layer

    public ListNet() {
    }

    public ListNet(final List samples, final int[] features, final MetricScorer scorer) {
        super(samples, features, scorer);
    }

    protected float[] feedForward(final RankList rl) {
        final float[] labels = new float[rl.size()];
        for (int i = 0; i < rl.size(); i++) {
            addInput(rl.get(i));
            propagate(i);
            labels[i] = rl.get(i).getLabel();
        }
        return labels;
    }

    protected void backPropagate(final float[] labels) {
        //back-propagate
        final PropParameter p = new PropParameter(labels);
        outputLayer.computeDelta(p);//starting at the output layer

        //weight update
        outputLayer.updateWeight(p);
    }

    @Override
    protected void estimateLoss() {
        error = 0.0;
        double sumLabelExp = 0;
        double sumScoreExp = 0;
        for (int i = 0; i < samples.size(); i++) {
            final RankList rl = samples.get(i);
            final double[] scores = new double[rl.size()];
            double err = 0;
            for (int j = 0; j < rl.size(); j++) {
                scores[j] = eval(rl.get(j));
                sumLabelExp += Math.exp(rl.get(j).getLabel());
                sumScoreExp += Math.exp(scores[j]);
            }
            for (int j = 0; j < rl.size(); j++) {
                final double p1 = Math.exp(rl.get(j).getLabel()) / sumLabelExp;
                final double p2 = (Math.exp(scores[j]) / sumScoreExp);
                err += -p1 * SimpleMath.logBase2(p2);
            }
            error += err / rl.size();
        }
        lastError = error;
    }

    @Override
    public void init() {
        logger.info(() -> "Initializing... ");

        //Set up the network
        setInputOutput(features.length, 1, 1);
        wire();

        if (validationSamples != null) {
            for (int i = 0; i < layers.size(); i++) {
                bestModelOnValidation.add(new ArrayList());
            }
        }

        Neuron.learningRate = learningRate;
    }

    @Override
    public void learn() {
        logger.info(() -> "Training starts...");
        printLogLn(new int[] { 7, 14, 9, 9 }, new String[] { "#epoch", "C.E. Loss", scorer.name() + "-T", scorer.name() + "-V" });

        for (int i = 1; i <= nIteration; i++) {
            for (int j = 0; j < samples.size(); j++) {
                final float[] labels = feedForward(samples.get(j));
                backPropagate(labels);
                clearNeuronOutputs();
            }
            printLog(new int[] { 7, 14 }, new String[] { Integer.toString(i) , Double.toString(SimpleMath.round(error, 6)) });
            if (i % 1 == 0) {
                scoreOnTrainingData = scorer.score(rank(samples));
                printLog(new int[] { 9 }, new String[] {Double.toString(SimpleMath.round(scoreOnTrainingData, 4)) });
                if (validationSamples != null) {
                    final double score = scorer.score(rank(validationSamples));
                    if (score > bestScoreOnValidationData) {
                        bestScoreOnValidationData = score;
                        saveBestModelOnValidation();
                    }
                    printLog(new int[] { 9 }, new String[] { Double.toString(SimpleMath.round(score, 4)) });
                }
            }
            flushLog();
        }

        //if validation data is specified ==> best model on this data has been saved
        //we now restore the current model to that best model
        if (validationSamples != null) {
            restoreBestModelOnValidation();
        }

        scoreOnTrainingData = SimpleMath.round(scorer.score(rank(samples)), 4);
        logger.info(() -> "Finished sucessfully.");
        logger.info(() -> scorer.name() + " on training data: " + scoreOnTrainingData);
        if (validationSamples != null) {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            logger.info(() -> scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }
    }

    @Override
    public double eval(final DataPoint p) {
        return super.eval(p);
    }

    @Override
    public Ranker createNew() {
        return new ListNet();
    }

    @Override
    public String toString() {
        return super.toString();
    }

    @Override
    public String model() {
        final StringBuilder output = new StringBuilder();
        output.append("## " + name() + "\n");
        output.append("## Epochs = " + nIteration + "\n");
        output.append("## No. of features = " + features.length + "\n");

        //print used features
        for (int i = 0; i < features.length; i++) {
            output.append(features[i] + ((i == features.length - 1) ? "" : " "));
        }
        output.append("\n");
        //print network information
        output.append("0\n");//[# hidden layers, *ALWAYS* 0 since we're using linear net]
        //print learned weights
        output.append(toString());
        return output.toString();
    }

    @Override
    public void loadFromString(final String fullText) {
        try (final BufferedReader in = new BufferedReader(new StringReader(fullText))) {
            String content = null;

            final List l = new ArrayList<>();
            while ((content = in.readLine()) != null) {
                content = content.trim();
                if (content.length() == 0 || content.indexOf("##") == 0) {
                    continue;
                }
                l.add(content);
            }
            //load the network
            //the first line contains features information
            final String[] tmp = l.get(0).split(" ");
            features = new int[tmp.length];
            for (int i = 0; i < tmp.length; i++) {
                features[i] = Integer.parseInt(tmp[i]);
            }
            //the 2nd line is a scalar indicating the number of hidden layers
            final int nHiddenLayer = Integer.parseInt(l.get(1));
            final int[] nn = new int[nHiddenLayer];
            //the next @nHiddenLayer lines contain the number of neurons in each layer
            int i = 2;
            for (; i < 2 + nHiddenLayer; i++) {
                nn[i - 2] = Integer.parseInt(l.get(i));
            }
            //create the network
            setInputOutput(features.length, 1);
            for (int j = 0; j < nHiddenLayer; j++) {
                addHiddenLayer(nn[j]);
            }
            wire();
            //fill in weights
            for (; i < l.size(); i++)//loop through all layers
            {
                final String[] s = l.get(i).split(" ");
                final int iLayer = Integer.parseInt(s[0]);//which layer?
                final int iNeuron = Integer.parseInt(s[1]);//which neuron?
                final Neuron n = layers.get(iLayer).get(iNeuron);
                for (int k = 0; k < n.getOutLinks().size(); k++) {
                    n.getOutLinks().get(k).setWeight(Double.parseDouble(s[k + 2]));
                }
            }
        } catch (final Exception ex) {
            throw RankLibError.create("Error in ListNet::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        logger.info(() -> "No. of epochs: " + nIteration);
        logger.info(() -> "Learning rate: " + learningRate);
    }

    @Override
    public String name() {
        return "ListNet";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy