
org.deeplearning4j.parallelism.parameterserver.ParameterServerTrainer Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.parallelism.parameterserver;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.trainer.DefaultTrainer;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.parameterserver.client.ParameterServerClient;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@Builder
@Slf4j
@AllArgsConstructor
@NoArgsConstructor
/**
 * A {@link DefaultTrainer} that fits the model it holds on each incoming batch and then
 * pushes that model's parameters through a {@link ParameterServerClient}.
 *
 * <p>Both feed methods perform the fit and the push synchronously in the calling code path;
 * see the FIXME notes inside them.
 */
public class ParameterServerTrainer extends DefaultTrainer {

    /** Client used to push the model parameters after every fit. */
    private ParameterServerClient parameterServerClient;

    /**
     * Fits the current model on a {@link MultiDataSet} and pushes the resulting parameters.
     *
     * @param dataSet the batch to fit on; must not be null (Lombok {@code @NonNull})
     * @param time    not used by this implementation
     * @throws IllegalArgumentException if the current model is not a {@link ComputationGraph}
     */
    @Override
    public void feedMultiDataSet(@NonNull MultiDataSet dataSet, long time) {
        // FIXME: this is wrong, and should be fixed
        // Guard: only a ComputationGraph is accepted on this path.
        if (!(getModel() instanceof ComputationGraph)) {
            throw new IllegalArgumentException("MultiLayerNetworks can't fit multi datasets");
        }

        ComputationGraph graph = (ComputationGraph) getModel();
        graph.fit(dataSet);

        log.info("Sending parameters");
        // Publish the parameters produced by the fit above.
        Model fitted = getModel();
        parameterServerClient.pushNDArray(fitted.params());
    }

    /**
     * Fits the current model on a {@link DataSet} and pushes the resulting parameters.
     *
     * <p>A {@link ComputationGraph} is fitted directly; any other model is cast to
     * {@link MultiLayerNetwork} before fitting.
     *
     * @param dataSet the batch to fit on; must not be null (Lombok {@code @NonNull})
     * @param time    not used by this implementation
     */
    @Override
    public void feedDataSet(@NonNull DataSet dataSet, long time) {
        // FIXME: this is wrong, and should be fixed. Training should happen within run() loop
        if (!(getModel() instanceof ComputationGraph)) {
            // Cast first so an unexpected model type fails before anything is logged.
            MultiLayerNetwork network = (MultiLayerNetwork) getModel();
            log.info("Calling fit on multi layer network");
            network.fit(dataSet);
        } else {
            ComputationGraph graph = (ComputationGraph) getModel();
            graph.fit(dataSet);
        }

        log.info("About to send params in");
        // Publish the parameters produced by the fit above.
        Model fitted = getModel();
        parameterServerClient.pushNDArray(fitted.params());
        log.info("Sent params");
    }

    /**
     * Returns the model held by the superclass.
     *
     * @return whatever {@code super.getModel()} returns
     */
    @Override
    public Model getModel() {
        return super.getModel();
    }

    /**
     * Hands the given model to the superclass implementation.
     *
     * @param model the model to pass on; must not be null (Lombok {@code @NonNull})
     */
    @Override
    public void updateModel(@NonNull Model model) {
        super.updateModel(model);
    }

    /**
     * Builder for {@link ParameterServerTrainer}.
     *
     * <p>Every inherited setter is re-declared with a narrowed return type so that a call
     * chain keeps the {@code ParameterServerTrainerBuilder} type. Each override delegates to
     * the parent builder and casts the result. The remaining members (for example the
     * setter for the client and {@code build()}) are presumably supplied by Lombok's
     * {@code @Builder} on the enclosing class.
     */
    public static class ParameterServerTrainerBuilder extends DefaultTrainerBuilder {

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder originalModel(Model originalModel) {
            return (ParameterServerTrainerBuilder) super.originalModel(originalModel);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder replicatedModel(Model replicatedModel) {
            return (ParameterServerTrainerBuilder) super.replicatedModel(replicatedModel);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder queue(LinkedBlockingQueue queue) {
            return (ParameterServerTrainerBuilder) super.queue(queue);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder queueMDS(LinkedBlockingQueue queueMDS) {
            return (ParameterServerTrainerBuilder) super.queueMDS(queueMDS);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder running(AtomicInteger running) {
            return (ParameterServerTrainerBuilder) super.running(running);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder threadId(int threadId) {
            return (ParameterServerTrainerBuilder) super.threadId(threadId);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder shouldUpdate(AtomicBoolean shouldUpdate) {
            return (ParameterServerTrainerBuilder) super.shouldUpdate(shouldUpdate);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder shouldStop(AtomicBoolean shouldStop) {
            return (ParameterServerTrainerBuilder) super.shouldStop(shouldStop);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder thrownException(Exception thrownException) {
            return (ParameterServerTrainerBuilder) super.thrownException(thrownException);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder useMDS(boolean useMDS) {
            return (ParameterServerTrainerBuilder) super.useMDS(useMDS);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder onRootModel(boolean onRootModel) {
            return (ParameterServerTrainerBuilder) super.onRootModel(onRootModel);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder parallelWrapper(ParallelWrapper parallelWrapper) {
            return (ParameterServerTrainerBuilder) super.parallelWrapper(parallelWrapper);
        }

        /** Delegates to the parent builder and returns the result as this builder type. */
        @Override
        public ParameterServerTrainerBuilder averagingFrequency(int frequency) {
            return (ParameterServerTrainerBuilder) super.averagingFrequency(frequency);
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy