// ![JAR search and dependency download from the Maven repository](/logo.png)
// org.bytedeco.pytorch.Module Maven / Gradle / Ivy
// The newest version!
// Targeted by JavaCPP version 1.5.11: DO NOT EDIT THIS FILE
package org.bytedeco.pytorch;
import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import org.bytedeco.javacpp.annotation.Cast;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;
import org.bytedeco.javacpp.chrono.*;
import static org.bytedeco.javacpp.global.chrono.*;
import static org.bytedeco.pytorch.global.torch.*;
/**
 * The base class for all modules in PyTorch.
 *
 * <p>Note: the design and implementation of this class is largely based on the
 * Python API. You may want to consult the Python documentation for
 * {@code torch.nn.Module} for further clarification on certain methods or
 * behavior.
 *
 * <p>A {@code Module} is an abstraction over the implementation of some function
 * or algorithm, possibly associated with some persistent data. A {@code Module}
 * may contain further {@code Module}s ("submodules"), each with their own
 * implementation, persistent data and further submodules. {@code Module}s can
 * thus be said to form a recursive tree structure. A {@code Module} is
 * registered as a submodule to another {@code Module} by calling
 * {@code register_module()}, typically from within a parent module's
 * constructor.
 *
 * <p>A distinction is made between three kinds of persistent data that may be
 * associated with a {@code Module}:
 * <ol>
 *   <li><em>Parameters</em>: tensors that record gradients, typically weights
 *       updated during the backward step (e.g. the {@code weight} of a
 *       {@code Linear} module);</li>
 *   <li><em>Buffers</em>: tensors that do not record gradients, typically
 *       updated during the forward step, such as running statistics (e.g.
 *       {@code mean} and {@code variance} in the {@code BatchNorm} module);</li>
 *   <li>any additional state, not necessarily tensors, required for the
 *       implementation or configuration of a {@code Module}.</li>
 * </ol>
 *
 * <p>The first two kinds of state are special in that they may be registered
 * with the {@code Module} system to allow convenient access and batch
 * configuration. For example, registered parameters in any {@code Module} may be
 * iterated over via the {@code parameters()} accessor. Further, changing the data
 * type or device of a {@code Module}'s registered parameters can be done
 * conveniently via {@code Module::to()}, e.g. {@code module->to(torch::kCUDA)} to
 * move all parameters to GPU memory. Lastly, registered parameters and buffers
 * are handled specially during a {@code clone()} operation, which performs a deep
 * copy of a cloneable {@code Module} hierarchy.
 *
 * <p>Parameters are registered with a {@code Module} via
 * {@code register_parameter}. Buffers are registered separately via
 * {@code register_buffer}. These methods are part of the public API of
 * {@code Module} and are typically invoked from within a concrete
 * {@code Module}'s constructor.
 */
