package ai.databand;
import ai.databand.config.DbndConfig;
import ai.databand.id.Sha1Long;
import ai.databand.id.Sha1Short;
import ai.databand.id.Uuid5;
import ai.databand.log.HistogramRequest;
import ai.databand.parameters.DatasetOperationPreview;
import ai.databand.parameters.Histogram;
import ai.databand.parameters.NullPreview;
import ai.databand.parameters.ParametersPreview;
import ai.databand.parameters.TaskParameterPreview;
import ai.databand.schema.AirflowTaskContext;
import ai.databand.schema.AzkabanTaskContext;
import ai.databand.schema.DatasetOperationStatus;
import ai.databand.schema.DatasetOperationType;
import ai.databand.schema.ErrorInfo;
import ai.databand.schema.LogDataset;
import ai.databand.schema.LogTarget;
import ai.databand.schema.Pair;
import ai.databand.schema.RootRun;
import ai.databand.schema.RunAndDefinition;
import ai.databand.schema.TaskDefinition;
import ai.databand.schema.TaskParamDefinition;
import ai.databand.schema.TaskRun;
import ai.databand.schema.TaskRunParam;
import ai.databand.schema.TaskRunsInfo;
import ai.databand.schema.TrackingSource;
import org.apache.log4j.spi.LoggingEvent;
import org.apache.spark.scheduler.SparkListenerStageCompleted;
import org.apache.spark.scheduler.StageInfo;
import org.apache.spark.sql.Dataset;
import org.apache.spark.util.AccumulatorV2;
import org.apache.spark.util.LongAccumulator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.Iterator;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
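/**
 * Default {@link DbndRun} implementation. A run is tracked as a tree of task runs: the
 * driver task is created in {@link #init}, nested tasks are pushed onto a stack as they
 * start, and parent-child and upstream relations are derived from the stack and from
 * matching task arguments against previously recorded task outputs.
 *
 * <p>Illustrative lifecycle, a sketch only (these calls are normally made by the agent
 * instrumentation rather than by user code; client and config construction is assumed):
 * <pre>{@code
 * DbndRun run = new DefaultDbndRun(dbndClient, config);
 * run.init(pipelineMethod, pipelineArgs); // creates the driver task, run state -> RUNNING
 * run.startTask(taskMethod, taskArgs);    // registers a nested task run
 * run.logMetric("rows", 42);              // metric goes to the task on top of the stack
 * run.completeTask(taskMethod, result);   // task -> SUCCESS, result logged as "write" target
 * run.stop();                             // driver task and root run -> SUCCESS
 * }</pre>
 */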
@SuppressWarnings("unchecked")
public class DefaultDbndRun implements DbndRun {
private static final Logger LOG = LoggerFactory.getLogger(DefaultDbndRun.class);
private final DbndClient dbnd;
private final List<TaskRun> taskRuns;
private final List<TaskDefinition> taskDefinitions;
private final List<List<String>> parentChildMap;
private final List<List<String>> upstreamsMap;
private final Deque<TaskRun> stack;
// todo: methods cache should be extracted to app-level cache
private final Map<Method, List<TaskParamDefinition>> methodsCache;
private final Map<Method, TaskRun> methodsRunsCache;
private final Map<Method, Integer> methodExecutionCounts;
private final ParametersPreview parameters;
private final Map<Integer, TaskRun> taskRunOutputs;
private String rootRunUid;
private String runId;
private String jobName;
private String driverTaskUid;
private TaskRun driverTask;
private AirflowTaskContext airflowContext;
private AzkabanTaskContext azkabanTaskContext;
private final DbndConfig config;
public DefaultDbndRun(DbndClient dbndClient, DbndConfig config) {
this.dbnd = dbndClient;
this.taskRuns = new ArrayList<>(1);
this.taskDefinitions = new ArrayList<>(1);
this.parentChildMap = new ArrayList<>(1);
this.upstreamsMap = new ArrayList<>(1);
this.stack = new ArrayDeque<>(1);
this.methodsCache = new HashMap<>(1);
this.methodsRunsCache = new HashMap<>(1);
this.methodExecutionCounts = new HashMap<>(1);
this.parameters = new ParametersPreview(config.isPreviewEnabled());
this.taskRunOutputs = new HashMap<>(1);
this.airflowContext = config.airflowContext().orElse(null);
this.azkabanTaskContext = config.azkabanContext().orElse(null);
this.config = config;
}
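/**
 * Initializes the run: generates a fresh run id, resolves the job name (from the Airflow
 * or Azkaban context when present, otherwise from the @Task annotation value or the
 * method name), submits the root run to the tracker and marks it RUNNING.
 */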
@Override
public void init(Method method, Object[] args) {
String annotationValue = getTaskName(method);
this.runId = UUID.randomUUID().toString();
String user = System.getProperty("user.name");
String source = null;
TrackingSource trackingSource = null;
if (airflowContext != null) {
this.jobName = airflowContext.jobName();
source = "airflow_tracking";
trackingSource = new TrackingSource(airflowContext);
} else if (azkabanTaskContext != null) {
this.jobName = azkabanTaskContext.databandJobName();
trackingSource = azkabanTaskContext.trackingSource();
if (trackingSource != null) {
source = "azkaban_tracking";
}
} else {
this.jobName = annotationValue == null || annotationValue.isEmpty() ? method.getName() : annotationValue;
}
config.jobName().ifPresent(name -> this.jobName = name);
TaskRunsInfo rootRun = buildRootRun(method, args);
RootRun root = config.azkabanContext().isPresent() ? config.azkabanContext().get().root() : null;
this.rootRunUid = dbnd.initRun(jobName, runId, user, config.runName(), rootRun, airflowContext, root, source, trackingSource, null);
dbnd.setRunState(this.rootRunUid, "RUNNING");
}
/**
 * Builds the root (driver) task run info for this run.
 *
 * @return task runs info describing the driver task, submitted to the tracker on init
 */
protected TaskRunsInfo buildRootRun(Method method, Object[] args) {
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
String runUid = new Uuid5("RUN_UID", runId).toString();
driverTaskUid = new Uuid5("DRIVER_TASK", runId).toString();
String taskRunEnvUid = new Uuid5("TASK_RUN_ENV_UID", runId).toString();
String taskRunAttemptUid = new Uuid5("TASK_RUN_ATTEMPT", runId).toString();
String cmd = config.cmd();
String version = "";
Sha1Short taskSignature = new Sha1Short("TASK_SIGNATURE", runId);
String taskDefinitionUid = new Uuid5("TASK_DEFINITION", runId).toString();
String taskAfId = getTaskName(method);
List<TaskParamDefinition> taskParamDefinitions = buildTaskParamDefinitions(method);
String methodName = method == null ? "pipeline" : method.getName();
Pair<List<TaskRunParam>, List<LogTarget>> paramsAndTargets = buildTaskRunParamsAndTargets(
method,
args,
runUid,
methodName,
taskRunAttemptUid,
taskDefinitionUid
);
this.driverTask = new TaskRun(
runUid,
true,
false,
null,
version,
driverTaskUid,
taskSignature.toString(),
jobName,
paramsAndTargets.left(),
taskSignature.toString(),
false,
now.toLocalDate(),
now,
"",
"RUNNING",
taskDefinitionUid,
cmd,
false,
false,
taskRunAttemptUid,
taskAfId,
airflowContext != null || azkabanTaskContext != null,
true,
cmd,
taskAfId,
"jvm",
Collections.emptyMap()
);
this.driverTask.setStartDate(now);
String sourceCode = extractSourceCode(method);
TrackingSource trackingSource = null;
if (airflowContext != null) {
this.parentChildMap.add(Arrays.asList(airflowContext.getAfOperatorUid(), driverTaskUid));
this.upstreamsMap.add(Arrays.asList(airflowContext.getAfOperatorUid(), driverTaskUid));
trackingSource = new TrackingSource(airflowContext);
}
if (azkabanTaskContext != null) {
this.parentChildMap.add(Arrays.asList(azkabanTaskContext.taskRunUid(), driverTaskUid));
this.upstreamsMap.add(Arrays.asList(azkabanTaskContext.taskRunUid(), driverTaskUid));
trackingSource = new TrackingSource(azkabanTaskContext);
}
return new TaskRunsInfo(
taskRunEnvUid,
parentChildMap,
runUid,
Collections.singletonList(driverTask),
Collections.emptyList(),
runUid,
upstreamsMap,
false,
Collections.singletonList(
new TaskDefinition(
methodName,
sourceCode,
new Sha1Long("SOURCE", runId).toString(),
"",
taskDefinitionUid,
new Sha1Long("MODULE_SOURCE", runId).toString(),
taskParamDefinitions,
"jvm_task",
"java",
""
)
),
trackingSource
);
}
// TODO: actual source code
protected String extractSourceCode(Method method) {
return "";
}
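/**
 * Registers the start of a task run. The parent is the task currently on top of the
 * stack (or the driver task when the stack is empty); additional upstream relations are
 * derived from arguments that match outputs of previously completed tasks.
 */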
@Override
public void startTask(Method method, Object[] args) {
RunAndDefinition runAndDefinition = buildRunAndDefinition(method, args, !stack.isEmpty());
TaskRun taskRun = runAndDefinition.taskRun();
taskRuns.add(taskRun);
TaskDefinition taskDefinition = runAndDefinition.taskDefinition();
taskDefinitions.add(taskDefinition);
TaskRun parent = stack.isEmpty() ? driverTask : stack.peek();
// detect nested tasks
if (!stack.isEmpty()) {
upstreamsMap.add(Arrays.asList(parent.getTaskRunUid(), taskRun.getTaskRunUid()));
}
taskRun.addUpstream(parent);
// detect upstream-downstream relations
for (Object arg : args) {
if (arg == null) {
continue;
}
TaskRun parentTask = taskRunOutputs.get(arg.hashCode());
if (parentTask != null) {
upstreamsMap.add(Arrays.asList(taskRun.getTaskRunUid(), parentTask.getTaskRunUid()));
}
}
stack.push(taskRun);
parentChildMap.add(Arrays.asList(parent.getTaskRunUid(), taskRun.getTaskRunUid()));
dbnd.addTaskRuns(rootRunUid, runId, taskRuns, taskDefinitions, parentChildMap, upstreamsMap);
dbnd.logTargets(taskRun.getTaskRunUid(), runAndDefinition.targets());
dbnd.updateTaskRunAttempt(taskRun.getTaskRunUid(), taskRun.getTaskRunAttemptUid(), "RUNNING", null, taskRun.getStartDate());
LOG.info("TASK: task_id={}", taskRun.getTaskId());
LOG.info("TIME: start={}", taskRun.getExecutionDate());
LOG.info("TRACKER: {}/app/jobs/{}/{}/{}", config.databandUrl(), this.driverTask.getTaskAfId(), this.driverTask.getRunUid(), taskRun.getTaskRunUid());
}
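/**
 * Builds parameter definitions for the method: one "task_input" definition per declared
 * parameter plus a single "result" output definition. Results are cached per method.
 */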
protected List<TaskParamDefinition> buildTaskParamDefinitions(Method method) {
if (method == null) {
return Collections.emptyList();
}
return methodsCache.computeIfAbsent(method, method1 -> {
List<TaskParamDefinition> result = new ArrayList<>(method.getParameterCount());
for (int i = 0; i < method.getParameterCount(); i++) {
Parameter parameter = method.getParameters()[i];
result.add(
new TaskParamDefinition(
parameter.getName(),
"task_input",
"user",
true,
false,
parameter.getParameterizedType().getTypeName(),
"",
""
)
);
}
result.add(
new TaskParamDefinition(
"result",
"task_output",
"user",
true,
false,
method.getReturnType().getTypeName(),
"",
""
)
);
return result;
});
}
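/**
 * Builds run parameters and log targets from the actual arguments. Each argument is
 * rendered through its parameter preview both as a compact param value and as a full
 * "read" target; the return type is added as a "result" param.
 */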
protected Pair<List<TaskRunParam>, List<LogTarget>> buildTaskRunParamsAndTargets(Method method,
Object[] args,
String taskRunUid,
String methodName,
String taskRunAttemptUid,
String taskDefinitionUid) {
if (method == null || args == null || args.length == 0) {
return new Pair<>(Collections.emptyList(), Collections.emptyList());
}
List<LogTarget> targets = new ArrayList<>(1);
List<TaskRunParam> params = new ArrayList<>(method.getParameterCount());
for (int i = 0; i < method.getParameterCount(); i++) {
Parameter parameter = method.getParameters()[i];
Object parameterValue = args[i];
TaskParameterPreview preview = parameters.get(parameter.getType());
String compactPreview = preview.compact(parameterValue);
params.add(
new TaskRunParam(
compactPreview,
"",
parameter.getName()
)
);
String targetPath = String.format("%s.%s", method.getName(), parameter.getName());
targets.add(
new LogTarget(
rootRunUid,
taskRunUid,
methodName,
taskRunAttemptUid,
targetPath,
parameter.getName(),
taskDefinitionUid,
"read",
"OK",
preview.full(parameterValue),
preview.dimensions(parameterValue),
preview.schema(parameterValue),
new Sha1Long("", compactPreview).toString()
)
);
}
TaskParameterPreview resultPreview = parameters.get(method.getReturnType());
params.add(
new TaskRunParam(
resultPreview.typeName(method.getReturnType()),
"",
"result"
)
);
return new Pair<>(params, targets);
}
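/**
 * Resolves the task name: the Spark application name when there is no method or the
 * method is synthetic ("$anon", i.e. running from spark-submit), otherwise the @Task
 * annotation value, falling back to the plain method name.
 */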
public String getTaskName(Method method) {
if (method == null || method.getName().contains("$anon")) {
// we're running from spark-submit
return config.sparkAppName();
}
Optional<Annotation> taskAnnotation = Arrays.stream(method.getAnnotations())
.filter(at -> at.toString().contains("ai.databand.annotations.Task(value="))
.findAny();
if (!taskAnnotation.isPresent()) {
return method.getName();
}
String annotationStr = taskAnnotation.get().toString();
String annotationValue = annotationStr.substring(annotationStr.indexOf('=') + 1, annotationStr.indexOf(')'));
return annotationValue.isEmpty() ? method.getName() : annotationValue;
}
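/**
 * Builds a task run and its definition for a single method invocation. Repeated
 * invocations of the same method are disambiguated with an execution-count suffix,
 * e.g. "process_2" for the second call.
 */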
protected RunAndDefinition buildRunAndDefinition(Method method, Object[] args, boolean hasUpstreams) {
int executionCount = methodExecutionCounts.computeIfAbsent(method, m -> 0);
executionCount++;
String taskName = getTaskName(method);
String methodName = executionCount == 1 ? taskName : String.format("%s_%s", taskName, executionCount);
methodExecutionCounts.put(method, executionCount);
List<TaskParamDefinition> paramDefinitions = buildTaskParamDefinitions(method);
String taskRunId = UUID.randomUUID().toString();
String taskRunUid = new Uuid5("TASK_RUN_UID", taskRunId).toString();
String taskSignature = new Sha1Short("TASK_SIGNATURE" + methodName, runId).toString();
String taskDefinitionUid = new Uuid5("TASK_DEFINITION" + methodName, runId).toString();
String taskRunAttemptUid = new Uuid5("TASK_RUN_ATTEMPT" + methodName, runId).toString();
String taskAfId = methodName;
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
Pair<List<TaskRunParam>, List<LogTarget>> paramsAndTargets = buildTaskRunParamsAndTargets(
method,
args,
taskRunUid,
methodName,
taskRunAttemptUid,
taskDefinitionUid
);
List<TaskRunParam> params = paramsAndTargets.left();
List<LogTarget> targets = paramsAndTargets.right();
TaskRun taskRun = new TaskRun(
rootRunUid,
false,
false,
null,
"",
taskRunUid,
taskSignature,
taskAfId,
params,
taskSignature,
false,
now.toLocalDate(),
now,
"",
"QUEUED",
taskDefinitionUid,
methodName,
false,
hasUpstreams,
taskRunAttemptUid,
taskAfId,
airflowContext != null,
false,
methodName,
taskAfId,
"jvm",
Collections.emptyMap()
);
TaskDefinition taskDefinition = new TaskDefinition(
methodName,
"",
new Sha1Long("SOURCE", runId).toString(),
"",
taskDefinitionUid,
new Sha1Long("MODULE_SOURCE", runId).toString(),
paramDefinitions,
"jvm_task",
"java",
""
);
methodsRunsCache.put(method, taskRun);
return new RunAndDefinition(taskRun, taskDefinition, targets);
}
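/**
 * Marks the current task as FAILED: flushes its log and metrics and attaches the error's
 * stack trace to the task run attempt.
 */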
@Override
public void errorTask(Method method, Throwable error) {
TaskRun task = stack.poll();
if (task == null) {
return;
}
String stackTrace = extractStackTrace(error);
task.appendLog(stackTrace);
dbnd.saveTaskLog(task.getTaskRunAttemptUid(), task.getTaskLog());
dbnd.logMetrics(task.getTaskRunAttemptUid(), task.getMetrics(), "spark");
ErrorInfo errorInfo = new ErrorInfo(
error.getLocalizedMessage(),
"",
false,
stackTrace,
"",
"",
false,
error.getClass().getCanonicalName()
);
dbnd.updateTaskRunAttempt(task.getTaskRunUid(), task.getTaskRunAttemptUid(), "FAILED", errorInfo, task.getStartDate());
}
protected String extractStackTrace(Throwable error) {
try (StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw)) {
error.printStackTrace(pw);
return sw.toString();
} catch (IOException e) {
LOG.error("Unable to extract stack trace from error", e);
return "";
}
}
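/**
 * Marks the current task as SUCCESS. A non-null result is logged as a "write" target and
 * remembered by its hash code, so a later task receiving it as an argument is linked as
 * downstream of this one.
 */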
@Override
public void completeTask(Method method, Object result) {
TaskRun task = stack.poll();
if (task == null) {
return;
}
if (result != null) {
TaskParameterPreview taskParameter = parameters.get(result.getClass());
String preview = taskParameter.full(result);
taskRunOutputs.put(result.hashCode(), task);
dbnd.logTargets(
task.getTaskRunUid(),
Collections.singletonList(
new LogTarget(
rootRunUid,
task.getTaskRunUid(),
task.getTaskAfId(),
task.getTaskRunAttemptUid(),
new Sha1Long("TARGET_PATH", preview).toString(),
"result",
task.getTaskDefinitionUid(),
"write",
"OK",
preview,
taskParameter.dimensions(result),
taskParameter.schema(result),
new Sha1Long("", preview).toString()
)
));
}
dbnd.saveTaskLog(task.getTaskRunAttemptUid(), task.getTaskLog());
dbnd.logMetrics(task.getTaskRunAttemptUid(), task.getMetrics(), "spark");
dbnd.updateTaskRunAttempt(task.getTaskRunUid(), task.getTaskRunAttemptUid(), "SUCCESS", null, task.getStartDate());
}
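/**
 * Completes the driver task and, when this JVM owns the root run, marks the run SUCCESS.
 */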
@Override
public void stop() {
dbnd.saveTaskLog(driverTask.getTaskRunAttemptUid(), driverTask.getTaskLog());
dbnd.logMetrics(driverTask.getTaskRunAttemptUid(), driverTask.getMetrics(), "spark");
dbnd.updateTaskRunAttempt(driverTask.getTaskRunUid(), driverTask.getTaskRunAttemptUid(), "SUCCESS", null, driverTask.getStartDate());
if (rootRunUid == null) {
// for agentless runs created inside Databand Context (when root run is outside of JVM) we shouldn't complete run
return;
}
dbnd.setRunState(rootRunUid, "SUCCESS");
}
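/**
 * Flushes the driver task's log and metrics without completing the run, for runs whose
 * lifecycle is controlled outside this JVM.
 */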
@Override
public void stopExternal() {
if (driverTask == null) {
return;
}
dbnd.saveTaskLog(driverTask.getTaskRunAttemptUid(), driverTask.getTaskLog());
dbnd.logMetrics(driverTask.getTaskRunAttemptUid(), driverTask.getMetrics(), "spark");
}
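/**
 * Fails the whole run: attaches the error to the driver task attempt and marks the root
 * run FAILED.
 */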
public void error(Throwable error) {
String stackTrace = extractStackTrace(error);
ErrorInfo errorInfo = new ErrorInfo(
error.getLocalizedMessage(),
"",
false,
stackTrace,
"",
"",
false,
error.getClass().getCanonicalName()
);
driverTask.appendLog(stackTrace);
dbnd.saveTaskLog(driverTask.getTaskRunAttemptUid(), driverTask.getTaskLog());
dbnd.logMetrics(driverTask.getTaskRunAttemptUid(), driverTask.getMetrics(), "spark");
dbnd.updateTaskRunAttempt(driverTask.getTaskRunUid(), driverTask.getTaskRunAttemptUid(), "FAILED", errorInfo, driverTask.getStartDate());
dbnd.setRunState(rootRunUid, "FAILED");
}
@Override
public void logMetric(String key, Object value) {
TaskRun currentTask = stack.peek();
if (currentTask == null) {
currentTask = driverTask;
}
this.logMetric(currentTask, key, value, null);
}
@Override
public void logMetrics(Map<String, Object> metrics) {
this.logMetrics(metrics, null);
}
@Override
public void logMetrics(Map<String, Object> metrics, String source) {
TaskRun currentTask = stack.peek();
if (currentTask == null) {
currentTask = driverTask;
}
this.logMetrics(currentTask, metrics, source);
}
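/**
 * Logs a dataframe preview as a metric on the current task and submits its histograms
 * under the "histograms" source.
 */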
@Override
public void logDataframe(String key, Dataset<?> value, HistogramRequest histogramRequest) {
try {
TaskRun currentTask = stack.peek();
if (currentTask == null) {
currentTask = driverTask;
}
logMetric(currentTask, key, value, "user", false);
dbnd.logMetrics(currentTask.getTaskRunAttemptUid(), new Histogram(key, value, histogramRequest).metricValues(), "histograms");
} catch (Exception e) {
LOG.error("Unable to log dataframe", e);
}
}
@Override
public void logHistogram(Map<String, Object> histogram) {
try {
TaskRun currentTask = stack.peek();
if (currentTask == null) {
currentTask = driverTask;
}
dbnd.logMetrics(currentTask.getTaskRunAttemptUid(), histogram, "histograms");
} catch (Exception e) {
LOG.error("Unable to log histogram", e);
}
}
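/**
 * Logs a dataset operation (read/write) with an optional value preview, dimensions and
 * schema against the current task.
 */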
@Override
public void logDatasetOperation(String path,
DatasetOperationType type,
DatasetOperationStatus status,
String error,
String valuePreview,
List<Long> dataDimensions,
Object dataSchema,
Boolean withPartition) {
try {
TaskRun currentTask = stack.peek();
if (currentTask == null) {
currentTask = driverTask;
}
dbnd.logDatasetOperations(currentTask.getTaskRunUid(), Collections.singletonList(
new LogDataset(
currentTask,
path,
type,
status,
error,
valuePreview,
dataDimensions,
dataSchema,
withPartition
)
));
} catch (Exception e) {
LOG.error("Unable to log dataset operation", e);
}
}
@Override
public void logDatasetOperation(String path,
DatasetOperationType type,
DatasetOperationStatus status,
Dataset<?> data,
Throwable error,
boolean withPreview,
boolean withSchema,
Boolean withPartition) {
TaskParameterPreview preview = withSchema ? new DatasetOperationPreview() : new NullPreview();
String errorStr = null;
if (error != null) {
StringWriter sw = new StringWriter();
try (PrintWriter pw = new PrintWriter(sw)) {
error.printStackTrace(pw);
errorStr = sw.toString();
}
}
logDatasetOperation(path, type, status, errorStr, preview.full(data), preview.dimensions(data), preview.schema(data), withPartition);
}
public void logMetric(TaskRun taskRun, String key, Object value, String source) {
logMetric(taskRun, key, value, source, true);
}
public void logMetric(TaskRun taskRun, String key, Object value, String source, boolean compact) {
try {
if (taskRun == null) {
return;
}
TaskParameterPreview taskParameter = parameters.get(value.getClass());
dbnd.logMetric(
taskRun.getTaskRunAttemptUid(),
key,
compact ? taskParameter.compact(value) : taskParameter.full(value),
source
);
} catch (Exception e) {
LOG.error("Unable to log metric", e);
}
}
public void logMetrics(TaskRun taskRun, Map<String, Object> metrics, String source) {
try {
if (taskRun == null) {
return;
}
Map<String, Object> result = new HashMap<>(metrics.size());
for (Map.Entry<String, Object> entry : metrics.entrySet()) {
TaskParameterPreview taskParameter = parameters.get(entry.getValue().getClass());
result.put(entry.getKey(), taskParameter.compact(entry.getValue()));
}
dbnd.logMetrics(taskRun.getTaskRunAttemptUid(), result, source);
} catch (Exception e) {
LOG.error("Unable to log metrics");
}
}
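/**
 * Appends a log4j event to the current task's log (or the driver task's log when no task
 * is active), skipping the tracking client's own messages to avoid self-logging loops.
 */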
@Override
public void saveLog(LoggingEvent event, String formattedEvent) {
try {
if (driverTask == null) {
return;
}
TaskRun currentTask = stack.peek();
// TODO: filter out unrelated messages
if (DbndClient.class.getName().equals(event.getLoggerName())) {
return;
}
if (currentTask == null) {
driverTask.appendLog(formattedEvent);
} else {
currentTask.appendLog(formattedEvent);
}
} catch (Exception e) {
LOG.error("Unable to save task log", e);
}
}
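/**
 * Captures numeric (LongAccumulator) metrics from a completed Spark stage and appends
 * them to the current task, both raw and prefixed with
 * {@code stage-<stageId>.<transformation>.}.
 */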
@Override
public void saveSparkMetrics(SparkListenerStageCompleted event) {
try {
StageInfo stageInfo = event.stageInfo();
TaskRun currentTask = stack.peek();
if (currentTask == null) {
currentTask = driverTask;
}
String transformationName = stageInfo.name().substring(0, stageInfo.name().indexOf(' '));
String metricPrefix = String.format("stage-%s.%s.", stageInfo.stageId(), transformationName);
Iterator<AccumulatorV2<?, ?>> it = stageInfo.taskMetrics().accumulators().iterator();
Map<String, String> values = new HashMap<>(1);
Map<String, String> prefixedValues = new HashMap<>(1);
while (it.hasNext()) {
AccumulatorV2<?, ?> next = it.next();
// we're capturing only numeric values
if (!(next instanceof LongAccumulator)) {
continue;
}
String metricName = next.name().get();
String value = String.valueOf(next.value());
prefixedValues.put(metricPrefix + metricName, value);
values.put(metricName, value);
}
currentTask.appendMetrics(values);
currentTask.appendPrefixedMetrics(prefixedValues);
} catch (Exception e) {
LOG.error("Unable to save spark metrics", e);
}
}
public void setDriverTask(TaskRun driverTask) {
this.driverTask = driverTask;
}
}