// org.deeplearning4j.nn.gradient.DefaultGradient Maven / Gradle / Ivy
/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 * Default {@link Gradient} implementation. Basically a lookup table from
 * variable name to the gradient ndarray for that variable, with an optional
 * per-variable flattening order.
 *
 * <p>Gradients are kept in insertion order (a {@link LinkedHashMap}); that
 * order is the one used by {@link #gradient()} when building the flattened
 * view.
 *
 * <p>NOTE: the flattened gradient returned by {@link #gradient()} is cached
 * after the first computation (or taken from the constructor argument) and is
 * not invalidated by {@link #setGradientFor(String, INDArray)} or
 * {@link #clear()}.
 *
 * @author Adam Gibson
 */
public class DefaultGradient implements Gradient {
    /** Order used for the final flattening step, and assumed for any variable without an explicit order. */
    public static final char DEFAULT_FLATTENING_ORDER = 'f';

    /** Variable name -> gradient array, in insertion order. */
    private Map<String, INDArray> gradients = new LinkedHashMap<>();
    /** Variable name -> explicit flattening order; created lazily, {@code null} until an order is first set. */
    private Map<String, Character> flatteningOrders;
    /** Cached result of {@link #gradient()}; {@code null} until computed or supplied via the constructor. */
    private INDArray flattenedGradient;

    /** Creates an empty gradient with no cached flattened view. */
    public DefaultGradient() {}

    /**
     * Creates an empty gradient whose {@link #gradient()} returns the given array.
     *
     * @param flattenedGradient array to use as the cached flattened gradient
     */
    public DefaultGradient(INDArray flattenedGradient) {
        this.flattenedGradient = flattenedGradient;
    }

    /**
     * @return the live (mutable) map of variable name to gradient array
     */
    @Override
    public Map<String, INDArray> gradientForVariable() {
        return gradients;
    }

    /**
     * Flattens the gradients of the named variables, in the order given.
     * Names with no entry in this gradient are skipped. The result is computed
     * on every call and does not touch the cached flattened gradient.
     *
     * @param order variable names, in the order they should appear in the result
     * @return the selected gradients flattened with {@link #DEFAULT_FLATTENING_ORDER}
     */
    @Override
    public INDArray gradient(List<String> order) {
        List<INDArray> toFlatten = new ArrayList<>();
        for (String variable : order) {
            if (!gradients.containsKey(variable)) {
                continue;
            }
            toFlatten.add(prepareForFlattening(variable, gradients.get(variable)));
        }
        return Nd4j.toFlattened(DEFAULT_FLATTENING_ORDER, toFlatten);
    }

    /**
     * Returns the array to hand to the final flattening step for one variable.
     * Arrays with a non-default flattening order get flattened in that order
     * first; everything else is passed through unchanged.
     */
    private INDArray prepareForFlattening(String variable, INDArray array) {
        if (flatteningOrders == null) {
            return array;
        }
        Character order = flatteningOrders.get(variable);
        if (order == null || order.charValue() == DEFAULT_FLATTENING_ORDER) {
            //default flattening order for this array
            return array;
        }
        //Specific flattening order for this array, that isn't the default
        //TODO revisit this, and make more efficient
        return Nd4j.toFlattened(order.charValue(), array);
    }

    /** Recomputes the cached flattened gradient from all stored gradients, in insertion order. */
    private void flattenGradient() {
        List<INDArray> toFlatten = new ArrayList<>();
        for (Map.Entry<String, INDArray> entry : gradients.entrySet()) {
            toFlatten.add(prepareForFlattening(entry.getKey(), entry.getValue()));
        }
        flattenedGradient = Nd4j.toFlattened(DEFAULT_FLATTENING_ORDER, toFlatten);
    }

    /**
     * Returns the flattened gradient, computing and caching it on first use.
     * Once a flattened gradient exists it is returned as-is on later calls.
     *
     * @return the cached flattened gradient
     */
    @Override
    public INDArray gradient() {
        if (flattenedGradient == null) {
            flattenGradient();
        }
        return flattenedGradient;
    }

    /**
     * Removes all per-variable gradients. The flattening orders and the cached
     * flattened gradient are left untouched.
     */
    @Override
    public void clear() {
        gradients.clear();
    }

    /**
     * @param variable variable name
     * @return the gradient stored for that variable, or {@code null} if there is none
     */
    @Override
    public INDArray getGradientFor(String variable) {
        return gradients.get(variable);
    }

    /**
     * Stores (or replaces) the gradient for a variable. Does not update the
     * cached flattened gradient.
     *
     * @param variable    variable name
     * @param newGradient gradient array to store
     * @return the gradient previously stored for the variable, or {@code null}
     */
    @Override
    public INDArray setGradientFor(String variable, INDArray newGradient) {
        INDArray last = gradients.put(variable, newGradient);
        // TODO revisit whether setGradientFor should update the gradient that can be pulled from this object in any form - currently does not update flattened
        // use of uninitialized var for flattenedGradient in backprop is generating an error in gradient calc if the call below is used
        // flattenGradient();
        return last;
    }

    /**
     * Stores (or replaces) the gradient for a variable and, when
     * {@code flatteningOrder} is non-null, records the order to flatten it in.
     * A {@code null} order leaves any previously recorded order in place.
     *
     * @param variable        variable name
     * @param gradient        gradient array to store
     * @param flatteningOrder flattening order for this variable, or {@code null}
     * @return the gradient previously stored for the variable, or {@code null}
     */
    @Override
    public INDArray setGradientFor(String variable, INDArray gradient, Character flatteningOrder) {
        INDArray last = setGradientFor(variable, gradient);
        if (flatteningOrder != null) {
            if (flatteningOrders == null) {
                flatteningOrders = new LinkedHashMap<>();
            }
            flatteningOrders.put(variable, flatteningOrder);
        }
        return last;
    }

    /**
     * @param variable variable name
     * @return the explicitly recorded flattening order for the variable, or
     *         {@code null} if none has been set
     */
    @Override
    public Character flatteningOrderForVariable(String variable) {
        if (flatteningOrders == null) {
            return null;
        }
        return flatteningOrders.get(variable);
    }

    @Override
    public String toString() {
        return "DefaultGradient{" + "gradients=" + gradients + (flatteningOrders != null ? flatteningOrders : "") + '}';
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy