org.jpmml.evaluator.MatrixUtil Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
/*
* Copyright (c) 2013 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.dmg.pmml.Array;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MatCell;
import org.dmg.pmml.Matrix;
public class MatrixUtil {
private MatrixUtil(){
}
/**
* @param row The row index. The index of the first row is 1
.
* @param column The column index. The index of the first column is 1
.
*
* @return The element at the specified location, or null
.
*
* @throws IndexOutOfBoundsException If either the row or column index is out of range.
*/
static
public Number getElementAt(Matrix matrix, int row, int column){
List arrays = matrix.getArrays();
List matCells = matrix.getMatCells();
Matrix.Kind kind = matrix.getKind();
switch(kind){
case DIAGONAL:
{
// "The content is just one Array of numbers representing the diagonal values"
if(arrays.size() == 1){
Array array = arrays.get(0);
List extends Number> elements = ArrayUtil.asNumberList(array);
// Diagonal element
if(row == column){
return elements.get(row - 1);
} else
// Off-diagonal element
{
int min = 1;
int max = elements.size();
if((row < min || row > max) || (column < min || column > max)){
throw new IndexOutOfBoundsException();
}
return matrix.getOffDiagDefault();
}
}
}
break;
case SYMMETRIC:
{
// "The content must be represented by Arrays"
if(!arrays.isEmpty()){
// Make sure the specified coordinates target the lower left triangle
if(column > row){
int temp = row;
row = column;
column = temp;
}
return getArrayValue(arrays, row, column);
}
}
break;
case ANY:
{
if(!arrays.isEmpty()){
return getArrayValue(arrays, row, column);
} // End if
if(!matCells.isEmpty()){
if(row < 1 || column < 1){
throw new IndexOutOfBoundsException();
}
Number value = getMatCellValue(matCells, row, column);
if(value == null){
if(row == column){
return matrix.getDiagDefault();
}
return matrix.getOffDiagDefault();
}
return value;
}
}
break;
default:
throw new UnsupportedAttributeException(matrix, kind);
}
throw new InvalidElementException(matrix);
}
static
private Number getArrayValue(List arrays, int row, int column){
Array array = arrays.get(row - 1);
List extends Number> elements = ArrayUtil.asNumberList(array);
return elements.get(column - 1);
}
static
private Number getMatCellValue(List matCells, int row, int column){
for(MatCell matCell : matCells){
if((matCell.getRow() == row) && (matCell.getCol() == column)){
return (Number)TypeUtil.parseOrCast(DataType.DOUBLE, matCell.getValue());
}
}
return null;
}
/**
* @return The number of rows.
*/
static
public int getRows(Matrix matrix){
Integer nbRows = matrix.getNbRows();
if(nbRows != null){
return nbRows;
}
List arrays = matrix.getArrays();
List matCells = matrix.getMatCells();
Matrix.Kind kind = matrix.getKind();
switch(kind){
case DIAGONAL:
{
if(arrays.size() == 1){
Array array = arrays.get(0);
return ArrayUtil.getSize(array);
}
}
break;
case SYMMETRIC:
{
if(!arrays.isEmpty()){
return arrays.size();
}
}
break;
case ANY:
{
if(!arrays.isEmpty()){
return arrays.size();
} // End if
if(!matCells.isEmpty()){
MatCell matCell = Collections.max(matCells, MatrixUtil.rowComparator);
return matCell.getRow();
}
}
break;
default:
throw new UnsupportedAttributeException(matrix, kind);
}
throw new InvalidElementException(matrix);
}
/**
* @return The number of columns.
*/
static
public int getColumns(Matrix matrix){
Integer nbCols = matrix.getNbCols();
if(nbCols != null){
return nbCols;
}
List arrays = matrix.getArrays();
List matCells = matrix.getMatCells();
Matrix.Kind kind = matrix.getKind();
switch(kind){
case DIAGONAL:
{
if(arrays.size() == 1){
Array array = arrays.get(0);
return ArrayUtil.getSize(array);
}
}
break;
case SYMMETRIC:
{
if(!arrays.isEmpty()){
return arrays.size();
}
}
break;
case ANY:
{
if(!arrays.isEmpty()){
Array array = arrays.get(arrays.size() - 1);
return ArrayUtil.getSize(array);
} // End if
if(!matCells.isEmpty()){
MatCell matCell = Collections.max(matCells, MatrixUtil.columnComparator);
return matCell.getCol();
}
}
break;
default:
throw new UnsupportedAttributeException(matrix, kind);
}
throw new InvalidElementException(matrix);
}
private static final Comparator rowComparator = new Comparator(){
@Override
public int compare(MatCell left, MatCell right){
return (left.getRow() - right.getRow());
}
};
private static final Comparator columnComparator = new Comparator(){
@Override
public int compare(MatCell left, MatCell right){
return (left.getCol() - right.getCol());
}
};
}