hex.deeplearning.Dropout Maven / Gradle / Ivy
package hex.deeplearning;
import water.util.RandomUtils;
import java.util.Arrays;
import java.util.Random;
/**
* Helper class for dropout training of Neural Nets
*/
public class Dropout {
  /** RNG driving every dropout decision; re-seeded at the start of each public operation. */
  private transient Random _rand;
  /** Hidden-layer mask: bit {@code i} set means unit {@code i} is kept (active). */
  private transient byte[] _bits;
  /** Probability of dropping a unit; 0 disables dropout, 1 drops everything. */
  private transient double _rate;

  /**
   * Returns the internal dropout mask (not a copy); bit {@code i} set means unit {@code i} is active.
   *
   * @return the backing mask array
   */
  public byte[] bits() { return _bits; }

  /**
   * Renders the RNG, the dropout rate and the full bit mask (including padding bits) for debugging.
   *
   * @return a multi-line human-readable description
   */
  @Override
  public String toString() {
    // StringBuilder avoids the quadratic cost of += over potentially many thousands of bits.
    StringBuilder sb = new StringBuilder("Dropout: ").append(super.toString());
    sb.append("\nRandom: ").append(_rand.toString());
    sb.append("\nDropout rate: ").append(_rate);
    sb.append("\nbits: ");
    for (int i = 0; i < _bits.length * 8; ++i) {
      sb.append(unit_active(i) ? "1" : "0");
    }
    sb.append("\n");
    return sb.toString();
  }

  /**
   * Creates a dropout helper for {@code units} units with the default drop rate of 0.5.
   *
   * @param units number of units covered by the mask; the mask is rounded up to whole bytes
   */
  Dropout(int units) {
    _bits = new byte[(units + 7) / 8];
    _rand = RandomUtils.getRNG(0);
    _rate = 0.5;
  }

  /**
   * Creates a dropout helper for {@code units} units with the given drop rate.
   *
   * @param units number of units covered by the mask
   * @param rate  probability of dropping each unit
   */
  Dropout(int units, double rate) {
    this(units);
    _rate = rate;
  }

  /**
   * Zeroes each entry of an input-layer activation vector independently with probability {@code _rate}.
   *
   * @param a    activation vector, either dense or sparse
   * @param seed seed for this pass (mangled by {@link #setSeed(long)})
   * @throws UnsupportedOperationException if {@code a} is neither a dense nor a sparse vector
   */
  public void randomlySparsifyActivation(Neurons.Vector a, long seed) {
    if (a instanceof Neurons.DenseVector)
      randomlySparsifyActivation((Neurons.DenseVector) a, seed);
    else if (a instanceof Neurons.SparseVector)
      randomlySparsifyActivation((Neurons.SparseVector) a, seed);
    else throw new UnsupportedOperationException("randomlySparsifyActivation not implemented for this type: " + a.getClass().getSimpleName());
  }

  /** Input-layer dropout for dense vectors: every position is a candidate for zeroing. */
  private void randomlySparsifyActivation(Neurons.DenseVector a, long seed) {
    if (_rate == 0) return;
    setSeed(seed);
    for (int i = 0; i < a.size(); i++)
      if (_rand.nextFloat() < _rate) a.set(i, 0);
  }

  /** Input-layer dropout for sparse vectors: only stored entries are visited and possibly zeroed. */
  private void randomlySparsifyActivation(Neurons.SparseVector a, long seed) {
    if (_rate == 0) return;
    setSeed(seed);
    for (Neurons.SparseVector.Iterator it = a.begin(); !it.equals(a.end()); it.next())
      if (_rand.nextFloat() < _rate) it.setValue(0f);
  }

  /**
   * Regenerates the hidden-layer mask so that each unit stays active with probability {@code 1 - _rate}.
   *
   * @param seed seed for this pass (mangled by {@link #setSeed(long)})
   */
  public void fillBytes(long seed) {
    setSeed(seed);
    if (_rate == 0) {
      // No dropout: keep every unit. The general loop below is not safe here because
      // nextFloat() may return exactly 0.0f, and 0.0f > 0 would drop that unit.
      Arrays.fill(_bits, (byte) 0xFF);
    } else if (_rate == 0.5) {
      // Fast path: each random bit is already active with probability 1/2.
      _rand.nextBytes(_bits);
    } else {
      Arrays.fill(_bits, (byte) 0);
      for (int i = 0; i < _bits.length * 8; ++i)
        if (_rand.nextFloat() > _rate) _bits[i / 8] |= 1 << (i % 8);
    }
  }

  /**
   * Tells whether unit {@code o} is active in the current mask.
   *
   * @param o unit index; must be below {@code bits().length * 8}
   * @return true if bit {@code o} of the mask is set
   */
  public boolean unit_active(int o) {
    return (_bits[o / 8] & (1 << (o % 8))) != 0;
  }

  /**
   * Seeds the RNG after OR-ing fixed constants into a seed half that is numerically small.
   *
   * If the upper 32 bits are below 0xffff, 0x5b93 is OR-ed into the top 16 bits. If the lower
   * 32 bits are below 0xffff, 0xdb91 is OR-ed into bits 16-31.
   * NOTE(review): presumably this keeps low-entropy seeds (e.g. small row indices) away from
   * each other in RNG state space; confirm the original intent.
   */
  private void setSeed(long seed) {
    if ((seed >>> 32) < 0x0000ffffL) seed |= 0x5b93000000000000L;
    if (((seed << 32) >>> 32) < 0x0000ffffL) seed |= 0xdb910000L;
    _rand.setSeed(seed);
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy