All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.api.scala.codegen.TreeGen.scala Maven / Gradle / Ivy

There is a newer version: 0.10.2-hadoop1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.api.scala.codegen

import scala.language.implicitConversions

import scala.reflect.macros.Context

private[flink] trait TreeGen[C <: Context] { this: MacroContextHolder[C] with TypeDescriptors[C] =>
  import c.universe._

    def mkDefault(tpe: Type): Tree = {
      tpe match {
        case definitions.BooleanTpe => Literal(Constant(false))
        case definitions.ByteTpe    => Literal(Constant(0: Byte))
        case definitions.CharTpe    => Literal(Constant(0: Char))
        case definitions.DoubleTpe  => Literal(Constant(0: Double))
        case definitions.FloatTpe   => Literal(Constant(0: Float))
        case definitions.IntTpe     => Literal(Constant(0: Int))
        case definitions.LongTpe    => Literal(Constant(0: Long))
        case definitions.ShortTpe   => Literal(Constant(0: Short))
        case definitions.UnitTpe    => Literal(Constant(()))
        case _            => Literal(Constant(null: Any))
      }
    }

    def mkUnit = reify( () ).tree
    def mkNull = reify( null ).tree
    def mkZero = reify( 0 ).tree
    def mkOne = reify( 1 ).tree

    def mkAsInstanceOf[T: c.WeakTypeTag](source: Tree): Tree =
      reify(c.Expr(source).splice.asInstanceOf[T]).tree

    def maybeMkAsInstanceOf[S: c.WeakTypeTag, T: c.WeakTypeTag](source: Tree): Tree = {
      if (weakTypeOf[S] <:< weakTypeOf[T]) {
        source
      } else {
        mkAsInstanceOf[T](source)
      }
    }

//    def mkIdent(target: Symbol): Tree = Ident(target) setType target.tpe
    def mkSelect(rootModule: String, path: String*): Tree =
      mkSelect(Ident(newTermName(rootModule)), path: _*)

    def mkSelect(source: Tree, path: String*): Tree =
      path.foldLeft(source) { (ret, item) => Select(ret, newTermName(item)) }

    def mkSelectSyms(source: Tree, path: Symbol*): Tree =
      path.foldLeft(source) { (ret, item) => Select(ret, item) }
    
    def mkCall(root: Tree, path: String*)(args: List[Tree]) = Apply(mkSelect(root, path: _*), args)

    def mkSeq(items: List[Tree]): Tree =
      Apply(mkSelect("_root_", "scala", "collection", "Seq", "apply"), items)

    def mkList(items: List[Tree]): Tree =
      Apply(mkSelect("_root_", "scala", "collection", "immutable", "List", "apply"), items)

    def mkMap(items: List[Tree]): Tree =
      Apply(mkSelect("_root_", "scala", "collection", "immutable", "Map", "apply"), items)

    def mkVal(name: String, flags: FlagSet, transient: Boolean, valTpe: Type, value: Tree): Tree = {
      ValDef(Modifiers(flags), newTermName(name), TypeTree(valTpe), value)
    }

    def mkVar(name: String, flags: FlagSet, transient: Boolean, valTpe: Type, value: Tree): Tree = {
      mkVal(name, flags | Flag.MUTABLE, transient, valTpe, value)
    }

    def mkValAndGetter(name: String, flags: FlagSet, valTpe: Type, value: Tree): List[Tree] = {
      val fieldName = name + " "
      val valDef = mkVal(fieldName, Flag.PRIVATE, false, valTpe, value)
      val defDef = mkMethod(name, flags, Nil, valTpe, Ident(newTermName(fieldName)))
      List(valDef, defDef)
    }

    def mkVarAndLazyGetter(
        name: String,
        flags: FlagSet,
        valTpe: Type,
        value: Tree): (Tree, Tree) = {
      val fieldName = name + " "
      val field = mkVar(fieldName, NoFlags, false, valTpe, mkNull)
      val fieldSel = Ident(newTermName(fieldName))

      val getter = mkMethod(name, flags, Nil, valTpe, {
        val eqeq = Select(fieldSel, newTermName("$eq$eq"))
        val chk = Apply(eqeq, List(mkNull))
        val init = Assign(fieldSel, value) 
        Block(List(If(chk, init, EmptyTree)), fieldSel)
      })

      (field, getter)
    }

    def mkIf(cond: Tree, bodyT: Tree): Tree = mkIf(cond, bodyT, EmptyTree)

    def mkIf(cond: Tree, bodyT: Tree, bodyF: Tree): Tree = cond match {
      case EmptyTree => bodyT
      case _         => If(cond, bodyT, bodyF)
    }

    def mkSingle(stats: Seq[Tree]): Tree = stats match {
      case Seq()     => EmptyTree
      case Seq(stat) => stat
      case _         => Block(stats.init.toList, stats.last)
    }

    def mkAnd(cond1: Tree, cond2: Tree): Tree = cond1 match {
      case EmptyTree => cond2
      case _ => cond2 match {
        case EmptyTree => cond1
        case _         => reify(c.Expr[Boolean](cond1).splice && c.Expr[Boolean](cond2).splice).tree
      }
    }

    def mkMethod(
        name: String,
        flags: FlagSet,
        args: List[(String, Type)],
        ret: Type,
        impl: Tree): Tree = {
      val valParams = args map { case (`name`, tpe) =>
        ValDef(Modifiers(Flag.PARAM), newTermName(name), TypeTree(tpe), EmptyTree)
      }
      DefDef(Modifiers(flags), newTermName(name), Nil, List(valParams), TypeTree(ret), impl)
    }

    def mkClass(
        name: TypeName,
        flags: FlagSet,
        parents: List[Type],
        members: List[Tree]): ClassDef = {
      val parentTypeTrees = parents map { TypeTree(_) }
      val selfType = ValDef(Modifiers(), nme.WILDCARD, TypeTree(NoType), EmptyTree)
      ClassDef(Modifiers(flags), name, Nil, Template(parentTypeTrees, selfType, members))
    }

    def mkThrow(tpe: Type, msg: Tree): Tree =
      Throw(Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), List(msg)))

//    def mkThrow(tpe: Type, msg: Tree): Tree = Throw(New(TypeTree(tpe)), List(List(msg))))

    def mkThrow(tpe: Type, msg: String): Tree = mkThrow(tpe, c.literal(msg).tree)

    def mkThrow(msg: String): Tree = mkThrow(typeOf[java.lang.RuntimeException], msg)

    implicit def tree2Ops[T <: Tree](tree: T) = new {
      // copy of Tree.copyAttrs, since that method is private
//      def copyAttrs(from: Tree): T = {
//        tree.pos = from.pos
//        tree.tpe = from.tpe
//        if (tree.hasSymbol) tree.symbol = from.symbol
//        tree
//      }

      def getSimpleClassName: String = {
        val name = tree.getClass.getName
        val idx = math.max(name.lastIndexOf('$'), name.lastIndexOf('.')) + 1
        name.substring(idx)
      }
    }
    
    def mkIteratorOf(tpe: Type) = {
      def makeIt[T: c.WeakTypeTag] = weakTypeOf[Iterator[T]]
      makeIt(c.WeakTypeTag(tpe))
    }
    
    def mkSeqOf(tpe: Type) = {
      def makeIt[T: c.WeakTypeTag] = weakTypeOf[Seq[T]]
      makeIt(c.WeakTypeTag(tpe))
    }
    
    def mkListOf(tpe: Type) = {
      def makeIt[T: c.WeakTypeTag] = weakTypeOf[List[T]]
      makeIt(c.WeakTypeTag(tpe))
    }
    
    def mkBuilderOf(elemTpe: Type, listTpe: Type) = {
      def makeIt[ElemTpe: c.WeakTypeTag, ListTpe: c.WeakTypeTag] =
        weakTypeOf[scala.collection.mutable.Builder[ElemTpe, ListTpe]]

      makeIt(c.WeakTypeTag(elemTpe), c.WeakTypeTag(listTpe))
    }
    
    def mkCanBuildFromOf(fromTpe: Type, elemTpe: Type, toTpe: Type) = {
      def makeIt[From: c.WeakTypeTag, Elem: c.WeakTypeTag, To: c.WeakTypeTag] =
        weakTypeOf[scala.collection.generic.CanBuildFrom[From, Elem, To]]

      makeIt(c.WeakTypeTag(fromTpe), c.WeakTypeTag(elemTpe), c.WeakTypeTag(toTpe))
    }
    
    def mkCtorCall(tpe: Type, args: List[Tree]) =
      Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), args)

    def mkSuperCall(args: List[Tree] = List()) =
      Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), args)

    def mkWhile(cond: Tree)(body: Tree): Tree = {
      val lblName = c.fresh[TermName]("while")
      val jump = Apply(Ident(lblName), Nil)
      val block = body match {
        case Block(stats, expr) => Block(stats :+ expr, jump)
        case _                  => Block(List(body), jump)
      }
      LabelDef(lblName, Nil, If(cond, block, EmptyTree))
    }
    
    def typeCheck(classDef: ClassDef): (ClassDef, Type) = {
      val block = Block(List(classDef), EmptyTree)
      val checkedBlock = c.typeCheck(block)

      // extract the class def from the block again
      val checkedDef = checkedBlock match {
        case Block((cls: ClassDef) :: Nil, _) => cls
      }
      val tpe = checkedDef.symbol.asClass.toType
      (checkedDef, tpe)
    }
    
    def extractOneInputUdf(fun: Tree) = {
      val (paramName, udfBody) = fun match {
        case Function(List(param), body) => (param.name.toString, body)
        case _ =>
          c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
      }
      val uncheckedUdfBody = c.resetLocalAttrs(udfBody)
      (paramName, uncheckedUdfBody)
    }
    
    def extractTwoInputUdf(fun: Tree) = {
      val (param1Name, param2Name, udfBody) = fun match {
        case Function(List(param1, param2), body) =>
          (param1.name.toString, param2.name.toString, body)
        case _ =>
          c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
      }
      val uncheckedUdfBody = c.resetLocalAttrs(udfBody)
      (param1Name, param2Name, uncheckedUdfBody)
    }
    
    def extractClass(block: Tree) = block match {
      case Block((cls: ClassDef) :: Nil, _) => cls
      case e => throw new RuntimeException("No class def at first position in block.")
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy