All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.models.sequencevectors.transformers.impl.GraphTransformer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.models.sequencevectors.transformers.impl;

import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.graph.primitives.IGraph;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.labels.LabelsProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;

/**
 *
 * This class is used to build vocabulary and sequences out of graph, via GraphWalkers
 *
 * @author [email protected]
 */
public class GraphTransformer implements Iterable> {
    protected IGraph sourceGraph;
    protected GraphWalker walker;
    protected LabelsProvider labelsProvider;
    protected AtomicInteger counter = new AtomicInteger(0);
    protected boolean shuffle = true;
    protected VocabCache vocabCache;

    protected static final Logger log = LoggerFactory.getLogger(GraphTransformer.class);

    protected GraphTransformer() {}

    /**
     * This method handles required initialization for GraphTransformer
     */
    protected void initialize() {
        log.info("Building Huffman tree for source graph...");
        int nVertices = sourceGraph.numVertices();
        //int[] degrees = new int[nVertices];
        //for( int i=0; i> iterator() {
        this.counter.set(0);
        this.walker.reset(shuffle);
        return new Iterator>() {
            private GraphWalker walker = GraphTransformer.this.walker;

            @Override
            public void remove() {
                throw new UnsupportedOperationException("This is not supported on read-only iterator");
            }

            @Override
            public boolean hasNext() {
                return walker.hasNext();
            }

            @Override
            public Sequence next() {
                Sequence sequence = walker.next();
                sequence.setSequenceId(counter.getAndIncrement());

                // we might already have labels defined from walker
                if (walker.isLabelEnabled() && sequence.getSequenceLabels() == null)
                    if (labelsProvider != null) {
                        // TODO: sequence labels to be implemented for graph walks
                        sequence.setSequenceLabel(labelsProvider.getLabel(sequence.getSequenceId()));
                    }

                return sequence;
            }
        };
    }

    public static class Builder {
        protected IGraph sourceGraph;
        protected LabelsProvider labelsProvider;
        protected GraphWalker walker;
        protected boolean shuffle = true;
        protected VocabCache vocabCache;

        public Builder() {
            //
        }

        public Builder(@NonNull GraphWalker walker) {
            this.walker = walker;
        }

        public Builder(@NonNull IGraph sourceGraph) {
            this.sourceGraph = sourceGraph;
        }


        public Builder setLabelsProvider(@NonNull LabelsProvider provider) {
            this.labelsProvider = provider;
            return this;
        }

        public Builder setGraphWalker(@NonNull GraphWalker walker) {
            this.walker = walker;
            return this;
        }

        public Builder setVocabCache(@NonNull VocabCache vocabCache) {
            this.vocabCache = vocabCache;
            return this;
        }

        public Builder shuffleOnReset(boolean reallyShuffle) {
            this.shuffle = reallyShuffle;
            return this;
        }

        public GraphTransformer build() {
            if (this.walker == null)
                throw new IllegalStateException("Please provide GraphWalker instance.");

            GraphTransformer transformer = new GraphTransformer<>();
            if (this.sourceGraph == null)
                this.sourceGraph = walker.getSourceGraph();

            transformer.sourceGraph = this.sourceGraph;
            transformer.labelsProvider = this.labelsProvider;
            transformer.shuffle = this.shuffle;
            transformer.vocabCache = this.vocabCache;
            transformer.walker = this.walker;

            // FIXME: get rid of this
            transformer.initialize();

            return transformer;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy