All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ciir.umass.edu.learning.neuralnet.RankNet Maven / Gradle / Ivy

The newest version!
/*===============================================================================
 * Copyright (c) 2010-2012 University of Massachusetts.  All Rights Reserved.
 *
 * Use of the RankLib package is subject to the terms of the software license set
 * forth in the LICENSE file included with this software, and also available at
 * http://people.cs.umass.edu/~vdang/ranklib_license.html
 *===============================================================================
 */

package ciir.umass.edu.learning.neuralnet;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

/**
 * @author vdang
 *
 *  This class implements RankNet.
 *  C.J.C. Burges, T. Shaked, E. Renshaw, A. Lazier, M. Deeds, N. Hamilton and G. Hullender. Learning to rank using gradient descent.
 *  In Proc. of ICML, pages 89-96, 2005.
 */
public class RankNet extends Ranker {
    private static final Logger logger = Logger.getLogger(RankNet.class.getName());

    //Parameters
    public static int nIteration = 100;
    public static int nHiddenLayer = 1;
    public static int nHiddenNodePerLayer = 10;
    public static double learningRate = 0.00005;

    //Variables
    protected List layers = new ArrayList<>();
    protected Layer inputLayer = null;
    protected Layer outputLayer = null;

    //to store the best model on validation data (if specified)
    protected List> bestModelOnValidation = new ArrayList<>();

    protected int totalPairs = 0;
    protected int misorderedPairs = 0;
    protected double error = 0.0;
    protected double lastError = Double.MAX_VALUE;
    protected int straightLoss = 0;

    public RankNet() {

    }

    public RankNet(final List samples, final int[] features, final MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Setting up the Neural Network
     */
    protected void setInputOutput(final int nInput, final int nOutput) {
        inputLayer = new Layer(nInput + 1);//plus the "bias" (output threshold)
        outputLayer = new Layer(nOutput);
        layers.clear();
        layers.add(inputLayer);
        layers.add(outputLayer);
    }

    protected void setInputOutput(final int nInput, final int nOutput, final int nType) {
        inputLayer = new Layer(nInput + 1, nType);//plus the "bias" (output threshold)
        outputLayer = new Layer(nOutput, nType);
        layers.clear();
        layers.add(inputLayer);
        layers.add(outputLayer);
    }

    protected void addHiddenLayer(final int size) {
        layers.add(layers.size() - 1, new Layer(size));
    }

    protected void wire() {
        //wire the input layer to the first hidden layer
        for (int i = 0; i < inputLayer.size() - 1; i++) {
            for (int j = 0; j < layers.get(1).size(); j++) {
                connect(0, i, 1, j);
            }
        }

        //wire one layer to the next, starting at layer 1 (the first hidden layer)
        for (int i = 1; i < layers.size() - 1; i++) {
            for (int j = 0; j < layers.get(i).size(); j++) {
                for (int k = 0; k < layers.get(i + 1).size(); k++) {
                    connect(i, j, i + 1, k);
                }
            }
        }

        //wire the "bias" neuron to all others (in all layers)
        for (int i = 1; i < layers.size(); i++) {
            for (int j = 0; j < layers.get(i).size(); j++) {
                connect(0, inputLayer.size() - 1, i, j);
            }
        }
    }

    protected void connect(final int sourceLayer, final int sourceNeuron, final int targetLayer, final int targetNeuron) {
        new Synapse(layers.get(sourceLayer).get(sourceNeuron), layers.get(targetLayer).get(targetNeuron));
    }

    /**
     *  Auxiliary functions for pair-wise preference network learning.
     */
    protected void addInput(final DataPoint p) {
        for (int k = 0; k < inputLayer.size() - 1; k++) {
            inputLayer.get(k).addOutput(p.getFeatureValue(features[k]));
        }
        //  and now the bias node with a fix "1.0"
        inputLayer.get(inputLayer.size() - 1).addOutput(1.0f);
    }

    protected void propagate(final int i) {
        for (int k = 1; k < layers.size(); k++) {
            layers.get(k).computeOutput(i);
        }
    }

    protected int[][] batchFeedForward(final RankList rl) {
        final int[][] pairMap = new int[rl.size()][];
        for (int i = 0; i < rl.size(); i++) {
            addInput(rl.get(i));
            propagate(i);

            int count = 0;
            for (int j = 0; j < rl.size(); j++) {
                if (rl.get(i).getLabel() > rl.get(j).getLabel()) {
                    count++;
                }
            }

            pairMap[i] = new int[count];
            int k = 0;
            for (int j = 0; j < rl.size(); j++) {
                if (rl.get(i).getLabel() > rl.get(j).getLabel()) {
                    pairMap[i][k++] = j;
                }
            }
        }
        return pairMap;
    }

    protected void batchBackPropagate(final int[][] pairMap, final float[][] pairWeight) {
        for (int i = 0; i < pairMap.length; i++) {
            //back-propagate
            final PropParameter p = new PropParameter(i, pairMap);
            outputLayer.computeDelta(p);//starting at the output layer
            for (int j = layers.size() - 2; j >= 1; j--) {
                layers.get(j).updateDelta(p);
            }

            //weight update
            outputLayer.updateWeight(p);
            for (int j = layers.size() - 2; j >= 1; j--) {
                layers.get(j).updateWeight(p);
            }
        }
    }

    protected void clearNeuronOutputs() {
        for (int k = 0; k < layers.size(); k++) {
            layers.get(k).clearOutputs();
        }
    }

    protected float[][] computePairWeight(final int[][] pairMap, final RankList rl) {
        return null;
    }

    protected RankList internalReorder(final RankList rl) {
        return rl;
    }

    /**
     * Model validation
     */
    protected void saveBestModelOnValidation() {
        for (int i = 0; i < layers.size() - 1; i++)//loop through all layers
        {
            final List l = bestModelOnValidation.get(i);
            l.clear();
            for (int j = 0; j < layers.get(i).size(); j++)//loop through all neurons on in the current layer
            {
                final Neuron n = layers.get(i).get(j);
                for (int k = 0; k < n.getOutLinks().size(); k++) {
                    l.add(n.getOutLinks().get(k).getWeight());
                }
            }
        }
    }

    protected void restoreBestModelOnValidation() {
        try {
            for (int i = 0; i < layers.size() - 1; i++)//loop through all layers
            {
                final List l = bestModelOnValidation.get(i);
                int c = 0;
                for (int j = 0; j < layers.get(i).size(); j++)//loop through all neurons on in the current layer
                {
                    final Neuron n = layers.get(i).get(j);
                    for (int k = 0; k < n.getOutLinks().size(); k++) {
                        n.getOutLinks().get(k).setWeight(l.get(c++));
                    }
                }
            }
        } catch (final Exception ex) {
            throw RankLibError.create("Error in NeuralNetwork.restoreBestModelOnValidation(): ", ex);
        }
    }

    protected double crossEntropy(final double o1, final double o2, final double targetValue) {
        final double oij = o1 - o2;
        return -targetValue * oij + SimpleMath.logBase2(1 + Math.exp(oij));
    }

    protected void estimateLoss() {
        misorderedPairs = 0;
        error = 0.0;
        for (int j = 0; j < samples.size(); j++) {
            final RankList rl = samples.get(j);
            for (int k = 0; k < rl.size() - 1; k++) {
                final double o1 = eval(rl.get(k));
                for (int l = k + 1; l < rl.size(); l++) {
                    if (rl.get(k).getLabel() > rl.get(l).getLabel()) {
                        final double o2 = eval(rl.get(l));
                        error += crossEntropy(o1, o2, 1.0f);
                        if (o1 < o2) {
                            misorderedPairs++;
                        }
                    }
                }
            }
        }
        error = SimpleMath.round(error / totalPairs, 4);

        //Neuron.learningRate *= 0.8;
        lastError = error;
    }

    /**
     * Main public functions
     */
    @Override
    public void init() {
        logger.info(() -> "Initializing... ");

        //Set up the network
        setInputOutput(features.length, 1);
        for (int i = 0; i < nHiddenLayer; i++) {
            addHiddenLayer(nHiddenNodePerLayer);
        }
        wire();

        totalPairs = 0;
        for (int i = 0; i < samples.size(); i++) {
            final RankList rl = samples.get(i).getCorrectRanking();
            for (int j = 0; j < rl.size() - 1; j++) {
                for (int k = j + 1; k < rl.size(); k++) {
                    if (rl.get(j).getLabel() > rl.get(k).getLabel()) {
                        totalPairs++;
                    }
                }
            }
        }

        if (validationSamples != null) {
            for (int i = 0; i < layers.size(); i++) {
                bestModelOnValidation.add(new ArrayList());
            }
        }

        Neuron.learningRate = learningRate;
    }

    @Override
    public void learn() {
        logger.info(() -> "Training starts...");
        printLogLn(new int[] { 7, 14, 9, 9 }, new String[] { "#epoch", "% mis-ordered", scorer.name() + "-T", scorer.name() + "-V" });
        printLogLn(new int[] { 7, 14, 9, 9 }, new String[] { " ", "  pairs", " ", " " });

        for (int i = 1; i <= nIteration; i++) {
            for (int j = 0; j < samples.size(); j++) {
                final RankList rl = internalReorder(samples.get(j));
                final int[][] pairMap = batchFeedForward(rl);
                final float[][] pairWeight = computePairWeight(pairMap, rl);
                batchBackPropagate(pairMap, pairWeight);
                clearNeuronOutputs();
            }

            scoreOnTrainingData = scorer.score(rank(samples));
            estimateLoss();
            printLog(new int[] { 7, 14 }, new String[] { Integer.toString(i), Double.toString(SimpleMath.round(((double) misorderedPairs) / totalPairs, 4)) });
            if (i % 1 == 0) {
                printLog(new int[] { 9 }, new String[] { Double.toString(SimpleMath.round(scoreOnTrainingData, 4)) });
                if (validationSamples != null) {
                    final double score = scorer.score(rank(validationSamples));
                    if (score > bestScoreOnValidationData) {
                        bestScoreOnValidationData = score;
                        saveBestModelOnValidation();
                    }
                    printLog(new int[] { 9 }, new String[] { Double.toString(SimpleMath.round(score, 4)) });
                }
            }
            flushLog();
        }

        //if validation data is specified ==> best model on this data has been saved
        //we now restore the current model to that best model
        if (validationSamples != null) {
            restoreBestModelOnValidation();
        }

        scoreOnTrainingData = SimpleMath.round(scorer.score(rank(samples)), 4);
        logger.info(() -> "Finished sucessfully.");
        logger.info(() -> scorer.name() + " on training data: " + scoreOnTrainingData);
        if (validationSamples != null) {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            logger.info(() -> scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }
    }

    @Override
    public double eval(final DataPoint p) {
        //feed input
        for (int k = 0; k < inputLayer.size() - 1; k++) {
            inputLayer.get(k).setOutput(p.getFeatureValue(features[k]));
        }
        //and now the bias node with a fix "1.0"
        inputLayer.get(inputLayer.size() - 1).setOutput(1.0f);
        //propagate
        for (int k = 1; k < layers.size(); k++) {
            layers.get(k).computeOutput();
        }
        return outputLayer.get(0).getOutput();
    }

    @Override
    public Ranker createNew() {
        return new RankNet();
    }

    @Override
    public String toString() {
        final StringBuilder output = new StringBuilder();
        for (int i = 0; i < layers.size() - 1; i++)//loop through all layers
        {
            for (int j = 0; j < layers.get(i).size(); j++)//loop through all neurons on in the current layer
            {
                output.append(i + " " + j + " ");
                final Neuron n = layers.get(i).get(j);
                for (int k = 0; k < n.getOutLinks().size(); k++) {
                    output.append(n.getOutLinks().get(k).getWeight() + ((k == n.getOutLinks().size() - 1) ? "" : " "));
                }
                output.append('\n');
            }
        }
        return output.toString();
    }

    @Override
    public String model() {
        final StringBuilder output = new StringBuilder();
        output.append("## " + name() + "\n");
        output.append("## Epochs = " + nIteration + "\n");
        output.append("## No. of features = " + features.length + "\n");
        output.append("## No. of hidden layers = " + (layers.size() - 2) + "\n");
        for (int i = 1; i < layers.size() - 1; i++) {
            output.append("## Layer " + i + ": " + layers.get(i).size() + " neurons\n");
        }

        //print used features
        for (int i = 0; i < features.length; i++) {
            output.append(features[i] + ((i == features.length - 1) ? "" : " "));
        }
        output.append('\n');
        //print network information
        output.append(layers.size() - 2 + "\n");//[# hidden layers]
        for (int i = 1; i < layers.size() - 1; i++) {
            output.append(layers.get(i).size() + "\n");//[#neurons]
        }
        //print learned weights
        output.append(toString());
        return output.toString();
    }

    @Override
    public void loadFromString(final String fullText) {
        try (final BufferedReader in = new BufferedReader(new StringReader(fullText))) {
            String content = null;

            final List l = new ArrayList<>();
            while ((content = in.readLine()) != null) {
                content = content.trim();
                if (content.length() == 0 || content.indexOf("##") == 0) {
                    continue;
                }
                l.add(content);
            }
            //load the network
            //the first line contains features information
            final String[] tmp = l.get(0).split(" ");
            features = new int[tmp.length];
            for (int i = 0; i < tmp.length; i++) {
                features[i] = Integer.parseInt(tmp[i]);
            }
            //the 2nd line is a scalar indicating the number of hidden layers
            final int nhl = Integer.parseInt(l.get(1));
            final int[] nn = new int[nhl];
            //the next @nHiddenLayer lines contain the number of neurons in each layer
            int i = 2;
            for (; i < 2 + nhl; i++) {
                nn[i - 2] = Integer.parseInt(l.get(i));
            }
            //create the network
            setInputOutput(features.length, 1);
            for (int j = 0; j < nhl; j++) {
                addHiddenLayer(nn[j]);
            }
            wire();
            //fill in weights
            for (; i < l.size(); i++)//loop through all layers
            {
                final String[] s = l.get(i).split(" ");
                final int iLayer = Integer.parseInt(s[0]);//which layer?
                final int iNeuron = Integer.parseInt(s[1]);//which neuron?
                final Neuron n = layers.get(iLayer).get(iNeuron);
                for (int k = 0; k < n.getOutLinks().size(); k++) {
                    n.getOutLinks().get(k).setWeight(Double.parseDouble(s[k + 2]));
                }
            }
        } catch (final Exception ex) {
            throw RankLibError.create("Error in RankNet::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        logger.info(() -> "No. of epochs: " + nIteration);
        logger.info(() -> "No. of hidden layers: " + nHiddenLayer);
        logger.info(() -> "No. of hidden nodes per layer: " + nHiddenNodePerLayer);
        logger.info(() -> "Learning rate: " + learningRate);
    }

    @Override
    public String name() {
        return "RankNet";
    }

    /**
     * FOR DEBUGGING PURPOSE ONLY
     */
    protected void printNetworkConfig() {
        if (logger.isLoggable(Level.INFO)) {
            for (int i = 1; i < layers.size(); i++) {
                logger.info("Layer-" + (i + 1));
                for (int j = 0; j < layers.get(i).size(); j++) {
                    final Neuron n = layers.get(i).get(j);
                    logger.info("Neuron-" + (j + 1) + ": " + n.getInLinks().size() + " inputs\t");
                    for (int k = 0; k < n.getInLinks().size(); k++) {
                        logger.info(n.getInLinks().get(k).getWeight() + "\t");
                    }
                }
            }
        }
    }

    protected void printWeightVector() {
        if (logger.isLoggable(Level.INFO)) {
            final StringBuilder buf = new StringBuilder();
            for (int j = 0; j < outputLayer.get(0).getInLinks().size(); j++) {
                buf.append(outputLayer.get(0).getInLinks().get(j).getWeight()).append(' ');
            }
            logger.info(buf.toString());
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy