org.nd4j.common.function.FunctionalUtils Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.common.function;
import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class FunctionalUtils {
/**
* For each key in left and right, cogroup returns the list of values
* as a pair for each value present in left as well as right.
* @param left the left list of pairs to join
* @param right the right list of pairs to join
* @param the key type
* @param the value type
* @return a map of the list of values by key for each value in the left as well as the right
* with each element in the pair representing the values in left and right respectively.
*/
public static Map,List>> cogroup(List> left,List> right) {
Map,List>> ret = new HashMap<>();
//group by key first to consolidate values
Map> leftMap = groupByKey(left);
Map> rightMap = groupByKey(right);
/**
* Iterate over each key in the list
* adding values to the left items
* as values are found in the list.
*/
for(Map.Entry> entry : leftMap.entrySet()) {
K key = entry.getKey();
if(!ret.containsKey(key)) {
List leftListPair = new ArrayList<>();
List rightListPair = new ArrayList<>();
Pair,List> p = Pair.of(leftListPair,rightListPair);
ret.put(key,p);
}
Pair,List> p = ret.get(key);
p.getFirst().addAll(entry.getValue());
}
/**
* Iterate over each key in the list
* adding values to the right items
* as values are found in the list.
*/
for(Map.Entry> entry : rightMap.entrySet()) {
K key = entry.getKey();
if(!ret.containsKey(key)) {
List leftListPair = new ArrayList<>();
List rightListPair = new ArrayList<>();
Pair,List> p = Pair.of(leftListPair,rightListPair);
ret.put(key,p);
}
Pair,List> p = ret.get(key);
p.getSecond().addAll(entry.getValue());
}
return ret;
}
/**
* Group the input pairs by the key of each pair.
* @param listInput the list of pairs to group
* @param the key type
* @param the value type
* @return a map representing a grouping of the
* keys by the given input key type and list of values
* in the grouping.
*/
public static Map> groupByKey(List> listInput) {
Map> ret = new HashMap<>();
for(Pair pair : listInput) {
List currList = ret.get(pair.getFirst());
if(currList == null) {
currList = new ArrayList<>();
ret.put(pair.getFirst(),currList);
}
currList.add(pair.getSecond());
}
return ret;
}
/**
* Convert a map with a set of entries of type K for key
* and V for value in to a list of {@link Pair}
* @param map the map to collapse
* @param the key type
* @param the value type
* @return the collapsed map as a {@link List}
*/
public static List> mapToPair(Map map) {
List> ret = new ArrayList<>(map.size());
for(Map.Entry entry : map.entrySet()) {
ret.add(Pair.of(entry.getKey(),entry.getValue()));
}
return ret;
}
}