
water.api.ModelsHandler Maven / Gradle / Ivy
package water.api;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import hex.Model;
import water.DKV;
import water.Futures;
import water.Iced;
import water.Key;
import water.KeySnapshot;
import water.Value;
import water.api.FramesHandler.Frames;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OKeyNotFoundArgumentException;
import water.exceptions.H2OKeyWrongTypeArgumentException;
import water.exceptions.H2OKeysNotFoundArgumentException;
import water.fvec.Frame;
import water.serial.ObjectTreeBinarySerializer;
import water.util.FileUtils;
class ModelsHandler> extends Handler {
/** Class which contains the internal representation of the models list and params. */
protected static final class Models extends Iced {
public Key model_id;
public Model[] models;
public boolean find_compatible_frames = false;
public static Model[] fetchAll() {
final Key[] modelKeys = KeySnapshot.globalSnapshot().filter(new KeySnapshot.KVFilter() {
@Override
public boolean filter(KeySnapshot.KeyInfo k) {
return Value.isSubclassOf(k._type, Model.class);
}
}).keys();
Model[] models = new Model[modelKeys.length];
for (int i = 0; i < modelKeys.length; i++) {
Model model = getFromDKV("(none)", modelKeys[i]);
models[i] = model;
}
return models;
}
/**
* Fetch all the Frames so we can see if they are compatible with our Model(s).
*/
protected Map> fetchFrameCols() {
Frame[] all_frames = null;
Map> all_frames_cols = null;
if (this.find_compatible_frames) {
// caches for this request
all_frames = Frames.fetchAll();
all_frames_cols = new HashMap>();
for (Frame f : all_frames) {
all_frames_cols.put(f, new HashSet(Arrays.asList(f._names)));
}
}
return all_frames_cols;
}
/**
* For a given model return an array of the compatible frames.
*
* @param model The model to fetch the compatible frames for.
* @param all_frames An array of all the Frames in the DKV.
* @param all_frames_cols A Map of Frame to a Set of its column names.
* @return
*/
private static Frame[] findCompatibleFrames(Model model, Frame[] all_frames, Map> all_frames_cols) {
List compatible_frames = new ArrayList();
Set model_column_names = new HashSet(Arrays.asList(model._output._names));
for (Map.Entry> entry : all_frames_cols.entrySet()) {
Frame frame = entry.getKey();
Set frame_cols = entry.getValue();
if (frame_cols.containsAll(model_column_names)) {
// See if adapt throws an exception or not.
try {
if( model.adaptTestForTrain(new Frame(frame), false, false).length == 0 )
compatible_frames.add(frame);
} catch( IllegalArgumentException e ) {
// skip
}
}
}
return compatible_frames.toArray(new Frame[0]);
}
}
/** Return all the models. */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelsV3 list(int version, ModelsV3 s) {
Models m = s.createAndFillImpl();
m.models = Models.fetchAll();
return (ModelsV3) s.fillFromImplWithSynopsis(m);
}
// TODO: almost identical to ModelsHandler; refactor
public static Model getFromDKV(String param_name, String key_str) {
return getFromDKV(param_name, Key.make(key_str));
}
// TODO: almost identical to ModelsHandler; refactor
public static Model getFromDKV(String param_name, Key key) {
if (null == key)
throw new H2OIllegalArgumentException(param_name, "Models.getFromDKV()", key);
Value v = DKV.get(key);
if (null == v)
throw new H2OKeyNotFoundArgumentException(param_name, key.toString());
Iced ice = v.get();
if (! (ice instanceof Model))
throw new H2OKeyWrongTypeArgumentException(param_name, key.toString(), Model.class, ice.getClass());
return (Model)ice;
}
/** Return a single model. */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelsV3 fetchPreview(int version, ModelsV3 s) {
s.preview = true;
return fetch(version, s);
}
/** Return a single model. */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelsV3 fetch(int version, ModelsV3 s) {
Model model = getFromDKV("key", s.model_id.key());
s.models = new ModelSchema[1];
s.models[0] = (ModelSchema)Schema.schema(version, model).fillFromImpl(model);
if (s.find_compatible_frames) {
// TODO: refactor fetchFrameCols so we don't need this Models object
Models m = new Models();
m.models = new Model[1];
m.models[0] = model;
m.find_compatible_frames = true;
Frame[] compatible = Models.findCompatibleFrames(model, Frames.fetchAll(), m.fetchFrameCols());
s.compatible_frames = new FrameV3[compatible.length]; // TODO: FrameBase
((ModelSchema)s.models[0]).compatible_frames = new String[compatible.length];
int i = 0;
for (Frame f : compatible) {
s.compatible_frames[i] = new FrameV3(f).fillFromImpl(f); // TODO: FrameBase
((ModelSchema)s.models[0]).compatible_frames[i] = f._key.toString();
i++;
}
}
return s;
}
/** Remove an unlocked model. Fails if model is in-use. */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelsV3 delete(int version, ModelsV3 s) {
Model model = getFromDKV("key", s.model_id.key());
model.delete(); // lock & remove
return s;
}
/**
* Remove ALL an unlocked models. Throws IAE for all deletes that failed
* (perhaps because the Models were locked & in-use).
*/
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelsV3 deleteAll(int version, ModelsV3 models) {
final Key[] keys = KeySnapshot.globalKeysOfClass(Model.class);
ArrayList missing = new ArrayList<>();
Futures fs = new Futures();
for( int i = 0; i < keys.length; i++ ) {
try {
getFromDKV("(none)", keys[i]).delete(null, fs);
} catch( IllegalArgumentException iae ) {
missing.add(keys[i].toString());
}
}
fs.blockForPending();
if( missing.size() != 0 ) throw new H2OKeysNotFoundArgumentException("(none)", missing.toArray(new String[missing.size()]));
return models;
}
public ModelsV3 importModel(int version, ModelImportV3 mimport) {
ModelsV3 s = (ModelsV3) Schema.newInstance(ModelsV3.class);
try {
List importedKeys = new ObjectTreeBinarySerializer().load(FileUtils.getURI(mimport.dir));
Model model = (Model) importedKeys.get(0).get();
s.models = new ModelSchema[1];
s.models[0] = (ModelSchema) Schema.schema(version, model).fillFromImpl(model);
} catch (IOException e) {
throw new H2OIllegalArgumentException("dir", "importModel", e);
}
return s;
}
public ModelExportV3 exportModel(int version, ModelExportV3 mexport) {
Model model = getFromDKV("model_id", mexport.model_id.key());
List keysToExport = new LinkedList<>();
keysToExport.add(model._key);
keysToExport.addAll(model.getPublishedKeys());
try {
URI targetUri = FileUtils.getURI(mexport.dir);
new ObjectTreeBinarySerializer(mexport.force).save(keysToExport, targetUri);
// Send back
mexport.dir = "file".equals(targetUri.getScheme()) ? new File(targetUri).getCanonicalPath() : targetUri.toString();
} catch (IOException e) {
throw new H2OIllegalArgumentException("dir", "exportModel", e);
}
return mexport;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy