All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.models.sequencevectors.listeners.SerializingListener Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.models.sequencevectors.listeners;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.common.util.SerializationUtils;

import java.io.File;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.concurrent.Semaphore;

/**
 *
 * This is example VectorsListener implementation. It can be used to serialize models in the middle of training process
 *
 * @author [email protected]
 */
@Slf4j
public class SerializingListener implements VectorsListener {
    private File targetFolder = new File("./");
    private String modelPrefix = "Model_";
    private boolean useBinarySerialization = true;
    private ListenerEvent targetEvent = ListenerEvent.EPOCH;
    private int targetFrequency = 100000;

    private Semaphore locker = new Semaphore(1);

    protected SerializingListener() {}

    /**
     * This method is called prior each processEvent call, to check if this specific VectorsListener implementation is viable for specific event
     *
     * @param event
     * @param argument
     * @return TRUE, if this event can and should be processed with this listener, FALSE otherwise
     */
    @Override
    public boolean validateEvent(ListenerEvent event, long argument) {
        try {
            /**
             * please note, since sequence vectors are multithreaded we need to stop processed while model is being saved
             */
            locker.acquire();

            if (event == targetEvent && argument % targetFrequency == 0) {
                return true;
            } else
                return false;
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            locker.release();
        }
    }

    /**
     * This method is called at each epoch end
     *
     * @param event
     * @param sequenceVectors
     * @param argument
     */
    @Override
    public void processEvent(ListenerEvent event, SequenceVectors sequenceVectors, long argument) {
        try {
            locker.acquire();

            SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS");

            StringBuilder builder = new StringBuilder(targetFolder.getAbsolutePath());
            builder.append("/").append(modelPrefix).append("_").append(sdf.format(new Date())).append(".seqvec");
            File targetFile = new File(builder.toString());

            if (useBinarySerialization) {
                SerializationUtils.saveObject(sequenceVectors, targetFile);
            } else {
                throw new UnsupportedOperationException("Not implemented yet");
            }

        } catch (Exception e) {
            log.error("",e);
        } finally {
            locker.release();
        }
    }

    public static class Builder {
        private File targetFolder = new File("./");
        private String modelPrefix = "Model_";
        private boolean useBinarySerialization = true;
        private ListenerEvent targetEvent = ListenerEvent.EPOCH;
        private int targetFrequency = 100000;

        public Builder(ListenerEvent targetEvent, int frequency) {
            this.targetEvent = targetEvent;
            this.targetFrequency = frequency;
        }

        /**
         * This method allows you to define template for file names that will be created during serialization
         * @param reallyUse
         * @return
         */
        public Builder setFilenamePrefix(boolean reallyUse) {
            this.useBinarySerialization = reallyUse;
            return this;
        }

        /**
         * This method specifies target folder where models should be saved
         *
         * @param folder
         * @return
         */
        public Builder setTargetFolder(@NonNull String folder) {
            this.setTargetFolder(new File(folder));
            return this;
        }

        /**
         * This method specifies target folder where models should be saved
         *
         * @param folder
         * @return
         */
        public Builder setTargetFolder(@NonNull File folder) {
            if (!folder.exists() || !folder.isDirectory())
                throw new IllegalStateException("Target folder must exist!");
            this.targetFolder = folder;
            return this;
        }

        /**
         * This method returns new SerializingListener instance
         *
         * @return
         */
        public SerializingListener build() {
            SerializingListener listener = new SerializingListener<>();
            listener.modelPrefix = this.modelPrefix;
            listener.targetFolder = this.targetFolder;
            listener.useBinarySerialization = this.useBinarySerialization;
            listener.targetEvent = this.targetEvent;
            listener.targetFrequency = this.targetFrequency;

            return listener;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy