All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This may take a few minutes...
Many resources are needed to download a project. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it how often you want.
org.deeplearning4j.nn.params.SameDiffParamInitializer Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.params;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.util.ArrayUtil;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
@Slf4j
public class SameDiffParamInitializer implements ParamInitializer {
private static final SameDiffParamInitializer INSTANCE = new SameDiffParamInitializer();
public static SameDiffParamInitializer getInstance() {
return INSTANCE;
}
@Override
public long numParams(NeuralNetConfiguration conf) {
return numParams(conf.getLayer());
}
@Override
public long numParams(Layer layer) {
AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer;
Map m = sd.getLayerParams().getParamShapes();
int n = 0;
for(val arr : m.values()){
n += ArrayUtil.prod(arr);
}
return n;
}
@Override
public List paramKeys(Layer layer) {
AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer;
return sd.getLayerParams().getParameterKeys();
}
@Override
public List weightKeys(Layer layer) {
AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer;
return sd.getLayerParams().getWeightParameterKeys();
}
@Override
public List biasKeys(Layer layer) {
AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer;
return sd.getLayerParams().getBiasParameterKeys();
}
@Override
public boolean isWeightParam(Layer layer, String key) {
return weightKeys(layer).contains(key);
}
@Override
public boolean isBiasParam(Layer layer, String key) {
return biasKeys(layer).contains(key);
}
@Override
public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
AbstractSameDiffLayer sd = (AbstractSameDiffLayer) conf.getLayer();
Map out = subsetAndReshape(sd.getLayerParams().getParameterKeys(),
sd.getLayerParams().getParamShapes(), paramsView, sd);
if(initializeParams){
sd.initializeParameters(out);
}
for(String s : sd.getLayerParams().getParameterKeys()){
conf.addVariable(s);
}
return out;
}
@Override
public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
AbstractSameDiffLayer sd = (AbstractSameDiffLayer) conf.getLayer();
return subsetAndReshape(sd.getLayerParams().getParameterKeys(), sd.getLayerParams().getParamShapes(),
gradientView, sd);
}
private Map subsetAndReshape(List params, Map paramShapes, INDArray view,
AbstractSameDiffLayer sdl){
return subsetAndReshape(params, paramShapes, view, sdl, null);
}
public Map subsetAndReshape(List params, Map paramShapes, INDArray view,
AbstractSameDiffLayer sdl, SameDiffVertex sdv){
Class clazz = (sdl != null ? sdl.getClass() : sdv.getClass());
String layerName = (sdl != null ? sdl.getLayerName() : ""); //TODO
Map out = new LinkedHashMap<>();
int soFar = 0;
for(String s : params){
val sh = paramShapes.get(s);
val length = ArrayUtil.prodLong(sh);
if(length <= 0){
throw new IllegalStateException("Invalid array state for parameter \"" + s + "\" in layer " + layerName
+ " of type " + clazz.getSimpleName() + ": parameter length (" + length
+ ") must be > 0 - parameter array shape: " + Arrays.toString(sh));
}
INDArray viewReshape = view.reshape(view.length());
INDArray sub = viewReshape.get(interval(soFar, soFar + length));
if(!Arrays.equals(sub.shape(), sh)){
char order = (sdl != null ? sdl.paramReshapeOrder(s) : sdv.paramReshapeOrder(s));
sub = sub.reshape(order, sh);
}
out.put(s, sub);
soFar += length;
}
return out;
}
}