All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.api.MLOutput Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.api;

import java.util.Map;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;

/**
 * This is a simple container object that returns the output of execute from MLContext 
 *
 */
public class MLOutput {
	
	Map> _outputs;
	private Map _outMetadata = null;
	
	public MLOutput(Map> outputs, Map outMetadata) {
		this._outputs = outputs;
		this._outMetadata = outMetadata;
	}
	
	public MatrixBlock getMatrixBlock(String varName) throws DMLRuntimeException {
		MatrixCharacteristics mc = getMatrixCharacteristics(varName);
		// The matrix block is always pushed to an RDD and then we do collect
		// We can later avoid this by returning symbol table rather than "Map> _outputs"
		return SparkExecutionContext.toMatrixBlock(getBinaryBlockedRDD(varName), (int) mc.getRows(), (int) mc.getCols(), 
				mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
	}
	
	@SuppressWarnings("unchecked")
	public JavaPairRDD getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
		if(_outputs.containsKey(varName)) {
			return (JavaPairRDD) _outputs.get(varName);
		}
		throw new DMLRuntimeException("Variable " + varName + " not found in the outputs.");
	}
	
	@SuppressWarnings("unchecked")
	public JavaPairRDD getFrameBinaryBlockedRDD(String varName) throws DMLRuntimeException {
		if(_outputs.containsKey(varName)) {
			return (JavaPairRDD)_outputs.get(varName);
		}
		throw new DMLRuntimeException("Variable " + varName + " not found in the outputs.");
	}
	
	public MatrixCharacteristics getMatrixCharacteristics(String varName) throws DMLRuntimeException {
		if(_outputs.containsKey(varName)) {
			return _outMetadata.get(varName);
		}
		throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
	}
	
	/**
	 * Note, the output DataFrame has an additional column ID.
	 * An easy way to get DataFrame without ID is by df.drop("__INDEX")
	 * @param sqlContext the SQLContext
	 * @param varName the variable name
	 * @return the DataFrame
	 * @throws DMLRuntimeException if DMLRuntimeException occurs
	 */
	public DataFrame getDF(SQLContext sqlContext, String varName) throws DMLRuntimeException {
		if(sqlContext == null) {
			throw new DMLRuntimeException("SQLContext is not created.");
		}
		JavaPairRDD rdd = getBinaryBlockedRDD(varName);
		if(rdd != null) {
			MatrixCharacteristics mc = _outMetadata.get(varName);
			return RDDConverterUtils.binaryBlockToDataFrame(sqlContext, rdd, mc, false);
		}
		throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
	}
	
	/**
	 * Obtain the DataFrame
	 * @param sqlContext the SQLContext
	 * @param varName the variable name
	 * @param outputVector if true, returns DataFrame with two column: ID and org.apache.spark.mllib.linalg.Vector
	 * @return the DataFrame
	 * @throws DMLRuntimeException if DMLRuntimeException occurs
	 */
	public DataFrame getDF(SQLContext sqlContext, String varName, boolean outputVector) throws DMLRuntimeException {
		if(sqlContext == null) {
			throw new DMLRuntimeException("SQLContext is not created.");
		}
		if(outputVector) {
			JavaPairRDD rdd = getBinaryBlockedRDD(varName);
			if(rdd != null) {
				MatrixCharacteristics mc = _outMetadata.get(varName);
				return RDDConverterUtils.binaryBlockToDataFrame(sqlContext, rdd, mc, true);
			}
			throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
		}
		else {
			return getDF(sqlContext, varName);
		}
		
	}
	
	/**
	 * This methods improves the performance of MLPipeline wrappers.
	 * @param sqlContext the SQLContext
	 * @param varName the variable name
	 * @param mc the matrix characteristics
	 * @return the DataFrame
	 * @throws DMLRuntimeException if DMLRuntimeException occurs
	 */
	public DataFrame getDF(SQLContext sqlContext, String varName, MatrixCharacteristics mc) 
		throws DMLRuntimeException 
	{
		if(sqlContext == null)
			throw new DMLRuntimeException("SQLContext is not created.");
			
		JavaPairRDD binaryBlockRDD = getBinaryBlockedRDD(varName);
		return RDDConverterUtils.binaryBlockToDataFrame(sqlContext, binaryBlockRDD, mc, true);
	}
	
	public JavaRDD getStringRDD(String varName, String format) throws DMLRuntimeException {
		if(format.equals("text")) {
			JavaPairRDD binaryRDD = getBinaryBlockedRDD(varName);
			MatrixCharacteristics mcIn = getMatrixCharacteristics(varName); 
			return RDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcIn);
		}
		else {
			throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
		}
	}
	
	public JavaRDD getStringFrameRDD(String varName, String format, CSVFileFormatProperties fprop ) throws DMLRuntimeException {
		JavaPairRDD binaryRDD = getFrameBinaryBlockedRDD(varName);
		MatrixCharacteristics mcIn = getMatrixCharacteristics(varName); 
		if(format.equals("csv")) {
			return FrameRDDConverterUtils.binaryBlockToCsv(binaryRDD, mcIn, fprop, false);
		}
		else if(format.equals("text")) {
			return FrameRDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcIn);
		}
		else {
			throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
		}
		
	}
	
	public DataFrame getDataFrameRDD(String varName, JavaSparkContext jsc) throws DMLRuntimeException {
		JavaPairRDD binaryRDD = getFrameBinaryBlockedRDD(varName);
		MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
		return FrameRDDConverterUtils.binaryBlockToDataFrame(new SQLContext(jsc), binaryRDD, mcIn, null);
	}
	
	public MLMatrix getMLMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException {
		if(sqlContext == null) {
			throw new DMLRuntimeException("SQLContext is not created.");
		}
		else if(ml == null) {
			throw new DMLRuntimeException("MLContext is not created.");
		}
		JavaPairRDD rdd = getBinaryBlockedRDD(varName);
		if(rdd != null) {
			MatrixCharacteristics mc = getMatrixCharacteristics(varName);
			StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
			return new MLMatrix(sqlContext.createDataFrame(rdd.map(new GetMLBlock()).rdd(), schema), mc, ml);
		}
		throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy