All downloads are free. Search and download functionality uses the official Maven repository.

org.datavec.image.transform.PipelineImageTransform Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.image.transform;

import lombok.Data;
import lombok.NonNull;

import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.util.*;

import org.bytedeco.opencv.opencv_core.*;

@Data
public class PipelineImageTransform extends BaseImageTransform {

    protected List> imageTransforms;
    protected boolean shuffle;
    protected org.nd4j.linalg.api.rng.Random rng;

    protected List currentTransforms = new ArrayList<>();

    public PipelineImageTransform(ImageTransform... transforms) {
        this(1234, false, transforms);
    }

    public PipelineImageTransform(long seed, boolean shuffle, ImageTransform... transforms) {
        super(null); // for perf reasons we ignore java Random, create our own

        List> pipeline = new LinkedList<>();
        for (int i = 0; i < transforms.length; i++) {
            pipeline.add(new Pair<>(transforms[i], 1.0));
        }

        this.imageTransforms = pipeline;
        this.shuffle = shuffle;
        this.rng = Nd4j.getRandom();
        rng.setSeed(seed);
    }

    public PipelineImageTransform(List> transforms) {
        this(1234, transforms, false);
    }

    public PipelineImageTransform(List> transforms, boolean shuffle) {
        this(1234, transforms, shuffle);
    }

    public PipelineImageTransform(long seed, List> transforms) {
        this(seed, transforms, false);
    }

    public PipelineImageTransform(long seed, List> transforms, boolean shuffle) {
        this(null, seed, transforms, shuffle);
    }

    public PipelineImageTransform(Random random, long seed, List> transforms,
            boolean shuffle) {
        super(random); // used by the transforms in the pipeline
        this.imageTransforms = transforms;
        this.shuffle = shuffle;
        this.rng = Nd4j.getRandom();
        rng.setSeed(seed);
    }

    /**
     * Takes an image and executes a pipeline of combined transforms.
     *
     * @param image to transform, null == end of stream
     * @param random object to use (or null for deterministic)
     * @return transformed image
     */
    @Override
    protected ImageWritable doTransform(ImageWritable image, Random random) {
        if (shuffle) {
            Collections.shuffle(imageTransforms);
        }

        currentTransforms.clear();

        // execute each item in the pipeline
        for (Pair tuple : imageTransforms) {
            if (tuple.getSecond() == 1.0 || rng.nextDouble() < tuple.getSecond()) { // probability of execution
                currentTransforms.add(tuple.getFirst());
                image = random != null ? tuple.getFirst().transform(image, random)
                        : tuple.getFirst().transform(image);
            }
        }

        return image;
    }

    @Override
    public float[] query(float... coordinates) {
        for (ImageTransform transform : currentTransforms) {
            coordinates = transform.query(coordinates);
        }
        return coordinates;
    }

    /**
     * Optional builder helper for PipelineImageTransform
     */
    public static class Builder {

        protected List> imageTransforms = new ArrayList<>();
        protected Long seed = null;

        /**
         * This method sets RNG seet for this pipeline
         *
         * @param seed
         * @return
         */
        public Builder setSeed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * This method adds given transform with 100% invocation probability to this pipelien
         *
         * @param transform
         * @return
         */
        public Builder addImageTransform(@NonNull ImageTransform transform) {
            return addImageTransform(transform, 1.0);
        }

        /**
         * This method adds given transform with given invocation probability to this pipelien
         *
         * @param transform
         * @param probability
         * @return
         */
        public Builder addImageTransform(@NonNull ImageTransform transform, Double probability) {
            if (probability < 0.0) {
                probability = 0.0;
            }
            if (probability > 1.0) {
                probability = 1.0;
            }

            imageTransforms.add(Pair.makePair(transform, probability));
            return this;
        }

        /**
         * This method returns new PipelineImageTransform instance
         *
         * @return
         */
        public PipelineImageTransform build() {
            if (seed != null) {
                return new PipelineImageTransform(seed, imageTransforms);
            } else {
                return new PipelineImageTransform(imageTransforms);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy