/*
* © Copyright Databand.ai, an IBM Company 2022
*/
package ai.databand;
import ai.databand.config.DbndConfig;
import ai.databand.log.HistogramRequest;
import ai.databand.log.LogDatasetRequest;
import ai.databand.schema.ColumnStats;
import ai.databand.schema.DatabandTaskContext;
import ai.databand.schema.DatasetOperationStatus;
import ai.databand.schema.DatasetOperationType;
import ai.databand.schema.LogDataset;
import ai.databand.schema.TaskRun;
import javassist.ClassPool;
import javassist.Loader;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.apache.log4j.spi.LoggingEvent;
import org.apache.spark.scheduler.SparkListenerEvent;
import org.apache.spark.scheduler.SparkListenerStageCompleted;
import org.apache.spark.sql.Dataset;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Method;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
/**
* AspectJ wrapper for the @Pipeline and @Task annotations.
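* <p>
* Besides the aspect-driven entry points, the singleton can also be used for manual logging.
* A minimal sketch, assuming tracking is enabled via {@code DbndConfig}; the metric names and
* values below are hypothetical:
* <pre>{@code
* DbndWrapper dbnd = DbndWrapper.instance();
* dbnd.logMetric("rows_processed", 42);           // hypothetical metric
* Map<String, Object> metrics = new HashMap<>();
* metrics.put("invalid_rows", 3L);                // hypothetical metric
* dbnd.logMetrics(metrics);
* }</pre>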
*/
public class DbndWrapper {
private static final org.slf4j.Logger LOG = LoggerFactory.getLogger(DbndWrapper.class);
private DbndClient dbnd;
private final DbndConfig config;
// state
private final Set<String> loadedClasses;
private final Map<String, Method> methodsCache;
private DbndRun run;
private boolean pipelineInitialized;
private final Deque<String> stack;
private boolean externalContextSet = false;
/**
* Indicates that Spark has started shutdown sequence.
*/
private boolean isSparkShutdown = false;
private static final DbndWrapper INSTANCE = new DbndWrapper();
public static DbndWrapper instance() {
return INSTANCE;
}
public DbndWrapper() {
config = new DbndConfig();
try {
dbnd = new DbndClient(config);
} catch (Exception e) {
dbnd = null;
LOG.error("Unable to initialize DbndClient, tracking will be disabled. Reason: {}", e.getMessage());
config.setTrackingEnabled(false);
}
methodsCache = new HashMap<>(1);
stack = new ArrayDeque<>(1);
loadedClasses = new HashSet<>(1);
// inject a log4j appender that will capture all log output and send it to the tracker.
String pattern = "[%d] {%c{2}} %p - %m%n";
DbndLogAppender dbndAppender = new DbndLogAppender(this);
dbndAppender.setLayout(new PatternLayout(pattern));
dbndAppender.setThreshold(Level.INFO);
dbndAppender.activateOptions();
Logger.getLogger("org.apache.spark").addAppender(dbndAppender);
Logger.getLogger("org.spark_project").addAppender(dbndAppender);
Logger.getLogger("ai.databand").addAppender(dbndAppender);
}
public Optional<Class<?>> loadClass(String className) {
try {
return Optional.of(Class.forName(className));
} catch (ClassNotFoundException e) {
// do nothing, class loader we've got doesn't have pipeline class
}
try {
// try to use Javassist classloader
return Optional.of(new Loader(ClassPool.getDefault()).loadClass(className));
} catch (ClassNotFoundException e) {
// do nothing
}
return Optional.empty();
}
public void beforePipeline(String className, String methodName, Object[] args) {
Method method = findMethodByName(methodName, className);
if (method == null) {
pipelineInitialized = false;
return;
}
// log4j system is not initialized properly at this point, so we're using stdout directly
System.out.println("Running Databand!");
System.out.printf("TRACKER URL: %s%n", config.databandUrl());
System.out.printf("CMD: %s%n", config.cmd());
System.out.println("Parsed Databand properties: " + config);
getOrCreateRun(method, args);
pipelineInitialized = true;
}
protected Method findMethodByName(String methodName, String classname) {
if (classname != null && !loadedClasses.contains(classname)) {
loadMethods(classname);
}
String truncated = removeArgsFromMethodName(methodName);
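// Illustrative (hypothetical) example of the matching below:
//   cache key (Method.toGenericString): "public void ai.databand.JavaSparkPipeline.execute(java.lang.String)"
//   incoming methodName:                "ai.databand.JavaSparkPipeline.execute(java.lang.String)"
//   truncated:                          "ai.databand.JavaSparkPipeline.execute("
// contains() is used because toGenericString() prepends modifiers, type parameters and the return type.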
for (Map.Entry<String, Method> mthd : methodsCache.entrySet()) {
if (mthd.getKey().contains(truncated)) {
return mthd.getValue();
}
}
return null;
}
/**
* Removes the arguments part from the string representation of a method name.
* ai.databand.JavaSparkPipeline.execute(java.lang.String) → ai.databand.JavaSparkPipeline.execute(
* The opening parenthesis is kept in the result because it is later used when matching against the methods cache.
*
* @param methodName fully qualified method name, possibly including an argument list
* @return the method name truncated right after the opening parenthesis, or the original name if no parenthesis is present
*/
protected String removeArgsFromMethodName(String methodName) {
int parenIndex = methodName.indexOf("(");
return parenIndex > 0 ? methodName.substring(0, parenIndex + 1) : methodName;
}
protected void loadMethods(String classname) {
Optional<Class<?>> pipelineClass = loadClass(classname);
if (!pipelineClass.isPresent()) {
LOG.error("Unable to build method cache for class {} because it can not be loaded", classname);
pipelineInitialized = false;
return;
}
for (Method mthd : pipelineClass.get().getDeclaredMethods()) {
String fullMethodName = mthd.toGenericString();
methodsCache.put(fullMethodName, mthd);
}
loadedClasses.add(classname);
}
public void afterPipeline() {
currentRun().stop();
cleanup();
}
public void errorPipeline(Throwable error) {
currentRun().error(error);
cleanup();
}
protected void cleanup() {
run = null;
methodsCache.clear();
pipelineInitialized = false;
loadedClasses.clear();
}
public void beforeTask(String className, String methodName, Object[] args) {
if (!pipelineInitialized) {
// this is first task, let's initialize pipeline
if (stack.isEmpty()) {
beforePipeline(className, methodName, args);
stack.push(methodName);
} else {
// main method was loaded by different classloader
beforePipeline(className, stack.peek(), args);
}
return;
}
DbndRun run = currentRun();
Method method = findMethodByName(methodName, className);
LOG.info("Running task {}", run.getTaskName(method));
run.startTask(method, args);
stack.push(methodName);
}
public void afterTask(String methodName, Object result) {
stack.pop();
if (stack.isEmpty()) {
// this was the last task on the stack, i.e. the pipeline itself
afterPipeline();
return;
}
DbndRun run = currentRun();
Method method = findMethodByName(methodName, null);
run.completeTask(method, result);
LOG.info("Task {} has been completed!", run.getTaskName(method));
}
public void errorTask(String methodName, Throwable error) {
String poll = stack.pop();
LOG.info("Task {} returned error!", poll);
if (stack.isEmpty()) {
// this was the last task on the stack, i.e. the pipeline itself
errorPipeline(error);
return;
}
DbndRun run = currentRun();
Method method = findMethodByName(methodName, null);
run.errorTask(method, error);
}
public void logTask(LoggingEvent event, String eventStr) {
DbndRun run = currentRun();
if (run == null) {
return;
}
run.saveLog(event, eventStr);
}
public void logMetric(String key, Object value) {
DbndRun run = currentRun();
if (run == null) {
run = createAgentlessRun();
}
run.logMetric(key, value);
LOG.info("Metric logged: [{}: {}]", key, value);
}
public void logDatasetOperation(String path,
DatasetOperationType type,
DatasetOperationStatus status,
Dataset<?> data,
Throwable error,
LogDatasetRequest params) {
DbndRun run = currentRun();
if (run == null) {
run = createAgentlessRun();
}
run.logDatasetOperation(path, type, status, data, error, params, LogDataset.OP_SOURCE_JAVA_MANUAL_LOGGING);
LOG.info("Dataset Operation [path: {}], [type: {}], [status: {}] logged", path, type, status);
}
public void logDatasetOperation(String path,
DatasetOperationType type,
DatasetOperationStatus status,
String valuePreview,
List<Long> dataDimensions,
String dataSchema,
Boolean withPartition,
List<ColumnStats> columnStats,
String operationSource) {
DbndRun run = currentRun();
if (run == null) {
run = createAgentlessRun();
}
run.logDatasetOperation(path, type, status, valuePreview, null, dataDimensions, dataSchema, withPartition, columnStats, operationSource);
LOG.info("Dataset Operation [path: {}], [type: {}], [status: {}] logged", path, type, status);
if (isSparkShutdown) {
// If spark is in the shutdown sequence, pyspark tracking is already completed.
// This call ensures Spark Listener will send `stop` signal.
LOG.info("Sending \"SUCCESS\" signal to the task run");
run.stopListener();
}
}
public void logMetrics(Map<String, Object> metrics) {
logMetrics(metrics, null);
}
public void logMetrics(Map<String, Object> metrics, String source) {
DbndRun run = currentRun();
if (run == null) {
run = createAgentlessRun();
}
run.logMetrics(metrics, source);
}
public void logDataframe(String key, Dataset<?> value, HistogramRequest histogramRequest) {
DbndRun run = currentRun();
if (run == null) {
run = createAgentlessRun();
}
run.logDataframe(key, value, histogramRequest);
}
public void logHistogram(Map<String, Object> histogram) {
DbndRun run = currentRun();
if (run == null) {
run = createAgentlessRun();
}
run.logHistogram(histogram);
}
public void logDataframe(String key, Dataset<?> value, boolean withHistograms) {
DbndRun run = currentRun();
if (run == null) {
run = createAgentlessRun();
}
run.logDataframe(key, value, new HistogramRequest(withHistograms));
LOG.info("Dataframe {} logged", key);
}
public void logSpark(SparkListenerEvent event) {
if (run == null) {
run = createAgentlessRun();
}
if (event instanceof SparkListenerStageCompleted) {
run.saveSparkMetrics((SparkListenerStageCompleted) event);
LOG.info("Spark metrics received from SparkListener saved");
}
}
public DbndConfig config() {
return config;
}
// TODO: replace synchronized with better approach to avoid performance bottlenecks
private synchronized DbndRun getOrCreateRun(Method method, Object[] args) {
if (currentRun() == null) {
initRun(method, args);
}
return currentRun();
}
private DbndRun createAgentlessRun() {
// add a JVM shutdown hook so the run is completed after the Spark job stops
// the hook should be added up front, because the listener is called asynchronously and Spark may already have started the stop sequence
if (!config.isTrackingEnabled()) {
return new NoopDbndRun();
}
Runtime.getRuntime().addShutdownHook(new Thread(this::stop));
// check if we're running inside databand task context
if (config.databandTaskContext().isPresent()) {
// don't init run from the scratch, reuse values
run = config.isTrackingEnabled() ? new DefaultDbndRun(dbnd, config) : new NoopDbndRun();
if (!config.isTrackingEnabled()) {
System.out.println("Tracking is not enabled. Set DBND__TRACKING variable to True if you want to enable it.");
}
System.out.println("Reusing existing task");
DatabandTaskContext dbndCtx = config.databandTaskContext().get();
TaskRun driverTask = new TaskRun();
driverTask.setRunUid(dbndCtx.getRootRunUid());
driverTask.setTaskRunUid(dbndCtx.getTaskRunUid());
driverTask.setTaskRunAttemptUid(dbndCtx.getTaskRunAttemptUid());
config.airflowContext().ifPresent(ctx -> driverTask.setName(ctx.getTaskId()));
run.setDriverTask(driverTask);
} else {
try {
StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
StackTraceElement main = null;
for (StackTraceElement el : stackTrace) {
if (el.getMethodName().equals("main")) {
main = el;
break;
}
}
if (main == null) {
main = stackTrace[stackTrace.length - 1];
}
// workaround to prevent ClassNotFoundException when Scala is using a layered classloader
// source: https://github.com/sbt/sbt/issues/4760
Class<?> entryPoint = Class.forName(main.getClassName(), true, Thread.currentThread().getContextClassLoader());
for (Method method : entryPoint.getMethods()) {
if (method.getName().contains(main.getMethodName())) {
Object[] args = new Object[method.getParameterCount()];
Arrays.fill(args, null);
beforePipeline(main.getClassName(), method.getName(), args);
break;
}
}
} catch (ClassNotFoundException e) {
System.out.printf("Class not found: %s%n", e.getMessage());
// do nothing
}
}
if (Objects.isNull(run)) {
// if the pipeline is not annotated, or the class could not be found, initialize the run with no arguments
getOrCreateRun(null, null);
}
return run;
}
protected void stop() {
if (run != null) {
run.stop();
}
}
private void setSparkShutdown() {
isSparkShutdown = true;
}
protected DbndRun currentRun() {
return run;
}
private void initRun(Method method, Object[] args) {
run = config.isTrackingEnabled() ? new DefaultDbndRun(dbnd, config) : new NoopDbndRun();
if (!config.isTrackingEnabled()) {
System.out.println("Tracking is not enabled. Set DBND__TRACKING variable to True if you want to enable it.");
return;
}
try {
run.init(method, args);
// log4j isn't initialized at this point
System.out.printf("Running pipeline %s%n", run.getTaskName(method));
} catch (Exception e) {
run = new NoopDbndRun();
System.out.println("Unable to init run:");
e.printStackTrace();
}
}
protected void printStack() {
StringBuilder buffer = new StringBuilder(3);
Iterator<String> iterator = stack.iterator();
buffer.append('[');
while (iterator.hasNext()) {
buffer.append(' ');
buffer.append(iterator.next());
buffer.append(' ');
}
buffer.append(']');
LOG.info(buffer.toString());
}
/**
* Set tracking context from an external source.
* This allows the context to be set externally (for instance, when calling a pyspark script) and avoids run duplication.
* TODO: since the context can be controlled externally in this way, it may make sense to start/stop JVM tasks from the Python side
*
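* A minimal usage sketch (the UIDs below are hypothetical placeholders, normally supplied by the
* external tracking side, e.g. the Python process driving this JVM):
* <pre>{@code
* DbndWrapper.instance().setExternalTaskContext(
*     "11111111-1111-1111-1111-111111111111",   // run UID (hypothetical)
*     "22222222-2222-2222-2222-222222222222",   // task run UID (hypothetical)
*     "33333333-3333-3333-3333-333333333333",   // task run attempt UID (hypothetical)
*     "my_spark_task");                          // task name (hypothetical)
* }</pre>
*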
* @param runUid            UID of the run the task belongs to
* @param taskRunUid        UID of the task run
* @param taskRunAttemptUid UID of the task run attempt
* @param taskName          name of the task to report
*/
public void setExternalTaskContext(String runUid, String taskRunUid, String taskRunAttemptUid, String taskName) {
if (externalContextSet) {
// external context was already set
// listener will report all dataset ops to the root task
// no need to set context again
return;
}
if (!config.isTrackingEnabled()) {
run = new NoopDbndRun();
LOG.info("Attempt to set external task context failed: tracking is not enabled");
return;
}
if (run == null) {
run = new DefaultDbndRun(dbnd, config);
// before Spark is stopped we have to submit all saved metrics from the last external task
Runtime.getRuntime().addShutdownHook(new Thread(run::stopExternal));
// when pyspark is running, python tracking completes before Spark starts its shutdown sequence
// the query listener keeps working during shutdown; we need to know this because the listener
// has to send a signal to the Databand tracker to recalculate dataset operations
Runtime.getRuntime().addShutdownHook(new Thread(this::setSparkShutdown));
}
// before setting context we should submit all gathered metrics from a previous context
run.stopExternal();
// and then set new context
TaskRun task = new TaskRun();
task.setRunUid(runUid);
task.setTaskRunUid(taskRunUid);
task.setTaskRunAttemptUid(taskRunAttemptUid);
task.setName(taskName);
run.setDriverTask(task);
externalContextSet = true;
LOG.info("External task context was set. run_uid: {}, task_run_uid: {}, task_run_attempt_uid: {}, task_name: {}", runUid, taskRunUid, taskRunAttemptUid, taskName);
}
}