hex.glrm.GlrmMojoWriter Maven / Gradle / Ivy
package hex.glrm;
import hex.ModelMojoWriter;
import hex.genmodel.algos.glrm.GlrmLoss;
import java.io.IOException;
import java.nio.ByteBuffer;
/**
* MOJO serializer for GLRM model.
*/
public class GlrmMojoWriter extends ModelMojoWriter {
@SuppressWarnings("unused") // Called through reflection in ModelBuildersHandler
public GlrmMojoWriter() {}
public GlrmMojoWriter(GLRMModel model) {
super(model);
}
@Override public String mojoVersion() {
return "1.00";
}
@Override
protected void writeModelData() throws IOException {
writekv("initialization", model._parms._init);
writekv("regularizationX", model._parms._regularization_x);
writekv("regularizationY", model._parms._regularization_y);
writekv("gammaX", model._parms._gamma_x);
writekv("gammaY", model._parms._gamma_y);
writekv("ncolX", model._parms._k);
// DataInfo mapping
writekv("cols_permutation", model._output._permutation);
writekv("num_categories", model._output._ncats);
writekv("num_numeric", model._output._nnums);
writekv("norm_sub", model._output._normSub);
writekv("norm_mul", model._output._normMul);
// Loss functions
writekv("ncolA", model._output._lossFunc.length);
startWritingTextFile("losses");
for (GlrmLoss loss : model._output._lossFunc) {
writeln(loss.toString());
}
finishWritingTextFile();
// Archetypes
GLRM.Archetypes arch = model._output._archetypes_raw;
writekv("ncolY", arch.nfeatures());
writekv("nrowY", arch.rank());
writekv("num_levels_per_category", arch._numLevels);
int n = arch.rank() * arch.nfeatures();
ByteBuffer bb = ByteBuffer.wrap(new byte[n * 8]);
for (double[] row : arch.getY(false))
for (double val : row)
bb.putDouble(val);
writeblob("archetypes", bb.array());
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy