All downloads are free. The search and download functionality uses the official Maven repository.

hex.api.MakeGLMModelHandler Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.api;

import hex.DataInfo;
import hex.DataInfo.TransformType;
import hex.Model;
import hex.glm.GLMModel;
import hex.glm.GLMModel.GLMOutput;
import hex.gram.Gram;
import hex.schemas.*;
import water.DKV;
import water.Key;
import water.MRTask;
import water.api.Handler;
import water.api.schemas3.KeyV3;
import water.fvec.*;
import water.fvec.Vec.VectorGroup;

import java.util.Arrays;
import java.util.Map;

/**
 * Created by tomasnykodym on 3/25/15.
 */
public class MakeGLMModelHandler extends Handler {
  /**
   * Builds a new GLM model from an existing one, replacing coefficient values with user-supplied ones.
   *
   * <p>The source model's coefficient map is used as the starting point; each {@code args.names[i]} is then
   * mapped to {@code args.beta[i]}. The final beta vector is read back from that map in the order of the
   * source model's coefficient names, a new {@link GLMModel} is created with the source model's parameters,
   * and the new model is put into the DKV under {@code args.dest} (or a freshly made key when no
   * destination is given).
   *
   * @param version API version of the request (not used by this method)
   * @param args    request schema holding the source model key, optional destination key, coefficient names
   *                and coefficient values
   * @return a schema filled from the newly created model
   * @throws IllegalArgumentException if the source model cannot be found, if names or beta values are missing,
   *                                  if the number of beta values differs from the number of model
   *                                  coefficients, or if more names than beta values are supplied
   */
  public GLMModelV3 make_model(int version, MakeGLMModelV3 args){
    GLMModel model = DKV.getGet(args.model.key());
    if(model == null)
      throw new IllegalArgumentException("missing source model " + args.model);
    // Reject missing arrays up front instead of failing later with a bare NullPointerException.
    if(args.names == null || args.beta == null)
      throw new IllegalArgumentException("coefficient names and beta values must both be provided");
    boolean multiClass = model._output._multinomial || model._output._ordinal;
    String [] names = multiClass?model._output.multiClassCoeffNames():model._output.coefficientNames(); // coefficient names in order and with Intercept
    Map<String, Double> coefs = model.coefficients();
    if (args.beta.length != names.length) {
      throw new IllegalArgumentException("model coefficient length " + names.length + " is different from coefficient" +
              " provided by user " + args.beta.length + ".\n model coefficients needed are:\n" + String.join("\n", names));
    }
    // Every supplied name needs a matching beta value; without this check the loop below would run off
    // the end of args.beta and surface as an ArrayIndexOutOfBoundsException.
    if (args.names.length > args.beta.length) {
      throw new IllegalArgumentException("number of coefficient names " + args.names.length +
              " exceeds number of coefficient values " + args.beta.length);
    }
    // Overwrite (or add) the user-supplied values by name.
    for(int i = 0; i < args.names.length; ++i)
      coefs.put(args.names[i],args.beta[i]);
    // Re-assemble the beta vector in the model's own coefficient order.
    double [] beta = model.beta().clone();
    for(int i = 0; i < beta.length; ++i)
      beta[i] = coefs.get(names[i]);
    GLMModel m = new GLMModel(args.dest != null?args.dest.key():Key.make(),model._parms,null, model._ymu,
            Double.NaN, Double.NaN, -1);
    m.setInputParms(model._input_parms);
    // NOTE(review): the transform is switched off on the DataInfo returned by the source model; whether this
    // also affects the source model itself depends on what dinfo() returns - confirm.
    DataInfo dinfo = model.dinfo();
    dinfo.setPredictorTransform(TransformType.NONE);
    m._output = new GLMOutput(model.dinfo(), model._output._names, model._output._column_types, model._output._domains,
            model._output.coefficientNames(), beta, model._output._binomial, model._output._multinomial, 
            model._output._ordinal);
    DKV.put(m._key, m);
    GLMModelV3 res = new GLMModelV3();
    res.fillFromImpl(m);
    return res;
  }
  
  /**
   * Returns the regularization path of the GLM model referenced by the request.
   *
   * @param v    API version of the request (not used by this method)
   * @param args request schema holding the key of the source model
   * @return a new schema filled from the model's regularization path
   * @throws IllegalArgumentException if no model is stored under the given key
   */
  public GLMRegularizationPathV3 extractRegularizationPath(int v, GLMRegularizationPathV3 args) {
    final GLMModel source = DKV.getGet(args.model.key());
    if (source != null) {
      final GLMRegularizationPathV3 schema = new GLMRegularizationPathV3();
      return schema.fillFromImpl(source.getRegularizationPath());
    }
    throw new IllegalArgumentException("missing source model " + args.model);
  }
  // Test-only functionality; it lives here to avoid adding a separate endpoint.
  /**
   * Get the expanded (interactions + offsets) dataset. Careful printing! Test only.
   *
   * <p>Looks up the frame referenced by {@code args.frame}, expands it through
   * {@link #oneHot} using all pairwise interactions of {@code args.interactions} (with
   * {@code skipMissing} set to {@code true}), and stores the key of the resulting frame in
   * {@code args.result}.
   *
   * @param version API version of the request (not used by this method)
   * @param args    request schema; its {@code result} field is overwritten
   * @return the same {@code args} instance with {@code result} filled in
   * @throws IllegalArgumentException if no frame is stored under the given key
   */
  public DataInfoFrameV3 getDataInfoFrame(int version, DataInfoFrameV3 args) {
    final Frame input = DKV.getGet(args.frame.key());
    if (input == null) {
      throw new IllegalArgumentException("no frame found");
    }
    final Model.InteractionSpec pairwise = Model.InteractionSpec.allPairwise(args.interactions);
    final Frame expanded =
        oneHot(input, pairwise, args.use_all, args.standardize, args.interactions_only, true);
    args.result = new KeyV3.FrameKeyV3(expanded._key);
    return args;
  }

  public static Frame oneHot(Frame fr, Model.InteractionSpec interactions, boolean useAll, boolean standardize, final boolean interactionsOnly, final boolean skipMissing) {
    final DataInfo dinfo = new DataInfo(fr,null,1,useAll,standardize?TransformType.STANDARDIZE:TransformType.NONE,TransformType.NONE,skipMissing,false,false,false,false,false, interactions);
    Frame res;
    if( interactionsOnly ) {
      if( null==dinfo._interactionVecs ) throw new IllegalArgumentException("no interactions");
      int noutputs=0;
      final int[] colIds = new int[dinfo._interactionVecs.length];
      final int[] offsetIds = new int[dinfo._interactionVecs.length];
      int idx=0;
      String[] coefNames = dinfo.coefNames();
      for(int i : dinfo._interactionVecs)
        noutputs+= ( offsetIds[idx++] = ((InteractionWrappedVec)dinfo._adaptedFrame.vec(i)).expandedLength());
      String[] names = new String[noutputs];
      int offset=idx=0;
      int namesIdx=0;
      for(int i=0;i dinfo._interactionVecs.length ) break; // no more interaciton vecs left
        } else {
          if( v.isCategorical() ) offset+= v.domain().length - (useAll?0:1);
          else                    offset++;
        }
      }
      res = new MRTask() {
        @Override public void map(Chunk[] cs, NewChunk ncs[]) {
          DataInfo.Row r = dinfo.newDenseRow();
          for(int i=0;i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy