/*
MIT License
Copyright (c) 2019 Gleethos
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
'Any fool can write code that a computer can understand.
Good programmers write code that humans can understand.'
– Martin Fowler
Use the following as search keys :)
§(1) : CONSTRUCTION
§(2) : FLAGS
§(3) : COMPONENT SYSTEM
§(4) : PROPERTIES
§(5) : OBJECT STATE MODIFICATION
§(6) : ND-ITERATOR LOGIC
§(7) : COMPONENT SPECIFIC
§(8) : (OVERLOADABLE) OPERATORS & OPERATIONS
§(9) : SLICING, INDEXING & INJECTING
§(10) : MAPPING
*/
package neureka;
import neureka.autograd.GraphNode;
import neureka.backend.api.ExecutionCall;
import neureka.backend.api.LazyRef;
import neureka.backend.main.memory.MemUtil;
import neureka.backend.main.operations.other.ReLayout;
import neureka.common.composition.AbstractComponentOwner;
import neureka.common.composition.Component;
import neureka.common.utility.DataConverter;
import neureka.common.utility.LogUtil;
import neureka.devices.Device;
import neureka.devices.host.CPU;
import neureka.dtype.DataType;
import neureka.fluent.slicing.SliceBuilder;
import neureka.fluent.slicing.SmartSlicer;
import neureka.fluent.slicing.states.AxisOrGetTensor;
import neureka.framing.NDFrame;
import neureka.framing.Relation;
import neureka.framing.fluent.AxisFrame;
import neureka.math.Function;
import neureka.math.args.Arg;
import neureka.ndim.Filler;
import neureka.ndim.NDConstructor;
import neureka.ndim.config.NDConfiguration;
import neureka.ndim.iterator.NDIterator;
import neureka.view.NdaAsString;
import org.slf4j.LoggerFactory;
import java.awt.*;
import java.awt.image.*;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;
/**
* The implementation for the {@link Tensor} API.
*
* @param <V> The type parameter for the individual value items within this tensor.
*/
final class TensorImpl<V> extends AbstractNda<Tensor<V>, V> implements MutateTensor<V>
{
static {
_LOG = LoggerFactory.getLogger( TensorImpl.class );
}
/**
* This field contains multiple flags.
* The bits of this integer are used to encode various states which a tensor can have.
* These bits are flipped by bitmasks which are defined below.
*/
private byte _flags = 0;
/**
* The following fields are bit masks used to store true / false values
* in a targeted bit inside the {@link #_flags} variable.
*/
private static final byte RQS_GRADIENT_MASK = 1;
private static final byte IS_VIRTUAL_MASK = 2;
private static final byte GRADIENT_APPLY_RQD_MASK = 4;
private static final byte IS_DELETED_MASK = 8;
private static final byte IS_INTERMEDIATE_MASK = 16;
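/*
    Illustration (not part of the original source): each flag occupies one bit of '_flags'.
    A flag is read by masking, and written by adding or subtracting the mask, which the
    setters below only ever do behind an inequality guard, making it equivalent to
    setting or clearing the bit:

        boolean rqsGradient = ( _flags & RQS_GRADIENT_MASK ) == RQS_GRADIENT_MASK;
        _flags += RQS_GRADIENT_MASK; // set the bit (only valid while it is 0)
        _flags -= RQS_GRADIENT_MASK; // clear the bit (only valid while it is 1)
*/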
/*==================================================================================================================
|
| §(1) : CONSTRUCTION
| ---------------------------
*/
static Tensor<?> _of( Object... args )
{
if ( args == null || args.length == 0 ) return new TensorImpl<>();
if ( args.length == 1 ) {
if ( args[ 0 ] == null ) {
String message = "Cannot create a tensor from a 'null' argument!";
_LOG.error( message );
throw new IllegalArgumentException( message );
}
return new TensorImpl<>(
constructFor(CPU.get(), NDConstructor.of(1)).newPopulatedFromOne( args[ 0 ], args[ 0 ].getClass() )
);
}
Class<?> commonType = _extractCommonType(args);
if ( commonType != null ) {
return new TensorImpl<>(constructFor(CPU.get(), NDConstructor.of( args.length ))
.tryConstructing(
DataType.of(commonType),
args
));
}
/* EXPRESSION BASED CONSTRUCTION:
The following allows the creation of tensors based on passing an expression
alongside input tensors to the constructor.
An example would be:
Tensor<?> t = Tensor.of( "tanh(", x, ") * 7 **", y );
*/
boolean containsString = false;
int numberOfTensors = 0;
for ( Object o : args ) {
containsString = ( o instanceof String ) || containsString;
if ( o instanceof TensorImpl)
numberOfTensors++;
}
TensorImpl[] tensors = new TensorImpl[ numberOfTensors ];
StringBuilder f = new StringBuilder();
int ti = 0;
for ( Object o : args ) {
if ( o instanceof TensorImpl) {
tensors[ ti ] = ( (TensorImpl) o );
f.append( "I[" ).append( ti ).append( "]" );
ti++;
}
else if ( o instanceof String ) f.append( (String) o );
else
_LOG.debug(
"Unexpected tensor construction argument of type '"+o.getClass().getSimpleName()+"'"
);
}
if ( tensors.length == 0 || tensors[0] == null) return new TensorImpl<>();
return Function.of( f.toString(), true ).call( tensors );
}
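/*
    Usage sketch for the expression based construction above (illustrative only,
    'x' and 'y' are hypothetical tensors):

        Tensor<?> t = Tensor.of( "tanh(", x, ") * 7 **", y );

    The loop above tokenizes these arguments into the function string
    "tanh(I[0]) * 7 ** I[1]", where I[n] references the n-th tensor argument,
    and then executes it through Function.of( f.toString(), true ).call( tensors ).
*/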
static <V> Tensor<V> _of( Iterable<V> iterable )
{
List<V> list = new ArrayList<>();
iterable.forEach( list::add );
return _of( list );
}
static <V> Tensor<V> _of( List<V> list )
{
return new TensorImpl<>(
constructFor(CPU.get(), NDConstructor.of( list.size() ))
.tryConstructing(
DataType.of(_extractCommonType( list.toArray() )),
list.toArray()
));
}
/**
* @param args The objects which should be checked.
* @return A common type or null if they are not all of the same type.
*/
private static Class<?> _extractCommonType( Object... args ) {
Class<?> commonType = null;
for ( Object o : args )
if ( o != null ) {
if ( commonType == null ) commonType = o.getClass();
else if ( !commonType.equals(o.getClass()) ) return null;
}
return commonType;
}
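/*
    Example behaviour (illustrative): _extractCommonType( 1, 2, 3 ) yields Integer.class,
    whereas _extractCommonType( 1, "two" ) yields null, because the items do not all
    share one exact class.
*/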
// Constructors:
/**
* This constructor creates a completely empty tensor which is void of any contents and meaning.
* The use case for this would be to use the produced {@link Tensor}
* instance as a target for an inline operation which fills this instance with an actual value.
* An example of this approach would be to call the {@link #putAt(List, Nda)} method with an empty list as key.
* This will be interpreted as an inline copy of the contents of the
* second parameter into this {@link Tensor} instance.
* This constructor will be called by the {@link Tensor#newInstance()} factory method.
*/
TensorImpl() {
_setData(new Data<V>() {
@Override public Device<V> owner() { return (Device<V>) (Device<?>) CPU.get(); }
@Override public Object getOrNull() { return null; }
@Override public DataType<V> dataType() {
return (DataType<V>) (DataType<?>) Neureka.get().settings().dtype().getDefaultDataType();
}
@Override public int usages() { return 1; }
});
}
TensorImpl( TensorConstructor.Args args ) {
NDConfiguration ndc = args.getConf();
Boolean isVirtual = args.isVirtual();
Data<V> data = (Data<V>) args.getData();
if ( isVirtual != null )
_setIsVirtual( isVirtual );
if ( ndc != null )
_setNDConf( ndc );
if ( data != null )
_setData( data );
}
public static <V> TensorImpl<V> _of( NDConstructor ndConstructor, Device<V> device, DataType<V> dataType, Object value ) {
Object data = value;
if ( List.class.isAssignableFrom( dataType.getItemTypeClass() ) )
data = new Object[]{ value }; // Make a nd-array of lists possible
if ( Object[].class.isAssignableFrom( dataType.getItemTypeClass() ) )
data = new Object[]{ value }; // Make a nd-array of arrays possible
if ( Object.class == dataType.getItemTypeClass() ) {
if ( value.getClass() != Object[].class )
data = new Object[]{ value };
}
if ( data instanceof List<?> ) {
List<?> range = (List<?>) data;
data = range.toArray(); // TODO: This is probably wrong!
}
return new TensorImpl<>(constructFor(device, ndConstructor).tryConstructing( dataType, data ));
}
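/*
    Note on the wrapping above (illustrative): when the tensor's item type is itself
    a List, an array, or plain Object, the passed value is wrapped in a single element
    Object[] so that it is treated as one tensor item rather than being flattened
    into multiple items.
*/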
static <V> TensorImpl<V> _of( NDConstructor ndConstructor, DataType<V> dataType, Data<V> data ) {
// We check if the type of the data is compatible with the type of the tensor:
if ( !dataType.getItemTypeClass().isAssignableFrom( data.dataType().getItemTypeClass() ) )
throw new IllegalArgumentException(
"The data type of the data is not compatible with the data type of the tensor!"
);
return new TensorImpl<>(constructFor(data.owner(), ndConstructor).constructTrusted( data ));
}
/**
* see {@link Tensor#of(DataType, Shape, Filler)}
*/
static <V> TensorImpl<V> _of( NDConstructor ndConstructor, DataType<V> type, Filler<V> filler ) {
LogUtil.nullArgCheck( ndConstructor, "ndcProducer", NDConstructor.class );
LogUtil.nullArgCheck( type, "type", DataType.class );
LogUtil.nullArgCheck( filler, "filler", Filler.class );
TensorImpl<V> t = new TensorImpl<>(constructFor(CPU.get(), ndConstructor).unpopulated( false, true, type ));
t._initDataArrayFrom( filler );
return t;
}
/**
* See {@link Tensor#of(Class, Shape, neureka.math.args.Arg.Seed)} and {@link #of(List, String)}
*/
static <V> TensorImpl<V> _of( Class<V> valueType, NDConstructor ndConstructor, Arg.Seed seed ) {
LogUtil.nullArgCheck( valueType, "valueType", Class.class );
LogUtil.nullArgCheck(ndConstructor, "ndcProducer", NDConstructor.class );
LogUtil.nullArgCheck( seed, "seed", Arg.Seed.class );
return new TensorImpl<>(constructFor(CPU.get(), ndConstructor).newSeeded( valueType, seed ));
}
static <V> TensorImpl<V> _of( NDConstructor ndConstructor, DataType<?> type ) {
LogUtil.nullArgCheck(ndConstructor, "ndcProducer", NDConstructor.class );
LogUtil.nullArgCheck( type, "type", DataType.class );
return new TensorImpl<>(constructFor(CPU.get(), ndConstructor).unpopulated( true, true, type ));
}
/*==================================================================================================================
|
| §(2) : FLAGS
| ----------------------
*/
/** {@inheritDoc} */
@Override
public Tensor<V> setRqsGradient( boolean rqsGradient ) {
if ( rqsGradient() != rqsGradient ) {
if ( !rqsGradient ) this.remove( TensorImpl.class );
else if ( has(GraphNode.class) ) {
if ( getGraphNode().map( n -> n.getMode() == 0 ).orElse(false) )
remove(GraphNode.class);
else
throw new IllegalArgumentException(
"This tensor is already part of a gradient dependent graph as " +
"branch node and therefore cannot be removed from it."
);
}
}
_setRqsGradient( rqsGradient );
return this;
}
/** {@inheritDoc} */
@Override public boolean rqsGradient() { return ( _flags & RQS_GRADIENT_MASK ) == RQS_GRADIENT_MASK; }
private void _setRqsGradient( boolean rqsGradient ) {
if ( rqsGradient() != rqsGradient ) {
if ( rqsGradient ) _flags += RQS_GRADIENT_MASK;
else _flags -= RQS_GRADIENT_MASK;
}
}
/** {@inheritDoc} */
@Override public boolean isIntermediate() { return ( _flags & IS_INTERMEDIATE_MASK ) == IS_INTERMEDIATE_MASK; }
/**
* Intermediate tensors are internal non-user tensors which may be eligible
* for deletion when further consumed by a {@link Function}.
* For the casual user of Neureka, this flag should always be false!
*
* @param isIntermediate The truth value determining if this tensor is not a user tensor but an internal
* tensor which may be eligible for deletion by {@link Function}s consuming it.
*/
private Tensor<V> _setIsIntermediate( boolean isIntermediate ) {
if ( isIntermediate() != isIntermediate ) {
if ( isIntermediate ) _flags += IS_INTERMEDIATE_MASK;
else _flags -= IS_INTERMEDIATE_MASK;
}
return this;
}
/** {@inheritDoc} */
@Override
public boolean isVirtual() { return ( _flags & IS_VIRTUAL_MASK ) == IS_VIRTUAL_MASK; }
/** {@inheritDoc} */
@Override
public Tensor<V> setIsVirtual( boolean isVirtual )
{
if ( getNDConf() == null )
throw new IllegalStateException(
"Cannot set the virtual flag of a tensor which has not been constructed yet!"
);
if ( isVirtual() != isVirtual )
{
if ( isVirtual )
_virtualize();
else
_actualize();
// Virtual and actual tensors require a different mapping from a given index to the underlying data..
// Therefore, we need to re-initialize the NDConfiguration object:
TensorConstructor.Args args = constructFor(getDevice(),NDConstructor.of(getNDConf().shape())).unpopulated( isVirtual, false, getDataType() );
_setState( args );
if ( isVirtual )
this.find( Relation.class )
.ifPresent( r ->
r.getChildren().forEach( c -> {
((TensorImpl<V>) c)._setData( _getData() );
((TensorImpl<V>) c).setIsVirtual( true );
})
);
else
this.find(Relation.class)
.map( relation -> ((Relation<V>) relation).getParent().orElse(null) )
.map( parent -> parent.get(Relation.class) )
.ifPresent( parentRelation -> parentRelation.removeChild( this ) );
}
else if ( isVirtual ) _allocateVirtual(); //> Only a single value representing the rest.
return this;
}
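/*
    Sketch of what "virtual" means here (assuming a homogeneous tensor): a virtual
    tensor of shape (2, 3) whose items all equal one value is backed by a data array
    of length 1 instead of length 6. Actualizing it expands the data array so that
    every element can be modified independently, which is why the index-to-data
    mapping (the NDConfiguration) must be re-initialized above.
*/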
private void _setState(TensorConstructor.Args args) {
Boolean isVirtual = args.isVirtual();
NDConfiguration ndc = args.getConf();
Data<V> data = (Data<V>) args.getData();
if ( isVirtual != null )
_setIsVirtual( isVirtual );
if ( ndc != null )
_setNDConf( ndc );
if ( data != null )
_setData( data );
}
/**
* This method is the inner counterpart to the public "{@link MutateTensor#setIsVirtual}" method.
* It actually performs the bit flipping by applying the corresponding bit mask.
*
* @param isVirtual The truth value which ought to be applied.
*/
@Override
protected void _setIsVirtual( boolean isVirtual ) {
if ( isVirtual() != isVirtual ) {
if ( isVirtual ) _flags += IS_VIRTUAL_MASK;
else _flags -= IS_VIRTUAL_MASK;
}
}
/** {@inheritDoc} */
@Override
public boolean isDeleted() { return ( _flags & IS_DELETED_MASK ) == IS_DELETED_MASK; }
/** {@inheritDoc} */
@Override
public boolean gradientApplyRequested() { return ( _flags & GRADIENT_APPLY_RQD_MASK ) == GRADIENT_APPLY_RQD_MASK; }
/** {@inheritDoc} */
@Override
public Tensor<V> setGradientApplyRequested( boolean applyRequested ) {
if ( gradientApplyRequested() != applyRequested ) {
if ( applyRequested ) {
if (
Neureka.get().settings().autograd().isApplyingGradientWhenRequested() &&
!Neureka.get().settings().autograd().isApplyingGradientWhenTensorIsUsed()
)
this.applyGradient();
else
_flags += GRADIENT_APPLY_RQD_MASK;
}
else _flags -= GRADIENT_APPLY_RQD_MASK;
}
return this;
}
/**
* Although tensors will be garbage collected when they are not strongly referenced,
* there is also the option to manually free up the tensor and its associated data.
* This is especially useful when tensors are stored on a device like the OpenCLDevice.
* In that case calling the "{@link MutateTensor#delete()}" method will free the memory reserved for this tensor.
* This manual memory freeing through this method can be faster than waiting for
* the garbage collector to kick in...
*
*
* @return This very tensor instance to allow for method chaining.
*/
private Tensor<V> _delete()
{
if ( isDeleted() ) return this;
getGraphNode().ifPresent( n -> {
if ( !n.canBeDeleted() ) {
String message = "Cannot delete a tensor which is used as derivative by the AD computation graph!";
_LOG.error( message );
throw new IllegalStateException( message );
}
});
this.find( Device.class ).ifPresent( device -> device.free( this ) );
_setData( null );
_setNDConf( null );
_flags = 0;
this.find( TensorImpl.class ).ifPresent(t -> t.mut().delete() );
_deleteComponents();
_flags += IS_DELETED_MASK;
return this;
}
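/*
    Usage sketch (illustrative): eagerly freeing a tensor stored on a device,
    instead of waiting for garbage collection:

        t.getMut().delete(); // frees device memory, clears data, components and flags

    After this call isDeleted() yields true and the tensor must no longer be used.
*/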
/*==================================================================================================================
|
| §(3) : COMPONENT SYSTEM
| --------------------------------
*/
/** {@inheritDoc} */
@Override public <T extends Component<Tensor<V>>> T get( Class<T> componentClass )
{
LogUtil.nullArgCheck( componentClass, "componentClass", Class.class );
if ( GraphNode.class.isAssignableFrom(componentClass) )
_guardGet(componentClass.getSimpleName());
else if ( NDFrame.class.isAssignableFrom(componentClass) )
_guardGet(componentClass.getSimpleName());
return super.get(componentClass);
}
/**
* This method is executed when a new Component is added to the tensor.
* The public add method is implemented in the super class
* '{@link AbstractComponentOwner}' from which this class inherits.
* In this super class the component logic is implemented.
*
* @param newComponent A component used to access features. ({@link GraphNode}, {@link NDFrame}, {@link Relation}, int[], ...)
* @return The unchanged object or maybe in future versions: null (component rejected)
*/
@Override
protected <T extends Component<Tensor<V>>> T _setOrReject( T newComponent ) { return newComponent; }
/**
* This method is executed when a component is being removed from the tensor.
* The public remove method is implemented in the super class
* '{@link AbstractComponentOwner}' from which this class inherits.
* In this super class the component logic is implemented.
*
* @param newComponent A component used to access features. ({@link GraphNode}, {@link NDFrame}, {@link Relation}, int[], ...)
* @return The unchanged object or when rejected: null (component rejected)
*/
@Override
protected <T extends Component<Tensor<V>>> T _removeOrReject( T newComponent )
{
if ( newComponent instanceof Device ) {
Device<V> device = (Device<V>) newComponent;
/*
The following seems like a redundant check, however often times a tensor
will be removed from a Device implementation inside the "restore" method
when the tensor has already been removed from the device...
Without the condition below a stack overflow would occur!
*/
if ( device.has( this ) ) {
try {
device.restore( this );
} catch ( Exception exception ) {
_LOG.error(
"Removing device from tensor / tensor from device failed.\n" +
"Restoring tensor from device threw exception.\n",
exception
);
throw exception;
}
}
}
return newComponent;
}
/*==================================================================================================================
|
| §(4) : PROPERTIES :
| ---------------------------------------
*/
/**
* {@inheritDoc}
*/
@Override
public int getVersion() { return _version; }
/*==================================================================================================================
|
| §(5) : OBJECT STATE MODIFICATION :
| ------------------------------------------
*/
/**
* This method is responsible for incrementing
* the "_version" field variable which represents the version of the data of this tensor.
* Meaning :
* Every time the underlying data (_value) changes this version ought to increment alongside.
* The method is called during the execution procedure.
*
* @param call The context object containing all relevant information that defines a call for tensor execution.
*/
private void _incrementVersionBecauseOf( ExecutionCall<?> call ) {
if ( Neureka.get().settings().autograd().isPreventingInlineOperations() ) {
_version++; // Autograd must be warned!
GraphNode<V> node = get( GraphNode.class );
if ( node != null && node.getPayloadReferenceVersion() != _version ) {
if ( node.usesAD() || node.isUsedAsDerivative() ) {
String error = "Inline operation occurred on tensor which is part of a computation graph node with autograd support!\n" +
"The following OperationType caused an internal version mismatch: '"+call.getOperation().getIdentifier()+"'";
_LOG.error( error );
throw new IllegalStateException( error );
}
}
}
}
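/*
    Illustration (not part of the original source): with inline operation prevention
    enabled, an inline call like

        a.getMut().plusAssign( b ); // increments the version of 'a'

    on a tensor whose GraphNode participates in autograd (or is used as a derivative)
    trips the version mismatch check above and throws an IllegalStateException.
*/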
/**
* {@inheritDoc}
*/
@Override
public MutateTensor<V> getMut() {
_guardGet("unsafe API");
return this;
}
/** {@inheritDoc} */
@Override public MutateNda.Item<V> at( int... indices ) {
return new MutateNda.Item<V>() {
@Override public V orElseNull() { return item( indices ); }
@Override public void set( V value ) { getMut().putAt( indices, value ); }
@Override public boolean equals( Object o ) {
if ( o == null ) return false;
if ( o == this ) return true;
if ( o.getClass() != this.getClass() ) return false;
Nda.Item<V> other = (Nda.Item<V>) o;
return this.get().equals( other.get() );
}
@Override public int hashCode() { V item = get(); return ( item == null ? 0 : item.hashCode() ); }
@Override public String toString() { return String.valueOf( get() ); }
};
}
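/*
    Usage sketch for the Item handle above (illustrative):

        MutateNda.Item<V> item = this.at( 1, 2 );
        V value = item.orElseNull(); // reads the item at index (1, 2)
        item.set( value );           // writes it back through putAt(...)
*/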
/**
* {@inheritDoc}
*/
@Override
public Tensor<V> setNDConf( NDConfiguration configuration ) { TensorImpl.this._setNDConf( configuration ); return TensorImpl.this; }
/**
* {@inheritDoc}
*/
@Override
public <T> Tensor<T> toType( Class<T> typeClass ) {
LogUtil.nullArgCheck( typeClass, "typeClass", Class.class, "Cannot convert tensor to 'null' data type." );
return TensorImpl.this._toType( typeClass );
}
/**
* {@inheritDoc}
*/
@Override
public <T> Tensor<T> upcast( Class<T> superType ) {
LogUtil.nullArgCheck( superType, "superType", Class.class );
if ( superType.isAssignableFrom(TensorImpl.this.itemType()) )
return (Tensor<T>) (Tensor<?>) TensorImpl.this;
else
throw new IllegalArgumentException("Provided type '"+superType+"' is not a super type of '"+ TensorImpl.this.itemType()+"'.");
}
/**
* {@inheritDoc}
*/
@Override
public Tensor<V> toLayout( NDConfiguration.Layout layout ) {
ReLayout.toLayout( this, layout );
return TensorImpl.this;
}
/**
* {@inheritDoc}
*/
@Override
public Tensor<V> incrementVersion( ExecutionCall<?> call ) {
LogUtil.nullArgCheck( call, "call", ExecutionCall.class );
_incrementVersionBecauseOf( call );
return TensorImpl.this;
}
/**
* {@inheritDoc}
*/
@Override
public Tensor<V> setIsIntermediate( boolean isIntermediate ) { return _setIsIntermediate( isIntermediate ); }
/**
* {@inheritDoc}
*/
@Override public Tensor<V> delete() { return TensorImpl.this._delete(); }
/**
* {@inheritDoc}
*/
@Override public Data<V> getData() { return _getData(); }
/**
* {@inheritDoc}
*/
@Override
public <A> A getDataAs( Class<A> arrayTypeClass ) {
return DataConverter.get().convert( _getData(false), arrayTypeClass );
}
/**
* {@inheritDoc}
*/
@Override
public Tensor<V> setDataAt( int i, V o ) {
_guardMod("data object");
_setDataAt( i, o );
return TensorImpl.this;
}
/**
* {@inheritDoc}
*/
@Override
public Tensor<V> setData( Data<V> data ) {
TensorImpl.this._setData( data );
return TensorImpl.this;
}
/**
* {@inheritDoc}
*/
@Override public Tensor<V> detach() { TensorImpl.this.remove( GraphNode.class ); return TensorImpl.this; }
/** {@inheritDoc} */
@Override public Tensor<V> timesAssign( Tensor<V> other ) {
LogUtil.nullArgCheck(other, "other", Tensor.class, "Cannot multiply-assign 'null' to a tensor!");
return Neureka.get().backend().getFunction().mulAssign().call( TensorImpl.this, other );
}
/** {@inheritDoc} */
@Override public Tensor<V> timesAssign( V other ) {
LogUtil.nullArgCheck(other, "other", TensorImpl.this.getItemType(), "Cannot multiply-assign 'null' to a tensor!");
return this.timesAssign( Tensor.of( getItemType(), this.shape(), other ) );
}
/** {@inheritDoc} */
@Override public Tensor<V> divAssign( Tensor<V> other ) {
LogUtil.nullArgCheck(other, "other", Tensor.class, "Cannot divide-assign a tensor by 'null' (In any sense of the word)!");
return Neureka.get().backend().getFunction().divAssign().call( TensorImpl.this, other );
}
/** {@inheritDoc} */
@Override public Tensor<V> modAssign( Tensor<V> other ) {
LogUtil.nullArgCheck(other, "other", Tensor.class, "Cannot perform tensor modulo 'null'!");
return Neureka.get().backend().getFunction().modAssign().call( TensorImpl.this, other );
}
/** {@inheritDoc} */
@Override public Tensor<V> plusAssign( Tensor<V> other ) {
LogUtil.nullArgCheck(other, "other", Tensor.class, "Cannot add-assign 'null' to a tensor!");
return Neureka.get().backend().getFunction().plusAssign().call( TensorImpl.this, other );
}
/** {@inheritDoc} */
@Override public Tensor<V> minusAssign( Tensor<V> other ) {
LogUtil.nullArgCheck(other, "other", Tensor.class, "Cannot subtract-assign 'null' from a tensor!");
return Neureka.get().backend().getFunction().minusAssign().call( TensorImpl.this, other );
}
/** {@inheritDoc} */
@Override public Tensor<V> minusAssign( V other ) {
LogUtil.nullArgCheck(other, "other", TensorImpl.this.getItemType(), "Cannot subtract-assign 'null' from a tensor!");
return minusAssign(
Tensor.of( TensorImpl.this.getDataType().getItemTypeClass() )
.withShape(TensorImpl.this.getNDConf().shape())
.all(other)
);
}
@Override
public Tensor<V> assign( V other ) {
LogUtil.nullArgCheck(other, "other", TensorImpl.this.getItemType(), "Cannot assign 'null' to a tensor!");
return assign(
Tensor.of( TensorImpl.this.getDataType().getItemTypeClass() )
.withShape(TensorImpl.this.getNDConf().shape())
.all(other)
);
}
@Override
public Tensor<V> assign( Nda<V> other ) {
LogUtil.nullArgCheck(other, "other", Tensor.class, "Cannot assign 'null' to a tensor!");
return Neureka.get().backend().getFunction().idy().call( TensorImpl.this, (Tensor<V>) other );
}
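/*
    Design note (illustrative): the scalar overloads above broadcast a single value
    by first building a constant tensor of the same shape via the fluent API,

        Tensor.of( getItemType() ).withShape( this.shape() ).all( value )

    and then delegating to the corresponding tensor-tensor operation.
*/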
@Override
public Tensor<V> labelAxes( String[]... labels ) {
LogUtil.nullArgCheck(labels, "labels", String[][].class, "Tensors cannot be labeled 'null'!");
if ( labels.length > this.rank() )
throw new IllegalArgumentException(
"Number of the provided axes labels is larger than the total number of axes (rank) of the nd-array."
);
NDFrame<V> frame = get( NDFrame.class );
if ( frame == null ) {
frame = new NDFrame<>( this, null);
this.set(frame);
}
for ( int i = 0; i < labels.length; i++ ) {
if ( labels[ i ] != null ) {
AxisFrame<Integer, V> atAxis = frame.atAxis( i );
for ( int ii = 0; ii < labels[ i ].length; ii++ ) {
if ( labels[ i ][ ii ] != null )
atAxis.atIndexAlias( labels[ i ][ ii ] ).setIndex( ii );
}
}
}
return this;
}
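/*
    Usage sketch (illustrative, hypothetical labels): labeling both axes of a
    rank 2 tensor; a null entry in the outer array leaves that axis unlabeled:

        t.getMut().labelAxes( new String[]{ "row1", "row2" }, new String[]{ "a", "b", "c" } );
*/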
/** {@inheritDoc} */
@Override
public Tensor<V> labelAxes( List<List<Object>> labels ) {
LogUtil.nullArgCheck(labels, "labels", List.class, "Tensors cannot be labeled 'null'!");
NDFrame<V> frame = get( NDFrame.class );
if ( frame == null ) set( new NDFrame<>( labels, this, null ) );
else set( frame.withAxesLabels( labels ) );
return TensorImpl.this;
}
/** {@inheritDoc} */
@Override
public Tensor<V> label( String label ) {
LogUtil.nullArgCheck( label, "label", String.class, "Tensors cannot be labeled 'null'!" );
NDFrame<V> frame = get( NDFrame.class );
if ( frame == null ) set( new NDFrame<>( Collections.emptyList(), this, label ) );
else set( frame.withLabel(label) );
return TensorImpl.this;
}
/** {@inheritDoc} */
@Override
public Tensor<V> labelAxes( Map