All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.ignite.ml.math.distributed.CacheUtils Maven / Gradle / Ivy

Go to download

Apache Ignite® is a Distributed Database For High-Performance Computing With In-Memory Speed.

There is a newer version: 2.15.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.math.distributed;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BinaryOperator;
import java.util.stream.Stream;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.cluster.ClusterGroup;
import org.apache.ignite.cluster.ClusterNode;
import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.lang.IgniteCallable;
import org.apache.ignite.lang.IgnitePredicate;
import org.apache.ignite.lang.IgniteRunnable;
import org.apache.ignite.ml.math.KeyMapper;
import org.apache.ignite.ml.math.distributed.keys.DataStructureCacheKey;
import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
import org.apache.ignite.ml.math.distributed.keys.impl.MatrixBlockKey;
import org.apache.ignite.ml.math.distributed.keys.impl.VectorBlockKey;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
import org.apache.ignite.ml.math.functions.IgniteConsumer;
import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.functions.IgniteTriFunction;
import org.apache.ignite.ml.math.impls.matrix.MatrixBlockEntry;
import org.apache.ignite.ml.math.impls.vector.VectorBlockEntry;

/**
 * Distribution-related misc. support.
 *
 * TODO: IGNITE-5102, fix sparse key filters.
 */
public class CacheUtils {
    /**
     * Cache entry support: pairs a cache entry with the cache instance it was obtained from.
     *
     * @param <K> Cache key type.
     * @param <V> Cache value type.
     */
    public static class CacheEntry<K, V> {
        /** Original cache entry. */
        private final Cache.Entry<K, V> entry;

        /** Cache instance the entry was read from. */
        private final IgniteCache<K, V> cache;

        /**
         * @param entry Original cache entry.
         * @param cache Cache instance.
         */
        CacheEntry(Cache.Entry<K, V> entry, IgniteCache<K, V> cache) {
            this.entry = entry;
            this.cache = cache;
        }

        /**
         * @return Original cache entry passed to the constructor.
         */
        public Cache.Entry<K, V> entry() {
            return entry;
        }

        /**
         * @return Cache instance passed to the constructor.
         */
        public IgniteCache<K, V> cache() {
            return cache;
        }
    }

    /**
     * Gets local Ignite instance.
     *
     * @return Whatever {@link Ignition#localIgnite()} resolves for the calling thread.
     */
    public static Ignite ignite() {
        Ignite locIgnite = Ignition.localIgnite();

        return locIgnite;
    }

    /**
     * Builds a cluster group consisting of the node the given key is mapped to by the cache affinity.
     *
     * @param cacheName Cache name.
     * @param k Key into the cache.
     * @param <K> Key type.
     * @return Cluster group for given key.
     */
    protected static <K> ClusterGroup getClusterGroupForGivenKey(String cacheName, K k) {
        Ignite ignite = ignite();

        return ignite.cluster().forNode(ignite.affinity(cacheName).mapKeyToNode(k));
    }

    /**
     * Sums the double representations of all cache values whose keys are accepted by the key mapper.
     * Partial sums are computed by {@link #fold(String, IgniteBiFunction)} and then added up.
     *
     * @param cacheName Cache name.
     * @param keyMapper {@link KeyMapper} to validate cache key.
     * @param valMapper {@link ValueMapper} to obtain double value for given cache key.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     * @return Sum of the values obtained for valid keys.
     */
    public static <K, V> double sum(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
        Collection<Double> subSums = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
            if (keyMapper.isValid(ce.entry().getKey())) {
                double v = valMapper.toDouble(ce.entry().getValue());

                // Accumulator stays null until the first valid key is seen.
                return acc == null ? v : acc + v;
            }
            else
                return acc;
        });

        return sum(subSums);
    }

    /**
     * Sums all values stored for the given matrix. Each cache value must be either a {@link Map} whose
     * values are doubles or a {@link MatrixBlockEntry}.
     *
     * @param matrixUuid Matrix UUID.
     * @param cacheName Cache name.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     * @return Sum obtained using sparse logic.
     * @throws UnsupportedOperationException If a cache value is of any other type.
     */
    @SuppressWarnings("unchecked")
    public static <K, V> double sparseSum(UUID matrixUuid, String cacheName) {
        A.notNull(matrixUuid, "matrixUuid");
        A.notNull(cacheName, "cacheName");

        Collection<Double> subSums = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
            V v = ce.entry().getValue();

            double sum;

            if (v instanceof Map) {
                Map<?, Double> map = (Map<?, Double>)v;

                sum = sum(map.values());
            }
            else if (v instanceof MatrixBlockEntry) {
                MatrixBlockEntry be = (MatrixBlockEntry)v;

                sum = be.sum();
            }
            else
                throw new UnsupportedOperationException();

            // Accumulator stays null until the first entry is folded.
            return acc == null ? sum : acc + sum;
        }, sparseKeyFilter(matrixUuid));

        return sum(subSums);
    }

    /**
     * Adds up the non-null elements of the given collection.
     *
     * @param c {@link Collection} of double values to sum; {@code null} elements are skipped.
     * @return Sum of the values, {@code 0.0} if there are no non-null elements.
     */
    private static double sum(Collection<Double> c) {
        // Fix for IGNITE-6762, some collections could store null values.
        return c.stream().filter(Objects::nonNull).mapToDouble(Double::doubleValue).sum();
    }

    /**
     * Finds the minimum of the double representations of all cache values whose keys are accepted by
     * the key mapper.
     *
     * @param cacheName Cache name.
     * @param keyMapper {@link KeyMapper} to validate cache key.
     * @param valMapper {@link ValueMapper} to obtain double value for given cache key.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     * @return Minimum value for valid keys.
     * @throws java.util.NoSuchElementException If no valid key was found on any node.
     */
    public static <K, V> double min(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
        Collection<Double> mins = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
            if (keyMapper.isValid(ce.entry().getKey())) {
                double v = valMapper.toDouble(ce.entry().getValue());

                if (acc == null)
                    return v;
                else
                    return Math.min(acc, v);
            }
            else
                return acc;
        });

        // A node that saw no valid keys reports null; skip such results (same as sum(Collection) does)
        // instead of failing with NullPointerException inside Collections.min.
        return mins.stream()
            .filter(Objects::nonNull)
            .min(Double::compare)
            .orElseThrow(() -> new java.util.NoSuchElementException("No valid values found in cache: " + cacheName));
    }

    /**
     * Finds the minimum value stored for the given matrix. Each cache value must be either a
     * {@link Map} whose values are doubles or a {@link MatrixBlockEntry}.
     *
     * @param matrixUuid Matrix UUID.
     * @param cacheName Cache name.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     * @return Minimum value obtained using sparse logic.
     * @throws UnsupportedOperationException If a cache value is of any other type.
     * @throws java.util.NoSuchElementException If no node produced a value for the matrix.
     */
    @SuppressWarnings("unchecked")
    public static <K, V> double sparseMin(UUID matrixUuid, String cacheName) {
        A.notNull(matrixUuid, "matrixUuid");
        A.notNull(cacheName, "cacheName");

        Collection<Double> mins = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
            V v = ce.entry().getValue();

            double min;

            if (v instanceof Map) {
                Map<?, Double> map = (Map<?, Double>)v;

                min = Collections.min(map.values());
            }
            else if (v instanceof MatrixBlockEntry) {
                MatrixBlockEntry be = (MatrixBlockEntry)v;

                min = be.minValue();
            }
            else
                throw new UnsupportedOperationException();

            if (acc == null)
                return min;
            else
                return Math.min(acc, min);

        }, sparseKeyFilter(matrixUuid));

        // A node without entries for this matrix reports null; skip such results (same as
        // sum(Collection) does) instead of failing with NullPointerException inside Collections.min.
        return mins.stream()
            .filter(Objects::nonNull)
            .min(Double::compare)
            .orElseThrow(() -> new java.util.NoSuchElementException("No values found for matrix: " + matrixUuid));
    }

    /**
     * Finds the maximum value stored for the given matrix. Each cache value must be either a
     * {@link Map} whose values are doubles or a {@link MatrixBlockEntry}.
     *
     * @param matrixUuid Matrix UUID.
     * @param cacheName Cache name.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     * @return Maximum value obtained using sparse logic.
     * @throws UnsupportedOperationException If a cache value is of any other type.
     * @throws java.util.NoSuchElementException If no node produced a value for the matrix.
     */
    @SuppressWarnings("unchecked")
    public static <K, V> double sparseMax(UUID matrixUuid, String cacheName) {
        A.notNull(matrixUuid, "matrixUuid");
        A.notNull(cacheName, "cacheName");

        Collection<Double> maxes = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
            V v = ce.entry().getValue();

            double max;

            if (v instanceof Map) {
                Map<?, Double> map = (Map<?, Double>)v;

                max = Collections.max(map.values());
            }
            else if (v instanceof MatrixBlockEntry) {
                MatrixBlockEntry be = (MatrixBlockEntry)v;

                max = be.maxValue();
            }
            else
                throw new UnsupportedOperationException();

            if (acc == null)
                return max;
            else
                return Math.max(acc, max);

        }, sparseKeyFilter(matrixUuid));

        // A node without entries for this matrix reports null; skip such results (same as
        // sum(Collection) does) instead of failing with NullPointerException inside Collections.max.
        return maxes.stream()
            .filter(Objects::nonNull)
            .max(Double::compare)
            .orElseThrow(() -> new java.util.NoSuchElementException("No values found for matrix: " + matrixUuid));
    }

    /**
     * Finds the maximum of the double representations of all cache values whose keys are accepted by
     * the key mapper.
     *
     * @param cacheName Cache name.
     * @param keyMapper {@link KeyMapper} to validate cache key.
     * @param valMapper {@link ValueMapper} to obtain double value for given cache key.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     * @return Maximum value for valid keys.
     * @throws java.util.NoSuchElementException If no valid key was found on any node.
     */
    public static <K, V> double max(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper) {
        Collection<Double> maxes = fold(cacheName, (CacheEntry<K, V> ce, Double acc) -> {
            if (keyMapper.isValid(ce.entry().getKey())) {
                double v = valMapper.toDouble(ce.entry().getValue());

                if (acc == null)
                    return v;
                else
                    return Math.max(acc, v);
            }
            else
                return acc;
        });

        // A node that saw no valid keys reports null; skip such results (same as sum(Collection) does)
        // instead of failing with NullPointerException inside Collections.max.
        return maxes.stream()
            .filter(Objects::nonNull)
            .max(Double::compare)
            .orElseThrow(() -> new java.util.NoSuchElementException("No valid values found in cache: " + cacheName));
    }

    /**
     * Replaces every cache value whose key is accepted by the key mapper with the result of applying
     * the mapping function to the value's double representation.
     *
     * @param cacheName Cache name.
     * @param keyMapper {@link KeyMapper} to validate cache key.
     * @param valMapper {@link ValueMapper} to convert cache values to and from double.
     * @param mapper Mapping {@link IgniteFunction}.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     */
    public static <K, V> void map(String cacheName, KeyMapper<K> keyMapper, ValueMapper<V> valMapper,
        IgniteFunction<Double, Double> mapper) {
        foreach(cacheName, (CacheEntry<K, V> ce) -> {
            K k = ce.entry().getKey();

            if (keyMapper.isValid(k))
                // Actual assignment.
                ce.cache().put(k, valMapper.fromDouble(mapper.apply(valMapper.toDouble(ce.entry().getValue()))));
        });
    }

    /**
     * Applies the mapping function to every value stored for the given matrix and writes the updated
     * value back to the cache. Each cache value must be either a {@link Map} whose values are doubles
     * or a {@link MatrixBlockEntry}.
     *
     * @param matrixUuid Matrix UUID.
     * @param mapper Mapping {@link IgniteDoubleFunction}.
     * @param cacheName Cache name.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     * @throws UnsupportedOperationException If a cache value is of any other type.
     */
    @SuppressWarnings("unchecked")
    public static <K, V> void sparseMap(UUID matrixUuid, IgniteDoubleFunction<Double> mapper, String cacheName) {
        A.notNull(matrixUuid, "matrixUuid");
        A.notNull(cacheName, "cacheName");
        A.notNull(mapper, "mapper");

        foreach(cacheName, (CacheEntry<K, V> ce) -> {
            K k = ce.entry().getKey();

            V v = ce.entry().getValue();

            if (v instanceof Map) {
                Map<?, Double> map = (Map<?, Double>)v;

                for (Map.Entry<?, Double> e : map.entrySet())
                    e.setValue(mapper.apply(e.getValue()));

            }
            else if (v instanceof MatrixBlockEntry) {
                MatrixBlockEntry be = (MatrixBlockEntry)v;

                be.map(mapper);
            }
            else
                throw new UnsupportedOperationException();

            // The value was modified in place; put it back so the cache stores the updated state.
            ce.cache().put(k, v);
        }, sparseKeyFilter(matrixUuid));
    }

    /**
     * Filter for distributed matrix keys. Accepts a key when the data structure id it carries (or, for
     * an {@link IgniteBiTuple} key, its second component) equals the given UUID.
     *
     * @param matrixUuid Matrix uuid.
     * @param <K> Cache key object type.
     * @return Predicate over cache keys; it throws {@link UnsupportedOperationException} for a key of
     *      any type other than those checked below.
     */
    private static <K> IgnitePredicate<K> sparseKeyFilter(UUID matrixUuid) {
        return key -> {
            if (key instanceof DataStructureCacheKey)
                return ((DataStructureCacheKey)key).dataStructureId().equals(matrixUuid);
            else if (key instanceof IgniteBiTuple)
                return ((IgniteBiTuple<?, ?>)key).get2().equals(matrixUuid);
            else if (key instanceof MatrixBlockKey)
                return ((MatrixBlockKey)key).dataStructureId().equals(matrixUuid);
            else if (key instanceof RowColMatrixKey)
                return ((RowColMatrixKey)key).dataStructureId().equals(matrixUuid);
            else if (key instanceof VectorBlockKey)
                return ((VectorBlockKey)key).dataStructureId().equals(matrixUuid);
            else
                throw new UnsupportedOperationException();
        };
    }

    /**
     * @param cacheName Cache name.
     * @param fun An operation that accepts a cache entry and processes it.
     * @param  Cache key object type.
     * @param  Cache value object type.
     */
    private static  void foreach(String cacheName, IgniteConsumer> fun) {
        foreach(cacheName, fun, null);
    }

    /**
     * @param cacheName Cache name.
     * @param fun An operation that accepts a cache entry and processes it.
     * @param keyFilter Cache keys filter.
     * @param  Cache key object type.
     * @param  Cache value object type.
     */
    protected static  void foreach(String cacheName, IgniteConsumer> fun,
        IgnitePredicate keyFilter) {
        bcast(cacheName, () -> {
            Ignite ignite = Ignition.localIgnite();
            IgniteCache cache = ignite.getOrCreateCache(cacheName);

            int partsCnt = ignite.affinity(cacheName).partitions();

            // Use affinity in filter for scan query. Otherwise we accept consumer in each node which is wrong.
            Affinity affinity = ignite.affinity(cacheName);
            ClusterNode locNode = ignite.cluster().localNode();

            // Iterate over all partitions. Some of them will be stored on that local node.
            for (int part = 0; part < partsCnt; part++) {
                int p = part;

                // Iterate over given partition.
                // Query returns an empty cursor if this partition is not stored on this node.
                for (Cache.Entry entry : cache.query(new ScanQuery(part,
                    (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
                    fun.accept(new CacheEntry<>(entry, cache));
            }
        });
    }

    /**
     * @param cacheName Cache name.
     * @param fun An operation that accepts a cache entry and processes it.
     * @param ignite Ignite.
     * @param keysGen Keys generator.
     * @param  Cache key object type.
     * @param  Cache value object type.
     */
    public static  void update(String cacheName, Ignite ignite,
        IgniteBiFunction, Stream>> fun, IgniteSupplier> keysGen) {
        bcast(cacheName, ignite, () -> {
            Ignite ig = Ignition.localIgnite();
            IgniteCache cache = ig.getOrCreateCache(cacheName);

            Affinity affinity = ig.affinity(cacheName);
            ClusterNode locNode = ig.cluster().localNode();

            Collection ks = affinity.mapKeysToNodes(keysGen.get()).get(locNode);

            if (ks == null)
                return;

            Map m = new ConcurrentHashMap<>();

            ks.parallelStream().forEach(k -> {
                V v = cache.localPeek(k);
                if (v != null)
                    (fun.apply(ignite, new CacheEntryImpl<>(k, v))).forEach(ent -> m.put(ent.getKey(), ent.getValue()));
            });

            cache.putAll(m);
        });
    }

    /**
     * @param cacheName Cache name.
     * @param fun An operation that accepts a cache entry and processes it.
     * @param ignite Ignite.
     * @param keysGen Keys generator.
     * @param  Cache key object type.
     * @param  Cache value object type.
     */
    public static  void update(String cacheName, Ignite ignite, IgniteConsumer> fun,
        IgniteSupplier> keysGen) {
        bcast(cacheName, ignite, () -> {
            Ignite ig = Ignition.localIgnite();
            IgniteCache cache = ig.getOrCreateCache(cacheName);

            Affinity affinity = ig.affinity(cacheName);
            ClusterNode locNode = ig.cluster().localNode();

            Collection ks = affinity.mapKeysToNodes(keysGen.get()).get(locNode);

            if (ks == null)
                return;

            Map m = new ConcurrentHashMap<>();

            for (K k : ks) {
                V v = cache.localPeek(k);
                fun.accept(new CacheEntryImpl<>(k, v));
                m.put(k, v);
            }

            cache.putAll(m);
        });
    }

    /**
     * Currently fold supports only commutative operations.
     *
     * @param cacheName Cache name.
     * @param folder Fold function operating over cache entries.
     * @param  Cache key object type.
     * @param  Cache value object type.
     * @param  Fold result type.
     * @return Fold operation result.
     */
    public static  Collection fold(String cacheName, IgniteBiFunction, A, A> folder) {
        return fold(cacheName, folder, null);
    }

    /**
     * Currently fold supports only commutative operations.
     *
     * @param cacheName Cache name.
     * @param folder Fold function operating over cache entries.
     * @param  Cache key object type.
     * @param  Cache value object type.
     * @param  Fold result type.
     * @return Fold operation result.
     */
    public static  Collection fold(String cacheName, IgniteBiFunction, A, A> folder,
        IgnitePredicate keyFilter) {
        return bcast(cacheName, () -> {
            Ignite ignite = Ignition.localIgnite();
            IgniteCache cache = ignite.getOrCreateCache(cacheName);

            int partsCnt = ignite.affinity(cacheName).partitions();

            // Use affinity in filter for ScanQuery. Otherwise we accept consumer in each node which is wrong.
            Affinity affinity = ignite.affinity(cacheName);
            ClusterNode locNode = ignite.cluster().localNode();

            A a = null;

            // Iterate over all partitions. Some of them will be stored on that local node.
            for (int part = 0; part < partsCnt; part++) {
                int p = part;

                // Iterate over given partition.
                // Query returns an empty cursor if this partition is not stored on this node.
                for (Cache.Entry entry : cache.query(new ScanQuery(part,
                    (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
                    a = folder.apply(new CacheEntry<>(entry, cache), a);
            }

            return a;
        });
    }

    /**
     * Distributed version of fold operation.
     *
     * @param cacheName Cache name.
     * @param folder Folder.
     * @param keyFilter Key filter.
     * @param accumulator Accumulator.
     * @param zeroValSupp Zero value supplier.
     */
    public static  A distributedFold(String cacheName, IgniteBiFunction, A, A> folder,
        IgnitePredicate keyFilter, BinaryOperator accumulator, IgniteSupplier zeroValSupp) {
        return sparseFold(cacheName, folder, keyFilter, accumulator, zeroValSupp, null, null, 0,
            false);
    }

    /**
     * Sparse version of fold. This method also applicable to sparse zeroes.
     *
     * @param cacheName Cache name.
     * @param folder Folder.
     * @param keyFilter Key filter.
     * @param accumulator Accumulator.
     * @param zeroValSupp Zero value supplier.
     * @param defVal Default value.
     * @param defKey Default key.
     * @param defValCnt Def value count.
     * @param isNilpotent Is nilpotent.
     */
    private static  A sparseFold(String cacheName, IgniteBiFunction, A, A> folder,
        IgnitePredicate keyFilter, BinaryOperator accumulator, IgniteSupplier zeroValSupp, V defVal, K defKey,
        long defValCnt, boolean isNilpotent) {

        A defRes = zeroValSupp.get();

        if (!isNilpotent)
            for (int i = 0; i < defValCnt; i++)
                defRes = folder.apply(new CacheEntryImpl<>(defKey, defVal), defRes);

        Collection totalRes = bcast(cacheName, () -> {
            Ignite ignite = Ignition.localIgnite();
            IgniteCache cache = ignite.getOrCreateCache(cacheName);

            int partsCnt = ignite.affinity(cacheName).partitions();

            // Use affinity in filter for ScanQuery. Otherwise we accept consumer in each node which is wrong.
            Affinity affinity = ignite.affinity(cacheName);
            ClusterNode locNode = ignite.cluster().localNode();

            A a = zeroValSupp.get();

            // Iterate over all partitions. Some of them will be stored on that local node.
            for (int part = 0; part < partsCnt; part++) {
                int p = part;

                // Iterate over given partition.
                // Query returns an empty cursor if this partition is not stored on this node.
                for (Cache.Entry entry : cache.query(new ScanQuery(part,
                    (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
                    a = folder.apply(entry, a);
            }

            return a;
        });
        return totalRes.stream().reduce(defRes, accumulator);
    }

    /**
     * Distributed version of fold operation. This method also applicable to sparse zeroes.
     *
     * @param cacheName Cache name.
     * @param ignite ignite
     * @param acc Accumulator
     * @param supp supplier
     * @param entriesGen entries generator
     * @param comb combiner
     * @param zeroValSupp Zero value supplier.
     * @return aggregated result
     */
    public static  A reduce(String cacheName, Ignite ignite,
        IgniteTriFunction, A, A> acc,
        IgniteSupplier supp,
        IgniteSupplier>> entriesGen, IgniteBinaryOperator comb,
        IgniteSupplier zeroValSupp) {

        A defRes = zeroValSupp.get();

        Collection totalRes = bcast(cacheName, ignite, () -> {
            // Use affinity in filter for ScanQuery. Otherwise we accept consumer in each node which is wrong.
            A a = zeroValSupp.get();

            W w = supp.get();

            for (Cache.Entry kvEntry : entriesGen.get())
                a = acc.apply(w, kvEntry, a);

            return a;
        });

        return totalRes.stream().reduce(defRes, comb);
    }

    /**
     * Distributed version of fold operation.
     *
     * @param cacheName Cache name.
     * @param acc Accumulator
     * @param supp supplier
     * @param entriesGen entries generator
     * @param comb combiner
     * @param zeroValSupp Zero value supplier
     * @return aggregated result
     */
    public static  A reduce(String cacheName,
        IgniteTriFunction, A, A> acc,
        IgniteSupplier supp,
        IgniteSupplier>> entriesGen,
        IgniteBinaryOperator comb,
        IgniteSupplier zeroValSupp) {
        return reduce(cacheName, Ignition.localIgnite(), acc, supp, entriesGen, comb, zeroValSupp);
    }

    /**
     * Broadcasts the runnable over the cluster group of data nodes for the given cache.
     *
     * @param cacheName Cache name.
     * @param ignite Ignite instance used to resolve the cluster group and to run the computation.
     * @param run {@link Runnable} to broadcast to cache nodes for given cache name.
     */
    public static void bcast(String cacheName, Ignite ignite, IgniteRunnable run) {
        ClusterGroup dataNodes = ignite.cluster().forDataNodes(cacheName);

        ignite.compute(dataNodes).broadcast(run);
    }

    /**
     * Broadcast runnable to data nodes of given cache, using the instance returned by {@link #ignite()}.
     *
     * @param cacheName Cache name.
     * @param run Runnable.
     */
    public static void bcast(String cacheName, IgniteRunnable run) {
        Ignite locIgnite = ignite();

        bcast(cacheName, locIgnite, run);
    }

    /**
     * Broadcast callable to data nodes of given cache, using the instance returned by {@link #ignite()}.
     *
     * @param cacheName Cache name.
     * @param call {@link IgniteCallable} to broadcast to cache nodes for given cache name.
     * @param <A> Type returned by the callable.
     * @return Results of the callable collected from the broadcast.
     */
    public static <A> Collection<A> bcast(String cacheName, IgniteCallable<A> call) {
        return bcast(cacheName, ignite(), call);
    }

    /**
     * Broadcast callable to data nodes of given cache.
     *
     * @param cacheName Cache name.
     * @param ignite Ignite instance.
     * @param call Callable to broadcast.
     * @param <A> Type of callable result.
     * @return Results of callable from each node.
     */
    public static <A> Collection<A> bcast(String cacheName, Ignite ignite, IgniteCallable<A> call) {
        return ignite.compute(ignite.cluster().forDataNodes(cacheName)).broadcast(call);
    }

    /**
     * Applies the mapping function to every value stored for the given vector and writes the result
     * back to the cache. A {@link VectorBlockEntry} value is mapped element by element in place; any
     * other value is cast to {@link Double}, mapped and replaced by the mapping result.
     *
     * @param vectorUuid Vector UUID.
     * @param mapper Mapping {@link IgniteDoubleFunction}.
     * @param cacheName Cache name.
     * @param <K> Cache key object type.
     * @param <V> Cache value object type.
     */
    @SuppressWarnings("unchecked")
    public static <K, V> void sparseMapForVector(UUID vectorUuid, IgniteDoubleFunction<V> mapper, String cacheName) {
        A.notNull(vectorUuid, "vectorUuid");
        A.notNull(cacheName, "cacheName");
        A.notNull(mapper, "mapper");

        foreach(cacheName, (CacheEntry<K, V> ce) -> {
            K k = ce.entry().getKey();

            V v = ce.entry().getValue();

            if (v instanceof VectorBlockEntry) {
                VectorBlockEntry entry = (VectorBlockEntry)v;

                for (int i = 0; i < entry.size(); i++)
                    entry.set(i, (Double)mapper.apply(entry.get(i)));

                // The block was modified in place; put it back so the cache stores the updated state.
                ce.cache().put(k, (V)entry);
            }
            else {
                V mappingRes = mapper.apply((Double)v);
                ce.cache().put(k, mappingRes);
            }
        }, sparseKeyFilter(vectorUuid));
    }
}