// Targeted by JavaCPP version 1.5.11: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;
import org.bytedeco.javacpp.chrono.*;
import static org.bytedeco.javacpp.global.chrono.*;

import static org.bytedeco.pytorch.global.torch.*;


/** The base class for all modules in PyTorch.
 * 
 *  \rst
 *  .. note::
 *    The design and implementation of this class is largely based on the Python
 *    API. You may want to consult the python documentation for
 *    :py:class:{@code pytorch:torch.nn.Module} for further clarification on certain
 *    methods or behavior.
 *  \endrst
 * 
 *  A {@code Module} is an abstraction over the implementation of some function or
 *  algorithm, possibly associated with some persistent data. A {@code Module} may
 *  contain further {@code Module}s ("submodules"), each with their own
 *  implementation, persistent data and further submodules. {@code Module}s can thus
 *  be said to form a recursive tree structure. A {@code Module} is registered as a
 *  submodule to another {@code Module} by calling {@code register_module()}, typically from
 *  within a parent module's constructor.
 * 
 *  A distinction is made between three kinds of persistent data that may be
 *  associated with a {@code Module}:
 * 
 *  1. *Parameters*: tensors that record gradients, typically weights updated
 *     during the backward step (e.g. the {@code weight} of a {@code Linear} module),
 *  2. *Buffers*: tensors that do not record gradients, typically updated during
 *     the forward step, such as running statistics (e.g. {@code mean} and {@code variance}
 *     in the {@code BatchNorm} module),
 *  3. Any additional state, not necessarily tensors, required for the
 *     implementation or configuration of a {@code Module}.
 * 
 *  The first two kinds of state are special in that they may be registered
 *  with the {@code Module} system to allow convenient access and batch configuration.
 *  For example, registered parameters in any {@code Module} may be iterated over via
 *  the {@code parameters()} accessor. Further, changing the data type of a {@code Module}'s
 *  registered parameters can be done conveniently via {@code Module::to()}, e.g.
 *  {@code module->to(torch::kCUDA)} to move all parameters to GPU memory. Lastly,
 *  registered parameters and buffers are handled specially during a {@code clone()}
 *  operation, which performs a deepcopy of a cloneable {@code Module} hierarchy.
 * 
 *  Parameters are registered with a {@code Module} via {@code register_parameter}. Buffers
 *  are registered separately via {@code register_buffer}. These methods are part of
 *  the public API of {@code Module} and are typically invoked from within a
 *  concrete {@code Module}'s constructor. */
@Namespace("torch::nn") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class Module extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public Module(Pointer p) { super(p); }


  /** Tells the base {@code Module} about the name of the submodule. */
  public Module(@StdString BytePointer name) { super((Pointer)null); allocate(name); }
  @SharedPtr @Name("std::make_shared") private native void allocate(@StdString BytePointer name);
  public Module(@StdString String name) { super((Pointer)null); allocate(name); }
  @SharedPtr @Name("std::make_shared") private native void allocate(@StdString String name);

  /** Constructs the module without immediate knowledge of the submodule's name.
   *  The name of the submodule is inferred via RTTI (if possible) the first
   *  time {@code .name()} is invoked. */
  public Module() { super((Pointer)null); allocate(); }
  @SharedPtr @Name("std::make_shared") private native void allocate();
  public Module(@Const @ByRef Module arg0) { super((Pointer)null); allocate(arg0); }
  @SharedPtr @Name("std::make_shared") private native void allocate(@Const @ByRef Module arg0);
  public native @ByRef @Name("operator =") Module put(@Const @ByRef Module arg0);

  /** Returns the name of the {@code Module}.
   * 
   *  A {@code Module} has an associated {@code name}, which is a string representation of
   *  the kind of concrete {@code Module} it represents, such as {@code "Linear"} for the
   *  {@code Linear} module. Under most circumstances, this name is automatically
   *  inferred via runtime type information (RTTI). In the unusual circumstance
   *  that you have this feature disabled, you may want to manually name your
   *  {@code Module}s by passing the string name to the {@code Module} base class'
   *  constructor. */
  
  ///
  ///
  public native @StdString @NoException(true) BytePointer name();

  /** Performs a recursive deep copy of the module and all its registered
   *  parameters, buffers and submodules.
   * 
   *  Optionally, this method sets the current device
   *  to the one supplied before cloning. If no device is given, each
   *  parameter and buffer will be moved to the device of its source.
   * 
   *  \rst
   *  .. attention::
   *    Attempting to call the {@code clone()} method inherited from the base {@code Module}
   *    class (the one documented here) will fail. To inherit an actual
   *    implementation of {@code clone()}, you must subclass {@code Cloneable}. {@code Cloneable}
   *    is templatized on the concrete module type, and can thus properly copy a
   *    {@code Module}. This method is provided on the base class' API solely for an
   *    easier-to-use polymorphic interface.
   *  \endrst */
  
  ///
  public native @SharedPtr("torch::nn::Module") @ByVal @Virtual(subclasses=false, method="clone") @Cast({"", "std::shared_ptr"}) @Const({false, false, true}) Module clone(
        @Const @ByRef(nullValue = "std::optional<torch::Device>(std::nullopt)") DeviceOptional device);

  /** Applies the {@code function} to the {@code Module} and recursively to every submodule.
   *  The function must accept a {@code Module&}.
   * 
   *  \rst
   *  .. code-block:: cpp
   *    MyModule module;
   *    module->apply([](nn::Module& module) {
   *      std::cout << module.name() << std::endl;
   *    });
   *  \endrst */
  
  ///
  public native void apply(@Const @ByRef ModuleApplyFunction function);
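
  /* Usage sketch (not part of the generated bindings): print the name of a module and of
   * every submodule. Assumes the generated ModuleApplyFunction follows the usual JavaCPP
   * FunctionPointer pattern, i.e. it can be subclassed with an overridden call(Module).
   *
   *   net.apply(new ModuleApplyFunction() {
   *     @Override public void call(Module m) {
   *       System.out.println(m.name().getString());
   *     }
   *   });
   */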

  /** Applies the {@code function} to the {@code Module} and recursively to every submodule.
   *  The function must accept a {@code const Module&}.
   * 
   *  \rst
   *  .. code-block:: cpp
   *    MyModule module;
   *    module->apply([](const nn::Module& module) {
   *      std::cout << module.name() << std::endl;
   *    });
   *  \endrst */

  /** Applies the {@code function} to the {@code Module} and recursively to every submodule.
   *  The function must accept a {@code const std::string&} for the key of the module,
   *  and a {@code Module&}. The key of the module itself is the empty string. If
   *  {@code name_prefix} is given, it is prepended to every key as
 *  {@code <name_prefix>.<key>} (and just {@code name_prefix} for the module itself).
   * 
   *  \rst
   *  .. code-block:: cpp
   *    MyModule module;
   *    module->apply([](const std::string& key, nn::Module& module) {
   *      std::cout << key << ": " << module.name() << std::endl;
   *    });
   *  \endrst */
  
  ///
  public native void apply(
        @Const @ByRef NamedModuleApplyFunction function,
        @StdString BytePointer name_prefix/*=std::string()*/);
  public native void apply(
        @Const @ByRef NamedModuleApplyFunction function);
  public native void apply(
        @Const @ByRef NamedModuleApplyFunction function,
        @StdString String name_prefix/*=std::string()*/);

  /** Applies the {@code function} to the {@code Module} and recursively to every submodule.
   *  The function must accept a {@code const std::string&} for the key of the module,
   *  and a {@code const Module&}. The key of the module itself is the empty string.
   *  If {@code name_prefix} is given, it is prepended to every key as
 *  {@code <name_prefix>.<key>} (and just {@code name_prefix} for the module itself).
   * 
   *  \rst
   *  .. code-block:: cpp
   *    MyModule module;
   *    module->apply([](const std::string& key, const nn::Module& module) {
   *      std::cout << key << ": " << module.name() << std::endl;
   *    });
   *  \endrst */

  /** Applies the {@code function} to the {@code Module} and recursively to every submodule.
 *  The function must accept a {@code const std::shared_ptr<Module>&}.
   * 
   *  \rst
   *  .. code-block:: cpp
   *    MyModule module;
 *    module->apply([](const std::shared_ptr<Module>& module) {
   *      std::cout << module->name() << std::endl;
   *    });
   *  \endrst */
  
  ///
  public native void apply(@Cast("const torch::nn::Module::ModulePointerApplyFunction*") @ByRef SharedModuleApplyFunction function);

  /** Applies the {@code function} to the {@code Module} and recursively to every submodule.
   *  The function must accept a {@code const std::string&} for the key of the module,
 *  and a {@code const std::shared_ptr<Module>&}. The key of the module itself is
 *  the empty string. If {@code name_prefix} is given, it is prepended to every key
 *  as {@code <name_prefix>.<key>} (and just {@code name_prefix} for the module itself).
   * 
   *  \rst
   *  .. code-block:: cpp
   *    MyModule module;
   *    module->apply([](const std::string& key,
 *                     const std::shared_ptr<Module>& module) {
   *      std::cout << key << ": " << module->name() << std::endl;
   *    });
   *  \endrst */
  public native void apply(
        @Const @ByRef NamedSharedModuleApplyFunction function,
        @StdString BytePointer name_prefix/*=std::string()*/);
  public native void apply(
        @Const @ByRef NamedSharedModuleApplyFunction function);
  public native void apply(
        @Const @ByRef NamedSharedModuleApplyFunction function,
        @StdString String name_prefix/*=std::string()*/);

  /** Returns the parameters of this {@code Module} and if {@code recurse} is true, also
   *  recursively of every submodule. */
  public native @ByVal TensorVector parameters(@Cast("bool") boolean recurse/*=true*/);
  public native @ByVal TensorVector parameters();
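
  /* Usage sketch (not part of the generated bindings): walk the flat list of parameters.
   * Assumes TensorVector exposes the usual JavaCPP std::vector accessors size() and get(long).
   *
   *   TensorVector params = net.parameters(true);
   *   for (long i = 0; i < params.size(); i++) {
   *     Tensor p = params.get(i);
   *     System.out.println(p.dim() + "-dim parameter, requires_grad=" + p.requires_grad());
   *   }
   */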

  /** Returns an {@code OrderedDict} with the parameters of this {@code Module} along with
   *  their keys, and if {@code recurse} is true also recursively of every submodule. */
  public native @ByVal StringTensorDict named_parameters(@Cast("bool") boolean recurse/*=true*/);
  public native @ByVal StringTensorDict named_parameters();

  /** Returns the buffers of this {@code Module} and if {@code recurse} is true, also
   *  recursively of every submodule. */
  public native @ByVal TensorVector buffers(@Cast("bool") boolean recurse/*=true*/);
  public native @ByVal TensorVector buffers();

  /** Returns an {@code OrderedDict} with the buffers of this {@code Module} along with
   *  their keys, and if {@code recurse} is true also recursively of every submodule. */
  
  ///
  public native @ByVal StringTensorDict named_buffers(@Cast("bool") boolean recurse/*=true*/);
  public native @ByVal StringTensorDict named_buffers();

  /** Returns the submodules of this {@code Module} (the entire submodule hierarchy)
   *  and if {@code include_self} is true, also inserts a {@code shared_ptr} to this module
   *  in the first position.
   * 
   *  \rst
   *  .. warning::
   *    Only pass {@code include_self} as {@code true} if this {@code Module} is stored in a
   *    {@code shared_ptr}! Otherwise an exception will be thrown. You may still call
   *    this method with {@code include_self} set to false if your {@code Module} is not
   *    stored in a {@code shared_ptr}.
   *  \endrst */
  
  ///
  public native @ByVal SharedModuleVector modules(@Cast("bool") boolean include_self/*=true*/);
  public native @ByVal SharedModuleVector modules();

  /** Returns an {@code OrderedDict} of the submodules of this {@code Module} (the entire
   *  submodule hierarchy) and their keys, and if {@code include_self} is true, also
   *  inserts a {@code shared_ptr} to this module in the first position. If
   *  {@code name_prefix} is given, it is prepended to every key as
 *  {@code <name_prefix>.<key>} (and just {@code name_prefix} for the module itself).
   * 
   *  \rst
   *  .. warning::
   *    Only pass {@code include_self} as {@code true} if this {@code Module} is stored in a
   *    {@code shared_ptr}! Otherwise an exception will be thrown. You may still call
   *    this method with {@code include_self} set to false if your {@code Module} is not
   *    stored in a {@code shared_ptr}.
   *  \endrst */
  public native @ByVal StringSharedModuleDict named_modules(
        @StdString BytePointer name_prefix/*=std::string()*/,
        @Cast("bool") boolean include_self/*=true*/);
  public native @ByVal StringSharedModuleDict named_modules();
  public native @ByVal StringSharedModuleDict named_modules(
        @StdString String name_prefix/*=std::string()*/,
        @Cast("bool") boolean include_self/*=true*/);

  /** Returns the direct submodules of this {@code Module}. */
  public native @ByVal SharedModuleVector children();
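
  /* Usage sketch (not part of the generated bindings): list the direct submodules.
   * Assumes SharedModuleVector exposes size() and get(long) like other JavaCPP vector mappings.
   *
   *   SharedModuleVector kids = net.children();
   *   for (long i = 0; i < kids.size(); i++) {
   *     System.out.println(kids.get(i).name().getString());
   *   }
   */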

  /** Returns an {@code OrderedDict} of the direct submodules of this {@code Module} and
   *  their keys. */
  public native @ByVal StringSharedModuleDict named_children();

  /** Enables "training" mode. */
  public native @Virtual(subclasses=false, method="train") void train(@Cast("bool") boolean on/*=true*/);

  /** Calls train(false) to enable "eval" mode.
   *  Do not override this method; override {@code train()} instead. */
  
  ///
  public native void eval();

  /** True if the module is in training mode.
   * 
   *  Every {@code Module} has a boolean associated with it that determines whether
   *  the {@code Module} is currently in *training* mode (set via {@code .train()}) or in
   *  *evaluation* (inference) mode (set via {@code .eval()}). This property is
   *  exposed via {@code is_training()}, and may be used by the implementation of a
   *  concrete module to modify its runtime behavior. See the {@code BatchNorm} or
   *  {@code Dropout} modules for examples of {@code Module}s that use different code paths
   *  depending on this property. */
  
  ///
  public native @Cast("bool") @Virtual(subclasses=false, method="is_training") @NoException(true) @Const({false, false, true}) boolean is_training();
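
  /* Usage sketch (not part of the generated bindings): toggle between training and
   * evaluation mode, mirroring module.train()/module.eval() in Python.
   *
   *   net.train(true);                        // training-specific behavior (dropout, batch-norm updates)
   *   // ... optimization steps ...
   *   net.eval();                             // switch to inference behavior
   *   boolean training = net.is_training();   // false after eval()
   */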

  /** Recursively casts all parameters to the given {@code dtype} and {@code device}.
   * 
   *  If {@code non_blocking} is true and the source is in pinned memory and
   *  destination is on the GPU or vice versa, the copy is performed
   *  asynchronously with respect to the host. Otherwise, the argument has no
   *  effect. */
  
  ///
  public native @Virtual(subclasses=false, method="to") void to(
        @ByVal Device device,
        ScalarType dtype,
        @Cast("bool") boolean non_blocking/*=false*/);

  /** Recursively casts all parameters to the given dtype.
   * 
   *  If {@code non_blocking} is true and the source is in pinned memory and
   *  destination is on the GPU or vice versa, the copy is performed
   *  asynchronously with respect to the host. Otherwise, the argument has no
   *  effect. */
  
  ///
  public native @Virtual(subclasses=false, method="to") void to(ScalarType dtype, @Cast("bool") boolean non_blocking/*=false*/);

  /** Recursively moves all parameters to the given device.
   * 
   *  If {@code non_blocking} is true and the source is in pinned memory and
   *  destination is on the GPU or vice versa, the copy is performed
   *  asynchronously with respect to the host. Otherwise, the argument has no
   *  effect. */
  public native @Virtual(subclasses=false, method="to") void to(@ByVal Device device, @Cast("bool") boolean non_blocking/*=false*/);

  /** Recursively zeros out the {@code grad} value of each registered parameter. */
  
  ///
  ///
  ///
  public native @Virtual(subclasses=false, method="zero_grad") void zero_grad(@Cast("bool") boolean set_to_none/*=true*/);

  /** Attempts to cast this {@code Module} to the given {@code ModuleType}.
   * 
   *  This method is useful when calling {@code apply()}.
   *  \rst
   *  .. code-block:: cpp
   * 
   *    void initialize_weights(nn::Module& module) {
   *      torch::NoGradGuard no_grad;
 *      if (auto* linear = module.as<nn::Linear>()) {
   *        linear->weight.normal_(0.0, 0.02);
   *      }
   *    }
   * 
   *    MyModule module;
   *    module->apply(initialize_weights);
   *  \endrst */

  /** Attempts to cast this {@code Module} to the given {@code ModuleType}.
   * 
   *  This method is useful when calling {@code apply()}.
   *  \rst
   *  .. code-block:: cpp
   *    void initialize_weights(nn::Module& module) {
   *      torch::NoGradGuard no_grad;
 *      if (auto* linear = module.as<nn::Linear>()) {
   *        linear->weight.normal_(0.0, 0.02);
   *      }
   *    }
   * 
   *    MyModule module;
   *    module->apply(initialize_weights);
   *  \endrst */

  /** Attempts to cast this {@code Module} to the given {@code ModuleType}.
   * 
   *  This method is useful when calling {@code apply()}.
   *  \rst
   *  .. code-block:: cpp
   * 
   *    void initialize_weights(nn::Module& module) {
   *      torch::NoGradGuard no_grad;
 *      if (auto* linear = module.as<nn::Linear>()) {
   *        linear->weight.normal_(0.0, 0.02);
   *      }
   *    }
   * 
   *    MyModule module;
   *    module.apply(initialize_weights);
   *  \endrst */
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ModuleDictImpl asModuleDict();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ModuleListImpl asModuleList();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SequentialImpl asSequential();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ParameterDictImpl asParameterDict();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ParameterListImpl asParameterList();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AdaptiveLogSoftmaxWithLossImpl asAdaptiveLogSoftmaxWithLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) BatchNorm1dImpl asBatchNorm1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) InstanceNorm1dImpl asInstanceNorm1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) Conv1dImpl asConv1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ConvTranspose1dImpl asConvTranspose1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) DropoutImpl asDropout();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) BatchNorm2dImpl asBatchNorm2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) InstanceNorm2dImpl asInstanceNorm2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) Conv2dImpl asConv2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ConvTranspose2dImpl asConvTranspose2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) Dropout2dImpl asDropout2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) BatchNorm3dImpl asBatchNorm3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) InstanceNorm3dImpl asInstanceNorm3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) Conv3dImpl asConv3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ConvTranspose3dImpl asConvTranspose3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) Dropout3dImpl asDropout3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AlphaDropoutImpl asAlphaDropout();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) FeatureAlphaDropoutImpl asFeatureAlphaDropout();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) CosineSimilarityImpl asCosineSimilarity();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) PairwiseDistanceImpl asPairwiseDistance();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) EmbeddingImpl asEmbedding();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) EmbeddingBagImpl asEmbeddingBag();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) FoldImpl asFold();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) UnfoldImpl asUnfold();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) IdentityImpl asIdentity();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LinearImpl asLinear();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) BilinearImpl asBilinear();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) FlattenImpl asFlatten();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) UnflattenImpl asUnflatten();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) L1LossImpl asL1Loss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) KLDivLossImpl asKLDivLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MSELossImpl asMSELoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) BCELossImpl asBCELoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) HingeEmbeddingLossImpl asHingeEmbeddingLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MultiMarginLossImpl asMultiMarginLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) CosineEmbeddingLossImpl asCosineEmbeddingLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SmoothL1LossImpl asSmoothL1Loss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) HuberLossImpl asHuberLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MultiLabelMarginLossImpl asMultiLabelMarginLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SoftMarginLossImpl asSoftMarginLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MultiLabelSoftMarginLossImpl asMultiLabelSoftMarginLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TripletMarginLossImpl asTripletMarginLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TripletMarginWithDistanceLossImpl asTripletMarginWithDistanceLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) CTCLossImpl asCTCLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) PoissonNLLLossImpl asPoissonNLLLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MarginRankingLossImpl asMarginRankingLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) NLLLossImpl asNLLLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) CrossEntropyLossImpl asCrossEntropyLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) BCEWithLogitsLossImpl asBCEWithLogitsLoss();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReflectionPad1dImpl asReflectionPad1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReplicationPad1dImpl asReplicationPad1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ConstantPad1dImpl asConstantPad1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ZeroPad1dImpl asZeroPad1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AvgPool1dImpl asAvgPool1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MaxPool1dImpl asMaxPool1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AdaptiveAvgPool1dImpl asAdaptiveAvgPool1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AdaptiveMaxPool1dImpl asAdaptiveMaxPool1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MaxUnpool1dImpl asMaxUnpool1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LPPool1dImpl asLPPool1d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReflectionPad2dImpl asReflectionPad2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReplicationPad2dImpl asReplicationPad2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ConstantPad2dImpl asConstantPad2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ZeroPad2dImpl asZeroPad2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AvgPool2dImpl asAvgPool2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MaxPool2dImpl asMaxPool2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AdaptiveAvgPool2dImpl asAdaptiveAvgPool2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AdaptiveMaxPool2dImpl asAdaptiveMaxPool2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MaxUnpool2dImpl asMaxUnpool2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) FractionalMaxPool2dImpl asFractionalMaxPool2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LPPool2dImpl asLPPool2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReflectionPad3dImpl asReflectionPad3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReplicationPad3dImpl asReplicationPad3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ConstantPad3dImpl asConstantPad3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ZeroPad3dImpl asZeroPad3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AvgPool3dImpl asAvgPool3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MaxPool3dImpl asMaxPool3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AdaptiveAvgPool3dImpl asAdaptiveAvgPool3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) AdaptiveMaxPool3dImpl asAdaptiveMaxPool3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MaxUnpool3dImpl asMaxUnpool3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) FractionalMaxPool3dImpl asFractionalMaxPool3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LPPool3dImpl asLPPool3d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) RNNImpl asRNN();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LSTMImpl asLSTM();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) GRUImpl asGRU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) RNNCellImpl asRNNCell();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LSTMCellImpl asLSTMCell();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) GRUCellImpl asGRUCell();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) PixelShuffleImpl asPixelShuffle();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) PixelUnshuffleImpl asPixelUnshuffle();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) UpsampleImpl asUpsample();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ELUImpl asELU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SELUImpl asSELU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) HardshrinkImpl asHardshrink();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) HardtanhImpl asHardtanh();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LeakyReLUImpl asLeakyReLU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LogSigmoidImpl asLogSigmoid();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SoftmaxImpl asSoftmax();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SoftminImpl asSoftmin();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LogSoftmaxImpl asLogSoftmax();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) Softmax2dImpl asSoftmax2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) PReLUImpl asPReLU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReLUImpl asReLU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ReLU6Impl asReLU6();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) RReLUImpl asRReLU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) CELUImpl asCELU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) GLUImpl asGLU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) GELUImpl asGELU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SiLUImpl asSiLU();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MishImpl asMish();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SigmoidImpl asSigmoid();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SoftplusImpl asSoftplus();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SoftshrinkImpl asSoftshrink();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) SoftsignImpl asSoftsign();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TanhImpl asTanh();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TanhshrinkImpl asTanhshrink();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) ThresholdImpl asThreshold();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) MultiheadAttentionImpl asMultiheadAttention();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LayerNormImpl asLayerNorm();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) LocalResponseNormImpl asLocalResponseNorm();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) CrossMapLRN2dImpl asCrossMapLRN2d();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) GroupNormImpl asGroupNorm();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TransformerEncoderLayerImpl asTransformerEncoderLayer();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TransformerDecoderLayerImpl asTransformerDecoderLayer();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TransformerEncoderImpl asTransformerEncoder();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TransformerDecoderImpl asTransformerDecoder();
  
  ///
  ///
  ///
  public native @Name("as") @NoException(true) TransformerImpl asTransformer();

  /** Attempts to cast this {@code Module} to the given {@code ModuleType}.
   * 
   *  This method is useful when calling {@code apply()}.
   *  \rst
   *  .. code-block:: cpp
   * 
   *    void initialize_weights(nn::Module& module) {
   *      torch::NoGradGuard no_grad;
 *      if (auto* linear = module.as<nn::Linear>()) {
   *        linear->weight.normal_(0.0, 0.02);
   *      }
   *    }
   * 
   *    MyModule module;
   *    module.apply(initialize_weights);
   *  \endrst */

  /** Serializes the {@code Module} into the given {@code OutputArchive}.
   * 
   *  If the {@code Module} contains unserializable submodules (e.g.
   *  {@code nn::Functional}), those submodules are skipped when serializing. */
  
  ///
  public native @Virtual(subclasses=false, method="save") @Const({false, false, true}) void save(@ByRef OutputArchive archive);

  /** Deserializes the {@code Module} from the given {@code InputArchive}.
   * 
   *  If the {@code Module} contains unserializable submodules (e.g.
   *  {@code nn::Functional}), we don't check the existence of those submodules in the
   *  {@code InputArchive} when deserializing. */
  
  ///
  public native @Virtual(subclasses=false, method="load") void load(@ByRef InputArchive archive);

  /** Streams a pretty representation of the {@code Module} into the given {@code stream}.
   *  By default, this representation will be the name of the module (taken from
   *  {@code name()}), followed by a recursive pretty print of all of the {@code Module}'s
   *  submodules.
   * 
   *  Override this method to change the pretty print. The input
   *  {@code stream} should be returned from the method, to allow easy chaining. */
  public native @Virtual(subclasses=false, method="pretty_print") @Const({false, false, true}) void pretty_print(@Cast("std::ostream*") @ByRef Pointer stream);

  /** Returns whether the {@code Module} is serializable. */
  
  ///
  ///
  ///
  ///
  public native @Cast("bool") @Virtual(subclasses=false, method="is_serializable") @Const({false, false, true}) boolean is_serializable();

  /** Registers a parameter with this {@code Module}.
   * 
   *  A parameter should be any gradient-recording tensor used in the
   *  implementation of your {@code Module}. Registering it makes it available to
   *  methods such as {@code parameters()}, {@code clone()} or {@code to()}.
   * 
   *  Note that registering an undefined Tensor (e.g.
   *  {@code module.register_parameter("param", Tensor())}) is allowed, and is
   *  equivalent to {@code module.register_parameter("param", None)} in the Python API.
   * 
   *  \rst
   *  .. code-block:: cpp
   * 
   *    MyModule::MyModule() {
   *      weight_ = register_parameter("weight", torch::randn({A, B}));
   *    }
   *  \endrst */
  
  ///
  ///
  ///
  public native @ByRef Tensor register_parameter(
        @StdString BytePointer name,
        @ByVal Tensor tensor,
        @Cast("bool") boolean requires_grad/*=true*/);
  public native @ByRef Tensor register_parameter(
        @StdString BytePointer name,
        @ByVal Tensor tensor);
  public native @ByRef Tensor register_parameter(
        @StdString String name,
        @ByVal Tensor tensor,
        @Cast("bool") boolean requires_grad/*=true*/);
  public native @ByRef Tensor register_parameter(
        @StdString String name,
        @ByVal Tensor tensor);
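
  /* Usage sketch (not part of the generated bindings): register a trainable tensor from a
   * Java subclass of Module. Assumes a randn(long...) factory overload is available via the
   * static import of org.bytedeco.pytorch.global.torch.
   *
   *   static class MyModule extends Module {
   *     Tensor weight;
   *     MyModule() {
   *       weight = register_parameter("weight", randn(4, 3));  // requires_grad defaults to true
   *     }
   *   }
   */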

  /** Registers a buffer with this {@code Module}.
   * 
   *  A buffer is intended to be state in your module that does not record
   *  gradients, such as running statistics. Registering it makes it available
   *  to methods such as {@code buffers()}, {@code clone()} or {@code to()}.
   * 
   *  \rst
   *  .. code-block:: cpp
   * 
   *    MyModule::MyModule() {
   *      mean_ = register_buffer("mean", torch::empty({num_features_}));
   *    }
   *  \endrst */
  
  ///
  ///
  ///
  public native @ByRef Tensor register_buffer(@StdString BytePointer name, @ByVal Tensor tensor);
  public native @ByRef Tensor register_buffer(@StdString String name, @ByVal Tensor tensor);
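
  /* Usage sketch (not part of the generated bindings): register a non-gradient buffer such as
   * a running mean, assuming a zeros(long...) factory overload from the presets.
   *
   *   runningMean = register_buffer("running_mean", zeros(numFeatures));
   */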

  /** Registers a submodule with this {@code Module}.
   * 
   *  Registering a module makes it available to methods such as {@code modules()},
   *  {@code clone()} or {@code to()}.
   * 
   *  \rst
   *  .. code-block:: cpp
   * 
   *    MyModule::MyModule() {
   *      submodule_ = register_module("linear", torch::nn::Linear(3, 4));
   *    }
   *  \endrst */
  
  ///
  ///
  ///
  ///
  private native @Name("register_module") void _register_module(@StdString BytePointer name, @SharedPtr @ByVal Module module);
  public <M extends Module> M register_module(BytePointer name, M module) { _register_module(name, module); return module; }
  private native @Name("register_module") void _register_module(@StdString String name, @SharedPtr @ByVal Module module);
  public <M extends Module> M register_module(String name, M module) { _register_module(name, module); return module; }
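
  /* Usage sketch (not part of the generated bindings): registering concrete submodules from a
   * Java subclass, following the pattern of the bytedeco examples; LinearImpl(in, out) is the
   * binding of torch::nn::LinearImpl, and relu/log_softmax come from the global torch class.
   *
   *   static class Net extends Module {
   *     final LinearImpl fc1, fc2;
   *     Net() {
   *       fc1 = register_module("fc1", new LinearImpl(784, 64));
   *       fc2 = register_module("fc2", new LinearImpl(64, 10));
   *     }
   *     Tensor forward(Tensor x) {
   *       return log_softmax(fc2.forward(relu(fc1.forward(x))), 1);
   *     }
   *   }
   */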

  /** Registers a submodule with this {@code Module}.
   * 
   *  This method deals with {@code ModuleHolder}s.
   * 
   *  Registering a module makes it available to methods such as {@code modules()},
   *  {@code clone()} or {@code to()}.
   * 
   *  \rst
   *  .. code-block:: cpp
   * 
   *    MyModule::MyModule() {
   *      submodule_ = register_module("linear", torch::nn::Linear(3, 4));
   *    }
   *  \endrst */

  /** Replaces a registered submodule with this {@code Module}.
   * 
   *  This takes care of the registration; if you use submodule members, you
   *  should also assign the submodule, i.e. use as
   *  {@code module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4));}
   *  It only works when a module of the name is already registered.
   * 
   *  This is useful for replacing a module after initialization, e.g.
   *  for finetuning. */

  /** Replaces a registered submodule with this {@code Module}.
   *  This method deals with {@code ModuleHolder}s.
   * 
   *  This takes care of the registration; if you use submodule members, you
   *  should also assign the submodule, i.e. use as
   *  {@code module->submodule_ = module->replace_module("linear", linear_holder);}
   *  It only works when a module of the name is already registered.
   * 
   *  This is useful for replacing a module after initialization, e.g.
   *  for finetuning. */

  /** Unregisters a submodule from this {@code Module}. If there is no such module
   *  with {@code name} an exception is thrown. */
  public native void unregister_module(@StdString BytePointer name);
  public native void unregister_module(@StdString String name);
  private static native @Namespace @Cast("std::ostream*") @ByRef @Name("operator <<") Pointer shiftLeft(
        @Cast("std::ostream*") @ByRef Pointer stream,
        @Const @ByRef Module module);
  public Pointer shiftLeft(Pointer stream) { return shiftLeft(stream, this); }
}



