All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.models.nn.LayeredNetwork Maven / Gradle / Ivy

package com.expleague.ml.models.nn;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.FuncC1;
import com.expleague.ml.func.generic.Const;
import com.expleague.ml.func.generic.SubVecFuncC1;
import com.expleague.ml.func.generic.WSumSigmoid;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Fully connected, layered network described by a list of layer sizes.
 *
 * <p>{@code config[0]} is the number of inputs; every following entry is the size of the next
 * layer, and the last entry is the number of outputs. Every non-input node is a
 * {@link WSumSigmoid} fed by all nodes of the previous layer.
 *
 * <p>Node indexing used throughout this class: index {@code 0} is reserved (it is never mapped
 * to a {@link Node} here), indices {@code 1..config[0]} are the input nodes, and the remaining
 * indices are the nodes of the subsequent layers in order.
 *
 * User: solar
 * Date: 26.05.15
 * Time: 11:46
 */
public class LayeredNetwork extends NeuralSpider {
  /** Nodes of all non-input layers, in layer order. */
  private final Node[] nodes;
  /** Random source used to decide which nodes are dropped out. */
  private final Random rng;
  /** Probability of dropping a node when dropout is requested for a topology. */
  private final double dropout;
  /** Layer sizes; {@code config[0]} is the input layer. */
  private final int[] config;
  /** Length of the parameter vector: {@code config[0] + sum(config[i] * config[i - 1])}. */
  private final int dim;

  /**
   * Builds the nodes of all non-input layers.
   *
   * @param rng     random source used for dropout decisions
   * @param dropout probability of dropping a node (see {@code isDroppedOut} of the topology)
   * @param config  layer sizes, starting with the input layer; must contain at least one entry
   */
  public LayeredNetwork(Random rng, double dropout, final int... config) {
    this.rng = rng;
    this.dropout = dropout;
    this.config = config;
    final int nodesCount;
    {
      // The first config[0] parameter slots are counted in dim but not assigned to any node
      // below (weights start at offset config[0]).
      int dim = config[0];
      // One reserved node at index 0 plus the input nodes.
      int count = 1 + config[0];
      for (int i = 1; i < config.length; i++) {
        count += config[i];
        dim += config[i] * config[i - 1];
      }
      this.dim = dim;
      nodesCount = count;
    }
    // Offset of the next node's weights inside the parameter vector.
    int wStart = config[0];
    // Typed list: a raw List would make toArray(Node[]) return Object[] and fail to compile.
    final List<Node> nodes = new ArrayList<>();
    for(int d = 1; d < config.length; d++) {
      final int prevLayerPower = config[d - 1];
      // State index of the first node of layer d (index 0 and the inputs come first).
      final int layerStart = nodes.size() + config[0] + 1;

      for (int i = 0; i < config[d]; i++) {
        final int fwStart = wStart;
        nodes.add(new Node() {
          @Override
          public FuncC1 transByParameters(Vec betta) {
            // Weights are fixed from betta; the function is applied to the previous layer's
            // slice [layerStart - prevLayerPower, layerStart) of the state vector.
            return new SubVecFuncC1(new WSumSigmoid(betta.sub(fwStart, prevLayerPower)), layerStart - prevLayerPower, prevLayerPower, nodesCount);
          }

          @Override
          public FuncC1 transByParents(Vec state) {
            // Parent outputs are fixed from state; the function is applied to this node's
            // weight slice [fwStart, fwStart + prevLayerPower) of the parameter vector.
            return new SubVecFuncC1(new WSumSigmoid(state.sub(layerStart - prevLayerPower, prevLayerPower)), fwStart, prevLayerPower, dim);
          }
        });
        wStart += prevLayerPower;
      }
    }
    this.nodes = nodes.toArray(new Node[nodes.size()]);
  }

  /**
   * @return length of the parameter vector of this network
   */
  @Override
  public int dim() {
    return dim;
  }

  /**
   * Creates the topology for a single input vector: constant input nodes holding the
   * components of {@code argument}, followed by the pre-built layer nodes.
   *
   * @param argument input vector; its dimension must equal {@code config[0]}
   * @param dropout  whether nodes may be randomly dropped out in this topology
   * @return topology over the input nodes and the layer nodes
   * @throws IllegalArgumentException if {@code argument.dim() != config[0]}
   */
  @Override
  protected Topology topology(final Vec argument, final boolean dropout) {
    if (argument.dim() != config[0])
      throw new IllegalArgumentException(
          "Expected input of dimension " + config[0] + " but got " + argument.dim());
    final Node[] inputLayer = new Node[config[0]];
    for(int i = 0; i < inputLayer.length; i++) {
      final int nindex = i;
      // Input nodes ignore both parameters and parent state: they just emit the input value.
      inputLayer[i] = new Node() {
        @Override
        public FuncC1 transByParameters(Vec betta) {
          return new Const(argument.get(nindex));
        }

        @Override
        public FuncC1 transByParents(Vec state) {
          return new Const(argument.get(nindex));
        }
      };
    }

    return new Topology.Stub() {
      @Override
      public int outputCount() {
        return config[config.length - 1];
      }

      @Override
      public boolean isDroppedOut(int nodeIndex) {
        // NOTE(review): with the indexing used by at()/length(), inputs occupy indices
        // 1..inputLayer.length and the output layer occupies the last outputCount() indices.
        // The bounds below do not line up with those ranges exactly (e.g. index
        // inputLayer.length is not exempt) - confirm the intended exemptions before changing.
        //noinspection SimplifiableIfStatement
        if (!dropout || nodeIndex < inputLayer.length || nodeIndex > nodes.length - inputLayer.length)
          return false;
        return LayeredNetwork.this.dropout > MathTools.EPSILON && rng.nextDouble() < LayeredNetwork.this.dropout;
      }

      @Override
      public Node at(int i) {
        // NOTE(review): i == 0 evaluates inputLayer[-1] and throws
        // ArrayIndexOutOfBoundsException; presumably index 0 is never requested - confirm.
        return i <= inputLayer.length ? inputLayer[i - 1] : nodes[i - inputLayer.length - 1];
      }

      @Override
      public int length() {
        // Reserved index 0 + input nodes + layer nodes.
        return inputLayer.length + nodes.length + 1;
      }
    };
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy