All downloads are free. The search and download features use the official Maven repository.

gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                LinearMultiCategorizer.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry Learning Core
 * 
 * Copyright January 28, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 */

package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

/**
 * A multi-category version of the LinearBinaryCategorizer that keeps a separate
 * LinearBinaryCategorizer for each category. Each of these linear categorizers
 * is called a prototype. Categorization is done by choosing the prototype
 * (w_c, b_c) such that w_c * x + b_c is highest.
 * 
 * @param   <CategoryType>
 *      The type of categories that the categorizer can output.
 * @author  Justin Basilico
 * @since   3.2.0
 */
public class LinearMultiCategorizer<CategoryType>
    extends AbstractCloneableSerializable
    implements DiscriminantCategorizer<Vectorizable, CategoryType, Double>,
        VectorInputEvaluator<Vectorizable, CategoryType>
{

    /** A map of each category to its prototype categorizer. */
    protected Map<CategoryType, LinearBinaryCategorizer> prototypes;

    /**
     * Creates a new, empty {@code LinearMultiCategorizer}. The prototypes are
     * stored in a new, empty {@code LinkedHashMap}.
     */
    public LinearMultiCategorizer()
    {
        this(new LinkedHashMap<CategoryType, LinearBinaryCategorizer>());
    }

    /**
     * Creates a new {@code LinearMultiCategorizer} with the given prototypes.
     * The map is stored by reference, not copied.
     *
     * @param   prototypes
     *      The mapping of categories to prototypes.
     */
    public LinearMultiCategorizer(
        final Map<CategoryType, LinearBinaryCategorizer> prototypes)
    {
        super();

        this.setPrototypes(prototypes);
    }

    /**
     * Creates a copy of this categorizer. If the prototype map is not null,
     * the copy gets its own {@code LinkedHashMap} that holds a clone of each
     * prototype under the same category key, in the same iteration order.
     *
     * @return
     *      A clone of this object.
     */
    @Override
    public LinearMultiCategorizer<CategoryType> clone()
    {
        // The cast is unchecked only because of type erasure; super.clone()
        // produces an instance of this same class.
        @SuppressWarnings("unchecked")
        final LinearMultiCategorizer<CategoryType> clone =
            (LinearMultiCategorizer<CategoryType>) super.clone();

        // Clone the prototypes so the copy does not share them with this.
        if (this.prototypes != null)
        {
            clone.prototypes =
                new LinkedHashMap<CategoryType, LinearBinaryCategorizer>(
                    this.prototypes.size());

            for (Map.Entry<CategoryType, LinearBinaryCategorizer> entry
                : this.prototypes.entrySet())
            {
                clone.prototypes.put(entry.getKey(),
                    entry.getValue().clone());
            }
        }

        return clone;
    }

    /**
     * Categorizes the given input by returning the category whose prototype
     * scores highest on it.
     *
     * @param   input
     *      The input to categorize.
     * @return
     *      The best-scoring category, or null if there are no categories.
     */
    @Override
    public CategoryType evaluate(
        final Vectorizable input)
    {
        return this.evaluateWithDiscriminant(input).getValue();
    }

    /**
     * Categorizes the given input and also reports the score of the chosen
     * category. When several categories share the highest score, the first
     * one encountered while iterating over the categories is kept.
     *
     * @param   input
     *      The input to categorize.
     * @return
     *      The best-scoring category paired with its score. If there are no
     *      categories, the category is null and the score is 0.0.
     */
    @Override
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(
        final Vectorizable input)
    {
        // Convert the input to a vector once, rather than once per category.
        final Vector inputVector = input.convertToVector();

        // Find the category that has the highest match.
        double bestScore = 0.0;
        CategoryType bestCategory = null;
        for (CategoryType category : this.getCategories())
        {
            final double score = this.evaluateAsDouble(inputVector, category);

            // The null check makes the first category always win initially,
            // so a negative best score is not masked by the 0.0 default.
            if (bestCategory == null || score > bestScore)
            {
                bestScore = score;
                bestCategory = category;
            }
        }

        // Return the discriminant for the category.
        return new DefaultWeightedValueDiscriminant<CategoryType>(
            bestCategory, bestScore);
    }

    /**
     * Evaluates how much the given input matches the prototype for the given
     * category.
     *
     * @param   input
     *      The input.
     * @param   category
     *      The category to match the input against.
     * @return
     *      A real value indicating how much the input matches the category.
     *      A larger value indicates a better match. It is 0.0 if there is no
     *      prototype for the category.
     */
    public double evaluateAsDouble(
        final Vectorizable input,
        final CategoryType category)
    {
        return this.evaluateAsDouble(input.convertToVector(), category);
    }
    
    /**
     * Evaluates how much the given input matches the prototype for the given
     * category.
     *
     * @param   input
     *      The input.
     * @param   category
     *      The category to match the input against.
     * @return
     *      A real value indicating how much the input matches the category.
     *      A larger value indicates a better match. It is 0.0 if there is no
     *      prototype for the category.
     */
    public double evaluateAsDouble(
        final Vector input,
        final CategoryType category)
    {
        final LinearBinaryCategorizer prototype = this.prototypes.get(category);

        if (prototype == null)
        {
            // Unknown category or null prototype: report a neutral score.
            return 0.0;
        }
        else
        {
            // Evaluate the input with the prototype.
            return prototype.evaluateAsDouble(input);
        }
    }

    /**
     * Gets the categories that this categorizer can output, which are the
     * keys of the prototype map.
     *
     * @return
     *      The key set of the prototype map.
     */
    @Override
    public Set<CategoryType> getCategories()
    {
        return this.prototypes.keySet();
    }

    /**
     * Gets the input dimensionality as reported by the first prototype.
     *
     * @return
     *      The input dimensionality of the first prototype, or -1 if no
     *      first prototype is available.
     */
    @Override
    public int getInputDimensionality()
    {
        final LinearBinaryCategorizer firstPrototype =
            CollectionUtil.getFirst(this.prototypes.values());
        if (firstPrototype == null)
        {
            return -1;
        }
        else
        {
            return firstPrototype.getInputDimensionality();
        }
    }

    /**
     * Gets the mapping of categories to prototypes.
     *
     * @return
     *      The mapping of categories to prototypes. This is the internal map
     *      itself, not a copy.
     */
    public Map<CategoryType, LinearBinaryCategorizer> getPrototypes()
    {
        return prototypes;
    }

    /**
     * Sets the mapping of categories to prototypes. The map is stored by
     * reference, not copied.
     *
     * @param   prototypes
     *      The mapping of categories to prototypes.
     */
    public void setPrototypes(
        final Map<CategoryType, LinearBinaryCategorizer> prototypes)
    {
        this.prototypes = prototypes;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy