All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.udf.lib.SGDNesterovUpdate Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.udf.lib;

import java.io.IOException;
import java.util.Iterator;
import java.util.Random;

import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;
import org.apache.sysml.udf.Matrix.ValueType;

/**
 * Use this class to perform an SGD update with Nesterov momentum in CP.
 * Assumption: the input batch fits in CP (which is also the assumption of most deep learning systems).
 * 
 * Usage:
 * update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v) return (matrix[double] X, matrix[double] v) implemented in (classname="org.apache.sysml.udf.lib.SGDNesterovUpdate",exectype="mem");
 * [X, v] = update_nesterov(X, dX, lr, mu, v);
 * 
 * Computed element-wise, where v is the incoming (previous) velocity:
 *   v_new = mu * v - lr * dX
 *   X_new = X - mu * v + (1 + mu) * v_new
 * 
 * Fusing the update into a single function avoids the per-instruction overhead and the
 * intermediate matrices that the equivalent sequence of DML operations would create.
 * Both outputs are always materialized as dense blocks.
 */
public class SGDNesterovUpdate extends PackageFunction {
	private static final long serialVersionUID = -3905212831582648882L;

	private Matrix updatedX;
	private Matrix updatedV;
	// Used only to generate unique names for the temporary output matrices.
	private Random rand = new Random();
	
	@Override
	public int getNumFunctionOutputs() {
		return 2;
	}

	/**
	 * Returns one of the two outputs produced by {@link #execute()}.
	 * 
	 * @param pos 0 for the updated X, 1 for the updated velocity v
	 * @return the requested output (null if execute() has not been called yet)
	 * @throws RuntimeException if pos is neither 0 nor 1
	 */
	@Override
	public FunctionParameter getFunctionOutput(int pos) {
		if(pos == 0)
			return updatedX;
		else if(pos == 1)
			return updatedV;
		
		throw new RuntimeException("Invalid function output being requested");
	}

	/**
	 * Reads inputs (X, dX, lr, mu, v) at positions 0..4, computes the Nesterov update and
	 * stores the new X and v as this function's outputs.
	 * 
	 * @throws RuntimeException wrapping any CacheException or IOException raised while
	 *         acquiring the inputs or materializing the outputs
	 */
	@Override
	public void execute() {
		Matrix xIn = (Matrix) getFunctionInput(0);
		Matrix dXIn = (Matrix) getFunctionInput(1);
		Matrix vIn = (Matrix) getFunctionInput(4);
		try {
			MatrixBlock X = null;
			MatrixBlock dX = null;
			MatrixBlock v = null;
			try {
				X = xIn.getMatrixObject().acquireRead();
				dX = dXIn.getMatrixObject().acquireRead();
				double lr = Double.parseDouble(((Scalar)getFunctionInput(2)).getValue());
				double mu = Double.parseDouble(((Scalar)getFunctionInput(3)).getValue());
				v = vIn.getMatrixObject().acquireRead();
				
				// v = mu * v - lr * dX
				updatedV = new Matrix( "tmp_" + rand.nextLong(), v.getNumRows(), v.getNumColumns(), ValueType.Double );
				MatrixBlock updatedVMB = allocateDenseMatrixBlock(updatedV);
				double [] updatedVData = updatedVMB.getDenseBlock();
				multiplyByConstant(v, mu, updatedVData);
				multiplyByConstant(dX, -lr, updatedVData);
				// nnz is left unknown (-1) instead of paying for a recomputeNonZeros() pass.
				updatedVMB.setNonZeros(-1);
				updatedV.setMatrixDoubleArray(updatedVMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
				
				// X = X - mu * v_prev + (1 + mu) * v   (v_prev is the input block 'v')
				updatedX = new Matrix( "tmp_" + rand.nextLong(), X.getNumRows(), X.getNumColumns(), ValueType.Double );
				MatrixBlock updatedXMB = allocateDenseMatrixBlock(updatedX);
				double [] updatedXData = updatedXMB.getDenseBlock();
				copy(X, updatedXData);
				multiplyByConstant(v, -mu, updatedXData);
				multiplyByConstant(updatedVData, 1+mu, updatedXData);
				// nnz is left unknown (-1) instead of paying for a recomputeNonZeros() pass.
				updatedXMB.setNonZeros(-1);
				updatedX.setMatrixDoubleArray(updatedXMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
			}
			finally {
				// Release exactly the inputs that were acquired, even when the update fails
				// part-way through, so no acquireRead() is left without its matching release().
				if(X != null)
					xIn.getMatrixObject().release();
				if(dX != null)
					dXIn.getMatrixObject().release();
				if(v != null)
					vIn.getMatrixObject().release();
			}
		} catch (CacheException e) {
			throw new RuntimeException("Exception while executing SGDNesterovUpdate", e);
		} catch (IOException e) {
			throw new RuntimeException("Exception while executing SGDNesterovUpdate", e);
		}
	}
	
	/**
	 * Creates a zero-initialized dense MatrixBlock with the dimensions of the given matrix.
	 */
	private static MatrixBlock allocateDenseMatrixBlock(Matrix mat) {
		int rows = (int) mat.getNumRows();
		int cols = (int) mat.getNumCols();
		MatrixBlock mb = new MatrixBlock(rows, cols, false);
		mb.allocateDenseBlock();
		return mb;
	}
	
	
	// out += constant*in  (both arrays are dense buffers of the same length)
	private static void multiplyByConstant(double [] in, double constant, double [] out) {
		for(int i = 0; i < out.length; i++) {
			out[i] += in[i]*constant;
		}
	}
	
	// out += constant*in, where out is the row-major dense buffer of a matrix
	// with the same dimensions as in.
	private static void multiplyByConstant(MatrixBlock in, double constant, double [] out) {
		if(in.isInSparseFormat()) {
			int ncol = in.getNumColumns();
			Iterator<IJV> iter = in.getSparseBlockIterator();
			while(iter.hasNext()) {
				IJV ijv = iter.next();
				// Row-major linear offset of cell (i, j).
				out[ijv.getI()*ncol + ijv.getJ()] += ijv.getV() * constant;
			}
		}
		else {
			double [] denseBlock = in.getDenseBlock();
			if(denseBlock != null) {
				// A null dense block denotes an empty (all-zero) input: nothing to add.
				for(int i = 0; i < out.length; i++) {
					out[i] += denseBlock[i]*constant;
				}
			}
		}
	}
	
	// dest = src, where dest is the row-major dense buffer of a matrix with the same
	// dimensions as src. Assumption: dest is zeroed out (only nonzeros of a sparse src are written).
	private static void copy(MatrixBlock src, double [] dest) {
		if(src.isInSparseFormat()) {
			int ncol = src.getNumColumns();
			Iterator<IJV> iter = src.getSparseBlockIterator();
			while(iter.hasNext()) {
				IJV ijv = iter.next();
				// Row-major linear offset of cell (i, j).
				dest[ijv.getI()*ncol + ijv.getJ()] = ijv.getV();
			}
		}
		else {
			double [] denseBlock = src.getDenseBlock();
			if(denseBlock != null) {
				// A null dense block denotes an empty (all-zero) input: dest stays zero.
				System.arraycopy(denseBlock, 0, dest, 0, dest.length);
			}
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy