All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.pytorch.engine.PtEngine Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.pytorch.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.pytorch.jni.LibUtils;
import ai.djl.training.GradientCollector;
import ai.djl.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.FileNotFoundException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * The {@code PtEngine} is an implementation of the {@link Engine} based on the PyTorch Deep Learning Framework.
 *
 * 

To get an instance of the {@code PtEngine} when it is not the default Engine, call {@link * Engine#getEngine(String)} with the Engine name "PyTorch". */ public final class PtEngine extends Engine { private static final Logger logger = LoggerFactory.getLogger(PtEngine.class); public static final String ENGINE_NAME = "PyTorch"; static final int RANK = 2; private PtEngine() {} @SuppressWarnings("PMD.AvoidRethrowingException") static Engine newInstance() { try { LibUtils.loadLibrary(); JniUtils.setGradMode(false); if (Integer.getInteger("ai.djl.pytorch.num_interop_threads") != null) { JniUtils.setNumInteropThreads( Integer.getInteger("ai.djl.pytorch.num_interop_threads")); } if (Integer.getInteger("ai.djl.pytorch.num_threads") != null) { JniUtils.setNumThreads(Integer.getInteger("ai.djl.pytorch.num_threads")); } // for ConvNN related model speed up if (Boolean.getBoolean("ai.djl.pytorch.cudnn_benchmark")) { JniUtils.setBenchmarkCuDNN(true); } if ("true".equals(System.getProperty("ai.djl.pytorch.graph_optimizer", "true"))) { logger.info( "PyTorch graph executor optimizer is enabled, this may impact your" + " inference latency and throughput. 
See:" + " https://docs.djl.ai/master/docs/development/inference_performance_optimization.html#graph-executor-optimization"); } logger.info("Number of inter-op threads is {}", JniUtils.getNumInteropThreads()); logger.info("Number of intra-op threads is {}", JniUtils.getNumThreads()); String paths = Utils.getEnvOrSystemProperty("PYTORCH_EXTRA_LIBRARY_PATH"); if (paths != null) { String[] files = paths.split(","); for (String file : files) { Path path = Paths.get(file); if (Files.notExists(path)) { throw new FileNotFoundException("PyTorch extra Library not found: " + file); } System.load(path.toAbsolutePath().toString()); // NOPMD } } return new PtEngine(); } catch (EngineException e) { throw e; } catch (Throwable t) { throw new EngineException("Failed to load PyTorch native library", t); } } /** {@inheritDoc} */ @Override public Engine getAlternativeEngine() { return null; } /** {@inheritDoc} */ @Override public String getEngineName() { return ENGINE_NAME; } /** {@inheritDoc} */ @Override public int getRank() { return RANK; } /** {@inheritDoc} */ @Override public String getVersion() { return LibUtils.getVersion(); } /** {@inheritDoc} */ @Override public boolean hasCapability(String capability) { return JniUtils.getFeatures().contains(capability); } /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { return new PtSymbolBlock((PtNDManager) manager); } /** {@inheritDoc} */ @Override public Model newModel(String name, Device device) { return new PtModel(name, device); } /** {@inheritDoc} */ @Override public NDManager newBaseManager() { return PtNDManager.getSystemManager().newSubManager(); } /** {@inheritDoc} */ @Override public NDManager newBaseManager(Device device) { return PtNDManager.getSystemManager().newSubManager(device); } /** {@inheritDoc} */ @Override public GradientCollector newGradientCollector() { return new PtGradientCollector(); } /** {@inheritDoc} */ @Override public void setRandomSeed(int seed) { 
super.setRandomSeed(seed); JniUtils.setSeed(seed); } /** {@inheritDoc} */ @Override public String toString() { StringBuilder sb = new StringBuilder(200); sb.append(getEngineName()).append(':').append(getVersion()).append(", capabilities: [\n"); for (String feature : JniUtils.getFeatures()) { sb.append("\t").append(feature).append(",\n"); // NOPMD } sb.append("]\nPyTorch Library: ").append(LibUtils.getLibtorchPath()); return sb.toString(); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy