// Many resources are needed to download a project. Please understand that we have to cover our server costs. Thank you in advance. The project price is only $1.
// You can buy this project and download or modify it as often as you want.
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.api.ops.impl.reduce.custom;
import lombok.EqualsAndHashCode;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
@EqualsAndHashCode
public class BatchMmul extends DynamicCustomOp {
// Transpose flag for the left-hand matrices, stored as an int: 1 = transpose, 0 = do not transpose.
protected int transposeA;
// Transpose flag for the right-hand matrices, stored as an int: 1 = transpose, 0 = do not transpose.
protected int transposeB;
// Number of matrix pairs in the batch; also the value returned by getNumOutputs().
protected int batchSize;
// M and N are read from the shape of the first left-hand matrix; the two dimensions are
// swapped when transposeA is set.
protected int M;
protected int N;
// K is read from the shape of a right-hand matrix: dimension 0 when transposeB is set, else dimension 1.
protected int K;
/**
 * Creates the op from separate arrays of left-hand and right-hand matrices.
 * <p>
 * The two arrays are concatenated (all of {@code matricesA} first, then all of
 * {@code matricesB}) and passed to the single-array constructor, which performs the
 * validation and argument setup.
 * <p>
 * NOTE(review): the two arrays are not compared for equal length here; the delegate only
 * checks that the combined length is even — confirm callers always pass equally sized arrays.
 *
 * @param sameDiff   the SameDiff instance the op is created in
 * @param matricesA  left-hand matrices
 * @param matricesB  right-hand matrices
 * @param transposeA whether the left-hand matrices are to be transposed
 * @param transposeB whether the right-hand matrices are to be transposed
 */
public BatchMmul(SameDiff sameDiff, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) {
this(sameDiff, ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB);
}
/**
 * Creates the op from a single array holding all left-hand matrices followed by all
 * right-hand matrices.
 * <p>
 * Two additional variables are prepended to the op inputs: a vector of ones (alphas) and a
 * vector of zeros (betas), each of length {@code matrices.length / 2}.
 * <p>
 * After the super constructor has run, the input is validated: the number of matrices must
 * be even, every matrix in the first half must have the same shape, and every matrix in the
 * second half must have the same shape. Because Java requires the super call to come first,
 * {@code matrices[0]} and {@code matrices[1]} are read before any validation, so an array
 * with fewer than two elements fails with an index error rather than a precondition message.
 *
 * @param sameDiff   the SameDiff instance the op is created in
 * @param matrices   left-hand matrices followed by right-hand matrices; length must be even
 * @param transposeA whether the left-hand matrices are to be transposed
 * @param transposeB whether the right-hand matrices are to be transposed
 * @throws IllegalStateException if the matrix count is odd or shapes within a half differ
 *                               (presumably the exception type raised by
 *                               {@code Preconditions.checkState} — confirm)
 */
public BatchMmul(SameDiff sameDiff,
                 SDVariable[] matrices,
                 boolean transposeA,
                 boolean transposeB) {
    super(null, sameDiff, ArrayUtils.addAll(
            new SDVariable[]{
                    sameDiff.var(Nd4j.ones(matrices[0].dataType(), matrices.length / 2)), // alphas
                    sameDiff.var(Nd4j.zeros(matrices[1].dataType(), matrices.length / 2))}, // betas
            matrices));

    // The original message was missing the space between "needs" and "to".
    Preconditions.checkState(matrices.length % 2 == 0,
            "The number of provided matrices needs to be divisible by two.");
    this.batchSize = matrices.length / 2;

    // Every left-hand matrix must share the shape of the first one.
    SDVariable firstMatrix = matrices[0];
    long[] firstShape = firstMatrix.getShape();
    for (int i = 0; i < batchSize; i++) {
        Preconditions.checkState(Arrays.equals(firstShape, matrices[i].getShape()),
                "All matrices in the first half (A) need to have the same shape.");
    }

    // Every right-hand matrix must share the shape of the last one.
    SDVariable lastMatrix = matrices[2 * batchSize - 1];
    long[] lastShape = lastMatrix.getShape();
    for (int i = batchSize; i < 2 * batchSize; i++) {
        Preconditions.checkState(Arrays.equals(lastShape, matrices[i].getShape()),
                "All matrices in the second half (B) need to have the same shape.");
    }

    this.transposeA = transposeA ? 1 : 0;
    this.transposeB = transposeB ? 1 : 0;

    // Shape entries are long; they are narrowed to int for the integer op arguments.
    this.M = transposeA ? (int) firstShape[1] : (int) firstShape[0];
    this.N = transposeA ? (int) firstShape[0] : (int) firstShape[1];
    this.K = transposeB ? (int) lastShape[0] : (int) lastShape[1];
    addArgs();
}
/**
 * Creates the op directly from arrays of left-hand and right-hand matrices.
 * <p>
 * The op inputs are {@code matricesA} followed by {@code matricesB}; no outputs are supplied.
 * The batch size is the length of {@code matricesA}, and M, N and K are read from the shapes
 * of the first matrix in each array, honouring the transpose flags.
 *
 * @param matricesA  left-hand matrices
 * @param matricesB  right-hand matrices
 * @param transposeA whether the left-hand matrices are to be transposed
 * @param transposeB whether the right-hand matrices are to be transposed
 */
public BatchMmul(INDArray[] matricesA, INDArray[] matricesB, boolean transposeA, boolean transposeB) {
    super(ArrayUtils.addAll(matricesA, matricesB), null);
    this.batchSize = matricesA.length;

    final long[] shapeA = matricesA[0].shape();
    final long[] shapeB = matricesB[0].shape();

    // Left-hand operand: the two dimensions swap roles when it is transposed.
    if (transposeA) {
        this.transposeA = 1;
        this.M = (int) shapeA[1];
        this.N = (int) shapeA[0];
    } else {
        this.transposeA = 0;
        this.M = (int) shapeA[0];
        this.N = (int) shapeA[1];
    }

    // Right-hand operand: K comes from dimension 0 when transposed, else dimension 1.
    if (transposeB) {
        this.transposeB = 1;
        this.K = (int) shapeB[0];
    } else {
        this.transposeB = 0;
        this.K = (int) shapeB[1];
    }

    addArgs();
}
/**
 * Reports how many outputs this op has: one per matrix pair in the batch.
 *
 * @return the batch size
 */
@Override
public int getNumOutputs() {
    return this.batchSize;
}
/**
 * Registers the integer arguments of the op: the two transpose flags, the matrix
 * dimensions, the three leading dimensions and the batch size, in that order.
 */
public void addArgs() {
    // LDA, LDB and LDC (leading dims / strides) from blas; set to the matrix dims here.
    final int lda = M;
    final int ldb = K;
    final int ldc = N;

    addIArgument(
            transposeA, transposeB,
            M, K, N, // K and N are swapped in libnd4j
            lda, ldb, ldc,
            batchSize);
}
/**
 * No-argument constructor. Sets none of the fields and registers no op arguments.
 * NOTE(review): presumably required for reflective instantiation of ops — confirm.
 */
public BatchMmul() {
}
/**
 * Returns the name of this op.
 * NOTE(review): presumably this must match the name of the native op — confirm.
 *
 * @return the string {@code "batched_gemm"}
 */
@Override
public String opName() {
return "batched_gemm";
}
/**
 * Builds the gradient variables for this op from the gradients of its outputs.
 * <p>
 * The op arguments are split into a left-hand slice (the first {@code batchSize} entries)
 * and a right-hand slice (the next {@code batchSize} entries). Two further batched
 * multiplications are then created: the output gradients times the right-hand slice, and
 * the left-hand slice times the output gradients.
 * <p>
 * NOTE(review): the SameDiff constructor prepends two extra inputs (alphas and betas) before
 * the matrices. If {@code args()} returns those as well, both slices below are shifted by
 * two and the number of returned gradients does not match the number of inputs — confirm
 * against {@code args()} and adjust the offsets if needed.
 *
 * @param grads gradients with respect to each output of this op
 * @return gradients for the left-hand matrices followed by those for the right-hand matrices
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> grads) {
    // The typed list makes toArray return SDVariable[]; on a raw List it yields Object[].
    SDVariable[] dLdOut = grads.toArray(new SDVariable[grads.size()]);

    SDVariable[] allArgs = args();
    SDVariable[] matricesA = Arrays.copyOfRange(allArgs, 0, batchSize);
    SDVariable[] matricesB = Arrays.copyOfRange(allArgs, batchSize, 2 * batchSize);

    SDVariable[] dLdx = sameDiff.batchMmul(dLdOut, matricesB, false, transposeB == 1);
    SDVariable[] dLdy = sameDiff.batchMmul(matricesA, dLdOut, transposeA == 1, false);

    List<SDVariable> ret = new ArrayList<>(dLdx.length + dLdy.length);
    Collections.addAll(ret, dLdx);
    Collections.addAll(ret, dLdy);
    return ret;
}
@Override
public List calculateOutputDataTypes(List dataTypes){
List out = new ArrayList<>();
for(int i=0; i