// com.simiacryptus.mindseye.layers.tensorflow.TFLayerBase (Maven / Gradle / Ivy)
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.tensorflow;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.tensorflow.TFIO;
import com.simiacryptus.mindseye.lang.tensorflow.TFUtil;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.tensorflow.TensorboardEventWriter;
import com.simiacryptus.tensorflow.TensorflowUtil;
import com.simiacryptus.util.Util;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.*;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.Summary;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
public abstract class TFLayerBase extends LayerBase {
/** Logger for this class (previously registered under {@code TFLayer}, a different class). */
private static final Logger log = LoggerFactory.getLogger(TFLayerBase.class);
/**
 * Optional, process-wide Tensorboard writer. While it is {@code null} no summary node is fetched
 * during evaluation; when set, the summary produced by the node named by {@link #getSummaryOut()}
 * is written to it after each forward pass.
 */
@Nullable
public static TensorboardEventWriter eventWriter = null;
/**
 * Weight tensors keyed by name. The keys are used both as JSON member names and as the graph node
 * names the tensors are fed to.
 */
private final RefMap<String, Tensor> weights = new RefHashMap<>();
/**
 * Restores a layer from its JSON form, loading one weight tensor per data key.
 *
 * @param json serialized layer; every key reported by {@link #getDataKeys(JsonObject)} is decoded
 *             with {@code Tensor.fromJson} and stored in the weight map under that key
 * @param rs   resource map handed through unchanged to {@code Tensor.fromJson}
 */
public TFLayerBase(@Nonnull JsonObject json, Map rs) {
  super(json);
  // The keys are iterated as String, so the element type is declared explicitly.
  Set<String> dataKeys = getDataKeys(json);
  // One reference for the whole loop, released even if decoding a tensor fails.
  RefMap weights = this.getWeights();
  assert weights != null;
  try {
    for (String key : dataKeys) {
      // put() returns the previous value for the key (if any); release it.
      RefUtil.freeRef(weights.put(key, Tensor.fromJson(json.get(key), rs)));
    }
  } finally {
    weights.freeRef();
  }
}
/**
 * Creates a layer whose weight map is seeded from {@code states}.
 *
 * @param states initial weights; may be {@code null}, in which case the layer starts with an empty
 *               weight map. A non-null map is handed to {@code putAll} exactly as before.
 */
public TFLayerBase(@Nullable RefMap states) {
  RefMap weights = this.getWeights();
  assert weights != null;
  try {
    // The parameter is declared @Nullable; putAll(null) would throw, so skip it.
    if (null != states) {
      weights.putAll(states);
    }
  } finally {
    weights.freeRef();
  }
}
/**
 * Supplies the TensorFlow graph definition of this layer. {@link #constGraph()} passes it to
 * {@code TFUtil.implantConstants} together with the current weights.
 *
 * @return the graph definition
 */
public abstract GraphDef getGraphDef();
/**
 * Names of the graph nodes that receive the layer inputs. During evaluation the i-th input is fed
 * to the i-th name in this list.
 *
 * @return the input node names; declared nullable, but the callers in this class dereference the
 *         result without a null fallback
 */
@Nullable
public abstract List getInputNodes();
/**
 * Name of the graph node fetched as the layer output; it is the first fetch registered on the
 * session runner.
 *
 * @return the output node name
 */
public abstract String getOutputNode();
/**
 * Name of the graph node whose value is fetched as a second output and parsed as a
 * {@code Summary}. It is only fetched while {@code eventWriter} is set.
 *
 * @return the summary node name; {@code null} or an empty string disables summary fetching
 */
@Nullable
public abstract String getSummaryOut();
/**
 * Hands out an additional reference to the weight map.
 *
 * @return the weight map after {@code addRef()}, or {@code null} if the field is not set; every
 *         caller in this class releases the returned reference with {@code freeRef()}
 */
@Nullable
public RefMap getWeights() {
  if (null == weights) {
    return null;
  }
  return weights.addRef();
}
/**
 * Builds a {@code TFLayer} from the serialized result of {@link #constGraph()}, an empty weight
 * map, and this layer's output and input node names.
 *
 * @return the new layer
 */
@Nonnull
public TFLayer asConstLayer() {
  // Serialize the graph first; the remaining arguments are evaluated in their original order.
  final byte[] graphBytes = constGraph().toByteArray();
  return new TFLayer(
      graphBytes,
      new RefHashMap<>(),
      getOutputNode(),
      getInputNodes().toArray(new String[]{}));
}
/**
 * Applies {@code TFUtil.implantConstants} to this layer's own graph definition and its current
 * weights.
 *
 * @return the graph definition returned by {@code TFUtil.implantConstants}
 */
@Nonnull
public GraphDef constGraph() {
  final GraphDef graphDef = getGraphDef();
  return TFUtil.implantConstants(graphDef, getWeights());
}
/**
 * Serializes the layer: starts from {@code getJsonStub()} and adds one member per weight, named by
 * the weight's key.
 *
 * @param resources      resource map passed through to each tensor's {@code getJson}
 * @param dataSerializer serializer passed through to each tensor's {@code getJson}
 * @return the JSON object containing the stub members plus one member per weight
 */
@Override
public JsonObject getJson(Map resources, @Nonnull DataSerializer dataSerializer) {
  JsonObject json = getJsonStub();
  // Typed view of the map so the lambda parameters are (String, Tensor).
  RefMap<String, Tensor> weights = getWeights();
  assert weights != null;
  try {
    weights.forEach((key, tensor) -> {
      final JsonElement tensorJson;
      try {
        tensorJson = tensor.getJson(resources, dataSerializer);
      } finally {
        // The tensor reference handed to the lambda is released even if serialization fails.
        tensor.freeRef();
      }
      json.add(key, tensorJson);
    });
  } finally {
    weights.freeRef();
  }
  return json;
}
/**
 * Collects the result of {@code getData()} for every weight tensor.
 *
 * @return a list with one entry per weight, in the iteration order of the weight map's values
 */
@Nullable
@Override
public RefList state() {
  // Typed so the mapping lambda can call Tensor methods on its argument.
  RefCollection<Tensor> values = weights.values();
  try {
    return values.stream().map(x -> {
      try {
        return x.getData();
      } finally {
        x.freeRef();
      }
    }).collect(RefCollectors.toList());
  } finally {
    // Released on the failure path as well as the normal one.
    values.freeRef();
  }
}
/**
 * Evaluates the layer in a newly created {@code TFSession} that is constructed with an added
 * reference to this layer.
 *
 * @param inputs the layer inputs, forwarded unchanged
 * @return the result of the session-based {@code eval} overload
 */
@Nullable
@Override
public Result eval(@Nullable Result... inputs) {
  final TFSession session = new TFSession(addRef());
  return eval(session, inputs);
}
/**
 * Hook with an empty default body. {@code _free()} invokes it after releasing the weight map and
 * before delegating to {@code super._free()}.
 */
public void close() {
}
/**
 * Flag passed as the {@code invertRanks} argument when weight tensors are converted through
 * {@code TFIO} for feeding, and forwarded to the {@code Accumulator} built during evaluation.
 * Layer inputs do not use this flag; they are always converted with {@code true}.
 *
 * @return {@code true} in this base implementation
 */
public boolean invertWeights() {
return true;
}
/**
 * Applies {@code TFUtil.implantConstants} to the given graph definition and this layer's current
 * weights.
 * NOTE(review): presumably this embeds the weights into the graph as constants; the helper is not
 * visible here, confirm against {@code TFUtil}.
 *
 * @param graphDef the graph definition to process
 * @return the graph definition returned by {@code TFUtil.implantConstants}
 */
public @Nonnull
GraphDef getConstGraph(GraphDef graphDef) {
return TFUtil.implantConstants(graphDef, getWeights());
}
/**
 * Releases this layer's resources: frees the weight map (when set), calls {@link #close()}, then
 * delegates to {@code super._free()}.
 */
public void _free() {
  if (weights != null) {
    weights.freeRef();
  }
  close();
  super._free();
}
/**
 * Narrows the return type of {@code super.addRef()} to {@code TFLayerBase}.
 *
 * @return the value returned by {@code super.addRef()}, cast to {@code TFLayerBase}
 */
@Nonnull
@Override
@SuppressWarnings("unused")
public TFLayerBase addRef() {
  return (TFLayerBase) super.addRef();
}
/**
 * Runs the forward pass in the given session and pairs the output with an {@code Accumulator}.
 *
 * <p>Steps: snapshot the weight names, feed weights and inputs to a new runner, register the
 * fetches, execute the runner and convert its first output, then build the accumulator from the
 * same runner and session.
 *
 * @param tfsession session wrapper whose {@code session} supplies the runner; it is also passed on
 *                  to the accumulator
 * @param inputs    layer inputs; an added reference goes to the feeding step and the array itself
 *                  is passed to the accumulator
 * @return a {@code Result} holding the forward output and the accumulator
 */
@Nonnull
Result eval(@Nonnull TFSession tfsession, @Nonnull Result... inputs) {
  RefMap<String, Tensor> weights = getWeights();
  assert weights != null;
  RefSet<String> keySet = weights.keySet();
  RefList<String> stateNames = keySet.stream().collect(RefCollectors.toList());
  keySet.freeRef();
  Session.Runner runner = tfsession.session.runner();
  // setTensors releases both the weight map reference and the added input reference.
  RefArrayList<org.tensorflow.Tensor<?>> tensors = setTensors(runner, weights, RefUtil.addRef(inputs));
  boolean summaryOut = run(runner);
  // getOutput releases the tensor list.
  TensorArray resultData = getOutput(runner, tensors, summaryOut);
  // The accumulator receives 2 when a summary node was fetched in addition to the output, else 1.
  Accumulator accumulator = new Accumulator(runner, summaryOut ? 2 : 1, stateNames, this.getId(), this.getWeights(),
      this.getOutputNode(), this.invertWeights(), this.getInputNodes(),
      this.floatInputs(), tfsession.getGradients(), tfsession, inputs);
  return new Result(resultData, accumulator);
}
/**
 * Reports which members of the serialized layer hold weight tensors. The JSON constructor decodes
 * each reported key with {@code Tensor.fromJson} and stores the tensor under the same key.
 * This method is invoked from that constructor, i.e. before any subclass constructor body has
 * run.
 *
 * @param json the serialized layer being loaded
 * @return the keys to read from {@code json}
 */
@Nonnull
protected abstract Set getDataKeys(JsonObject json);
/**
 * Selects the tensor conversion used when feeding weights and inputs: {@code true} uses
 * {@code TFIO.getFloatTensor}, {@code false} uses {@code TFIO.getDoubleTensor}. The value is also
 * forwarded to the {@code Accumulator} built during evaluation.
 *
 * @return {@code false} in this base implementation
 */
protected boolean floatInputs() {
return false;
}
/**
 * Registers the fetches for a forward pass; it does not execute the runner. The output node is
 * always fetched. The summary node is fetched as well when {@code eventWriter} is set and
 * {@link #getSummaryOut()} returns a non-empty name.
 *
 * @param runner the runner to register fetches on
 * @return {@code true} if the summary node was registered as a second fetch
 */
private boolean run(Session.Runner runner) {
  runner.fetch(getOutputNode());
  if (null == eventWriter) {
    return false;
  }
  if (null == getSummaryOut() || getSummaryOut().isEmpty()) {
    return false;
  }
  runner.fetch(getSummaryOut());
  return true;
}
@NotNull
private RefArrayList> setTensors(Session.Runner runner, RefMap weights,
@Nonnull Result[] inputs) {
RefArrayList> tensors = new RefArrayList<>();
weights.forEach((nodeName, data) -> {
@Nonnull
org.tensorflow.Tensor extends Number> tensor;
boolean invertRanks = invertWeights();
if (floatInputs()) {
tensor = TFIO.getFloatTensor(data, invertRanks);
} else {
tensor = TFIO.getDoubleTensor(data, invertRanks);
}
runner.feed(nodeName, tensor);
tensors.add(tensor);
});
weights.freeRef();
final List inputNodes = getInputNodes();
assert inputNodes != null;
for (int i = 0; i < inputNodes.size(); i++) {
String inputNode = inputNodes.get(i);
TensorList data = inputs[i].getData();
@Nonnull
org.tensorflow.Tensor extends Number> tensor;
if (floatInputs()) {
tensor = TFIO.getFloatTensor(data, true);
} else {
tensor = TFIO.getDoubleTensor(data, true);
}
runner.feed(inputNode, tensor);
tensors.add(tensor);
}
RefUtil.freeRef(inputs);
return tensors;
}
@NotNull
private TensorArray getOutput(Session.Runner runner, RefArrayList> tensors,
boolean summaryOut) {
Session.Run fwd;
try {
fwd = runner.runAndFetchMetadata();
} catch (IllegalArgumentException e) {
throw e;
}
org.tensorflow.Tensor> tensor = fwd.outputs.get(0);
TensorArray resultData = TFIO.getTensorList(tensor);
tensors.add(tensor);
tensors.freeRef();
if (summaryOut) {
final Summary summary;
try {
summary = Summary.parseFrom(fwd.outputs.get(1).expect(String.class).bytesValue());
} catch (InvalidProtocolBufferException e) {
throw Util.throwException(e);
}
try {
if (null != eventWriter)
eventWriter.write(summary);
} catch (IOException e) {
throw Util.throwException(e);
}
}
return resultData;
}
static class TFSession extends ReferenceCountingBase {
@Nonnull
public final Graph graph;
public final Singleton © 2015 - 2025 Weber Informatics LLC | Privacy Policy