All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.runtime.functionobjects.CM Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.functionobjects;

import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;


/**
 * Value function that incrementally maintains the statistics stored in a
 * {@link CM_COV_Object} (weight {@code w}, {@code mean}, and the moment
 * aggregates {@code m2}, {@code m3}, {@code m4}) for one configured
 * {@link AggregateOperationTypes}.
 * 
 * Three update forms are provided: adding one observation with implicit weight 1,
 * adding one observation with an explicit weight, and merging the statistics of
 * two partitions of the data.
 * 
 * Instances hold mutable scratch buffers and are therefore created per use via
 * {@link #getCMFnObject(AggregateOperationTypes)} instead of being shared.
 * 
 * GENERAL NOTE:
 * * 05/28/2014: We decided to handle weights consistently with SPSS in an operation-specific manner, 
 *   i.e., we (1) round instead of casting where required (e.g. count), and (2) consistently use
 *   fractional weight values elsewhere. In case a count-based interpretation of weights is needed, just 
 *   ensure rounding before calling CM/COV/KahanPlus.
 * 
 */
public class CM extends ValueFunction 
{

	private static final long serialVersionUID = 9177194651533064123L;

	/** Statistic maintained by this instance; assigned once in the constructor. */
	private final AggregateOperationTypes _type;
	
	//helper function objects for specific types
	/** Addition helper used for every mean/moment update; stays null for COUNT and unknown types. */
	private KahanPlus _plus = null;
	/** Scratch copy of m2, allocated for CM2, CM3 and CM4 only. */
	private KahanObject _buff2 = null;
	/** Scratch copy of m3, allocated for CM3 and CM4 only. */
	private KahanObject _buff3 = null;
	
	
	/**
	 * Creates a function object for the given statistic and allocates only the
	 * helper objects that this statistic needs.
	 * 
	 * @param type the statistic to maintain
	 */
	private CM( AggregateOperationTypes type ) 
	{
		_type = type;
		
		// The cases below fall through on purpose: each higher-order statistic
		// needs everything the lower-order ones need, plus one more buffer.
		switch( _type ) //helper obj on demand
		{
			case COUNT:
				break;
			case CM4:
			case CM3:
				_buff3 = new KahanObject(0, 0);
				// fall through
			case CM2:
				_buff2 = new KahanObject(0, 0);
				// fall through
			case VARIANCE:
			case MEAN:
				_plus = KahanPlus.getKahanPlusFnObject();
				break;
			default:
				//do nothing
				break;
		}
	}
	
	/**
	 * Factory method returning a new function object for the given statistic.
	 * 
	 * @param type the statistic to maintain
	 * @return a newly created, unshared instance
	 */
	public static CM getCMFnObject( AggregateOperationTypes type ) {
		//return new obj, required for correctness in multi-threaded
		//execution due to state in cm object (buff2, buff3)	
		return new CM( type ); 
	}
	
	/**
	 * Cloning is not supported; use {@link #getCMFnObject(AggregateOperationTypes)}
	 * to obtain another instance.
	 * 
	 * @return never returns normally
	 * @throws CloneNotSupportedException always
	 */
	@Override
	public Object clone() throws CloneNotSupportedException {
		// cloning is intentionally unsupported; note that this class is NOT a
		// singleton (getCMFnObject creates a new instance on every call)
		throw new CloneNotSupportedException();
	}

	/**
	 * @return the statistic this instance was created for
	 */
	public AggregateOperationTypes getAggOpType() {
		return _type;
	}

	/**
	 * Adds one observation with implicit weight 1 to the running statistics
	 * (special case for weights w2==1). Unlike the weighted variant, the COUNT
	 * case does not round the resulting weight.
	 * 
	 * @param in1 running statistics; must be a {@link CM_COV_Object}, updated in place
	 * @param in2 the new observation
	 * @return the updated {@code in1} object
	 * @throws DMLRuntimeException if the configured operation type is not handled
	 *         (only reached when {@code in1} does not report isCMAllZeros())
	 */
	@Override
	public Data execute(Data in1, double in2) 
		throws DMLRuntimeException 
	{
		CM_COV_Object cm1=(CM_COV_Object) in1;
		
		// first observation: initialize directly, independent of _type
		if(cm1.isCMAllZeros())
		{
			cm1.w=1;
			cm1.mean.set(in2, 0);
			cm1.m2.set(0,0);
			cm1.m3.set(0,0);
			cm1.m4.set(0,0);
			return cm1;
		}
		
		// In all cases below: w is the new total weight, d the distance of the
		// new value from the OLD mean, and cm1.w still holds the old weight
		// until it is overwritten at the end of the case.
		switch( _type )
		{
			case COUNT:
			{
				cm1.w = cm1.w + 1;
				break;
			}
			case MEAN:
			{
				double w= cm1.w + 1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				cm1.w=w;
				break;
			}
			case CM2:
			{
				double w= cm1.w + 1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double lt1=t1*d; // = d^2 * w_old / w
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				cm1.m2.set(_buff2);
				cm1.w=w;
				break;
			}
			case CM3:
			{
				double w = cm1.w + 1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2));
				double f2=1.0/w;
				// m2 is updated in the scratch buffer so that cm1.m2._sum below
				// is still the pre-update value needed by the m3 correction term
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				// write back only after all reads of the old moments are done
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case CM4:
			{
				double w=cm1.w+1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2));
				double lt3=Math.pow(t1, 4)*(1.0-Math.pow(t2, 3));
				double f2=1.0/w;
				// m2/m3 are updated in scratch buffers; the m3 and m4 updates
				// below read the pre-update cm1.m2._sum and cm1.m3._sum
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d);
				// write back only after all reads of the old moments are done
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case VARIANCE:
			{
				// same update as CM2, but applied to cm1.m2 directly (no buffer)
				double w=cm1.w+1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double lt1=t1*d;
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
				cm1.w=w;
				break;
			}
			
			default:
				throw new DMLRuntimeException("Unsupported operation type: "+_type);
		}
		
		return cm1;
	}
	
	/**
	 * Adds one observation with an explicit weight to the running statistics
	 * (general case for arbitrary weights w2).
	 * 
	 * NOTE(review): for CM3/CM4 a weight of exactly 0 makes 1/Math.pow(w2, 2)
	 * infinite while t1 is 0, so lt2/lt3 evaluate to NaN; a total weight of 0
	 * makes the mean update divide by zero. Confirm that callers never pass
	 * zero weights, or decide how such observations should be treated.
	 * 
	 * NOTE(review): the first observation stores w2 unrounded, whereas later
	 * COUNT updates round the accumulated weight; confirm this is intended.
	 * 
	 * @param in1 running statistics; must be a {@link CM_COV_Object}, updated in place
	 * @param in2 the new observation
	 * @param w2 weight of the new observation
	 * @return the updated {@code in1} object
	 * @throws DMLRuntimeException if the configured operation type is not handled
	 *         (only reached when {@code in1} does not report isCMAllZeros())
	 */
	@Override
	public Data execute(Data in1, double in2, double w2) 
		throws DMLRuntimeException 
	{
		CM_COV_Object cm1=(CM_COV_Object) in1;
		
		// first observation: initialize directly, independent of _type
		if(cm1.isCMAllZeros())
		{
			cm1.w=w2;
			cm1.mean.set(in2, 0);
			cm1.m2.set(0,0);
			cm1.m3.set(0,0);
			cm1.m4.set(0,0);
			return cm1;
		}
		
		// In all cases below: w is the new total weight, d the distance of the
		// new value from the OLD mean, and cm1.w still holds the old weight
		// until it is overwritten at the end of the case.
		switch( _type )
		{
			case COUNT:
			{
				// counts are rounded, see the general note on the class
				cm1.w = Math.round(cm1.w + w2);
				break;
			}
			case MEAN:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				cm1.w=w;
				break;
			}
			case CM2:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double lt1=t1*d; // = d^2 * w_old * w2 / w
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				cm1.m2.set(_buff2);
				cm1.w=w;
				break;
			}
			case CM3:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2));
				double f2=w2/w;
				// m2 is updated in the scratch buffer so that cm1.m2._sum below
				// is still the pre-update value needed by the m3 correction term
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				// write back only after all reads of the old moments are done
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case CM4:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2));
				double lt3=Math.pow(t1, 4)*(1/Math.pow(w2, 3)-Math.pow(t2, 3));
				double f2=w2/w;
				// m2/m3 are updated in scratch buffers; the m3 and m4 updates
				// below read the pre-update cm1.m2._sum and cm1.m3._sum
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d);
				// write back only after all reads of the old moments are done
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case VARIANCE:
			{
				// same update as CM2, but applied to cm1.m2 directly (no buffer)
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double lt1=t1*d;
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
				cm1.w=w;
				break;
			}
			
			default:
				throw new DMLRuntimeException("Unsupported operation type: "+_type);
		}
		
		return cm1;
	}

	/**
	 * Combines the statistics of two partitions of the data. The result is
	 * accumulated into {@code in1}; {@code in2} is only read.
	 * 
	 * If {@code in1} reports isCMAllZeros(), all fields of {@code in2} are copied
	 * into it; if {@code in2} reports isCMAllZeros(), {@code in1} is returned
	 * unchanged. Both shortcuts are taken independent of the operation type.
	 * 
	 * NOTE(review): for CM3/CM4 a partition weight cm2.w of exactly 0 with
	 * otherwise non-zero statistics would make 1/Math.pow(cm2.w, 2) infinite
	 * and yield NaN; confirm such inputs cannot occur.
	 * 
	 * @param in1 statistics of the first partition; must be a {@link CM_COV_Object}, updated in place
	 * @param in2 statistics of the second partition; must be a {@link CM_COV_Object}
	 * @return the updated {@code in1} object
	 * @throws DMLRuntimeException if the configured operation type is not handled
	 *         (only reached when neither input reports isCMAllZeros())
	 */
	@Override
	public Data execute(Data in1, Data in2) throws DMLRuntimeException 
	{
		CM_COV_Object cm1=(CM_COV_Object) in1;
		CM_COV_Object cm2=(CM_COV_Object) in2;
		
		if(cm1.isCMAllZeros())
		{
			cm1.w=cm2.w;
			cm1.mean.set(cm2.mean);
			cm1.m2.set(cm2.m2);
			cm1.m3.set(cm2.m3);
			cm1.m4.set(cm2.m4);
			return cm1;
		}
		if(cm2.isCMAllZeros())
			return cm1;
		
		// In all cases below: w is the combined weight, d the difference of the
		// two partition means, f1/f2 the weight fractions of partition 1/2, and
		// cm1.w still holds the old weight until the end of the case.
		switch( _type )
		{
			case COUNT:
			{
				// counts are rounded, see the general note on the class
				cm1.w = Math.round(cm1.w + cm2.w);
				break;
			}
			case MEAN:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				cm1.w=w;
				break;
			}
			case CM2:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double lt1=t1*d; // = d^2 * w1 * w2 / w
				_buff2.set(cm1.m2);
				// add partition 2's m2 (sum and correction), then the cross term
				_buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				cm1.m2.set(_buff2);
				cm1.w=w;
				break;
			}
			case CM3:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2));
				double f1=cm1.w/w;
				double f2=cm2.w/w;
				// m2 is updated in the scratch buffer so that cm1.m2._sum below
				// is still the pre-merge value needed by the m3 cross term
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction);
				_buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2);
				// write back only after all reads of the old moments are done
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case CM4:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2));
				double lt3=Math.pow(t1, 4)*(1/Math.pow(cm2.w, 3)-Math.pow(t2, 3));
				double f1=cm1.w/w;
				double f2=cm2.w/w;
				// m2/m3 are updated in scratch buffers; the m3 and m4 updates
				// below read the pre-merge cm1.m2._sum and cm1.m3._sum
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction);
				_buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2);
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, cm2.m4._sum, cm2.m4._correction);
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, 4*(-f2*cm1.m3._sum+f1*cm2.m3._sum)*d 
						+ 6*(Math.pow(-f2, 2)*cm1.m2._sum+Math.pow(f1, 2)*cm2.m2._sum)*Math.pow(d, 2) + lt3);
				// write back only after all reads of the old moments are done
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case VARIANCE:
			{
				// same update as CM2, but applied to cm1.m2 directly (no buffer)
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double lt1=t1*d;
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, cm2.m2._sum, cm2.m2._correction);
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
				cm1.w=w;
				break;
			}
			
			default:
				throw new DMLRuntimeException("Unsupported operation type: "+_type);
		}
		
		return cm1;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy