org.ejml.alg.block.GeneratorBlockInnerMultiplication Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of ejml Show documentation
Show all versions of ejml Show documentation
A fast and easy to use dense matrix linear algebra library written in Java.
/*
* Copyright (c) 2009-2013, Peter Abeles. All Rights Reserved.
*
* This file is part of Efficient Java Matrix Library (EJML).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.ejml.alg.block;
import org.ejml.alg.generic.CodeGeneratorMisc;
import java.io.FileNotFoundException;
import java.io.PrintStream;
/**
* @author Peter Abeles
*/
public class GeneratorBlockInnerMultiplication {
String className;
PrintStream stream;
public GeneratorBlockInnerMultiplication( String className ) throws FileNotFoundException {
this.className = className;
stream = new PrintStream(className+".java");
}
public void createClass() {
printTop();
for( int i = 0; i < 2; i++ ) {
boolean hasAlpha = i==1;
for( Operation o : Operation.values()) {
if( hasAlpha && o == Operation.MINUS )
continue;
print_mult(hasAlpha,o);
print_multTransA(hasAlpha,o);
print_multTransB(hasAlpha,o);
}
}
stream.print("}\n");
}
private void printTop() {
String foo = CodeGeneratorMisc.COPYRIGHT +
"\n" +
"package org.ejml.alg.block;\n" +
"\n" +
"/**\n" +
" * \n" +
" * Matrix multiplication for the inner row major blocks, typically inside of a {@link org.ejml.data.BlockMatrix64F}.\n" +
" *
\n" +
" *\n" +
" * \n" +
" * This code was auto generated by {@link GeneratorBlockInnerMultiplication} and should not be modified directly.\n" +
" *
\n" +
" *\n" +
" * @author Peter Abeles\n" +
" */\n" +
"public class "+className+" {\n";
stream.print(foo);
}
private void print_mult( boolean hasAlpha , Operation opType ) {
createHeader(hasAlpha,opType,false,false);
stream.print(
"// for( int i = 0; i < heightA; i++ ) {\n" +
"// for( int k = 0; k < widthA; k++ ) {\n" +
"// for( int j = 0; j < widthC; j++ ) {\n" +
"// dataC[ i*widthC + j + indexC ] += dataA[i*widthA + k + indexA] * dataB[k*widthC + j + indexB];\n" +
"// }\n" +
"// }\n" +
"// }\n");
stream.println();
String o = ( opType == Operation.MINUS ) ? "-=" : "+=";
String m = hasAlpha ? "alpha*" : "";
stream.print(
" int a = indexA;\n"+
" int rowC = indexC;\n"+
" for( int i = 0; i < heightA; i++ , rowC += widthC ) {\n" +
" int b = indexB;\n" +
"\n" +
" final int endC = rowC + widthC;\n" +
" final int endA = a + widthA;"+
"\n"+
" while( a != endA ) {//for( int k = 0; k < widthA; k++ ) {\n" +
" double valA = "+m+"dataA[a++];\n" +
"\n" +
" int c = rowC;\n" +
"\n");
if( opType == Operation.SET ) {
stream.print(
" if( b == indexB ) {\n" +
" while( c != endC ) {//for( int j = 0; j < widthC; j++ ) {\n" +
" dataC[ c++ ] = valA * dataB[ b++ ];\n" +
" }\n" +
" } else {\n" +
" while( c != endC ) {//for( int j = 0; j < widthC; j++ ) {\n" +
" dataC[ c++ ] "+o+" valA * dataB[ b++ ];\n" +
" }\n" +
" }\n");
} else {
stream.print(
" while( c != endC ) {//for( int j = 0; j < widthC; j++ ) {\n" +
" dataC[ c++ ] "+o+" valA * dataB[ b++ ];\n" +
" }\n");
}
stream.println(
" }\n" +
" }");
stream.println(" }");
}
private String createOpString(boolean hasAlpha, Operation opType) {
String o = opString(opType);
if( hasAlpha ) o += " alpha * ";
return o;
}
private void print_multTransA( boolean hasAlpha , Operation opType ) {
createHeader(hasAlpha,opType,true,false);
String o = ( opType == Operation.MINUS ) ? "-=" : "+=";
String m = hasAlpha ? "alpha*" : "";
stream.print(
"// for( int i = 0; i < widthA; i++ ) {\n" +
"// for( int k = 0; k < heightA; k++ ) {\n" +
"// double valA = dataA[k*widthA + i + indexA];\n" +
"// for( int j = 0; j < widthC; j++ ) {\n" +
"// dataC[ i*widthC + j + indexC ] += valA * dataB[k*widthC + j + indexB];\n" +
"// }\n" +
"// }\n" +
"// }\n");
stream.println();
stream.print(
" int rowC = indexC;\n"+
" for( int i = 0; i < widthA; i++ , rowC += widthC) {\n" +
" int colA = i + indexA;\n" +
" int endA = colA + widthA*heightA;\n" +
" int b = indexB;\n" +
"\n" +
" // for( int k = 0; k < heightA; k++ ) {\n" +
" while(colA != endA ) {\n" +
" double valA = "+m+"dataA[colA];\n" +
"\n" +
" int c = rowC;\n" +
" final int endB = b + widthC;\n" +
"\n" +
" //for( int j = 0; j < widthC; j++ ) {\n");
if( opType == Operation.SET ) {
stream.print(
" if( b == indexB ) {\n" +
" while( b != endB ) {\n" +
" dataC[ c++ ] = valA * dataB[b++];\n" +
" } \n" +
" } else {\n" +
" while( b != endB ) {\n" +
" dataC[ c++ ] "+o+" valA * dataB[b++];\n" +
" }\n" +
" }\n");
} else {
stream.print(
" while( b != endB ) {\n" +
" dataC[ c++ ] "+o+" valA * dataB[b++];\n" +
" }\n");
}
stream.print(
" colA += widthA;\n"+
" }\n" +
" }\n");
stream.println(" }");
}
private void print_multTransB( boolean hasAlpha , Operation opType ) {
createHeader(hasAlpha,opType,false,true);
String o = createOpString(hasAlpha, opType);
stream.println(
" for( int i = 0; i < heightA; i++ ) {\n" +
" for( int j = 0; j < widthC; j++ ) {\n" +
" double val = 0;\n" +
"\n" +
" for( int k = 0; k < widthA; k++ ) {\n" +
" val += dataA[i*widthA + k + indexA] * dataB[j*widthA + k + indexB];\n" +
" }\n" +
"\n" +
" dataC[ i*widthC + j + indexC ] "+o+" val;\n" +
" }\n" +
" }");
stream.println(" }");
}
private void createHeader( boolean hasAlpha , Operation opType , boolean transA , boolean transB )
{
String alphaString = hasAlpha ? " α " : "";
String alphaParam = hasAlpha ? " double alpha ," : "";
String transAString = transA ? "T" : "";
String transBString = transB ? "T" : "";
String opTypeString;
switch( opType ) {
case MINUS: opTypeString = "C - "; break;
case PLUS: opTypeString = "C + "; break;
case SET: opTypeString = ""; break;
default: throw new RuntimeException("Unknown optype");
}
String funcName = "blockMult"+opName(opType);
if( transA && transB ) funcName += "TransAB";
else if( transA ) funcName += "TransA";
else if( transB ) funcName += "TransB";
stream.println();
stream.print(
" /**\n" +
" * \n" +
" * Performs the follow operation on individual inner blocks:
\n" +
" *
\n");
stream.print(
" * C = "+opTypeString+alphaString+"A"+transAString+" * B"+transBString+"\n");
stream.print(
" *
\n" +
" */\n" +
" public static void "+funcName+"("+alphaParam+" final double[] dataA, final double []dataB, final double []dataC,\n" +
" int indexA, int indexB, int indexC,\n" +
" final int heightA, final int widthA, final int widthC) {\n");
}
private String opString( Operation opType ) {
switch( opType ) {
case MINUS:
return "-=";
case PLUS:
return "+=";
case SET:
return "=";
default:
throw new RuntimeException("Unknown opType "+opType);
}
}
private String opName( Operation opType ) {
switch( opType ) {
case MINUS:
return "Minus";
case PLUS:
return "Plus";
case SET:
return "Set";
default:
throw new RuntimeException("Unknown opType "+opType);
}
}
private static enum Operation
{
/** Add results to output matrix */
PLUS,
/** Subtract results from output matrix */
MINUS,
/** set output matrix to results */
SET
}
public static void main( String args[] ) throws FileNotFoundException {
GeneratorBlockInnerMultiplication app = new GeneratorBlockInnerMultiplication("BlockInnerMultiplication");
app.createClass();
System.out.println("Done generating class");
}
}