All Downloads are FREE. Search and download functionalities are using the official Maven repository.

deepboof.impl.forward.standard.BaseFunction Maven / Gradle / Ivy

There is a newer version: 0.5.3
Show newest version
/*
 * Copyright (c) 2016, Peter Abeles. All Rights Reserved.
 *
 * This file is part of DeepBoof
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package deepboof.impl.forward.standard;

import deepboof.Function;
import deepboof.Tensor;
import deepboof.misc.TensorOps;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Base class which implements common functionality between all {@link Function functions}.
 *
 * <p>Stores a copy of the input shape given to {@link #initialize}, keeps a defensive copy of the
 * parameters given to {@link #setParameters}, and validates the input/output tensor shapes and the
 * mini-batch size before delegating to {@link #_forward}. Concrete functions implement the
 * {@code _initialize}, {@code _setParameters}, and {@code _forward} hooks.</p>
 *
 * @param <T> tensor type processed by this function
 *
 * @author Peter Abeles
 */
public abstract class BaseFunction<T extends Tensor> implements Function<T> {
	/**
	 * Copy of the shape passed to {@link #initialize}. NOTE(review): appears to exclude the
	 * mini-batch axis (dimension 0 is handled separately in {@link #forward}) — confirm.
	 */
	protected int [] shapeInput = new int[0];
	/** Expected shapes of the parameters. Cleared by {@link #initialize} before {@link #_initialize()} runs. */
	protected List<int[]> shapeParameters = new ArrayList<>();
	/** Expected output shape; checked against the output tensor in {@link #forward}. */
	protected int [] shapeOutput = new int[0];

	/** Copy of the list most recently passed to {@link #setParameters}; null until it is called. */
	protected List<T> parameters;

	/**
	 * Number of inputs in the mini-batch
	 */
	protected int miniBatchSize;

	/**
	 * True once {@link #initialize} has completed. Needed because {@code shapeInput} defaults to an
	 * empty array, so checking it for null can never detect a missing initialize() call.
	 */
	private boolean initialized = false;

	/**
	 * Configures the function for inputs of the specified shape.
	 *
	 * <p>Copies {@code shapeInput}, clears the expected parameter shapes, fills the current output
	 * shape array with -1, then invokes {@link #_initialize()}. The function is only marked as
	 * initialized if {@code _initialize()} returns normally.</p>
	 *
	 * @param shapeInput shape of the input
	 */
	@Override
	public void initialize(int... shapeInput) {
		// Reset first so that a failed re-initialization does not leave a stale "ready" state.
		initialized = false;
		this.shapeInput = shapeInput.clone();
		shapeParameters.clear();
		Arrays.fill(shapeOutput,-1);

		_initialize();
		initialized = true;
	}

	/**
	 * Subclass hook invoked at the end of {@link #initialize}, after {@link #shapeInput} has been set.
	 */
	public abstract void _initialize();

	/**
	 * Specifies the function's parameters. Their shapes are validated against
	 * {@link #shapeParameters} via {@code TensorOps.checkShape}, a defensive copy is stored, and then
	 * {@link #_setParameters(List)} is invoked with the caller's list.
	 *
	 * @param parameters parameter tensors
	 */
	@Override
	@SuppressWarnings("unchecked") // raw cast needed to satisfy TensorOps.checkShape's parameter type
	public void setParameters(List<T> parameters) {
		TensorOps.checkShape("parameters", shapeParameters, (List) parameters, false);

		this.parameters = new ArrayList<>(parameters);
		_setParameters(parameters);
	}

	/**
	 * Subclass hook invoked by {@link #setParameters} after the parameter shapes have been checked.
	 *
	 * @param parameters the list passed to {@link #setParameters}
	 */
	public abstract void _setParameters(List<T> parameters);

	/**
	 * Returns the copy of the parameters stored by {@link #setParameters}.
	 *
	 * @return stored parameter list, or null if {@link #setParameters} has not been called
	 */
	@Override
	public List<T> getParameters() {
		return parameters;
	}

	/**
	 * Validates the input and output tensors, records the mini-batch size, then invokes
	 * {@link #_forward}.
	 *
	 * @param input input tensor; dimension 0 is the mini-batch size
	 * @param output output tensor; dimension 0 must equal the input's dimension 0
	 * @throws IllegalArgumentException if {@link #initialize} has not completed, or if dimension 0 of
	 *         {@code output} differs from that of {@code input}. Shape mismatches are reported by
	 *         {@code TensorOps.checkShape} (presumably also an IllegalArgumentException — confirm).
	 */
	@Override
	public void forward(T input, T output) {
		// shapeInput is never null (it defaults to an empty array), so the flag is the real guard;
		// the null test is kept in case a subclass clears the field.
		if( !initialized || shapeInput == null )
			throw new IllegalArgumentException("Must initialize first!");

		TensorOps.checkShape("input",-1,shapeInput,input.getShape(),true);
		TensorOps.checkShape("output", -1,shapeOutput,output.getShape(),true);

		// see if the number of stacked inputs is the same in input and output
		miniBatchSize = input.length(0);
		if( output.length(0) != miniBatchSize) {
			int M = output.length(0);
			throw new IllegalArgumentException("Dimension 0 in the output is "+M+
					" and does not match input dimension 0 of "+ miniBatchSize);
		}

		_forward(input, output);
	}

	/**
	 * Subclass hook performing the actual computation. It is called only after {@link #forward} has
	 * validated the tensors and set {@link #miniBatchSize}.
	 *
	 * @param input input tensor
	 * @param output output tensor to be written to
	 */
	public abstract void _forward(T input, T output);

	/**
	 * Returns the expected parameter shapes.
	 *
	 * @return the internal (mutable) list of parameter shapes
	 */
	@Override
	public List<int[]> getParameterShapes() {
		return shapeParameters;
	}

	/**
	 * Returns the expected output shape.
	 *
	 * @return the internal (mutable) output shape array
	 */
	@Override
	public int[] getOutputShape() {
		return shapeOutput;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy