All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.diffkt.Reshape.kt Maven / Gradle / Ivy

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

@file:Suppress("UNREDUCED_STYPE_ERROR") // shapeTyping #72: affects extension functions referencing `this` and shaped properties
package org.diffkt

fun DTensor.reshape(vararg newShape: Int): DTensor = reshape(Shape(newShape))

fun DTensor.reshape(newShape: Shape): DTensor {
    // Check that new shape is valid.
    val oldShape = this.shape
    if (newShape == oldShape) return this
    require(oldShape.product() == newShape.product())
    return if (newShape.isScalar)
        this.operations.reshapeToScalar(this)
    else
        this.operations.reshape(this, newShape)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy