All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.MatrixUtil Maven / Gradle / Ivy

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import org.dmg.pmml.Array;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MatCell;
import org.dmg.pmml.Matrix;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.UnsupportedAttributeException;

/**
 * Utility methods for reading the dimensions and the elements of a PMML {@link Matrix} element.
 *
 * <p>
 * Row and column indices are 1-based throughout.
 * </p>
 */
public class MatrixUtil {

	private MatrixUtil(){
	}

	/**
	 * Returns the element at the specified location.
	 *
	 * <p>
	 * The lookup strategy depends on the matrix kind:
	 * </p>
	 * <ul>
	 *   <li>{@link Matrix.Kind#DIAGONAL} - the content must be exactly one Array that holds the diagonal values.
	 *   Off-diagonal locations inside the matrix evaluate to the "offDiagDefault" value.</li>
	 *   <li>{@link Matrix.Kind#SYMMETRIC} - the content must be Arrays that hold the lower left triangle, one Array per row.
	 *   Locations in the upper right triangle are mirrored across the diagonal.</li>
	 *   <li>{@link Matrix.Kind#ANY} - the content is either Arrays (one Array per row), or sparse MatCell elements.
	 *   In the latter case, locations without a MatCell evaluate to the "diagDefault" or "offDiagDefault" value.</li>
	 * </ul>
	 *
	 * @param matrix The matrix.
	 * @param row The row index. The index of the first row is 1.
	 * @param column The column index. The index of the first column is 1.
	 *
	 * @return The element at the specified location, or null (presumably when the applicable default attribute is not set).
	 *
	 * @throws IndexOutOfBoundsException If either the row or column index is out of range.
	 * @throws UnsupportedAttributeException If the matrix kind is not supported.
	 * @throws InvalidElementException If the matrix content does not match the matrix kind.
	 */
	static
	public Number getElementAt(Matrix matrix, int row, int column){
		List<Array> arrays = matrix.getArrays();
		List<MatCell> matCells = matrix.getMatCells();

		Matrix.Kind kind = matrix.getKind();
		switch(kind){
			case DIAGONAL:
				{
					// "The content is just one Array of numbers representing the diagonal values"
					if(arrays.size() == 1){
						Array array = arrays.get(0);

						List<? extends Number> elements = ArrayUtil.asNumberList(array);

						// Diagonal element. List#get performs the bounds check
						if(row == column){
							return elements.get(row - 1);
						}

						// Off-diagonal element. It is not backed by any storage, so the bounds check must be explicit
						int max = elements.size();

						if(row < 1 || row > max || column < 1 || column > max){
							throw new IndexOutOfBoundsException("Location (" + row + ", " + column + ") is outside of a " + max + "x" + max + " matrix");
						}

						return matrix.getOffDiagDefault();
					}
				}
				break;
			case SYMMETRIC:
				{
					// "The content must be represented by Arrays"
					if(!arrays.isEmpty()){

						// Only the lower left triangle is stored, so mirror upper right coordinates across the diagonal
						if(column > row){
							int temp = row;

							row = column;
							column = temp;
						}

						return getArrayValue(arrays, row, column);
					}
				}
				break;
			case ANY:
				{
					if(!arrays.isEmpty()){
						return getArrayValue(arrays, row, column);
					} // End if

					if(!matCells.isEmpty()){

						// NOTE(review): only the lower bound is validated here; locations beyond the last MatCell
						// silently evaluate to the default values
						if(row < 1 || column < 1){
							throw new IndexOutOfBoundsException("Location (" + row + ", " + column + ") is outside of the matrix");
						}

						Number value = getMatCellValue(matCells, row, column);
						if(value == null){

							if(row == column){
								return matrix.getDiagDefault();
							}

							return matrix.getOffDiagDefault();
						}

						return value;
					}
				}
				break;
			default:
				throw new UnsupportedAttributeException(matrix, kind);
		}

		// The content did not match the declared kind (e.g. a DIAGONAL matrix without exactly one Array)
		throw new InvalidElementException(matrix);
	}

	/**
	 * Reads an element from Array content where each Array holds one row.
	 *
	 * @param arrays The Arrays, one per row.
	 * @param row The 1-based row index.
	 * @param column The 1-based column index.
	 *
	 * @return The element at the specified location.
	 *
	 * @throws IndexOutOfBoundsException If either index does not address an existing element.
	 */
	static
	private Number getArrayValue(List<Array> arrays, int row, int column){
		Array array = arrays.get(row - 1);

		List<? extends Number> elements = ArrayUtil.asNumberList(array);

		return elements.get(column - 1);
	}

	/**
	 * Looks up a sparse matrix element.
	 *
	 * @param matCells The MatCell elements.
	 * @param row The 1-based row index.
	 * @param column The 1-based column index.
	 *
	 * @return The value of the first MatCell at the specified location, converted to {@link DataType#DOUBLE},
	 * or null if there is no such MatCell.
	 */
	static
	private Number getMatCellValue(List<MatCell> matCells, int row, int column){

		for(int i = 0, max = matCells.size(); i < max; i++){
			MatCell matCell = matCells.get(i);

			if((matCell.getRow() == row) && (matCell.getCol() == column)){
				return (Number)TypeUtil.parseOrCast(DataType.DOUBLE, matCell.getValue());
			}
		}

		return null;
	}

	/**
	 * Returns the number of rows.
	 *
	 * <p>
	 * The "nbRows" attribute takes precedence. If it is not set, then the row count is inferred from the content:
	 * the size of the single Array (DIAGONAL), the number of Arrays (SYMMETRIC and ANY),
	 * or the largest MatCell row index (ANY).
	 * </p>
	 *
	 * @return The number of rows.
	 *
	 * @throws UnsupportedAttributeException If the matrix kind is not supported.
	 * @throws InvalidElementException If the matrix content does not match the matrix kind.
	 */
	static
	public int getRows(Matrix matrix){
		Integer nbRows = matrix.getNbRows();
		if(nbRows != null){
			return nbRows;
		}

		List<Array> arrays = matrix.getArrays();
		List<MatCell> matCells = matrix.getMatCells();

		Matrix.Kind kind = matrix.getKind();
		switch(kind){
			case DIAGONAL:
				{
					if(arrays.size() == 1){
						Array array = arrays.get(0);

						return ArrayUtil.getSize(array);
					}
				}
				break;
			case SYMMETRIC:
				{
					if(!arrays.isEmpty()){
						return arrays.size();
					}
				}
				break;
			case ANY:
				{
					if(!arrays.isEmpty()){
						return arrays.size();
					} // End if

					if(!matCells.isEmpty()){
						MatCell matCell = Collections.max(matCells, MatrixUtil.rowComparator);

						return matCell.getRow();
					}
				}
				break;
			default:
				throw new UnsupportedAttributeException(matrix, kind);
		}

		throw new InvalidElementException(matrix);
	}

	/**
	 * Returns the number of columns.
	 *
	 * <p>
	 * The "nbCols" attribute takes precedence. If it is not set, then the column count is inferred from the content:
	 * the size of the single Array (DIAGONAL), the number of Arrays (SYMMETRIC), the size of the last Array (ANY),
	 * or the largest MatCell column index (ANY).
	 * </p>
	 *
	 * @return The number of columns.
	 *
	 * @throws UnsupportedAttributeException If the matrix kind is not supported.
	 * @throws InvalidElementException If the matrix content does not match the matrix kind.
	 */
	static
	public int getColumns(Matrix matrix){
		Integer nbCols = matrix.getNbCols();
		if(nbCols != null){
			return nbCols;
		}

		List<Array> arrays = matrix.getArrays();
		List<MatCell> matCells = matrix.getMatCells();

		Matrix.Kind kind = matrix.getKind();
		switch(kind){
			case DIAGONAL:
				{
					if(arrays.size() == 1){
						Array array = arrays.get(0);

						return ArrayUtil.getSize(array);
					}
				}
				break;
			case SYMMETRIC:
				{
					// A symmetric matrix is square
					if(!arrays.isEmpty()){
						return arrays.size();
					}
				}
				break;
			case ANY:
				{
					if(!arrays.isEmpty()){
						Array array = arrays.get(arrays.size() - 1);

						return ArrayUtil.getSize(array);
					} // End if

					if(!matCells.isEmpty()){
						MatCell matCell = Collections.max(matCells, MatrixUtil.columnComparator);

						return matCell.getCol();
					}
				}
				break;
			default:
				throw new UnsupportedAttributeException(matrix, kind);
		}

		throw new InvalidElementException(matrix);
	}

	/**
	 * Orders MatCell elements by their row index.
	 */
	private static final Comparator<MatCell> rowComparator = new Comparator<MatCell>(){

		@Override
		public int compare(MatCell left, MatCell right){
			// Integer#compare cannot overflow, unlike plain subtraction
			return Integer.compare(left.getRow(), right.getRow());
		}
	};

	/**
	 * Orders MatCell elements by their column index.
	 */
	private static final Comparator<MatCell> columnComparator = new Comparator<MatCell>(){

		@Override
		public int compare(MatCell left, MatCell right){
			// Integer#compare cannot overflow, unlike plain subtraction
			return Integer.compare(left.getCol(), right.getCol());
		}
	};
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy