All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.common.primitives.CounterMap Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.common.primitives;

import lombok.EqualsAndHashCode;

import java.io.Serializable;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 *
 * @author [email protected]
 */
@EqualsAndHashCode
public class CounterMap implements Serializable{
    private static final long serialVersionUID = 119L;

    protected Map> maps = new ConcurrentHashMap<>();

    public CounterMap() {

    }

    /**
     * This method checks if this CounterMap has any values stored
     *
     * @return
     */
    public boolean isEmpty() {
        return maps.isEmpty();
    }

    /**
     * This method checks if this CounterMap has any values stored for a given first element
     *
     * @param element
     * @return
     */
    public boolean isEmpty(F element){
        if (isEmpty())
            return true;

        Counter m = maps.get(element);
        if (m == null)
            return true;
        else
            return m.isEmpty();
    }

    /**
     * This method will increment values of this counter, by counts of other counter
     *
     * @param other
     */
    public void incrementAll(CounterMap other) {
        for (Map.Entry> entry : other.maps.entrySet()) {
            F key = entry.getKey();
            Counter innerCounter = entry.getValue();
            for (Map.Entry innerEntry : innerCounter.entrySet()) {
                S value = innerEntry.getKey();
                incrementCount(key, value, innerEntry.getValue().get());
            }
        }
    }

    /**
     * This method will increment counts for a given first/second pair
     *
     * @param first
     * @param second
     * @param inc
     */
    public void incrementCount(F first, S second, double inc) {
        Counter counter = maps.get(first);
        if (counter == null) {
            counter = new Counter();
            maps.put(first, counter);
        }

        counter.incrementCount(second, inc);
    }

    /**
     * This method returns counts for a given first/second pair
     *
     * @param first
     * @param second
     * @return
     */
    public double getCount(F first, S second) {
        Counter counter = maps.get(first);
        if (counter == null)
            return 0.0;

        return counter.getCount(second);
    }

    /**
     * This method allows you to set counter value for a given first/second pair
     *
     * @param first
     * @param second
     * @param value
     * @return
     */
    public double setCount(F first, S second, double value) {
        Counter counter = maps.get(first);
        if (counter == null) {
            counter = new Counter();
            maps.put(first, counter);
        }

        return counter.setCount(second, value);
    }

    /**
     * This method returns pair of elements with a max value
     *
     * @return
     */
    public Pair argMax() {
        Double maxCount = -Double.MAX_VALUE;
        Pair maxKey = null;
        for (Map.Entry> entry : maps.entrySet()) {
            Counter counter = entry.getValue();
            S localMax = counter.argMax();
            if (counter.getCount(localMax) > maxCount || maxKey == null) {
                maxKey = new Pair(entry.getKey(), localMax);
                maxCount = counter.getCount(localMax);
            }
        }
        return maxKey;
    }

    /**
     * This method purges all counters
     */
    public void clear() {
        maps.clear();
    }

    /**
     * This method purges counter for a given first element
     * @param element
     */
    public void clear(F element) {
        Counter s = maps.get(element);
        if (s != null)
            s.clear();
    }

    /**
     * This method returns Set of all first elements
     * @return
     */
    public Set keySet() {
        return maps.keySet();
    }

    /**
     * This method returns counter for a given first element
     *
     * @param first
     * @return
     */
    public Counter getCounter(F first) {
        return maps.get(first);
    }

    /**
     * This method returns Iterator of all first/second pairs stored in this counter
     *
     * @return
     */
    public Iterator> getIterator() {
        return new Iterator>() {

            Iterator outerIt;
            Iterator innerIt;
            F curKey;

            {
                outerIt = keySet().iterator();
            }

            private boolean hasInside() {
                if (innerIt == null || !innerIt.hasNext()) {
                    if (!outerIt.hasNext()) {
                        return false;
                    }
                    curKey = outerIt.next();
                    innerIt = getCounter(curKey).keySet().iterator();
                }
                return true;
            }

            public boolean hasNext() {
                return hasInside();
            }

            public Pair next() {
                hasInside();
                if (curKey == null)
                    throw new RuntimeException("Outer element can't be null");

                return Pair.makePair(curKey, innerIt.next());
            }

            public void remove() {
                //
            }
        };
    }

    /**
     * This method returns number of First elements in this CounterMap
     * @return
     */
    public int size() {
        return maps.size();
    }

    /**
     * This method returns total number of elements in this CounterMap
     * @return
     */
    public int totalSize() {
        int size = 0;
        for (F first: keySet()) {
            size += getCounter(first).size();
        }

        return size;
    }
}