All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.translate.Pipeline Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.translate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** {@code Pipeline} allows applying multiple transforms on an input {@link NDList}. */
public class Pipeline {

    private PairList transforms;

    /** Creates a new instance of {@code Pipeline} that has no {@link Transform} defined yet. */
    public Pipeline() {
        transforms = new PairList<>();
    }

    /**
     * Creates a new instance of {@code Pipeline} that can apply the given transforms on its input.
     *
     * 

Since no keys are provided for these transforms, they will be applied to the first element * in the input {@link NDList} when the {@link #transform(NDList) transform} method is called on * this object. * * @param transforms the transforms to be applied when the {@link #transform(NDList) transform} * method is called on this object */ public Pipeline(Transform... transforms) { this.transforms = new PairList<>(); for (Transform transform : transforms) { this.transforms.add(new IndexKey(0), transform); } } /** * Adds the given {@link Transform} to the list of transforms to be applied on the input when * the {@link #transform(NDList) transform} method is called on this object. * *

Since no keys are provided for this {@link Transform}, it will be applied to the first * element in the input {@link NDList}. * * @param transform the {@link Transform} to be added * @return this {@code Pipeline} */ public Pipeline add(Transform transform) { transforms.add(new IndexKey(0), transform); return this; } /** * Adds the given {@link Transform} to the list of transforms to be applied on the {@link * NDArray} at the given index in the input {@link NDList}. * * @param index the index corresponding to the {@link NDArray} in the input {@link NDList} on * which the given transform must be applied to * @param transform the {@link Transform} to be added * @return this {@code Pipeline} */ public Pipeline add(int index, Transform transform) { transforms.add(new IndexKey(index), transform); return this; } /** * Adds the given {@link Transform} to the list of transforms to be applied on the {@link * NDArray} with the given key as name in the input {@link NDList}. * * @param name the key corresponding to the {@link NDArray} in the input {@link NDList} on which * the given transform must be applied to * @param transform the {@code Transform} to be applied when the {@link #transform(NDList) * transform} method is called on this object * @return this {@code Pipeline} */ public Pipeline add(String name, Transform transform) { transforms.add(new IndexKey(name), transform); return this; } /** * Inserts the given {@link Transform} to the list of transforms at the given position. * *

Since no keys or indices are provided for this {@link Transform}, it will be applied to * the first element in the input {@link NDList} when the {@link #transform(NDList) transform} * method is called on this object. * * @param position the position at which the {@link Transform} must be inserted * @param transform the {@code Transform} to be inserted * @return this {@code Pipeline} */ public Pipeline insert(int position, Transform transform) { transforms.add(position, new IndexKey(0), transform); return this; } /** * Inserts the given {@link Transform} to the list of transforms at the given position to be * applied on the {@link NDArray} at the given index in the input {@link NDList}. * * @param position the position at which the {@link Transform} must be inserted * @param index the index corresponding to the {@link NDArray} in the input {@link NDList} on * which the given transform must be applied to * @param transform the {@code Transform} to be inserted * @return this {@code Pipeline} */ public Pipeline insert(int position, int index, Transform transform) { transforms.add(position, new IndexKey(index), transform); return this; } /** * Inserts the given {@link Transform} to the list of transforms at the given position to be * applied on the {@link NDArray} with the given name in the input {@link NDList}. * * @param position the position at which the {@link Transform} must be inserted * @param name the key corresponding to the {@link NDArray} in the input {@link NDList} on which * the given transform must be applied to * @param transform the {@code Transform} to be inserted * @return this {@code Pipeline} */ public Pipeline insert(int position, String name, Transform transform) { transforms.add(position, new IndexKey(name), transform); return this; } /** * Applies the transforms configured in this object on the input {@link NDList}. * *

If a key is specified with the transform, those transforms will only be applied to the * {@link NDArray} in the input {@link NDList}. If a key is not specified, it will be applied to * the first element in the input {@link NDList}. * * @param input the input {@link NDList} on which the tranforms are to be applied * @return the output {@link NDList} after applying the tranforms */ public NDList transform(NDList input) { if (transforms.isEmpty() || input.isEmpty()) { return input; } NDArray[] arrays = input.toArray(new NDArray[0]); Map map = new ConcurrentHashMap<>(); // create mapping for (int i = 0; i < input.size(); i++) { String key = input.get(i).getName(); if (key != null) { map.put(new IndexKey(key), i); } map.put(new IndexKey(i), i); } // apply transform for (Pair transform : transforms) { IndexKey key = transform.getKey(); int index = map.get(key); NDArray array = arrays[index]; arrays[index] = transform.getValue().transform(array); arrays[index].setName(array.getName()); } return new NDList(arrays); } private static final class IndexKey { private String key; private int index; private IndexKey(String key) { this.key = key; } private IndexKey(int index) { this.index = index; } /** {@inheritDoc} */ @Override public int hashCode() { if (key == null) { return index; } return key.hashCode(); } /** {@inheritDoc} */ @Override public boolean equals(Object obj) { if (this == obj) { return true; } if (!(obj instanceof IndexKey)) { return false; } IndexKey other = (IndexKey) obj; if (key == null) { return index == other.index; } return key.equals(other.key); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy