/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.api;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.util.UtilFunctions;
/**
 * This is a simple container object that holds the outputs of an MLContext
 * execute call, keyed by output variable name.
 */
public class MLOutput {
HashMap<String, JavaPairRDD<MatrixIndexes, MatrixBlock>> _outputs;
private HashMap<String, MatrixCharacteristics> _outMetadata = null;
public MLOutput(HashMap<String, JavaPairRDD<MatrixIndexes, MatrixBlock>> outputs, HashMap<String, MatrixCharacteristics> outMetadata) {
this._outputs = outputs;
this._outMetadata = outMetadata;
}
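// Usage sketch (illustrative only; 'sc' is an existing JavaSparkContext, and the
// script path and output variable name "W" are assumptions, not part of this class):
//
//   MLContext ml = new MLContext(sc);
//   ml.registerOutput("W");
//   MLOutput out = ml.execute("algorithm.dml", new String[]{});
//   MatrixCharacteristics mc = out.getMatrixCharacteristics("W");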
public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
if(_outputs.containsKey(varName)) {
return _outputs.get(varName);
}
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
}
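// Example (a minimal sketch, assuming an MLOutput 'out' with a registered output
// variable "W"): the returned RDD is an ordinary Spark pair RDD of
// (MatrixIndexes, MatrixBlock) and can be consumed with standard Spark operations.
//
//   JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = out.getBinaryBlockedRDD("W");
//   long numBlocks = blocks.count();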
public MatrixCharacteristics getMatrixCharacteristics(String varName) throws DMLRuntimeException {
if(_outputs.containsKey(varName)) {
return _outMetadata.get(varName);
}
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
}
/**
 * Note: the output DataFrame has an additional ID column.
 * An easy way to get a DataFrame without the ID column is df.sort("ID").drop("ID").
 * @param sqlContext the SQLContext used to create the DataFrame
 * @param varName the name of the output variable
 * @return the output variable as a DataFrame with an ID column
 * @throws DMLRuntimeException if the variable is not found in the output symbol table
 */
public DataFrame getDF(SQLContext sqlContext, String varName) throws DMLRuntimeException {
if(sqlContext == null) {
throw new DMLRuntimeException("SQLContext is not created.");
}
JavaPairRDD<MatrixIndexes, MatrixBlock> rdd = getBinaryBlockedRDD(varName);
if(rdd != null) {
MatrixCharacteristics mc = _outMetadata.get(varName);
return RDDConverterUtilsExt.binaryBlockToDataFrame(rdd, mc, sqlContext);
}
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
}
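// Example (a minimal sketch, assuming an MLOutput 'out' and an existing SQLContext
// 'sqlContext'): as noted in the javadoc above, the returned DataFrame carries an
// extra ID column that can be sorted on and then dropped.
//
//   DataFrame df = out.getDF(sqlContext, "W");
//   DataFrame noId = df.sort("ID").drop("ID");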
/**
 * @param sqlContext the SQLContext used to create the DataFrame
 * @param varName the name of the output variable
 * @param outputVector if true, returns a DataFrame with two columns: ID and an org.apache.spark.mllib.linalg.Vector holding the entire row
 * @return the output variable as a DataFrame
 * @throws DMLRuntimeException if the variable is not found in the output symbol table
 */
public DataFrame getDF(SQLContext sqlContext, String varName, boolean outputVector) throws DMLRuntimeException {
if(sqlContext == null) {
throw new DMLRuntimeException("SQLContext is not created.");
}
if(outputVector) {
JavaPairRDD<MatrixIndexes, MatrixBlock> rdd = getBinaryBlockedRDD(varName);
if(rdd != null) {
MatrixCharacteristics mc = _outMetadata.get(varName);
return RDDConverterUtilsExt.binaryBlockToVectorDataFrame(rdd, mc, sqlContext);
}
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
}
else {
return getDF(sqlContext, varName);
}
}
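// Example (a minimal sketch): with outputVector=true the entire matrix row is
// packed into a single mllib Vector column alongside the ID column, which is the
// layout typically expected by MLlib-style estimators.
//
//   DataFrame vecDf = out.getDF(sqlContext, "W", true);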
/**
 * This method improves the performance of MLPipeline wrappers.
 * @param sqlContext the SQLContext used to create the DataFrame
 * @param varName the name of the output variable
 * @param range mapping from output column name to an inclusive, 1-based (low, high) column range
 * @return the output variable as a DataFrame with an ID column
 * @throws DMLRuntimeException if the variable is not found in the output symbol table
 */
public DataFrame getDF(SQLContext sqlContext, String varName, HashMap<String, Tuple2<Long, Long>> range) throws DMLRuntimeException {
if(sqlContext == null) {
throw new DMLRuntimeException("SQLContext is not created.");
}
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlockRDD = getBinaryBlockedRDD(varName);
if(binaryBlockRDD == null) {
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
}
MatrixCharacteristics mc = _outMetadata.get(varName);
long rlen = mc.getRows(); long clen = mc.getCols();
int brlen = mc.getRowsPerBlock(); int bclen = mc.getColsPerBlock();
ArrayList<Tuple2<String, Tuple2<Long, Long>>> alRange = new ArrayList<Tuple2<String, Tuple2<Long, Long>>>();
for(Entry<String, Tuple2<Long, Long>> e : range.entrySet()) {
alRange.add(new Tuple2<String, Tuple2<Long, Long>>(e.getKey(), e.getValue()));
}
// Very expensive operation here: groupByKey (where number of keys might be too large)
JavaRDD<Row> rowsRDD = binaryBlockRDD.flatMapToPair(new ProjectRows(rlen, clen, brlen, bclen))
.groupByKey().map(new ConvertDoubleArrayToRangeRows(clen, bclen, alRange));
int numColumns = (int) clen;
if(numColumns <= 0) {
throw new DMLRuntimeException("Output dimensions unknown after executing the script and hence cannot create the dataframe");
}
List<StructField> fields = new ArrayList<StructField>();
// Using LongType for the ID column throws an error: java.lang.Double incompatible with java.lang.Long
fields.add(DataTypes.createStructField("ID", DataTypes.DoubleType, false));
for(int k = 0; k < alRange.size(); k++) {
String colName = alRange.get(k)._1;
long low = alRange.get(k)._2._1;
long high = alRange.get(k)._2._2;
if(low != high)
fields.add(DataTypes.createStructField(colName, new VectorUDT(), false));
else
fields.add(DataTypes.createStructField(colName, DataTypes.DoubleType, false));
}
// This will cause infinite recursion due to bug in Spark
// https://issues.apache.org/jira/browse/SPARK-6999
// return sqlContext.createDataFrame(rowsRDD, colNames); // where ArrayList<String> colNames
return sqlContext.createDataFrame(rowsRDD.rdd(), DataTypes.createStructType(fields));
}
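// Example (a minimal sketch; the column names "label" and "features" and the
// ranges are assumptions): each entry maps a column name to an inclusive, 1-based
// column range. A width-1 range becomes a DoubleType column and a wider range
// becomes a VectorUDT column.
//
//   HashMap<String, Tuple2<Long, Long>> range = new HashMap<String, Tuple2<Long, Long>>();
//   range.put("label", new Tuple2<Long, Long>(1L, 1L));
//   range.put("features", new Tuple2<Long, Long>(2L, 11L));
//   DataFrame df = out.getDF(sqlContext, "W", range);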
public JavaRDD<String> getStringRDD(String varName, String format) throws DMLRuntimeException {
if(format.compareTo("text") == 0) {
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryRDD = getBinaryBlockedRDD(varName);
MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
return RDDConverterUtilsExt.binaryBlockToStringRDD(binaryRDD, mcIn, format);
}
// else if(format.compareTo("csv") == 0) {
//
// }
else {
throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
}
}
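// Example (a minimal sketch; the output path is an assumption): the "text" format
// yields one cell per line in SystemML's "i j v" text representation, which can
// be persisted directly.
//
//   JavaRDD<String> cells = out.getStringRDD("W", "text");
//   cells.saveAsTextFile("hdfs:///tmp/W.text");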
public MLMatrix getMLMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException {
if(sqlContext == null) {
throw new DMLRuntimeException("SQLContext is not created.");
}
else if(ml == null) {
throw new DMLRuntimeException("MLContext is not created.");
}
JavaPairRDD<MatrixIndexes, MatrixBlock> rdd = getBinaryBlockedRDD(varName);
if(rdd != null) {
MatrixCharacteristics mc = getMatrixCharacteristics(varName);
StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
return new MLMatrix(sqlContext.createDataFrame(rdd.map(new GetMLBlock()).rdd(), schema), mc, ml);
}
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
}
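// Example (a minimal sketch, assuming an MLContext 'ml' and SQLContext
// 'sqlContext'): an MLMatrix wraps the blocked output so it can be used in
// further MLContext operations without converting out of binary-block form.
//
//   MLMatrix m = out.getMLMatrix(ml, sqlContext, "W");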
// /**
// * Experimental: Please use this with caution as it will fail in many corner cases.
// * @return org.apache.spark.mllib.linalg.distributed.BlockMatrix
// * @throws DMLRuntimeException
// */
// public BlockMatrix getMLLibBlockedMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException {
// return getMLMatrix(ml, sqlContext, varName).toBlockedMatrix();
// }
public static class ProjectRows implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, Tuple2<Long, Double[]>> {
private static final long serialVersionUID = -4792573268900472749L;
long rlen; long clen;
int brlen; int bclen;
public ProjectRows(long rlen, long clen, int brlen, int bclen) {
this.rlen = rlen;
this.clen = clen;
this.brlen = brlen;
this.bclen = bclen;
}
@Override
public Iterable<Tuple2<Long, Tuple2<Long, Double[]>>> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
// ------------------------------------------------------------------
// Compute local block size:
// Example: For matrix: 1500 X 1100 with block length 1000 X 1000
// We will have four local block sizes (1000X1000, 1000X100, 500X1000 and 500X100)
long blockRowIndex = kv._1.getRowIndex();
long blockColIndex = kv._1.getColumnIndex();
int lrlen = UtilFunctions.computeBlockSize(rlen, blockRowIndex, brlen);
int lclen = UtilFunctions.computeBlockSize(clen, blockColIndex, bclen);
// ------------------------------------------------------------------
long startRowIndex = (kv._1.getRowIndex()-1) * brlen; // global row offset of this block, using the row block size
MatrixBlock blk = kv._2;
ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>> retVal = new ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>>();
for(int i = 0; i < lrlen; i++) {
Double[] partialRow = new Double[lclen];
for(int j = 0; j < lclen; j++) {
partialRow[j] = blk.getValue(i, j);
}
retVal.add(new Tuple2<Long, Tuple2<Long, Double[]>>(startRowIndex + i, new Tuple2<Long, Double[]>(kv._1.getColumnIndex(), partialRow)));
}
return retVal;
}
}
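// Worked example of the local block size computation in ProjectRows (values
// assumed): for rlen=1500 and brlen=1000, block row index 1 covers rows 1..1000,
// so UtilFunctions.computeBlockSize(1500, 1, 1000) returns 1000, while block row
// index 2 covers rows 1001..1500 and returns 500. Columns work the same way with
// clen and bclen.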
public static class ConvertDoubleArrayToRows implements Function<Tuple2<Long, Iterable<Tuple2<Long, Double[]>>>, Row> {
private static final long serialVersionUID = 4441184411670316972L;
int bclen; long clen;
boolean outputVector;
public ConvertDoubleArrayToRows(long clen, int bclen, boolean outputVector) {
this.bclen = bclen;
this.clen = clen;
this.outputVector = outputVector;
}
@Override
public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> arg0)
throws Exception {
HashMap<Long, Double[]> partialRows = new HashMap<Long, Double[]>();
int sizeOfPartialRows = 0;
for(Tuple2<Long, Double[]> kv : arg0._2) {
partialRows.put(kv._1, kv._2);
sizeOfPartialRows += kv._2.length;
}
// The first field of each row is the row index
Object[] row = null;
if(outputVector) {
row = new Object[2];
double [] vecVals = new double[sizeOfPartialRows];
for(long columnBlockIndex = 1; columnBlockIndex <= partialRows.size(); columnBlockIndex++) {
if(partialRows.containsKey(columnBlockIndex)) {
Double [] array = partialRows.get(columnBlockIndex);
// ------------------------------------------------------------------
// Compute local block size:
int lclen = UtilFunctions.computeBlockSize(clen, columnBlockIndex, bclen);
// ------------------------------------------------------------------
if(array.length != lclen) {
throw new Exception("Incorrect double array provided by ProjectRows");
}
for(int i = 0; i < lclen; i++) {
vecVals[(int) ((columnBlockIndex-1)*bclen + i)] = array[i];
}
}
else {
throw new Exception("The block for column index " + columnBlockIndex + " is missing. Make sure the last instruction is not returning empty blocks");
}
}
long rowIndex = arg0._1;
row[0] = new Double(rowIndex);
row[1] = new DenseVector(vecVals); // breeze.util.JavaArrayOps.arrayDToDv(vecVals);
}
else {
row = new Double[sizeOfPartialRows + 1];
long rowIndex = arg0._1;
row[0] = new Double(rowIndex);
for(long columnBlockIndex = 1; columnBlockIndex <= partialRows.size(); columnBlockIndex++) {
if(partialRows.containsKey(columnBlockIndex)) {
Double [] array = partialRows.get(columnBlockIndex);
// ------------------------------------------------------------------
// Compute local block size:
int lclen = UtilFunctions.computeBlockSize(clen, columnBlockIndex, bclen);
// ------------------------------------------------------------------
if(array.length != lclen) {
throw new Exception("Incorrect double array provided by ProjectRows");
}
for(int i = 0; i < lclen; i++) {
row[(int) ((columnBlockIndex-1)*bclen + i) + 1] = array[i];
}
}
else {
throw new Exception("The block for column index " + columnBlockIndex + " is missing. Make sure the last instruction is not returning empty blocks");
}
}
}
Object[] row_fields = row;
return RowFactory.create(row_fields);
}
}
public static class ConvertDoubleArrayToRangeRows implements Function<Tuple2<Long, Iterable<Tuple2<Long, Double[]>>>, Row> {
private static final long serialVersionUID = 4441184411670316972L;
int bclen; long clen;
ArrayList<Tuple2<String, Tuple2<Long, Long>>> range;
public ConvertDoubleArrayToRangeRows(long clen, int bclen, ArrayList<Tuple2<String, Tuple2<Long, Long>>> range) {
this.bclen = bclen;
this.clen = clen;
this.range = range;
}
@Override
public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> arg0)
throws Exception {
HashMap<Long, Double[]> partialRows = new HashMap<Long, Double[]>();
int sizeOfPartialRows = 0;
for(Tuple2<Long, Double[]> kv : arg0._2) {
partialRows.put(kv._1, kv._2);
sizeOfPartialRows += kv._2.length;
}
// The first field of each row is the row index
Object[] row = new Object[range.size() + 1];
double [] vecVals = new double[sizeOfPartialRows];
for(long columnBlockIndex = 1; columnBlockIndex <= partialRows.size(); columnBlockIndex++) {
if(partialRows.containsKey(columnBlockIndex)) {
Double [] array = partialRows.get(columnBlockIndex);
// ------------------------------------------------------------------
// Compute local block size:
int lclen = UtilFunctions.computeBlockSize(clen, columnBlockIndex, bclen);
// ------------------------------------------------------------------
if(array.length != lclen) {
throw new Exception("Incorrect double array provided by ProjectRows");
}
for(int i = 0; i < lclen; i++) {
vecVals[(int) ((columnBlockIndex-1)*bclen + i)] = array[i];
}
}
else {
throw new Exception("The block for column index " + columnBlockIndex + " is missing. Make sure the last instruction is not returning empty blocks");
}
}
long rowIndex = arg0._1;
row[0] = new Double(rowIndex);
int i = 1;
//for(Entry<String, Tuple2<Long, Long>> e : range.entrySet()) {
for(int k = 0; k < range.size(); k++) {
long low = range.get(k)._2._1;
long high = range.get(k)._2._2;
if(high < low) {
throw new Exception("Incorrect range:" + high + "<" + low);
}
if(low == high) {
row[i] = new Double(vecVals[(int) (low-1)]);
}
else {
int lengthOfVector = (int) (high - low + 1);
double [] tempVector = new double[lengthOfVector];
for(int j = 0; j < lengthOfVector; j++) {
tempVector[j] = vecVals[(int) (low + j - 1)];
}
row[i] = new DenseVector(tempVector);
}
i++;
}
Object[] row_fields = row;
return RowFactory.create(row_fields);
}
}
}