
org.mwg.ml.BaseMLNode Maven / Gradle / Ivy
/**
* Copyright 2017 The MWG Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.mwg.ml;
import org.mwg.Callback;
import org.mwg.Constants;
import org.mwg.DeferCounter;
import org.mwg.Graph;
import org.mwg.base.BaseNode;
import org.mwg.internal.task.CoreActions;
import org.mwg.plugin.Job;
import org.mwg.task.Task;
import org.mwg.task.TaskResult;
import static org.mwg.internal.task.CoreActions.*;
import static org.mwg.task.Tasks.emptyResult;
import static org.mwg.task.Tasks.newTask;
public abstract class BaseMLNode extends BaseNode {
public static String FROM_SEPARATOR = ";";
public static String FROM = "from";
public BaseMLNode(long p_world, long p_time, long p_id, Graph p_graph) {
super(p_world, p_time, p_id, p_graph);
}
/**
* If {@code obj} is null, throws {@code NullPointerException} with a {@code message}
*
* @param obj
* @param message
*/
protected static void requireNotNull(Object obj, String message) {
if (obj == null) {
throw new RuntimeException(message);
}
}
/**
* Asserts that condition is true. If not - throws {@code IllegalArgumentException} with a specified error message
*
* @param condition Condition to test
* @param errorMessage Error message thrown with {@code IllegalArgumentException} (if thrown)
* @throws IllegalArgumentException if condition is false
*/
protected void illegalArgumentIfFalse(boolean condition, String errorMessage) {
assert errorMessage != null;
if (!condition) {
throw new IllegalArgumentException(errorMessage);
}
}
protected void extractFeatures(final Callback callback) {
String query = (String) super.get(FROM);
if (query != null) {
//TODO CACHE TO AVOID PARSING EVERY TIME
String[] split = query.split(FROM_SEPARATOR);
Task[] tasks = new Task[split.length];
for (int i = 0; i < split.length; i++) {
Task t = newTask().then(travelInWorld("" + world()));
t.then(CoreActions.travelInTime(time() + ""));
t.parse(split[i].trim(),graph());
tasks[i] = t;
}
//END TODO IN CACHE
final double[] result = new double[tasks.length];
final DeferCounter waiter = graph().newCounter(tasks.length);
for (int i = 0; i < split.length; i++) {
final int taskIndex = i;
final TaskResult initial = emptyResult();
initial.add(this);
tasks[i].executeWith(graph(), initial, new Callback() {
@Override
public void on(TaskResult currentResult) {
if (currentResult == null) {
result[taskIndex] = Constants.NULL_LONG;
} else {
result[taskIndex] = Double.parseDouble(currentResult.get(0).toString());
currentResult.free();
}
waiter.count();
}
});
}
waiter.then(new Job() {
@Override
public void run() {
callback.on(result);
}
});
} else {
callback.on(null);
}
}
}