// All downloads are free. The search and download functionality uses the official Maven repository.
//
// pl.touk.nussknacker.engine.compile.ProcessCompiler.scala (Maven / Gradle / Ivy)
//
// The newest version!
package pl.touk.nussknacker.engine.compile

import cats.data.Validated._
import cats.data.{NonEmptyList, Validated, ValidatedNel}
import cats.instances.list._
import com.typesafe.scalalogging.LazyLogging
import pl.touk.nussknacker.engine._
import pl.touk.nussknacker.engine.api.context.ProcessCompilationError._
import pl.touk.nussknacker.engine.api.context._
import pl.touk.nussknacker.engine.api.dict.DictRegistry
import pl.touk.nussknacker.engine.api.process.ComponentUseCase
import pl.touk.nussknacker.engine.api.{JobData, MetaData, NodeId}
import pl.touk.nussknacker.engine.canonicalgraph.CanonicalProcess
import pl.touk.nussknacker.engine.canonize.ProcessCanonizer
import pl.touk.nussknacker.engine.compile.FragmentValidator.validateUniqueFragmentOutputNames
import pl.touk.nussknacker.engine.compile.nodecompilation.{LazyParameterCreationStrategy, NodeCompiler}
import pl.touk.nussknacker.engine.compile.nodecompilation.NodeCompiler.NodeCompilationResult
import pl.touk.nussknacker.engine.compiledgraph.part.{PotentiallyStartPart, TypedEnd}
import pl.touk.nussknacker.engine.compiledgraph.{CompiledProcessParts, part}
import pl.touk.nussknacker.engine.definition.fragment.FragmentParametersDefinitionExtractor
import pl.touk.nussknacker.engine.definition.model.ModelDefinitionWithClasses
import pl.touk.nussknacker.engine.expression.ExpressionEvaluator
import pl.touk.nussknacker.engine.graph.node.{Source => _, _}
import pl.touk.nussknacker.engine.resultcollector.PreventInvocationCollector
import pl.touk.nussknacker.engine.split._
import pl.touk.nussknacker.engine.splittedgraph._
import pl.touk.nussknacker.engine.splittedgraph.end.NormalEnd
import pl.touk.nussknacker.engine.splittedgraph.part._
import pl.touk.nussknacker.engine.splittedgraph.splittednode.EndingNode
import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap
import pl.touk.nussknacker.engine.util.ThreadUtils
import pl.touk.nussknacker.engine.variables.GlobalVariablesPreparer

import scala.util.control.NonFatal

/**
 * Compiles a `CanonicalProcess` into `CompiledProcessParts` and, through `ProcessValidator`,
 * exposes that compilation as one of the validation steps.
 *
 * The compilation logic is inherited from `ProcessCompilerBase`; this class supplies the
 * collaborators and provides the single `compile` implementation required by both parent traits.
 */
class ProcessCompiler(
    protected val classLoader: ClassLoader,
    protected val sub: PartSubGraphCompiler,
    protected val globalVariablesPreparer: GlobalVariablesPreparer,
    protected val nodeCompiler: NodeCompiler,
    protected val customProcessValidator: CustomProcessValidator
) extends ProcessCompilerBase
    with ProcessValidator {

  /**
   * Builds a new compiler whose sub-graph compiler and node compiler are replaced by the results of
   * their own `withLabelsDictTyper`; the remaining collaborators are reused unchanged.
   */
  override def withLabelsDictTyper: ProcessCompiler = {
    val subWithLabels          = sub.withLabelsDictTyper
    val nodeCompilerWithLabels = nodeCompiler.withLabelsDictTyper
    new ProcessCompiler(classLoader, subWithLabels, globalVariablesPreparer, nodeCompilerWithLabels, customProcessValidator)
  }

  /** Delegates to the inherited implementation; the explicit override picks it for both parent traits. */
  override def compile(
      process: CanonicalProcess
  )(implicit jobData: JobData): CompilationResult[CompiledProcessParts] =
    super.compile(process)

}

/**
 * Validation facade over process compilation.
 *
 * Implementations provide `compile` and a `customProcessValidator`; `validate` combines them with
 * the id validation and the fragment output name uniqueness check.
 */
trait ProcessValidator extends LazyLogging {

  /**
   * Validates `process` by combining four checks: `IdValidator`, the custom process validator,
   * the unique fragment output names check and the compilation of the process itself.
   *
   * A non-fatal exception thrown by any of them is logged and reported as a single
   * `FatalUnknownError` instead of being propagated to the caller.
   *
   * @param process    canonical process to validate
   * @param isFragment forwarded to the id validation and to the fragment output names check
   * @param jobData    forwarded implicitly to `compile`
   * @return combined result of all checks; it carries no value besides `Unit`
   */
  def validate(process: CanonicalProcess, isFragment: Boolean)(implicit jobData: JobData): CompilationResult[Unit] = {

    try {
      CompilationResult.map4(
        CompilationResult(IdValidator.validate(process, isFragment)),
        CompilationResult(validateWithCustomProcessValidators(process)),
        CompilationResult(validateUniqueFragmentOutputNames(process, isFragment)),
        compile(process).map(_ => ()): CompilationResult[Unit]
      )((_, _, _, _) => ())
    } catch {
      case NonFatal(e) =>
        logger.warn(s"Unexpected error during compilation of ${process.name}", e)
        // Throwable.getMessage may be null (e.g. for NullPointerException); fall back to toString
        // so that the reported error never carries a null message.
        val errorDescription = Option(e.getMessage).getOrElse(e.toString)
        CompilationResult(Invalid(NonEmptyList.of(FatalUnknownError(errorDescription))))
    }
  }

  /** Runs the pluggable `customProcessValidator` against the process. */
  private def validateWithCustomProcessValidators(
      process: CanonicalProcess
  ): ValidatedNel[ProcessCompilationError, Unit] = {
    customProcessValidator.validate(process)
  }

  /** Variant of this validator configured with a labels dict typer; see implementations for details. */
  def withLabelsDictTyper: ProcessValidator

  /** Compiles the process; `validate` uses only the success or failure and discards the compiled value. */
  protected def compile(process: CanonicalProcess)(implicit jobData: JobData): CompilationResult[_]

  /** Additional validator applied by `validate` to the whole process. */
  protected def customProcessValidator: CustomProcessValidator

}

protected trait ProcessCompilerBase {

  /** Validates the node sub-graph of a single part against a validation context. */
  protected def sub: PartSubGraphCompiler

  /** Passed to `ThreadUtils.withThisAsContextClassLoader` around the whole compilation. */
  protected def classLoader: ClassLoader

  /** Compiles the source, sink and custom node objects of the individual parts. */
  protected def nodeCompiler: NodeCompiler

  /** Prepares the validation context that is used when only global variables are available. */
  protected def globalVariablesPreparer: GlobalVariablesPreparer

  /**
   * Entry point of the compilation: uncanonizes the process, splits it into parts and compiles
   * them, all inside `ThreadUtils.withThisAsContextClassLoader(classLoader)`.
   *
   * @param process canonical process to compile
   * @param jobData forwarded implicitly to the part compilation
   * @return the compiled parts, extracted from the result of the artificial uncanonization
   */
  protected def compile(
      process: CanonicalProcess
  )(implicit jobData: JobData): CompilationResult[CompiledProcessParts] = {
    ThreadUtils.withThisAsContextClassLoader(classLoader) {
      ProcessCanonizer
        .uncanonizeArtificial(process)
        .map(uncanonized => ProcessSplitter.split(uncanonized))
        .map(splitted => compile(splitted))
        .extract
    }
  }

  /**
   * Fallback validation context obtained from
   * `globalVariablesPreparer.prepareValidationContextWithGlobalVariablesOnly` for the current job.
   */
  private def contextWithOnlyGlobalVariables(implicit jobData: JobData): ValidationContext = {
    val preparer = globalVariablesPreparer
    preparer.prepareValidationContextWithGlobalVariablesOnly(jobData)
  }

  /**
   * Compiles an already split process: checks node ids for duplicates and compiles all source
   * parts, combining both outcomes into a single result.
   */
  private def compile(
      splittedProcess: SplittedProcess
  )(implicit jobData: JobData): CompilationResult[CompiledProcessParts] = {
    val sourceParts = splittedProcess.sources
    val duplicatesCheck: CompilationResult[Unit] =
      CompilationResult(findDuplicates(sourceParts).toValidatedNel)
    val compiledSources = compileSources(sourceParts)
    // The duplicates check contributes only its errors; the value comes from the compiled sources.
    CompilationResult.map2(duplicatesCheck, compiledSources)((_, startParts) => CompiledProcessParts(startParts))
  }

  /*
    Source parts have to be sorted first, so that the types of variables in branches are known by the
    time a join is compiled. See the comment in PartSort.
    Once the process is represented directly as a graph, this will probably no longer be needed.
   */
  private def compileSources(
      sources: NonEmptyList[SourcePart]
  )(implicit jobData: JobData): CompilationResult[NonEmptyList[PotentiallyStartPart]] = {
    type Accumulator = (CompilationResult[List[PotentiallyStartPart]], BranchEndContexts)

    val initial: Accumulator =
      (CompilationResult(Valid(List.empty[PotentiallyStartPart])), new BranchEndContexts(Nil))
    val orderedParts = PartSort.sort(sources.toList)

    // A fold (rather than map/sequence) is required: a part starting from a Join can be compiled only
    // after the compilation results of all branches ending in that join are stored in BranchEndContexts.
    val folded = orderedParts.foldLeft(initial) { (accumulator, sourcePart) =>
      val (compiledSoFar, knownBranchEnds) = accumulator
      val compiledPart                     = compile(sourcePart, knownBranchEnds)
      // map2 instead of andThen on CompilationResult: errors in one part must not stop the others
      val extended = CompilationResult.map2(compiledSoFar, compiledPart)((done, next) => done :+ next)
      (extended, knownBranchEnds.addPart(sourcePart, compiledPart))
    }

    folded._1.map(NonEmptyList.fromListUnsafe)
  }

  /**
   * Checks that no node id occurs more than once among the nodes collected from all parts.
   *
   * @return `valid` when every id is unique, otherwise `invalid` with a `DuplicatedNodeIds`
   *         holding the set of repeated ids
   */
  private def findDuplicates(parts: NonEmptyList[SourcePart]): Validated[ProcessCompilationError, Unit] = {
    val idsOfAllNodes = NodesCollector.collectNodesInAllParts(parts).map(_.id)
    val repeatedIds = idsOfAllNodes
      .groupBy(identity)
      .filter { case (_, occurrences) => occurrences.size > 1 }
      .keys
      .toSet

    if (repeatedIds.nonEmpty) invalid(DuplicatedNodeIds(repeatedIds))
    else valid(())
  }

  /**
   * Compiles one source part, dispatching on the data of its starting node: `SourceNodeData` is
   * compiled as a source, a `Join` as a custom node fed with the branch end contexts gathered so far.
   */
  private def compile(source: SourcePart, branchEndContexts: BranchEndContexts)(
      implicit jobData: JobData
  ): CompilationResult[compiledgraph.part.PotentiallyStartPart] = {
    implicit val nodeId: NodeId = NodeId(source.id)

    source match {
      case SourcePart(splittednode.SourceNode(regularSource: SourceNodeData, _), _, _) =>
        compileSourcePart(source, regularSource)
      case SourcePart(startNode @ splittednode.SourceNode(joinData: Join, _), _, _) =>
        // The pattern refines only the node's data, not the node's type parameter, hence the cast.
        val joinNode = startNode.asInstanceOf[splittednode.SourceNode[Join]]
        compileCustomNodePart(source, joinNode, joinData, Right(branchEndContexts))
    }

  }

  /**
   * Compiles every subsequent part with the validation context registered under its id.
   *
   * @param parts parts following the part that is currently being compiled
   * @param ctx   validation contexts keyed by part id
   * @return all compiled parts; a part whose id has no entry in `ctx` contributes a `MissingPart` error
   */
  private def compileParts(parts: List[SubsequentPart], ctx: Map[String, ValidationContext])(
      implicit jobData: JobData
  ): CompilationResult[List[compiledgraph.part.SubsequentPart]] = {
    import CompilationResult._

    def compileOrReportMissing(nextPart: SubsequentPart): CompilationResult[compiledgraph.part.SubsequentPart] =
      ctx.get(nextPart.id) match {
        case Some(partCtx) =>
          compileSubsequentPart(nextPart, partCtx)
        case None =>
          CompilationResult(Invalid(NonEmptyList.of[ProcessCompilationError](MissingPart(nextPart.id))))
      }

    parts.map(compileOrReportMissing).sequence
  }

  /**
   * Compiles a single subsequent part in the given validation context, choosing the compilation
   * routine by the kind of part: sink, ending custom node or custom node with one output.
   */
  private def compileSubsequentPart(part: SubsequentPart, ctx: ValidationContext)(
      implicit jobData: JobData
  ): CompilationResult[compiledgraph.part.SubsequentPart] = {
    implicit val nodeId: NodeId = NodeId(part.id)
    part match {
      case SinkPart(sinkNode) =>
        compileSinkPart(sinkNode, ctx)
      case CustomNodePart(endingNode @ splittednode.EndingNode(endingData), _, _) =>
        compileEndingCustomNodePart(endingNode, endingData, ctx)
      case CustomNodePart(innerNode @ splittednode.OneOutputSubsequentNode(innerData, _), _, _) =>
        compileCustomNodePart(part, innerNode, innerData, Left(ctx))
    }
  }

  /**
   * Compiles a part that starts with a regular source: the source object itself, the sub-graph of
   * the part and all parts that follow it.
   *
   * @param part       source part being compiled
   * @param sourceData data of the source node of `part`
   * @return compiled source part together with the typing information gathered on the way
   */
  def compileSourcePart(
      part: SourcePart,
      sourceData: SourceNodeData
  )(implicit nodeId: NodeId, jobData: JobData): CompilationResult[compiledgraph.part.SourcePart] = {
    val NodeCompilationResult(expressionTyping, sourceParameters, sourceOutputCtx, sourceObject, _) =
      nodeCompiler.compileSource(sourceData)

    // When the source did not yield a valid context, the sub-graph is still validated,
    // with only the global variables available.
    val subGraphValidation = sub.validate(part.node, sourceOutputCtx.valueOr(_ => contextWithOnlyGlobalVariables))
    val contextsByNodeId   = subGraphValidation.typing.mapValuesNow(_.inputValidationContext)
    val sourceTyping =
      Map(part.id -> NodeTypingInfo(contextWithOnlyGlobalVariables, expressionTyping, sourceParameters))

    val compiledNextParts  = compileParts(part.nextParts, contextsByNodeId)
    val validatedOutputCtx = CompilationResult(sourceOutputCtx)
    val typedSourceObject  = CompilationResult(sourceTyping, sourceObject)

    CompilationResult.map4(subGraphValidation, compiledNextParts, validatedOutputCtx, typedSourceObject) {
      (_, followingParts, outputCtx, sourceInvoker) =>
        compiledgraph.part.SourcePart(
          sourceInvoker,
          splittednode.SourceNode(sourceData, part.node.next),
          outputCtx,
          followingParts,
          part.ends.map(end => TypedEnd(end, contextsByNodeId(end.nodeId)))
        )
    }
  }

  /**
   * Compiles a part consisting of a single sink node.
   *
   * @param node ending node carrying the sink definition
   * @param ctx  validation context at the input of the sink; also stored as the part's second context
   * @return compiled sink part together with the typing information of the sink node
   */
  def compileSinkPart(
      node: EndingNode[Sink],
      ctx: ValidationContext
  )(implicit jobData: JobData, nodeId: NodeId): CompilationResult[part.SinkPart] = {
    val NodeCompilationResult(expressionTyping, sinkParameters, _, sinkObject, _) =
      nodeCompiler.compileSink(node.data, ctx)
    val sinkTyping      = Map(node.id -> NodeTypingInfo(ctx, expressionTyping, sinkParameters))
    val nodeValidation  = sub.validate(node, ctx)
    val typedSinkObject = CompilationResult(sinkTyping, sinkObject)

    CompilationResult.map2(nodeValidation, typedSinkObject) { (_, sinkInvoker) =>
      compiledgraph.part.SinkPart(sinkInvoker, node, ctx, ctx)
    }
  }

  /**
   * Compiles a part made of a custom node that ends its branch, i.e. has no following parts.
   *
   * @param node ending node of the part
   * @param data custom node definition to compile
   * @param ctx  validation context at the input of the node
   * @return compiled custom node part with no next parts and one `NormalEnd` typed with `ctx`
   */
  def compileEndingCustomNodePart(
      node: splittednode.EndingNode[CustomNode],
      data: CustomNodeData,
      ctx: ValidationContext
  )(implicit jobData: JobData, nodeId: NodeId): CompilationResult[compiledgraph.part.CustomNodePart] = {
    val NodeCompilationResult(expressionTyping, nodeParameters, outputCtx, customNodeObject, _) =
      nodeCompiler.compileCustomNodeObject(data, Left(ctx), ending = true)
    val customNodeTyping = Map(node.id -> NodeTypingInfo(ctx, expressionTyping, nodeParameters))

    val typedNodeObject    = CompilationResult(customNodeTyping, customNodeObject)
    val validatedOutputCtx = CompilationResult(outputCtx)

    val compiledPart = CompilationResult.map2(typedNodeObject, validatedOutputCtx) { (invoker, ctxAfterNode) =>
      compiledgraph.part.CustomNodePart(
        invoker,
        node,
        ctx,
        ctxAfterNode,
        List.empty,
        List(TypedEnd(NormalEnd(node.id), ctx))
      )
    }
    compiledPart.distinctErrors
  }

  /**
   * Compiles a part that starts with a custom node having one output: the node object itself,
   * the sub-graph behind it and all parts that follow.
   *
   * @param part part being compiled; its `nextParts` and `ends` are compiled/typed here
   * @param node the custom node (or join) the part starts with
   * @param data definition of that node
   * @param ctx  `Left` with the single input validation context for a regular custom node, or
   *             `Right` with the branch end contexts, from which the contexts registered for the
   *             join `data.id` are taken
   * @return compiled custom node part together with the typing information gathered on the way
   */
  def compileCustomNodePart(
      part: ProcessPart,
      node: splittednode.OneOutputNode[CustomNodeData],
      data: CustomNodeData,
      ctx: Either[ValidationContext, BranchEndContexts]
  )(implicit jobData: JobData, nodeId: NodeId): CompilationResult[compiledgraph.part.CustomNodePart] = {
    val NodeCompilationResult(typingInfo, parameters, validatedNextCtx, compiledNode, _) =
      nodeCompiler.compileCustomNodeObject(data, ctx.map(_.contextsForJoin(data.id)), ending = false)

    // If the node did not yield a valid output context, the sub-graph is still validated: against the
    // input context when there is one (Left), otherwise against the global variables only.
    val nextPartsValidation =
      sub.validate(node, validatedNextCtx.valueOr(_ => ctx.left.getOrElse(contextWithOnlyGlobalVariables)))
    val typesForParts = nextPartsValidation.typing.mapValuesNow(_.inputValidationContext)
    // Typing info of the node itself is registered with its input context (global variables only for joins).
    val nodeTypingInfo = Map(
      node.id -> NodeTypingInfo(ctx.left.getOrElse(contextWithOnlyGlobalVariables), typingInfo, parameters)
    )

    CompilationResult
      .map4(
        CompilationResult(nodeTypingInfo, compiledNode),
        nextPartsValidation,
        compileParts(part.nextParts, typesForParts),
        CompilationResult(validatedNextCtx)
      ) { (nodeInvoker, _, nextPartsCompiled, nextCtx) =>
        // TODO: what should be passed for joins here?
        // For joins (Right) there is no single input context, so an empty one is stored for now.
        compiledgraph.part.CustomNodePart(
          nodeInvoker,
          node,
          ctx.left.getOrElse(ValidationContext.empty),
          nextCtx,
          nextPartsCompiled,
          part.ends.map(e => TypedEnd(e, typesForParts(e.nodeId)))
        )
      }
      // NOTE(review): presumably removes errors reported more than once by the combined results - confirm
      .distinctErrors
  }

  /**
   * Immutable collection of validation contexts found at branch ends, stored as
   * `joinId -> (branchId -> context)` pairs.
   */
  private class BranchEndContexts(joinIdBranchIdContexts: List[(String, (String, ValidationContext))]) {

    /**
     * Returns a new instance extended with the contexts of all branch ends found in `part`.
     * Each context is read from `result.typing` under the artificial node id of the branch end.
     */
    def addPart(part: ProcessPart, result: CompilationResult[_]): BranchEndContexts = {
      val branchEndsOfPart = NodesCollector.collectNodesInAllParts(part).collect {
        case splittednode.EndingNode(BranchEndData(definition)) =>
          val branchEndCtx = result.typing(definition.artificialNodeId).inputValidationContext
          (definition.joinId, (definition.id, branchEndCtx))
      }
      new BranchEndContexts(joinIdBranchIdContexts ++ branchEndsOfPart)
    }

    /** Contexts of the branches ending in the join with the given id, keyed by branch id. */
    def contextsForJoin(joinId: String): Map[String, ValidationContext] =
      joinIdBranchIdContexts.collect { case (id, branchCtx) if id == joinId => branchCtx }.toMap

  }

}

/** Factory methods building the default [[ProcessValidator]]. */
object ProcessValidator {

  /** Builds the default validator from the definitions, dict registry, custom validator and class loader of `modelData`. */
  def default(modelData: ModelData): ProcessValidator =
    default(
      definitionWithTypes = modelData.modelDefinitionWithClasses,
      dictRegistry = modelData.designerDictServices.dictRegistry,
      customProcessValidator = modelData.customProcessValidator,
      classLoader = modelData.modelClassLoader.classLoader
    )

  /**
   * Builds a validator backed by a `ProcessCompiler` that is configured for validation:
   * un-optimized expression evaluation and compilation, `PreventInvocationCollector` as the result
   * collector and `ComponentUseCase.Validation` as the use case.
   *
   * @param definitionWithTypes    model definition together with its class definitions
   * @param dictRegistry           dictionary registry handed to the expression compiler
   * @param customProcessValidator validator handed to the resulting compiler
   * @param classLoader            class loader used by the compilers; defaults to the one that loaded this object
   */
  def default(
      definitionWithTypes: ModelDefinitionWithClasses,
      dictRegistry: DictRegistry,
      customProcessValidator: CustomProcessValidator,
      classLoader: ClassLoader = getClass.getClassLoader
  ): ProcessValidator = {
    val modelDefinition  = definitionWithTypes.modelDefinition
    val expressionConfig = modelDefinition.expressionConfig

    val globalVariablesPreparer = GlobalVariablesPreparer(expressionConfig)
    val expressionEvaluator     = ExpressionEvaluator.unOptimizedEvaluator(globalVariablesPreparer)
    val expressionCompiler = ExpressionCompiler.withoutOptimization(
      classLoader,
      dictRegistry,
      expressionConfig,
      definitionWithTypes.classDefinitions,
      expressionEvaluator
    )

    val nodeCompiler = new NodeCompiler(
      modelDefinition,
      new FragmentParametersDefinitionExtractor(classLoader),
      expressionCompiler,
      classLoader,
      Seq.empty,
      PreventInvocationCollector,
      ComponentUseCase.Validation,
      nonServicesLazyParamStrategy = LazyParameterCreationStrategy.default
    )
    val subGraphCompiler = new PartSubGraphCompiler(nodeCompiler)

    new ProcessCompiler(classLoader, subGraphCompiler, globalVariablesPreparer, nodeCompiler, customProcessValidator)
  }

}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy