
org.jpmml.evaluator.MatrixUtil Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
/*
* Copyright (c) 2013 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import org.dmg.pmml.Array;
import org.dmg.pmml.MatCell;
import org.dmg.pmml.Matrix;
public class MatrixUtil {
private MatrixUtil(){
}
/**
* @param row The row index. The index of the first row is 1
.
* @param column The column index. The index of the first column is 1
.
*
* @return The element at the specified location, or null
.
*
* @throws IndexOutOfBoundsException If either the row or column index is out of range.
*/
static
public Number getElementAt(Matrix matrix, int row, int column){
List arrays = matrix.getArrays();
List matCells = matrix.getMatCells();
Matrix.Kind kind = matrix.getKind();
switch(kind){
case DIAGONAL:
{
// "The content is just one Array of numbers representing the diagonal values"
if(arrays.size() == 1){
Array array = arrays.get(0);
List extends Number> elements = ArrayUtil.asNumberList(array);
// Diagonal element
if(row == column){
return elements.get(row - 1);
} else
// Off-diagonal element
{
int min = 1;
int max = elements.size();
if((row < min || row > max) || (column < min || column > max)){
throw new IndexOutOfBoundsException();
}
return matrix.getOffDiagDefault();
}
}
}
break;
case SYMMETRIC:
{
// "The content must be represented by Arrays"
if(arrays.size() > 0){
// Make sure the specified coordinates target the lower left triangle
if(column > row){
int temp = row;
row = column;
column = temp;
}
return getArrayValue(arrays, row, column);
}
}
break;
case ANY:
{
if(arrays.size() > 0){
return getArrayValue(arrays, row, column);
} // End if
if(matCells.size() > 0){
if(row < 1 || column < 1){
throw new IndexOutOfBoundsException();
}
Number value = getMatCellValue(matCells, row, column);
if(value == null){
if(row == column){
return matrix.getDiagDefault();
}
return matrix.getOffDiagDefault();
}
return value;
}
}
break;
default:
throw new UnsupportedFeatureException(matrix, kind);
}
throw new InvalidFeatureException(matrix);
}
static
private Number getArrayValue(List arrays, int row, int column){
Array array = arrays.get(row - 1);
List extends Number> elements = ArrayUtil.asNumberList(array);
return elements.get(column - 1);
}
static
private Number getMatCellValue(List matCells, final int row, final int column){
Predicate filter = new Predicate(){
@Override
public boolean apply(MatCell matCell){
return (getRow(matCell) == row) && (getColumn(matCell) == column);
}
};
MatCell matCell = Iterables.getFirst(Iterables.filter(matCells, filter), null);
if(matCell != null){
return Double.parseDouble(matCell.getValue());
}
return null;
}
/**
* @return The number of rows.
*/
static
public int getRows(Matrix matrix){
Integer nbRows = matrix.getNbRows();
if(nbRows != null){
return nbRows.intValue();
}
List arrays = matrix.getArrays();
List matCells = matrix.getMatCells();
Matrix.Kind kind = matrix.getKind();
switch(kind){
case DIAGONAL:
{
if(arrays.size() == 1){
Array array = arrays.get(0);
return ArrayUtil.getSize(array);
}
}
break;
case SYMMETRIC:
{
if(arrays.size() > 0){
return arrays.size();
}
}
break;
case ANY:
{
if(arrays.size() > 0){
return arrays.size();
} // End if
if(matCells.size() > 0){
MatCell matCell = Collections.max(matCells, MatrixUtil.rowComparator);
return getRow(matCell);
}
}
break;
default:
throw new UnsupportedFeatureException(matrix, kind);
}
throw new InvalidFeatureException(matrix);
}
/**
* @return The number of columns.
*/
static
public int getColumns(Matrix matrix){
Integer nbCols = matrix.getNbCols();
if(nbCols != null){
return nbCols.intValue();
}
List arrays = matrix.getArrays();
List matCells = matrix.getMatCells();
Matrix.Kind kind = matrix.getKind();
switch(kind){
case DIAGONAL:
{
if(arrays.size() == 1){
Array array = arrays.get(0);
return ArrayUtil.getSize(array);
}
}
break;
case SYMMETRIC:
{
if(arrays.size() > 0){
return arrays.size();
}
}
break;
case ANY:
{
if(arrays.size() > 0){
Array array = arrays.get(arrays.size() - 1);
return ArrayUtil.getSize(array);
} // End if
if(matCells.size() > 0){
MatCell matCell = Collections.max(matCells, MatrixUtil.columnComparator);
return getColumn(matCell);
}
}
break;
default:
throw new UnsupportedFeatureException(matrix, kind);
}
throw new InvalidFeatureException(matrix);
}
static
private int getRow(MatCell matCell){
Integer row = matCell.getRow();
if(row == null){
throw new InvalidFeatureException(matCell);
}
return row.intValue();
}
static
private int getColumn(MatCell matCell){
Integer column = matCell.getCol();
if(column == null){
throw new InvalidFeatureException(matCell);
}
return column.intValue();
}
private static final Comparator rowComparator = new Comparator(){
@Override
public int compare(MatCell left, MatCell right){
return (getRow(left) - getRow(right));
}
};
private static final Comparator columnComparator = new Comparator(){
@Override
public int compare(MatCell left, MatCell right){
return (getColumn(left) - getColumn(right));
}
};
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy