All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.runtime.codegen.SpoofMultiAggregate Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.codegen;

import java.io.Serializable;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysml.runtime.functionobjects.KahanFunction;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Abstract base class for generated ("spoof") operators that compute several
 * aggregates over the same input. Each entry of the {@code AggOp} array given
 * at construction corresponds to one output value; partial-result arrays in
 * this class are sized {@code _aggOps.length}.
 *
 * NOTE(review): this copy of the file appears damaged by text extraction --
 * text between angle brackets (generic type arguments, loop conditions and
 * the code following them) is missing on several lines, and some lines of
 * different methods are merged. As shown it does not compile; restore the
 * missing spans from the upstream source before changing any logic.
 */
public abstract class SpoofMultiAggregate extends SpoofOperator implements Serializable
{
	private static final long serialVersionUID = -6164871955591089349L;
	// Cell-count threshold compared against rows*cols of the first input in
	// execute(..., k). Presumably smaller inputs stay single-threaded; the
	// comparison operator and that branch are lost in this copy -- confirm upstream.
	private static final long PAR_NUMCELL_THRESHOLD = 1024*1024;   //Min 1M elements
	
	// Aggregation operators, one per output value.
	private final AggOp[] _aggOps;
	
	/**
	 * Creates the operator with the given aggregation operators.
	 * The array is stored by reference (no defensive copy).
	 *
	 * @param aggOps aggregation operators, one per output value
	 */
	public SpoofMultiAggregate(AggOp... aggOps) {
		_aggOps = aggOps;
	}
	
	/**
	 * Returns the aggregation operators. This is the internal array itself, not
	 * a copy, so mutating it affects this operator.
	 *
	 * @return the aggregation operators passed to the constructor
	 */
	public AggOp[] getAggOps() {
		return _aggOps;
	}
	
	/**
	 * Returns "MA" followed by the second dot-separated component of the runtime
	 * class name (e.g. a class named {@code pkg.Foo} yields "MAFoo"; for
	 * {@code a.b.C} it would yield "MAb").
	 * NOTE(review): if the class name contains no dot, index [1] throws
	 * ArrayIndexOutOfBoundsException. Presumably generated subclasses live in a
	 * single-level package -- confirm.
	 */
	@Override
	public String getSpoofType() {
		return "MA" +  getClass().getName().split("\\.")[1];
	}
	
	/**
	 * Single-threaded entry point; delegates to
	 * {@link #execute(ArrayList, ArrayList, MatrixBlock, int)} with k=1.
	 */
	@Override
	public void execute(ArrayList inputs, ArrayList scalarObjects, MatrixBlock out) 
		throws DMLRuntimeException
	{
		execute(inputs, scalarObjects, out, 1);
	}
	
	/**
	 * Computes all configured aggregates over the inputs and writes them to
	 * {@code out}, using up to {@code k} threads on the parallel path.
	 *
	 * Visible behavior: throws RuntimeException (not DMLRuntimeException) if
	 * {@code inputs} is null or empty. On the parallel path, the row range
	 * [0, m) is split into blocks of {@code blklen} rows. Each block is evaluated
	 * by a ParAggTask via an ExecutorService, and the partial results are merged
	 * with aggregatePartialResults. Finally, the non-zero count of {@code out}
	 * is recomputed and examSparsity() is called on it.
	 *
	 * @param inputs matrix inputs; the first one is used for the cell-count check
	 * @param scalarObjects scalar arguments
	 * @param out output block receiving the aggregates
	 * @param k degree of parallelism
	 * @throws DMLRuntimeException wrapping any exception raised on the parallel path
	 */
	@Override
	public void execute(ArrayList inputs, ArrayList scalarObjects, MatrixBlock out, int k)	
		throws DMLRuntimeException
	{
		//sanity check
		if( inputs==null || inputs.size() < 1  )
			throw new RuntimeException("Invalid input arguments.");
		
		// NOTE(review): the next line is truncated/merged in this copy. The
		// threshold comparison, the single-threaded path, the definitions of m
		// and c, the thread-pool creation and the task-creation loop are missing.
		if( inputs.get(0).getNumRows()*inputs.get(0).getNumColumns() tasks = new ArrayList();
				// Task count: min(8*k, m/32), passed through
				// UtilFunctions.roundToNext(.., k) (presumably rounding up to a
				// multiple of k). NOTE(review): when m < 32, m/32 is 0; whether
				// roundToNext(0, k) yields a positive count cannot be determined
				// here -- verify for wide inputs with very few rows.
				int nk = UtilFunctions.roundToNext(Math.min(8*k,m/32), k);
				// Rows per task, rounded up so that nk blocks cover all m rows.
				int blklen = (int)(Math.ceil((double)m/nk));
				// NOTE(review): pool.shutdown() below is only reached if invokeAll
				// returns normally; on an exception the pool is not shut down
				// (consider a finally block).
				for( int i=0; i> taskret = pool.invokeAll(tasks);	
				pool.shutdown();
			
				//aggregate partial results
				// invokeAll returns futures in task order, so pret is in row-block order.
				ArrayList pret = new ArrayList();
				for( Future task : taskret )
					pret.add(task.get());
				aggregatePartialResults(c, pret);
			}
			// Any failure (including InterruptedException/ExecutionException from
			// the futures) is wrapped in DMLRuntimeException.
			// NOTE(review): the thread's interrupt status is not restored here.
			catch(Exception ex) {
				throw new DMLRuntimeException(ex);
			}
		}
	
		//post-processing
		// Refresh the non-zero count of out; examSparsity presumably re-evaluates
		// its dense/sparse representation.
		out.recomputeNonZeros();
		out.examSparsity();	
	}
	
	/**
	 * Aggregates a row range of a dense input {@code a} into the partial result
	 * {@code c}. The initial offset ix=rl*n implies a row-major layout with
	 * {@code n} columns; the range is presumably [rl, ru).
	 *
	 * @param a dense input values
	 * @param b side inputs as double arrays
	 * @param scalars scalar arguments
	 * @param c partial-result array, one slot per aggregate
	 * @param m row count
	 * @param n column count
	 * @param rl first row to process
	 * @param ru row upper bound
	 */
	private void executeDense(double[] a, double[][] b, double[] scalars, double[] c, int m, int n, int rl, int ru) throws DMLRuntimeException 
	{
		//core dense aggregation operation
		// NOTE(review): the loop condition and body are lost in this copy. The
		// line continues directly into the tail of a later method signature
		// (aggregatePartialResults(..., pret)), so any executeSparse /
		// setInitialOutputValues definitions in between are missing too.
		for( int i=rl, ix=rl*n; i pret) 
		throws DMLRuntimeException 
	{
		// Merges per-task partial results (one double[] per task, indexed by
		// aggregate position k) into c.
		// - Kahan-based aggregates: the partials are re-summed from zero with
		//   KahanPlus and c[k] is overwritten. Presumably this includes
		//   sum-of-squares, whose partials are already squared sums, so they
		//   are added linearly.
		// - All other aggregates: folded into the existing c[k] with their
		//   ValueFunction (presumably c holds each operator's initial value at
		//   this point -- confirm).
		ValueFunction[] vfun = getAggFunctions(_aggOps); 
		for( int k=0; k<_aggOps.length; k++ ) {
			if( vfun[k] instanceof KahanFunction ) {
				KahanObject kbuff = new KahanObject(0, 0);
				KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
				for(double[] tmp : pret)
					kplus.execute2(kbuff, tmp[k]);
				c[k] = kbuff._sum;
			}
			else {
				for(double[] tmp : pret)
					c[k] = vfun[k].execute(c[k], tmp[k]);
			}
		}
	}
		
	/**
	 * Merges the aggregates in row 0 of {@code b} into row 0 of {@code c}, where
	 * column k holds the value of aggregate k.
	 * - Kahan-based aggregates are combined with KahanPlus, starting from c's
	 *   current value with a zero correction term (any earlier compensation is
	 *   not carried over).
	 * - Other aggregates are combined with their ValueFunction.
	 *
	 * @param aggOps aggregation operators, one per column
	 * @param c accumulator block, updated in place
	 * @param b block holding the partial aggregates to merge
	 * @throws DMLRuntimeException as declared; presumably propagated from the value functions
	 */
	public static void aggregatePartialResults(AggOp[] aggOps, MatrixBlock c, MatrixBlock b) 
		throws DMLRuntimeException 
	{
		ValueFunction[] vfun = getAggFunctions(aggOps);
		
		for( int k=0; k< aggOps.length; k++ ) {
			if( vfun[k] instanceof KahanFunction ) {
				KahanObject kbuff = new KahanObject(c.quickGetValue(0, k), 0);
				KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
				kplus.execute2(kbuff, b.quickGetValue(0, k));
				c.quickSetValue(0, k, kbuff._sum);
			}
			else {
				double cval = c.quickGetValue(0, k);
				double bval = b.quickGetValue(0, k);
				c.quickSetValue(0, k, vfun[k].execute(cval, bval));
			}
		}
	}

	/**
	 * Maps each AggOp to the ValueFunction used to combine values for it. The
	 * returned array is parallel to {@code aggOps}.
	 * NOTE(review): the loop body and the end of this method are lost in this
	 * copy (the line merges into the ParAggTask class header). The mapping
	 * cannot be confirmed here; the file imports KahanPlus, KahanPlusSq and
	 * Builtin/BuiltinCode, which presumably cover sum, sum-of-squares and
	 * min/max.
	 *
	 * @param aggOps aggregation operators
	 * @return one ValueFunction per aggregation operator
	 */
	public static ValueFunction[] getAggFunctions(AggOp[] aggOps) {
		ValueFunction[] fun = new ValueFunction[aggOps.length];
		for( int i=0; i 
	{
		// ParAggTask: computes the partial aggregates for one row range of _a.
		// The class header is lost in this copy; it presumably implements
		// Callable<double[]>. It must be a non-static inner class, because
		// call() reads the enclosing instance's _aggOps and calls its
		// executeDense/executeSparse methods.
		private final MatrixBlock _a;
		private final double[][] _b;
		private final double[] _scalars;
		private final int _rlen;
		private final int _clen;
		private final int _rl;
		private final int _ru;

		/**
		 * @param a input matrix
		 * @param b side inputs as double arrays, passed through to the kernels
		 * @param scalars scalar arguments
		 * @param rlen row count, passed to the kernels as m
		 * @param clen column count, passed to the kernels as n
		 * @param rl first row of this task's range
		 * @param ru row upper bound of this task's range (presumably exclusive)
		 */
		protected ParAggTask( MatrixBlock a, double[][] b, double[] scalars, 
				int rlen, int clen, int rl, int ru ) {
			_a = a;
			_b = b;
			_scalars = scalars;
			_rlen = rlen;
			_clen = clen;
			_rl = rl;
			_ru = ru;
		}
		
		/**
		 * Computes this task's partial aggregates:
		 * 1. Allocates a partial-result array with one slot per aggregate.
		 * 2. Initializes it via setInitialOutputValues.
		 * 3. Aggregates this task's row range with the dense or sparse kernel,
		 *    depending on the input's current format.
		 *
		 * @return the partial aggregates for this row range
		 */
		@Override
		public double[] call() throws DMLRuntimeException {
			double[] c = new double[_aggOps.length];
			setInitialOutputValues(c);
			if( !_a.isInSparseFormat() )
				executeDense(_a.getDenseBlock(), _b, _scalars, c, _rlen, _clen, _rl, _ru);
			else	
				executeSparse(_a.getSparseBlock(), _b, _scalars, c, _rlen, _clen, _rl, _ru);
			return c;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy