All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.diffkt.model.AvgPool2d.kt Maven / Gradle / Ivy

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.diffkt.model

import org.diffkt.DTensor
import org.diffkt.combineHash

/**
 * Two-dimensional average pooling layer.
 *
 * Downsamples the input by returning the average of each pool/window.
 *
 * @property poolHeight height of the pool/window handed to [avgPool]
 * @property poolWidth width of the pool/window handed to [avgPool]
 */
class AvgPool2d(
    val poolHeight: Int,
    val poolWidth: Int
) : LayerSingleInput {
    /** Applies [avgPool] to [input] using this layer's pool dimensions. */
    override fun invoke(input: DTensor): DTensor = avgPool(input, poolHeight, poolWidth)

    /** Two layers are equal when both are [AvgPool2d] with matching pool dimensions. */
    override fun equals(other: Any?): Boolean {
        if (other !is AvgPool2d) return false
        return poolHeight == other.poolHeight && poolWidth == other.poolWidth
    }

    /** Hashes the layer tag together with both pool dimensions, in line with [equals]. */
    override fun hashCode(): Int = combineHash("AvgPool2d", poolHeight, poolWidth)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy