
org.flyte.jflyte.utils.ProjectClosure (jflyte-utils)
Primarily used by jflyte, but it can also be used to extend jflyte or to build an alternative to it.
/*
* Copyright 2021-2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.jflyte.utils;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static org.flyte.jflyte.utils.MoreCollectors.mapValues;
import static org.flyte.jflyte.utils.MoreCollectors.toUnmodifiableList;
import static org.flyte.jflyte.utils.MoreCollectors.toUnmodifiableMap;
import static org.flyte.jflyte.utils.QuantityUtil.asJavaQuantity;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.protobuf.ByteString;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.flyte.api.v1.Container;
import org.flyte.api.v1.ContainerTask;
import org.flyte.api.v1.ContainerTaskRegistrar;
import org.flyte.api.v1.DynamicWorkflowTask;
import org.flyte.api.v1.DynamicWorkflowTaskRegistrar;
import org.flyte.api.v1.IfBlock;
import org.flyte.api.v1.IfElseBlock;
import org.flyte.api.v1.KeyValuePair;
import org.flyte.api.v1.LaunchPlan;
import org.flyte.api.v1.LaunchPlanIdentifier;
import org.flyte.api.v1.LaunchPlanRegistrar;
import org.flyte.api.v1.Node;
import org.flyte.api.v1.PartialTaskIdentifier;
import org.flyte.api.v1.PartialWorkflowIdentifier;
import org.flyte.api.v1.Resources;
import org.flyte.api.v1.Resources.ResourceName;
import org.flyte.api.v1.RunnableTask;
import org.flyte.api.v1.RunnableTaskRegistrar;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.Task;
import org.flyte.api.v1.TaskIdentifier;
import org.flyte.api.v1.TaskTemplate;
import org.flyte.api.v1.WorkflowIdentifier;
import org.flyte.api.v1.WorkflowNode;
import org.flyte.api.v1.WorkflowTemplate;
import org.flyte.api.v1.WorkflowTemplateRegistrar;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
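// The fully-resolved registration closure of a packaged project: all task specs, workflow
// specs and launch plans discovered on the package class path, with identifiers rewritten
// to the target project/domain/version.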
@AutoValue
public abstract class ProjectClosure {
private static final Logger LOG = LoggerFactory.getLogger(ProjectClosure.class);
public abstract Map<TaskIdentifier, TaskSpec> taskSpecs();
public abstract Map<WorkflowIdentifier, WorkflowSpec> workflowSpecs();
public abstract Map<LaunchPlanIdentifier, LaunchPlan> launchPlans();
ProjectClosure applyCustom(JFlyteCustom custom) {
Map<TaskIdentifier, TaskSpec> rewrittenTaskSpecs =
mapValues(taskSpecs(), x -> applyCustom(x, custom));
return ProjectClosure.builder()
.workflowSpecs(workflowSpecs())
.launchPlans(launchPlans())
.taskSpecs(rewrittenTaskSpecs)
.build();
}
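// Writes one protobuf file per task, workflow and launch plan. File names follow the pattern
// "<zero-padded index>_<entity name>_<kind>.pb", where kind is 1 for tasks, 2 for workflows
// and 3 for launch plans, so the files sort in a stable order.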
public void serialize(BiConsumer<String, ByteString> output) {
int size = taskSpecs().size() + launchPlans().size() + workflowSpecs().size();
int sizeDigits = (int) (Math.log10(size) + 1);
AtomicInteger counter = new AtomicInteger();
taskSpecs()
.forEach(
(id, spec) -> {
int i = counter.getAndIncrement();
String filename = String.format("%0" + sizeDigits + "d_%s_1.pb", i, id.name());
output.accept(filename, ProtoUtil.serialize(spec).toByteString());
});
workflowSpecs()
.forEach(
(id, spec) -> {
int i = counter.getAndIncrement();
String filename = String.format("%0" + sizeDigits + "d_%s_2.pb", i, id.name());
output.accept(filename, ProtoUtil.serialize(id, spec).toByteString());
});
launchPlans()
.forEach(
(id, spec) -> {
int i = counter.getAndIncrement();
String filename = String.format("%0" + sizeDigits + "d_%s_3.pb", i, id.name());
output.accept(filename, ProtoUtil.serialize(spec).toByteString());
});
}
private static TaskSpec applyCustom(TaskSpec taskSpec, JFlyteCustom custom) {
Struct rewrittenCustom = merge(custom.serializeToStruct(), taskSpec.taskTemplate().custom());
TaskTemplate rewrittenTaskTemplate =
taskSpec.taskTemplate().toBuilder().custom(rewrittenCustom).build();
return TaskSpec.create(rewrittenTaskTemplate);
}
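// Loads the closure from the packaged project, stages the package files when the closure
// contains runnable or dynamic workflow tasks, and attaches the staged artifacts to every
// task template via the jflyte custom struct.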
public static ProjectClosure loadAndStage(
String packageDir,
ExecutionConfig config,
Supplier<ArtifactStager> stagerSupplier,
FlyteAdminClient adminClient) {
IdentifierRewrite rewrite =
IdentifierRewrite.builder()
.adminClient(adminClient)
.domain(config.domain())
.project(config.project())
.version(config.version())
.build();
// before we run anything, switch class loader, because we will be touching user classes;
// setting it in thread context will give us access to the right class loader
ClassLoader packageClassLoader = ClassLoaders.forDirectory(new File(packageDir));
ProjectClosure closure = ProjectClosure.load(config, rewrite, packageClassLoader);
List<Artifact> artifacts;
if (isStagingRequired(closure)) {
artifacts = stagePackageFiles(stagerSupplier.get(), packageDir);
} else {
artifacts = emptyList();
LOG.info(
"Skipping artifact staging because there are no runnable tasks or dynamic workflow tasks");
}
JFlyteCustom custom = JFlyteCustom.builder().artifacts(artifacts).build();
return closure.applyCustom(custom);
}
private static boolean isStagingRequired(ProjectClosure closure) {
return closure.taskSpecs().values().stream()
.map(TaskSpec::taskTemplate)
.map(TaskTemplate::type)
.anyMatch(type -> !type.equals("raw-container"));
}
private static List<Artifact> stagePackageFiles(ArtifactStager stager, String packageDir) {
try (Stream<Path> fileStream = Files.list(Paths.get(packageDir))) {
List<String> files =
fileStream.map(x -> x.toFile().getAbsolutePath()).collect(toUnmodifiableList());
return stager.stageFiles(files);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
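// Discovers tasks, workflows and launch plans by running all registrars with the package
// class loader; the execution coordinates are passed to the registrars via the env map.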
static ProjectClosure load(
ExecutionConfig config, IdentifierRewrite rewrite, ClassLoader packageClassLoader) {
Map<String, String> env =
ImmutableMap.<String, String>builder()
// we keep JFLYTE_ only for backwards-compatibility
.put("JFLYTE_DOMAIN", config.domain())
.put("JFLYTE_PROJECT", config.project())
.put("JFLYTE_VERSION", config.version())
.put("FLYTE_INTERNAL_DOMAIN", config.domain())
.put("FLYTE_INTERNAL_PROJECT", config.project())
.put("FLYTE_INTERNAL_VERSION", config.version())
.build();
// 1. load classes, and create templates
Map<TaskIdentifier, RunnableTask> runnableTasks =
ClassLoaders.withClassLoader(
packageClassLoader, () -> Registrars.loadAll(RunnableTaskRegistrar.class, env));
Map<TaskIdentifier, DynamicWorkflowTask> dynamicWorkflowTasks =
ClassLoaders.withClassLoader(
packageClassLoader, () -> Registrars.loadAll(DynamicWorkflowTaskRegistrar.class, env));
Map<TaskIdentifier, ContainerTask> containerTasks =
ClassLoaders.withClassLoader(
packageClassLoader, () -> Registrars.loadAll(ContainerTaskRegistrar.class, env));
Map<WorkflowIdentifier, WorkflowTemplate> workflows =
ClassLoaders.withClassLoader(
packageClassLoader, () -> Registrars.loadAll(WorkflowTemplateRegistrar.class, env));
Map<LaunchPlanIdentifier, LaunchPlan> launchPlans =
ClassLoaders.withClassLoader(
packageClassLoader, () -> Registrars.loadAll(LaunchPlanRegistrar.class, env));
return load(
config,
rewrite,
runnableTasks,
dynamicWorkflowTasks,
containerTasks,
workflows,
launchPlans);
}
static ProjectClosure load(
ExecutionConfig config,
IdentifierRewrite rewrite,
Map<TaskIdentifier, RunnableTask> runnableTasks,
Map<TaskIdentifier, DynamicWorkflowTask> dynamicWorkflowTasks,
Map<TaskIdentifier, ContainerTask> containerTasks,
Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates,
Map<LaunchPlanIdentifier, LaunchPlan> launchPlans) {
Map<TaskIdentifier, TaskTemplate> taskTemplates =
createTaskTemplates(config, runnableTasks, dynamicWorkflowTasks, containerTasks);
// 2. rewrite workflows and launch plans
Map<WorkflowIdentifier, WorkflowTemplate> rewrittenWorkflowTemplates =
mapValues(
workflowTemplates,
rewrite::apply,
"Couldn't rewrite Workflow identifier: [%s]",
id -> new Object[] {id.name()});
Map<LaunchPlanIdentifier, LaunchPlan> rewrittenLaunchPlans =
mapValues(
launchPlans,
rewrite::apply,
"Couldn't rewrite Launch Plan identifier: [%s]",
id -> new Object[] {id.name()});
checkCycles(rewrittenWorkflowTemplates);
// 3. create specs for registration
Map<WorkflowIdentifier, WorkflowSpec> workflowSpecs =
mapValues(
rewrittenWorkflowTemplates,
workflowTemplate -> {
Map<WorkflowIdentifier, WorkflowTemplate> subWorkflows =
collectSubWorkflows(workflowTemplate.nodes(), rewrittenWorkflowTemplates);
return WorkflowSpec.builder()
.workflowTemplate(workflowTemplate)
.subWorkflows(subWorkflows)
.build();
});
Map<TaskIdentifier, TaskSpec> taskSpecs = mapValues(taskTemplates, TaskSpec::create);
return ProjectClosure.builder()
.taskSpecs(taskSpecs)
.workflowSpecs(workflowSpecs)
.launchPlans(rewrittenLaunchPlans)
.build();
}
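// Rejects workflow sets in which a workflow contains itself as a sub-workflow node,
// directly or transitively.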
@VisibleForTesting
static void checkCycles(Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows) {
Optional<WorkflowIdentifier> cycle =
allWorkflows.keySet().stream()
.filter(
workflowId ->
checkCycles(
workflowId,
allWorkflows,
/* beingVisited= */ new HashSet<>(),
/* visited= */ new HashSet<>()))
.findFirst();
if (cycle.isPresent()) {
throw new IllegalArgumentException(
String.format(
"Workflow [%s] cannot have itself as a node, directly or indirectly", cycle.get()));
}
}
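// Depth-first search: a reference to a sub-workflow that is still on the current DFS path
// ("beingVisited") is a back edge, i.e. a cycle.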
static boolean checkCycles(
WorkflowIdentifier workflowId,
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows,
Set<WorkflowIdentifier> beingVisited,
Set<WorkflowIdentifier> visited) {
beingVisited.add(workflowId);
WorkflowTemplate workflow = allWorkflows.get(workflowId);
List<Node> nodes =
workflow.nodes().stream().flatMap(ProjectClosure::flatBranch).collect(toUnmodifiableList());
for (Node node : nodes) {
if (isSubWorkflowNode(node)) {
PartialWorkflowIdentifier partialSubWorkflowId =
Objects.requireNonNull(node.workflowNode()).reference().subWorkflowRef();
WorkflowIdentifier subWorkflowId =
WorkflowIdentifier.builder()
.project(partialSubWorkflowId.project())
.name(partialSubWorkflowId.name())
.domain(partialSubWorkflowId.domain())
.version(partialSubWorkflowId.version())
.build();
if (beingVisited.contains(subWorkflowId) // backward edge
|| (!visited.contains(subWorkflowId)
&& checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) {
return true;
}
}
}
beingVisited.remove(workflowId);
visited.add(workflowId);
return false;
}
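// Collects every workflow template referenced as a sub-workflow by the given nodes,
// including sub-workflows nested inside branches and inside other sub-workflows.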
@VisibleForTesting
static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
List<Node> nodes, Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows) {
return collectSubWorkflows(nodes, allWorkflows, Function.identity());
}
public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
List<Node> nodes,
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows,
Function<List<Node>, List<Node>> nodesRewriter) {
List<Node> rewrittenNodes = nodesRewriter.apply(nodes);
return collectSubWorkflowIds(rewrittenNodes).stream()
// all identifiers should be rewritten at this point
.map(
workflowId ->
WorkflowIdentifier.builder()
.project(workflowId.project())
.name(workflowId.name())
.domain(workflowId.domain())
.version(workflowId.version())
.build())
.distinct()
.flatMap(
workflowId -> {
WorkflowTemplate subWorkflow = allWorkflows.get(workflowId);
if (subWorkflow == null) {
throw new NoSuchElementException(
"Can't find referenced sub-workflow " + workflowId);
}
Map<WorkflowIdentifier, WorkflowTemplate> nestedSubWorkflows =
collectSubWorkflows(subWorkflow.nodes(), allWorkflows, nodesRewriter);
return Stream.concat(
Stream.of(Maps.immutableEntry(workflowId, subWorkflow)),
nestedSubWorkflows.entrySet().stream());
})
.collect(toUnmodifiableMap());
}
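// Resolves the task templates referenced by a dynamic workflow's nodes, preferring templates
// from this closure and falling back to the remote fetcher for tasks not found locally.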
public static Map<TaskIdentifier, TaskTemplate> collectDynamicWorkflowTasks(
List<Node> nodes,
Map<TaskIdentifier, TaskTemplate> allTasks,
Function<TaskIdentifier, TaskTemplate> remoteTaskTemplateFetcher) {
return collectTaskIds(nodes).stream()
// all identifiers should be rewritten at this point
.map(
taskId ->
TaskIdentifier.builder()
.project(taskId.project())
.name(taskId.name())
.domain(taskId.domain())
.version(taskId.version())
.build())
.distinct()
.map(
taskId -> {
TaskTemplate taskTemplate =
Optional.ofNullable(allTasks.get(taskId))
.orElseGet(() -> remoteTaskTemplateFetcher.apply(taskId));
if (taskTemplate == null) {
throw new NoSuchElementException("Can't find referenced task " + taskId);
}
return Maps.immutableEntry(taskId, taskTemplate);
})
.collect(toUnmodifiableMap());
}
private static List<PartialTaskIdentifier> collectTaskIds(List<Node> rewrittenNodes) {
return rewrittenNodes.stream()
.filter(x -> x.taskNode() != null)
.map(x -> x.taskNode().referenceId())
.collect(toUnmodifiableList());
}
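// Builds task templates for every discovered task: runnable and dynamic workflow tasks run
// inside the configured jflyte image, while container tasks bring their own image, command
// and resources.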
public static Map<TaskIdentifier, TaskTemplate> createTaskTemplates(
ExecutionConfig config,
Map<TaskIdentifier, RunnableTask> runnableTasks,
Map<TaskIdentifier, DynamicWorkflowTask> dynamicWorkflowTasks,
Map<TaskIdentifier, ContainerTask> containerTasks) {
Map<TaskIdentifier, TaskTemplate> taskTemplates = new HashMap<>();
runnableTasks.forEach(
(id, task) -> {
TaskTemplate taskTemplate = createTaskTemplateForRunnableTask(task, config.image());
taskTemplates.put(id, taskTemplate);
});
dynamicWorkflowTasks.forEach(
(id, task) -> {
TaskTemplate taskTemplate = createTaskTemplateForDynamicWorkflow(task, config.image());
taskTemplates.put(id, taskTemplate);
});
containerTasks.forEach(
(id, task) -> {
TaskTemplate taskTemplate = createTaskTemplateForContainerTask(task);
taskTemplates.put(id, taskTemplate);
});
return taskTemplates;
}
@VisibleForTesting
static TaskTemplate createTaskTemplateForRunnableTask(RunnableTask task, String image) {
Container container =
Container.builder()
.command(ImmutableList.of())
.args(
ImmutableList.of(
"jflyte",
"execute",
"--task",
task.getName(),
"--inputs",
"{{.input}}",
"--outputPrefix",
"{{.outputPrefix}}",
"--taskTemplatePath",
"{{.taskTemplatePath}}"))
.image(image)
.env(javaToolOptionsEnv(task).map(ImmutableList::of).orElse(ImmutableList.of()))
.resources(task.getResources())
.build();
return createTaskTemplate(task, container);
}
@VisibleForTesting
static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) {
Resources resources = task.getResources();
Container container =
Container.builder()
.command(task.getCommand())
.args(task.getArgs())
.image(task.getImage())
.env(task.getEnv())
.resources(resources)
.build();
return createTaskTemplate(task, container);
}
private static TaskTemplate createTaskTemplate(Task task, Container container) {
TaskTemplate.Builder templateBuilder =
TaskTemplate.builder()
.container(container)
.interface_(task.getInterface())
.retries(task.getRetries())
.type(task.getType())
.custom(task.getCustom())
.discoverable(task.isCached())
.cacheSerializable(task.isCacheSerializable());
if (task.getCacheVersion() != null) {
templateBuilder.discoveryVersion(task.getCacheVersion());
}
return templateBuilder.build();
}
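// Derives JAVA_TOOL_OPTIONS for the task container: -Xmx is taken from the memory limit
// (if one is set), followed by any custom Java tool options declared by the task.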
private static Optional<KeyValuePair> javaToolOptionsEnv(RunnableTask task) {
List<String> javaToolOptions = new ArrayList<>();
Resources resources = task.getResources();
Map<ResourceName, String> limits = resources.limits();
if (limits != null && limits.containsKey(ResourceName.MEMORY)) {
String maxMemory = asJavaQuantity(limits.get(ResourceName.MEMORY));
javaToolOptions.add("-Xmx" + maxMemory);
}
javaToolOptions.addAll(task.getCustomJavaToolOptions());
if (javaToolOptions.isEmpty()) {
return Optional.empty();
} else {
return Optional.of(KeyValuePair.of("JAVA_TOOL_OPTIONS", String.join(" ", javaToolOptions)));
}
}
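// Dynamic workflow tasks are executed through the generic "jflyte execute-dynamic-workflow"
// entry point; caching is currently disabled for them (see the TODO below).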
private static TaskTemplate createTaskTemplateForDynamicWorkflow(
DynamicWorkflowTask task, String image) {
Container container =
Container.builder()
.command(ImmutableList.of())
.args(
ImmutableList.of(
"jflyte",
"execute-dynamic-workflow",
"--task",
task.getName(),
"--inputs",
"{{.input}}",
"--outputPrefix",
"{{.outputPrefix}}",
"--taskTemplatePath",
"{{.taskTemplatePath}}"))
.image(image)
.env(emptyList())
.build();
return TaskTemplate.builder()
.container(container)
.interface_(task.getInterface())
.retries(task.getRetries())
.type("container")
.custom(Struct.of(emptyMap()))
// TODO: consider if cache makes sense for a dynamic task then implement
// it or change this comment to explicitly say no cache for dynamic tasks
.discoverable(false)
.cacheSerializable(false)
.build();
}
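// Shallow merge of two structs; on duplicate keys, fields from "source" win over "target".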
@VisibleForTesting
public static Struct merge(Struct source, Struct target) {
Map<String, Struct.Value> fields = new HashMap<>(target.fields());
fields.putAll(source.fields());
return Struct.of(Collections.unmodifiableMap(fields));
}
private static List<PartialWorkflowIdentifier> collectSubWorkflowIds(List<Node> rewrittenNodes) {
return rewrittenNodes.stream()
.flatMap(ProjectClosure::flatBranch)
.filter(ProjectClosure::isSubWorkflowNode)
.map(x -> Objects.requireNonNull(x.workflowNode()).reference().subWorkflowRef())
.collect(toUnmodifiableList());
}
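// Expands a branch node into the nodes of all its if/else cases (recursively for nested
// branches), so sub-workflow references inside branches are not missed.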
private static Stream<Node> flatBranch(Node node) {
if (node.branchNode() == null) {
return Stream.of(node);
}
IfElseBlock ifElseBlock = node.branchNode().ifElse();
return Stream.concat(
ifElseBlock.other().stream().map(IfBlock::thenNode),
Stream.of(ifElseBlock.case_().thenNode(), ifElseBlock.elseNode()))
.filter(Objects::nonNull)
// Nested branch
.flatMap(ProjectClosure::flatBranch);
}
private static boolean isSubWorkflowNode(Node node) {
return node.workflowNode() != null
&& node.workflowNode().reference().kind() == WorkflowNode.Reference.Kind.SUB_WORKFLOW_REF;
}
static Builder builder() {
return new AutoValue_ProjectClosure.Builder();
}
@AutoValue.Builder
abstract static class Builder {
abstract Builder taskSpecs(Map<TaskIdentifier, TaskSpec> taskSpecs);
abstract Builder launchPlans(Map<LaunchPlanIdentifier, LaunchPlan> launchPlans);
abstract Builder workflowSpecs(Map<WorkflowIdentifier, WorkflowSpec> workflowSpecs);
abstract ProjectClosure build();
}
}
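A minimal usage sketch, not part of the class above: assuming ExecutionConfig, ArtifactStager and FlyteAdminClient are public types in org.flyte.jflyte.utils (as this listing suggests) and are already configured by the caller, a jflyte alternative could stage a packaged project and write the registration protos to a directory. The RegistrationSketch class, the writeClosure helper and the outputDir parameter are hypothetical.

import com.google.protobuf.ByteString;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.function.Supplier;
import org.flyte.jflyte.utils.ArtifactStager;
import org.flyte.jflyte.utils.ExecutionConfig;
import org.flyte.jflyte.utils.FlyteAdminClient;
import org.flyte.jflyte.utils.ProjectClosure;

class RegistrationSketch {
  // Hypothetical helper: loads and stages a packaged project, then writes one .pb file
  // per task, workflow and launch plan into outputDir, ready for registration.
  static void writeClosure(
      String packageDir,
      ExecutionConfig config,                  // configured by the caller
      Supplier<ArtifactStager> stagerSupplier, // configured by the caller
      FlyteAdminClient adminClient,            // configured by the caller
      Path outputDir) {
    ProjectClosure closure =
        ProjectClosure.loadAndStage(packageDir, config, stagerSupplier, adminClient);
    closure.serialize(
        (String filename, ByteString bytes) -> {
          try {
            // serialize() hands us the file name and the proto bytes; persist them as-is.
            Files.write(outputDir.resolve(filename), bytes.toByteArray());
          } catch (IOException e) {
            throw new UncheckedIOException(e);
          }
        });
  }
}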