@Namespace("torch::nn") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class Module extends Pointer {
    static { Loader.load(); }

    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public Module(Pointer p) { super(p); }

    // The allocators below name the JavaCPP-generated virtual subclass of
    // torch::nn::Module (mangled "JavaCPP_" + C++ name) so that, presumably,
    // Java overrides of the @Virtual methods are dispatched from C++.
    // NOTE(review): confirm this name against the generated JNI source.

    /** Tells the base {@code Module} about the name of the submodule. */
    public Module(@StdString BytePointer name) { super((Pointer)null); allocate(name); }
    @SharedPtr @Name("std::make_shared<JavaCPP_torch_0003a_0003ann_0003a_0003aModule>") private native void allocate(@StdString BytePointer name);
    /** Tells the base {@code Module} about the name of the submodule. */
    public Module(@StdString String name) { super((Pointer)null); allocate(name); }
    @SharedPtr @Name("std::make_shared<JavaCPP_torch_0003a_0003ann_0003a_0003aModule>") private native void allocate(@StdString String name);

    /** Constructs the module without immediate knowledge of the submodule's name.
     *  The name of the submodule is inferred via RTTI (if possible) the first
     *  time {@code .name()} is invoked. */
    public Module() { super((Pointer)null); allocate(); }
    @SharedPtr @Name("std::make_shared<JavaCPP_torch_0003a_0003ann_0003a_0003aModule>") private native void allocate();

    /** Constructs a new native module by invoking the C++ copy constructor with {@code arg0}. */
    public Module(@Const @ByRef Module arg0) { super((Pointer)null); allocate(arg0); }
    @SharedPtr @Name("std::make_shared<JavaCPP_torch_0003a_0003ann_0003a_0003aModule>") private native void allocate(@Const @ByRef Module arg0);

    /** Invokes the C++ copy-assignment {@code operator=} on this module with {@code arg0}. */
    public native @ByRef @Name("operator =") Module put(@Const @ByRef Module arg0);

    /** Returns the name of the {@code Module}.
     *
     *  <p>A {@code Module} has an associated {@code name}, which is a string
     *  representation of the kind of concrete {@code Module} it represents, such as
     *  {@code "Linear"} for the {@code Linear} module. Under most circumstances,
     *  this name is automatically inferred via runtime type information (RTTI). In
     *  the unusual circumstance that you have this feature disabled, you may want
     *  to manually name your {@code Module}s by passing the string name to the
     *  {@code Module} base class' constructor. */
    public native @StdString @NoException(true) BytePointer name();

    /** Performs a recursive deep copy of the module and all its registered
     *  parameters, buffers and submodules.
     *
     *  <p>Optionally, this method sets the current device to the one supplied
     *  before cloning. If no device is given, each parameter and buffer will be
     *  moved to the device of its source.
     *
     *  <p><b>Attention:</b> attempting to call the {@code clone()} method inherited
     *  from the base {@code Module} class (the one documented here) will fail. To
     *  inherit an actual implementation of {@code clone()}, you must subclass
     *  {@code Cloneable}. {@code Cloneable} is templatized on the concrete module
     *  type, and can thus properly copy a {@code Module}. This method is provided
     *  on the base class' API solely for an easier-to-use polymorphic interface.
     *
     *  @param device target device, or an empty optional to keep each tensor on
     *                its source device (the native default is {@code std::nullopt}) */
    public native @SharedPtr("torch::nn::Module") @ByVal @Virtual(subclasses=false, method="clone") @Cast({"", "std::shared_ptr<torch::nn::Module>"}) @Const({false, false, true}) Module clone(
            @Const @ByRef(nullValue = "std::optional<torch::Device>(std::nullopt)") DeviceOptional device);

    /** Applies the {@code function} to the {@code Module} and recursively to every
     *  submodule. The function must accept a {@code Module&}.
     *
     *  <pre>{@code
     *  MyModule module;
     *  module->apply([](nn::Module& module) {
     *    std::cout << module.name() << std::endl;
     *  });
     *  }</pre> */
    public native void apply(@Const @ByRef ModuleApplyFunction function);

    // The C++ overload of apply() taking a function of `const Module&` is not
    // declared in this class.

    /** Applies the {@code function} to the {@code Module} and recursively to every
     *  submodule. The function must accept a {@code const std::string&} for the key
     *  of the module, and a {@code Module&}. The key of the module itself is the
     *  empty string. If {@code name_prefix} is given, it is prepended to every key
     *  as {@code <name_prefix>.<key>} (and just {@code name_prefix} for the module
     *  itself).
     *
     *  <pre>{@code
     *  MyModule module;
     *  module->apply([](const std::string& key, nn::Module& module) {
     *    std::cout << key << ": " << module.name() << std::endl;
     *  });
     *  }</pre> */
    public native void apply(
            @Const @ByRef NamedModuleApplyFunction function,
            @StdString BytePointer name_prefix/*=std::string()*/);
    public native void apply(
            @Const @ByRef NamedModuleApplyFunction function);
    public native void apply(
            @Const @ByRef NamedModuleApplyFunction function,
            @StdString String name_prefix/*=std::string()*/);

    // The C++ overload of apply() taking a function of
    // `(const std::string&, const Module&)` is not declared in this class.

    /** Applies the {@code function} to the {@code Module} and recursively to every
     *  submodule. The function must accept a {@code const std::shared_ptr<Module>&}.
     *
     *  <pre>{@code
     *  MyModule module;
     *  module->apply([](const std::shared_ptr<nn::Module>& module) {
     *    std::cout << module->name() << std::endl;
     *  });
     *  }</pre> */
    public native void apply(@Cast("const torch::nn::Module::ModulePointerApplyFunction*") @ByRef SharedModuleApplyFunction function);

    /** Applies the {@code function} to the {@code Module} and recursively to every
     *  submodule. The function must accept a {@code const std::string&} for the key
     *  of the module, and a {@code const std::shared_ptr<Module>&}. The key of the
     *  module itself is the empty string. If {@code name_prefix} is given, it is
     *  prepended to every key as {@code <name_prefix>.<key>} (and just
     *  {@code name_prefix} for the module itself).
     *
     *  <pre>{@code
     *  MyModule module;
     *  module->apply([](const std::string& key,
     *                   const std::shared_ptr<nn::Module>& module) {
     *    std::cout << key << ": " << module->name() << std::endl;
     *  });
     *  }</pre> */
    public native void apply(
            @Const @ByRef NamedSharedModuleApplyFunction function,
            @StdString BytePointer name_prefix/*=std::string()*/);
    public native void apply(
            @Const @ByRef NamedSharedModuleApplyFunction function);
    public native void apply(
            @Const @ByRef NamedSharedModuleApplyFunction function,
            @StdString String name_prefix/*=std::string()*/);

    /** Returns the parameters of this {@code Module} and, if {@code recurse} is
     *  true (the native default), also recursively of every submodule. */
    public native @ByVal TensorVector parameters(@Cast("bool") boolean recurse/*=true*/);
    public native @ByVal TensorVector parameters();

    /** Returns an {@code OrderedDict} with the parameters of this {@code Module}
     *  along with their keys, and if {@code recurse} is true (the native default)
     *  also recursively of every submodule. */
    public native @ByVal StringTensorDict named_parameters(@Cast("bool") boolean recurse/*=true*/);
    public native @ByVal StringTensorDict named_parameters();

    /** Returns the buffers of this {@code Module} and, if {@code recurse} is true
     *  (the native default), also recursively of every submodule. */
    public native @ByVal TensorVector buffers(@Cast("bool") boolean recurse/*=true*/);
    public native @ByVal TensorVector buffers();

    /** Returns an {@code OrderedDict} with the buffers of this {@code Module} along
     *  with their keys, and if {@code recurse} is true (the native default) also
     *  recursively of every submodule. */
    public native @ByVal StringTensorDict named_buffers(@Cast("bool") boolean recurse/*=true*/);
    public native @ByVal StringTensorDict named_buffers();

    /** Returns the submodules of this {@code Module} (the entire submodule
     *  hierarchy) and, if {@code include_self} is true, also inserts a
     *  {@code shared_ptr} to this module in the first position.
     *
     *  <p><b>Warning:</b> only pass {@code include_self} as {@code true} if this
     *  {@code Module} is stored in a {@code shared_ptr}! Otherwise an exception will
     *  be thrown. You may still call this method with {@code include_self} set to
     *  false if your {@code Module} is not stored in a {@code shared_ptr}. */
    public native @ByVal SharedModuleVector modules(@Cast("bool") boolean include_self/*=true*/);
    public native @ByVal SharedModuleVector modules();

    /** Returns an {@code OrderedDict} of the submodules of this {@code Module} (the
     *  entire submodule hierarchy) and their keys, and if {@code include_self} is
     *  true, also inserts a {@code shared_ptr} to this module in the first position.
     *  If {@code name_prefix} is given, it is prepended to every key as
     *  {@code <name_prefix>.<key>} (and just {@code name_prefix} for the module
     *  itself).
     *
     *  <p><b>Warning:</b> only pass {@code include_self} as {@code true} if this
     *  {@code Module} is stored in a {@code shared_ptr}! Otherwise an exception will
     *  be thrown. You may still call this method with {@code include_self} set to
     *  false if your {@code Module} is not stored in a {@code shared_ptr}. */
    public native @ByVal StringSharedModuleDict named_modules(
            @StdString BytePointer name_prefix/*=std::string()*/,
            @Cast("bool") boolean include_self/*=true*/);
    public native @ByVal StringSharedModuleDict named_modules();
    public native @ByVal StringSharedModuleDict named_modules(
            @StdString String name_prefix/*=std::string()*/,
            @Cast("bool") boolean include_self/*=true*/);

    /** Returns the direct submodules of this {@code Module}. */
    public native @ByVal SharedModuleVector children();

    /** Returns an {@code OrderedDict} of the direct submodules of this
     *  {@code Module} and their keys. */
    public native @ByVal StringSharedModuleDict named_children();

    /** Enables "training" mode ({@code on = true}, the native default) or disables it. */
    public native @Virtual(subclasses=false, method="train") void train(@Cast("bool") boolean on/*=true*/);

    /** Calls {@code train(false)} to enable "eval" mode.
     *  Do not override this method, override {@code train()} instead. */
    public native void eval();

    /** True if the module is in training mode.
     *
     *  <p>Every {@code Module} has a boolean associated with it that determines
     *  whether the {@code Module} is currently in <em>training</em> mode (set via
     *  {@code .train()}) or in <em>evaluation</em> (inference) mode (set via
     *  {@code .eval()}). This property is exposed via {@code is_training()}, and
     *  may be used by the implementation of a concrete module to modify its runtime
     *  behavior. See the {@code BatchNorm} or {@code Dropout} modules for examples
     *  of {@code Module}s that use different code paths depending on this
     *  property. */
    public native @Cast("bool") @Virtual(subclasses=false, method="is_training") @NoException(true) @Const({false, false, true}) boolean is_training();

    /** Recursively casts all parameters to the given {@code dtype} and
     *  {@code device}.
     *
     *  <p>If {@code non_blocking} is true and the source is in pinned memory and
     *  destination is on the GPU or vice versa, the copy is performed
     *  asynchronously with respect to the host. Otherwise, the argument has no
     *  effect. */
    public native @Virtual(subclasses=false, method="to") void to(
            @ByVal Device device,
            ScalarType dtype,
            @Cast("bool") boolean non_blocking/*=false*/);

    /** Recursively casts all parameters to the given {@code dtype}.
     *
     *  <p>If {@code non_blocking} is true and the source is in pinned memory and
     *  destination is on the GPU or vice versa, the copy is performed
     *  asynchronously with respect to the host. Otherwise, the argument has no
     *  effect. */
    public native @Virtual(subclasses=false, method="to") void to(ScalarType dtype, @Cast("bool") boolean non_blocking/*=false*/);

    /** Recursively moves all parameters to the given {@code device}.
     *
     *  <p>If {@code non_blocking} is true and the source is in pinned memory and
     *  destination is on the GPU or vice versa, the copy is performed
     *  asynchronously with respect to the host. Otherwise, the argument has no
     *  effect. */
    public native @Virtual(subclasses=false, method="to") void to(@ByVal Device device, @Cast("bool") boolean non_blocking/*=false*/);

    /** Recursively zeros out the {@code grad} value of each registered parameter.
     *
     *  @param set_to_none native default is {@code true} */
    public native @Virtual(subclasses=false, method="zero_grad") void zero_grad(@Cast("bool") boolean set_to_none/*=true*/);

    // -------------------------------------------------------------------------
    // Typed downcasts. Each asXxx() method binds the C++ template
    // torch::nn::Module::as<torch::nn::XxxImpl>() for one concrete module type,
    // and is declared @NoException(true).
    // -------------------------------------------------------------------------

    /** Attempts to cast this {@code Module} to the given module type. The sibling
     *  {@code asXxx()} methods below do the same for other concrete module types.
     *
     *  <p>This is useful when calling {@code apply()}:
     *  <pre>{@code
     *  void initialize_weights(nn::Module& module) {
     *    torch::NoGradGuard no_grad;
     *    if (auto* linear = module.as<nn::Linear>()) {
     *      linear->weight.normal_(0.0, 0.02);
     *    }
     *  }
     *
     *  MyModule module;
     *  module->apply(initialize_weights);
     *  }</pre>
     *
     *  <p>As the example shows, the C++ method yields a null pointer when the cast
     *  fails; presumably this surfaces as {@code null} in Java. */
    public native @Name("as<torch::nn::ModuleDictImpl>") @NoException(true) ModuleDictImpl asModuleDict();
    public native @Name("as<torch::nn::ModuleListImpl>") @NoException(true) ModuleListImpl asModuleList();
    public native @Name("as<torch::nn::SequentialImpl>") @NoException(true) SequentialImpl asSequential();
    public native @Name("as<torch::nn::ParameterDictImpl>") @NoException(true) ParameterDictImpl asParameterDict();
    public native @Name("as<torch::nn::ParameterListImpl>") @NoException(true) ParameterListImpl asParameterList();
    public native @Name("as<torch::nn::AdaptiveLogSoftmaxWithLossImpl>") @NoException(true) AdaptiveLogSoftmaxWithLossImpl asAdaptiveLogSoftmaxWithLoss();
    public native @Name("as<torch::nn::BatchNorm1dImpl>") @NoException(true) BatchNorm1dImpl asBatchNorm1d();
    public native @Name("as<torch::nn::InstanceNorm1dImpl>") @NoException(true) InstanceNorm1dImpl asInstanceNorm1d();
    public native @Name("as<torch::nn::Conv1dImpl>") @NoException(true) Conv1dImpl asConv1d();
    public native @Name("as<torch::nn::ConvTranspose1dImpl>") @NoException(true) ConvTranspose1dImpl asConvTranspose1d();
    public native @Name("as<torch::nn::DropoutImpl>") @NoException(true) DropoutImpl asDropout();
    public native @Name("as<torch::nn::BatchNorm2dImpl>") @NoException(true) BatchNorm2dImpl asBatchNorm2d();
    public native @Name("as<torch::nn::InstanceNorm2dImpl>") @NoException(true) InstanceNorm2dImpl asInstanceNorm2d();
    public native @Name("as<torch::nn::Conv2dImpl>") @NoException(true) Conv2dImpl asConv2d();
    public native @Name("as<torch::nn::ConvTranspose2dImpl>") @NoException(true) ConvTranspose2dImpl asConvTranspose2d();
    public native @Name("as<torch::nn::Dropout2dImpl>") @NoException(true) Dropout2dImpl asDropout2d();
    public native @Name("as<torch::nn::BatchNorm3dImpl>") @NoException(true) BatchNorm3dImpl asBatchNorm3d();
    public native @Name("as<torch::nn::InstanceNorm3dImpl>") @NoException(true) InstanceNorm3dImpl asInstanceNorm3d();
    public native @Name("as<torch::nn::Conv3dImpl>") @NoException(true) Conv3dImpl asConv3d();
    public native @Name("as<torch::nn::ConvTranspose3dImpl>") @NoException(true) ConvTranspose3dImpl asConvTranspose3d();
    public native @Name("as<torch::nn::Dropout3dImpl>") @NoException(true) Dropout3dImpl asDropout3d();
    public native @Name("as<torch::nn::AlphaDropoutImpl>") @NoException(true) AlphaDropoutImpl asAlphaDropout();
    public native @Name("as<torch::nn::FeatureAlphaDropoutImpl>") @NoException(true) FeatureAlphaDropoutImpl asFeatureAlphaDropout();
    public native @Name("as<torch::nn::CosineSimilarityImpl>") @NoException(true) CosineSimilarityImpl asCosineSimilarity();
    public native @Name("as<torch::nn::PairwiseDistanceImpl>") @NoException(true) PairwiseDistanceImpl asPairwiseDistance();
    public native @Name("as<torch::nn::EmbeddingImpl>") @NoException(true) EmbeddingImpl asEmbedding();
    public native @Name("as<torch::nn::EmbeddingBagImpl>") @NoException(true) EmbeddingBagImpl asEmbeddingBag();
    public native @Name("as<torch::nn::FoldImpl>") @NoException(true) FoldImpl asFold();
    public native @Name("as<torch::nn::UnfoldImpl>") @NoException(true) UnfoldImpl asUnfold();
    public native @Name("as<torch::nn::IdentityImpl>") @NoException(true) IdentityImpl asIdentity();
    public native @Name("as<torch::nn::LinearImpl>") @NoException(true) LinearImpl asLinear();
    public native @Name("as<torch::nn::BilinearImpl>") @NoException(true) BilinearImpl asBilinear();
    public native @Name("as<torch::nn::FlattenImpl>") @NoException(true) FlattenImpl asFlatten();
    public native @Name("as<torch::nn::UnflattenImpl>") @NoException(true) UnflattenImpl asUnflatten();
    public native @Name("as<torch::nn::L1LossImpl>") @NoException(true) L1LossImpl asL1Loss();
    public native @Name("as<torch::nn::KLDivLossImpl>") @NoException(true) KLDivLossImpl asKLDivLoss();
    public native @Name("as<torch::nn::MSELossImpl>") @NoException(true) MSELossImpl asMSELoss();
    public native @Name("as<torch::nn::BCELossImpl>") @NoException(true) BCELossImpl asBCELoss();
    public native @Name("as<torch::nn::HingeEmbeddingLossImpl>") @NoException(true) HingeEmbeddingLossImpl asHingeEmbeddingLoss();
    public native @Name("as<torch::nn::MultiMarginLossImpl>") @NoException(true) MultiMarginLossImpl asMultiMarginLoss();
    public native @Name("as<torch::nn::CosineEmbeddingLossImpl>") @NoException(true) CosineEmbeddingLossImpl asCosineEmbeddingLoss();
    public native @Name("as<torch::nn::SmoothL1LossImpl>") @NoException(true) SmoothL1LossImpl asSmoothL1Loss();
    public native @Name("as<torch::nn::HuberLossImpl>") @NoException(true) HuberLossImpl asHuberLoss();
    public native @Name("as<torch::nn::MultiLabelMarginLossImpl>") @NoException(true) MultiLabelMarginLossImpl asMultiLabelMarginLoss();
    public native @Name("as<torch::nn::SoftMarginLossImpl>") @NoException(true) SoftMarginLossImpl asSoftMarginLoss();
    public native @Name("as<torch::nn::MultiLabelSoftMarginLossImpl>") @NoException(true) MultiLabelSoftMarginLossImpl asMultiLabelSoftMarginLoss();
    public native @Name("as<torch::nn::TripletMarginLossImpl>") @NoException(true) TripletMarginLossImpl asTripletMarginLoss();
    public native @Name("as<torch::nn::TripletMarginWithDistanceLossImpl>") @NoException(true) TripletMarginWithDistanceLossImpl asTripletMarginWithDistanceLoss();
    public native @Name("as<torch::nn::CTCLossImpl>") @NoException(true) CTCLossImpl asCTCLoss();
    public native @Name("as<torch::nn::PoissonNLLLossImpl>") @NoException(true) PoissonNLLLossImpl asPoissonNLLLoss();
    public native @Name("as<torch::nn::MarginRankingLossImpl>") @NoException(true) MarginRankingLossImpl asMarginRankingLoss();
    public native @Name("as<torch::nn::NLLLossImpl>") @NoException(true) NLLLossImpl asNLLLoss();
    public native @Name("as<torch::nn::CrossEntropyLossImpl>") @NoException(true) CrossEntropyLossImpl asCrossEntropyLoss();
    public native @Name("as<torch::nn::BCEWithLogitsLossImpl>") @NoException(true) BCEWithLogitsLossImpl asBCEWithLogitsLoss();
    public native @Name("as<torch::nn::ReflectionPad1dImpl>") @NoException(true) ReflectionPad1dImpl asReflectionPad1d();
    public native @Name("as<torch::nn::ReplicationPad1dImpl>") @NoException(true) ReplicationPad1dImpl asReplicationPad1d();
    public native @Name("as<torch::nn::ConstantPad1dImpl>") @NoException(true) ConstantPad1dImpl asConstantPad1d();
    public native @Name("as<torch::nn::ZeroPad1dImpl>") @NoException(true) ZeroPad1dImpl asZeroPad1d();
    public native @Name("as<torch::nn::AvgPool1dImpl>") @NoException(true) AvgPool1dImpl asAvgPool1d();
    public native @Name("as<torch::nn::MaxPool1dImpl>") @NoException(true) MaxPool1dImpl asMaxPool1d();
    public native @Name("as<torch::nn::AdaptiveAvgPool1dImpl>") @NoException(true) AdaptiveAvgPool1dImpl asAdaptiveAvgPool1d();
    public native @Name("as<torch::nn::AdaptiveMaxPool1dImpl>") @NoException(true) AdaptiveMaxPool1dImpl asAdaptiveMaxPool1d();
    public native @Name("as<torch::nn::MaxUnpool1dImpl>") @NoException(true) MaxUnpool1dImpl asMaxUnpool1d();
    public native @Name("as<torch::nn::LPPool1dImpl>") @NoException(true) LPPool1dImpl asLPPool1d();
    public native @Name("as<torch::nn::ReflectionPad2dImpl>") @NoException(true) ReflectionPad2dImpl asReflectionPad2d();
    public native @Name("as<torch::nn::ReplicationPad2dImpl>") @NoException(true) ReplicationPad2dImpl asReplicationPad2d();
    public native @Name("as<torch::nn::ConstantPad2dImpl>") @NoException(true) ConstantPad2dImpl asConstantPad2d();
    public native @Name("as<torch::nn::ZeroPad2dImpl>") @NoException(true) ZeroPad2dImpl asZeroPad2d();
    public native @Name("as<torch::nn::AvgPool2dImpl>") @NoException(true) AvgPool2dImpl asAvgPool2d();
    public native @Name("as<torch::nn::MaxPool2dImpl>") @NoException(true) MaxPool2dImpl asMaxPool2d();
    public native @Name("as<torch::nn::AdaptiveAvgPool2dImpl>") @NoException(true) AdaptiveAvgPool2dImpl asAdaptiveAvgPool2d();
    public native @Name("as<torch::nn::AdaptiveMaxPool2dImpl>") @NoException(true) AdaptiveMaxPool2dImpl asAdaptiveMaxPool2d();
    public native @Name("as<torch::nn::MaxUnpool2dImpl>") @NoException(true) MaxUnpool2dImpl asMaxUnpool2d();
    public native @Name("as<torch::nn::FractionalMaxPool2dImpl>") @NoException(true) FractionalMaxPool2dImpl asFractionalMaxPool2d();
    public native @Name("as<torch::nn::LPPool2dImpl>") @NoException(true) LPPool2dImpl asLPPool2d();
    public native @Name("as<torch::nn::ReflectionPad3dImpl>") @NoException(true) ReflectionPad3dImpl asReflectionPad3d();
    public native @Name("as<torch::nn::ReplicationPad3dImpl>") @NoException(true) ReplicationPad3dImpl asReplicationPad3d();
    public native @Name("as<torch::nn::ConstantPad3dImpl>") @NoException(true) ConstantPad3dImpl asConstantPad3d();
    public native @Name("as<torch::nn::ZeroPad3dImpl>") @NoException(true) ZeroPad3dImpl asZeroPad3d();
    public native @Name("as<torch::nn::AvgPool3dImpl>") @NoException(true) AvgPool3dImpl asAvgPool3d();
    public native @Name("as<torch::nn::MaxPool3dImpl>") @NoException(true) MaxPool3dImpl asMaxPool3d();
    public native @Name("as<torch::nn::AdaptiveAvgPool3dImpl>") @NoException(true) AdaptiveAvgPool3dImpl asAdaptiveAvgPool3d();
    public native @Name("as<torch::nn::AdaptiveMaxPool3dImpl>") @NoException(true) AdaptiveMaxPool3dImpl asAdaptiveMaxPool3d();
    public native @Name("as<torch::nn::MaxUnpool3dImpl>") @NoException(true) MaxUnpool3dImpl asMaxUnpool3d();
    public native @Name("as<torch::nn::FractionalMaxPool3dImpl>") @NoException(true) FractionalMaxPool3dImpl asFractionalMaxPool3d();
    public native @Name("as<torch::nn::LPPool3dImpl>") @NoException(true) LPPool3dImpl asLPPool3d();
    public native @Name("as<torch::nn::RNNImpl>") @NoException(true) RNNImpl asRNN();
    public native @Name("as<torch::nn::LSTMImpl>") @NoException(true) LSTMImpl asLSTM();
    public native @Name("as<torch::nn::GRUImpl>") @NoException(true) GRUImpl asGRU();
    public native @Name("as<torch::nn::RNNCellImpl>") @NoException(true) RNNCellImpl asRNNCell();
    public native @Name("as<torch::nn::LSTMCellImpl>") @NoException(true) LSTMCellImpl asLSTMCell();
    public native @Name("as<torch::nn::GRUCellImpl>") @NoException(true) GRUCellImpl asGRUCell();
    public native @Name("as<torch::nn::PixelShuffleImpl>") @NoException(true) PixelShuffleImpl asPixelShuffle();
    public native @Name("as<torch::nn::PixelUnshuffleImpl>") @NoException(true) PixelUnshuffleImpl asPixelUnshuffle();
    public native @Name("as<torch::nn::UpsampleImpl>") @NoException(true) UpsampleImpl asUpsample();
    public native @Name("as<torch::nn::ELUImpl>") @NoException(true) ELUImpl asELU();
    public native @Name("as<torch::nn::SELUImpl>") @NoException(true) SELUImpl asSELU();
    public native @Name("as<torch::nn::HardshrinkImpl>") @NoException(true) HardshrinkImpl asHardshrink();
    public native @Name("as<torch::nn::HardtanhImpl>") @NoException(true) HardtanhImpl asHardtanh();
    public native @Name("as<torch::nn::LeakyReLUImpl>") @NoException(true) LeakyReLUImpl asLeakyReLU();
    public native @Name("as<torch::nn::LogSigmoidImpl>") @NoException(true) LogSigmoidImpl asLogSigmoid();
    public native @Name("as<torch::nn::SoftmaxImpl>") @NoException(true) SoftmaxImpl asSoftmax();
    public native @Name("as<torch::nn::SoftminImpl>") @NoException(true) SoftminImpl asSoftmin();
    public native @Name("as<torch::nn::LogSoftmaxImpl>") @NoException(true) LogSoftmaxImpl asLogSoftmax();
    public native @Name("as<torch::nn::Softmax2dImpl>") @NoException(true) Softmax2dImpl asSoftmax2d();
    public native @Name("as<torch::nn::PReLUImpl>") @NoException(true) PReLUImpl asPReLU();
    public native @Name("as<torch::nn::ReLUImpl>") @NoException(true) ReLUImpl asReLU();
    public native @Name("as<torch::nn::ReLU6Impl>") @NoException(true) ReLU6Impl asReLU6();
    public native @Name("as<torch::nn::RReLUImpl>") @NoException(true) RReLUImpl asRReLU();
    public native @Name("as<torch::nn::CELUImpl>") @NoException(true) CELUImpl asCELU();
    public native @Name("as<torch::nn::GLUImpl>") @NoException(true) GLUImpl asGLU();
    public native @Name("as<torch::nn::GELUImpl>") @NoException(true) GELUImpl asGELU();
    public native @Name("as<torch::nn::SiLUImpl>") @NoException(true) SiLUImpl asSiLU();
    public native @Name("as<torch::nn::MishImpl>") @NoException(true) MishImpl asMish();
    public native @Name("as<torch::nn::SigmoidImpl>") @NoException(true) SigmoidImpl asSigmoid();
    public native @Name("as<torch::nn::SoftplusImpl>") @NoException(true) SoftplusImpl asSoftplus();
    public native @Name("as<torch::nn::SoftshrinkImpl>") @NoException(true) SoftshrinkImpl asSoftshrink();
    public native @Name("as<torch::nn::SoftsignImpl>") @NoException(true) SoftsignImpl asSoftsign();
    public native @Name("as<torch::nn::TanhImpl>") @NoException(true) TanhImpl asTanh();
    public native @Name("as<torch::nn::TanhshrinkImpl>") @NoException(true) TanhshrinkImpl asTanhshrink();
    public native @Name("as<torch::nn::ThresholdImpl>") @NoException(true) ThresholdImpl asThreshold();
    public native @Name("as<torch::nn::MultiheadAttentionImpl>") @NoException(true) MultiheadAttentionImpl asMultiheadAttention();
    public native @Name("as<torch::nn::LayerNormImpl>") @NoException(true) LayerNormImpl asLayerNorm();
    public native @Name("as<torch::nn::LocalResponseNormImpl>") @NoException(true) LocalResponseNormImpl asLocalResponseNorm();
    public native @Name("as<torch::nn::CrossMapLRN2dImpl>") @NoException(true) CrossMapLRN2dImpl asCrossMapLRN2d();
    public native @Name("as<torch::nn::GroupNormImpl>") @NoException(true) GroupNormImpl asGroupNorm();
    public native @Name("as<torch::nn::TransformerEncoderLayerImpl>") @NoException(true) TransformerEncoderLayerImpl asTransformerEncoderLayer();
    public native @Name("as<torch::nn::TransformerDecoderLayerImpl>") @NoException(true) TransformerDecoderLayerImpl asTransformerDecoderLayer();
    public native @Name("as<torch::nn::TransformerEncoderImpl>") @NoException(true) TransformerEncoderImpl asTransformerEncoder();
    public native @Name("as<torch::nn::TransformerDecoderImpl>") @NoException(true) TransformerDecoderImpl asTransformerDecoder();
    public native @Name("as<torch::nn::TransformerImpl>") @NoException(true) TransformerImpl asTransformer();

    /** Serializes the {@code Module} into the given {@code OutputArchive}.
     *
     *  <p>If the {@code Module} contains unserializable submodules (e.g.
     *  {@code nn::Functional}), those submodules are skipped when serializing. */
    public native @Virtual(subclasses=false, method="save") @Const({false, false, true}) void save(@ByRef OutputArchive archive);

    /** Deserializes the {@code Module} from the given {@code InputArchive}.
     *
     *  <p>If the {@code Module} contains unserializable submodules (e.g.
     *  {@code nn::Functional}), we don't check the existence of those submodules in
     *  the {@code InputArchive} when deserializing. */
    public native @Virtual(subclasses=false, method="load") void load(@ByRef InputArchive archive);

    /** Streams a pretty representation of the {@code Module} into the given
     *  {@code stream} (a native {@code std::ostream}). By default, this
     *  representation will be the name of the module (taken from {@code name()}),
     *  followed by a recursive pretty print of all of the {@code Module}'s
     *  submodules.
     *
     *  <p>Override this method to change the pretty print. It writes into
     *  {@code stream} and returns nothing. */
    public native @Virtual(subclasses=false, method="pretty_print") @Const({false, false, true}) void pretty_print(@Cast("std::ostream*") @ByRef Pointer stream);

    /** Returns whether the {@code Module} is serializable. */
    public native @Cast("bool") @Virtual(subclasses=false, method="is_serializable") @Const({false, false, true}) boolean is_serializable();

    /** Registers a parameter with this {@code Module}.
     *
     *  <p>A parameter should be any gradient-recording tensor used in the
     *  implementation of your {@code Module}. Registering it makes it available to
     *  methods such as {@code parameters()}, {@code clone()} or {@code to()}.
     *
     *  <p>Note that registering an undefined Tensor (e.g.
     *  {@code module.register_parameter("param", Tensor())}) is allowed, and is
     *  equivalent to {@code module.register_parameter("param", None)} in the Python
     *  API.
     *
     *  <pre>{@code
     *  MyModule::MyModule() {
     *    weight_ = register_parameter("weight", torch::randn({A, B}));
     *  }
     *  }</pre>
     *
     *  @param requires_grad native default is {@code true} */
    public native @ByRef Tensor register_parameter(
            @StdString BytePointer name,
            @ByVal Tensor tensor,
            @Cast("bool") boolean requires_grad/*=true*/);
    public native @ByRef Tensor register_parameter(
            @StdString BytePointer name,
            @ByVal Tensor tensor);
    public native @ByRef Tensor register_parameter(
            @StdString String name,
            @ByVal Tensor tensor,
            @Cast("bool") boolean requires_grad/*=true*/);
    public native @ByRef Tensor register_parameter(
            @StdString String name,
            @ByVal Tensor tensor);

    /** Registers a buffer with this {@code Module}.
     *
     *  <p>A buffer is intended to be state in your module that does not record
     *  gradients, such as running statistics. Registering it makes it available to
     *  methods such as {@code buffers()}, {@code clone()} or {@code to()}.
     *
     *  <pre>{@code
     *  MyModule::MyModule() {
     *    mean_ = register_buffer("mean", torch::empty({num_features_}));
     *  }
     *  }</pre> */
    public native @ByRef Tensor register_buffer(@StdString BytePointer name, @ByVal Tensor tensor);
    public native @ByRef Tensor register_buffer(@StdString String name, @ByVal Tensor tensor);

    /** Native binding used by {@link #register_module(BytePointer, Module)}. */
    private native @Name("register_module") void _register_module(@StdString BytePointer name, @SharedPtr @ByVal Module module);

    /** Registers a submodule with this {@code Module}.
     *
     *  <p>Registering a module makes it available to methods such as
     *  {@code modules()}, {@code clone()} or {@code to()}.
     *
     *  <pre>{@code
     *  MyModule::MyModule() {
     *    submodule_ = register_module("linear", torch::nn::Linear(3, 4));
     *  }
     *  }</pre>
     *
     *  @param <M>    concrete Java type of the submodule
     *  @param name   key under which the submodule is registered
     *  @param module the submodule to register
     *  @return the same {@code module} object that was passed in, so it can be
     *          assigned to a field in one expression */
    public <M extends Module> M register_module(BytePointer name, M module) { _register_module(name, module); return module; }

    /** Native binding used by {@link #register_module(String, Module)}. */
    private native @Name("register_module") void _register_module(@StdString String name, @SharedPtr @ByVal Module module);

    /** Registers a submodule with this {@code Module}. Same as
     *  {@link #register_module(BytePointer, Module)}, with a Java {@code String}
     *  key.
     *
     *  @param <M>    concrete Java type of the submodule
     *  @param name   key under which the submodule is registered
     *  @param module the submodule to register
     *  @return the same {@code module} object that was passed in */
    public <M extends Module> M register_module(String name, M module) { _register_module(name, module); return module; }

    // The following C++ members are not declared in this class:
    //  - register_module(name, ModuleHolder) -- the ModuleHolder overload;
    //  - replace_module(name, module) and its ModuleHolder overload, which
    //    replace an already-registered submodule (useful e.g. for fine-tuning).
    //    If you keep submodule members, assign the result back as well, e.g.
    //      module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4));
    //    They only work when a module of that name is already registered.

    /** Unregisters a submodule from this {@code Module}. If there is no such module
     *  with {@code name} an exception is thrown. */
    public native void unregister_module(@StdString BytePointer name);
    public native void unregister_module(@StdString String name);

    /** Binds the free C++ {@code operator<<(std::ostream&, const Module&)}. */
    private static native @Namespace @Cast("std::ostream*") @ByRef @Name("operator <<") Pointer shiftLeft(
            @Cast("std::ostream*") @ByRef Pointer stream,
            @Const @ByRef Module module);

    /** Writes this module to the native {@code std::ostream} {@code stream} and
     *  returns the stream pointer produced by the C++ {@code operator<<}. */
    public Pointer shiftLeft(Pointer stream) { return shiftLeft(stream, this); }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy