All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.collection.IntArrayKeyMap Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.collection;

import com.google.common.primitives.Ints;
import lombok.Getter;
import org.nd4j.base.Preconditions;

import java.util.*;

/**
 * A map for int arrays backed by a {@link java.util.TreeMap}
 * @param  the value for the map.
 *
 * @author Adam Gibson
 */
public class IntArrayKeyMap implements Map {

    private Map map = new LinkedHashMap<>();

    @Override
    public int size() {
        return map.size();
    }

    @Override
    public boolean isEmpty() {
        return map.isEmpty();
    }

    @Override
    public boolean containsKey(Object o) {
        return map.containsKey(new IntArray((int[]) o));
    }

    @Override
    public boolean containsValue(Object o) {
        return map.containsValue(new IntArray((int[]) o));
    }

    @Override
    public V get(Object o) {
        return map.get(new IntArray((int[]) o));
    }

    @Override
    public V put(int[] ints, V v) {
        return map.put(new IntArray(ints),v);
    }

    @Override
    public V remove(Object o) {
        return map.remove(new IntArray((int[]) o));
    }

    @Override
    public void putAll(Map map) {
        for(Entry entry : map.entrySet()) {
            this.map.put(new IntArray(entry.getKey()),entry.getValue());
        }
    }

    @Override
    public void clear() {
        map.clear();
    }

    @Override
    public Set keySet() {
        Set intArrays = map.keySet();
        Set ret = new LinkedHashSet<>();
        for(IntArray intArray : intArrays)
            ret.add(intArray.backingArray);
        return ret;
    }

    @Override
    public Collection values() {
        return map.values();
    }

    @Override
    public Set> entrySet() {
        Set> intArrays = map.entrySet();
        Set> ret = new LinkedHashSet<>();
        for(Map.Entry intArray : intArrays) {
            final Map.Entry intArray2 = intArray;
            ret.add(new Map.Entry() {
                @Override
                public int[] getKey() {
                    return intArray2.getKey().backingArray;
                }

                @Override
                public V getValue() {
                    return intArray2.getValue();
                }

                @Override
                public V setValue(V v) {
                    return intArray2.setValue(v);
                }
            });
        }
        return ret;
    }


    public static class IntArray implements Comparable {
        @Getter
        private int[] backingArray;

        public IntArray(int[] backingArray) {
            Preconditions.checkNotNull(backingArray,"Backing array must not be null!");
            this.backingArray = Ints.toArray(new LinkedHashSet<>(Ints.asList(backingArray)));
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;

            IntArray intArray = (IntArray) o;

            return Arrays.equals(intArray.backingArray,backingArray);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(backingArray);
        }

        @Override
        public int compareTo(IntArray intArray) {
            if(this.backingArray.length == 0 || intArray.backingArray.length == 0) {
                return 1;
            }

            else if(Arrays.equals(backingArray,intArray.backingArray))
                return 1;

            return Ints.compare(Ints.max(backingArray),Ints.max(intArray.backingArray));
        }
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy