All downloads are free. Search and download functionality uses the official Maven repository.

com.mysema.scalagen.ControlStatements.scala Maven / Gradle / Ivy

There is a newer version: 1.0.2
Show newest version
/*
 * Copyright (C) 2011, Mysema Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.mysema.scalagen

import japa.parser.ast.visitor._
import java.util.ArrayList
import UnitTransformer._
import com.mysema.scalagen.ast.BeginClosureExpr

/** Default instance of [[ControlStatements]], so callers can invoke `ControlStatements.transform(cu)` directly. */
object ControlStatements extends ControlStatements

/**
 * ControlStatements rewrites Java control-flow constructs into more idiomatic Scala:
 *
 *  - counting loops `for (int i = 0; i < x; i++)` become `for (i <- 0 until x)`
 *  - `for (Map.Entry e : m.entrySet())` loops become `for ((key, value) <- m)`
 *  - a block of "foreach + if-return" followed by a trailing return becomes
 *    `it.find(...).getOrElse(...)`
 *  - `if (c) t = x else t = y` becomes `t = if (c) x else y`
 *  - `System.out.println(...)` becomes `println(...)`
 *  - a trailing `break` is dropped from `switch` entries
 *
 * Each rewrite fires only when its pattern (and guard) matches; otherwise the node
 * produced by the base visitor is returned unchanged.
 */
class ControlStatements extends UnitTransformerBase {
  
  private val KEY = new Name("key")
  
  private val VALUE = new Name("value")
  
  /** Replaces every visited name contained in the argument set with the placeholder `_`. */
  private val toUnderscore = new ModifierVisitor[Set[String]] {    
    override def visitName(n: String, arg: Set[String]): String = {
      if (arg.contains(n)) "_" else n
    }  
  }
  
  /** Counts how many names handed to the visitor's `visitName` hook in `n` equal `variableName`. */
  private def numMatchingNames(n: Node, variableName: String): Int = {
    var matched = 0
    val visitor = new ModifierVisitor[Null] {
      override def visitName(n: String, dummy: Null): String = {
        if (n == variableName) matched += 1
        n
      }
    }
    n.accept(visitor, null)
    matched
  }
  
  /** Counts the `variableName.getKey()` / `variableName.getValue()` calls contained in `n`. */
  private def numEntryAccessors(n: Node, variableName: String): Int = {
    var matched = 0
    val visitor = new ModifierVisitor[Null] {
      override def visit(m: MethodCall, dummy: Null): Node = {
        m match {
          case MethodCall(str(`variableName`), "getKey" | "getValue", Nil) => matched += 1
          case _ =>
        }
        // keep descending so nested calls are counted as well
        super.visit(m, dummy)
      }
    }
    n.accept(visitor, null)
    matched
  }
  
  /** Rewrites `arg.getKey()` into the name `key` and `arg.getValue()` into the name `value`. */
  private val toKeyAndValue = new ModifierVisitor[String] {
    override def visit(nn: MethodCall, arg: String): Node = {
      val n = super.visit(nn, arg).asInstanceOf[MethodCall]
      n match {
        case MethodCall(str(`arg`), "getKey", Nil) => KEY
        case MethodCall(str(`arg`), "getValue", Nil) => VALUE
        case _ => n
      }
    }    
  }
     
  /** Applies all control-statement rewrites to `cu` and returns the resulting unit. */
  def transform(cu: CompilationUnit): CompilationUnit = {
    cu.accept(this, cu).asInstanceOf[CompilationUnit] 
  }  
        
  override def visit(nn: For, arg: CompilationUnit): Node = {
    // transform
    //   for (int i = 0; i < x; i++) block 
    // into
    //   for (i <- 0 until x) block
    // only when the `<` condition and the increment both refer to the single declared
    // loop variable; e.g. `for (int i = 0; j < x; i++)` is left untouched.
    // NOTE(review): assignments to the loop variable inside `block` are not detected,
    // and the generated `0 until x` evaluates `x` once instead of on every iteration.
    val n = super.visit(nn, arg).asInstanceOf[For]    
    n match {
      case For((init: VariableDeclaration) :: Nil, l lt r, incr(target) :: Nil, _)
          if isCounterVariable(init, l, target) => {
        val until = new MethodCall(init.getVars.get(0).getInit, "until", r :: Nil)
        init.getVars.get(0).setInit(null)
        new Foreach(init, until, n.getBody)
      }
      case _ => n
    }
  }
  
  /**
   * True when `init` declares exactly one initialized variable and both `bound` (the left
   * operand of the loop's `<` condition) and `step` (the incremented expression) print as
   * that variable's name.
   */
  private def isCounterVariable(init: VariableDeclaration, bound: Any, step: Any): Boolean = {
    val vars = init.getVars
    vars.size == 1 && vars.get(0).getInit != null && {
      val vid = vars.get(0).getId.toString
      bound.toString == vid && step.toString == vid
    }
  }
  
  override def visit(nn: MethodCall, arg: CompilationUnit): Node = {
    // transform
    //   System.out.println
    // into 
    //   println
    val n = super.visit(nn, arg).asInstanceOf[MethodCall]
    n match {
      case MethodCall(str("System.out"), "println", args) => {
        new MethodCall(null, "println", args)
      }
      case _ => n
    }
  }
  
  override def visit(nn: Foreach, arg: CompilationUnit): Node = {
    // transform
    //   for (Map.Entry e : map.entrySet()) { ... e.getKey() ... e.getValue() ... }
    // into
    //   for ((key, value) <- map) { ... key ... value ... }
    val n = super.visit(nn, arg).asInstanceOf[Foreach]
    n match {
      case Foreach(
          VariableDeclaration(_, v :: Nil), 
          MethodCall(scope, "entrySet", Nil), _)
          if scope != null && canDestructureEntry(n.getBody, v.getId.toString) => {
        val vid = v.getId.toString
        new Foreach(
            VariableDeclaration(0, "(key, value)", Type.Object), 
            scope, n.getBody.accept(toKeyAndValue, vid).asInstanceOf[Statement])            
      }
      case _ => n
    }    
  }
  
  /**
   * Whether an `entrySet` loop body can be rewritten to use the `(key, value)` binding:
   * the body must not already use the names `key` / `value` (they would be shadowed or
   * produce `val key = key`), and every use of the entry variable must be a `getKey()` /
   * `getValue()` call (any other use would be left referring to an undefined name).
   */
  private def canDestructureEntry(body: Node, entryName: String): Boolean = {
    numMatchingNames(body, "key") == 0 &&
    numMatchingNames(body, "value") == 0 &&
    numMatchingNames(body, entryName) == numEntryAccessors(body, entryName)
  }
  
  // TODO : maybe move this to own class
  override def visit(nn: Block, arg: CompilationUnit): Node = {
    // simplify
    //   for (format <- values if format.mimetype == contentType) return format
    //   defaultFormat
    // into
    //   values.find(_.mimetype == contentType).getOrElse(defaultFormat)
    // Both returns must carry a value; a bare `return;` has nothing to find or default to.
    val n = super.visit(nn, arg).asInstanceOf[Block]
    n match {
      case Block( 
          Foreach(v, it, If(cond, Return(rv1), null)) ::
          Return(rv2) :: Nil) if rv1 != null && rv2 != null =>
        createFindCall(it, v, cond, rv1, rv2)
      case _ => n
    }
  }
  
  /**
   * Builds closure arguments for `expr` over the variable `vid`: a `_ =>` closure when
   * `vid` is unused, placeholder syntax when it is used exactly once, and an explicitly
   * named parameter otherwise.
   */
  private def createClosure(vid: String, expr: Expression): List[Expression] = numMatchingNames(expr, vid) match {
    case 0 => List(new BeginClosureExpr("_"), expr)
    case 1 => List(expr.accept(toUnderscore, Set(vid)).asInstanceOf[Expression])
    case _ => List(new BeginClosureExpr(vid), expr)
  }
  
  /**
   * Produces `it.find(cond)[.map(rv1)].getOrElse(rv2)` wrapped in a block; the `map` step is
   * omitted when the loop returns the iteration variable itself.
   */
  private def createFindCall(it: Expression, v: VariableDeclaration, 
      cond: Expression, rv1: Expression, rv2: Expression): Statement = {
    val vid = v.getVars.get(0).getId.toString
    val newCond = createClosure(vid, cond)
    // `a until b` must be parenthesized before a method call is chained onto it
    val newIt = it match {
      case MethodCall(_, "until", _ :: Nil) => new Enclosed(it)
      case _ => it
    }
    val findCall = new MethodCall(newIt, "find", newCond)
    val expr = if (vid == rv1.toString) findCall
               else new MethodCall(findCall, "map", createClosure(vid, rv1))
    val getOrElse = new MethodCall(expr, "getOrElse", rv2 :: Nil)
    new Block(new ExpressionStmt(getOrElse) :: Nil)
  } 
  
  override def visit(nn: If, arg: CompilationUnit): Node = {
    // transform
    //   if (condition) target = x else target = y
    // into
    //   target = if (condition) x else y    
    val n = super.visit(nn, arg).asInstanceOf[If]    
    n match {
      case If(cond, Stmt(t1 set v1), Stmt(t2 set v2)) if t1 == t2 => {
        new ExpressionStmt(new Assign(t1, new Conditional(cond, v1, v2), Assign.assign))  
      }
      case _ => n
    }    
  }
  
  override def visit(nn: SwitchEntry, arg: CompilationUnit) = {    
    // remove a trailing break; an entry consisting of only `break` keeps it
    val n = super.visit(nn, arg).asInstanceOf[SwitchEntry]
    val size = if (n.getStmts == null) 0 else n.getStmts.size
    if (size > 1 && n.getStmts.get(size-1).isInstanceOf[Break]) {
      n.setStmts(n.getStmts.dropRight(1))
    }
    n
  }
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy