All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.ecfront.common.BeanHelper.scala Maven / Gradle / Ivy

There is a newer version: 1.2.3
Show newest version
package com.ecfront.common

import org.apache.commons.beanutils.BeanUtilsBean

import scala.annotation.{StaticAnnotation, tailrec}
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe._

/**
 * Bean操作辅助类
 */
object BeanHelper {

  /**
   * Runtime reflection mirror for this object's class loader; every symbol lookup and
   * reflective read/write in this helper goes through it.
   */
  val rm: Mirror = runtimeMirror(getClass.getClassLoader)

  /** Bean copier that skips null-valued source properties. */
  private val copyPropertiesAdapter = new NullAwareBeanUtilsBean

  /**
   * Copies bean properties from `orig` into `dest`, skipping properties whose source value is null.
   *
   * @param dest destination bean that receives the values
   * @param orig source bean that supplies the values
   */
  def copyProperties(dest: AnyRef, orig: AnyRef): Unit = copyPropertiesAdapter.copyProperties(dest, orig)

  /**
   * Returns the name and rendered return type of every public getter of a bean class.
   *
   * @param beanClazz         bean class to inspect
   * @param filterNames       getter names to exclude (null or empty means no name filtering)
   * @param filterAnnotations getters whose backing field carries one of these annotations
   *                          (searched up the class hierarchy) are excluded; empty disables this
   * @return map from getter name to the `toString` of its return type
   */
  def findFields(beanClazz: Class[_], filterNames: Seq[String] = Seq(), filterAnnotations: Seq[Class[_ <: StaticAnnotation]] = Seq(classOf[Ignore])): Map[String, String] = {
    val ignoredFields: Set[String] =
      if (filterAnnotations.nonEmpty) findFieldAnnotations(beanClazz, filterAnnotations).map(_.fieldName).toSet
      else Set.empty
    publicGetters(beanClazz, filterNames)
      .map(getter => getter.name.toString.trim -> getter.returnType.toString.trim)
      .filterNot { case (name, _) => ignoredFields.contains(name) }
      .toMap
  }

  /**
   * Reads the current value of every public getter of `bean`.
   *
   * @param bean        bean to read
   * @param filterNames getter names to skip (null or empty means read all)
   * @return map from getter name to the value the getter returned
   */
  def findValues(bean: AnyRef, filterNames: Seq[String] = Seq()): Map[String, Any] = {
    val instance = rm.reflect(bean)
    publicGetters(bean.getClass, filterNames)
      .map(getter => getter.name.toString.trim -> instance.reflectMethod(getter).apply())
      .toMap
  }

  /**
   * Reads a single field (or val/var) of `bean` by name.
   *
   * @param bean      bean to read
   * @param fieldName name of the field or accessor
   * @return `None` when the bean's class has no term member with that name; otherwise `Some`
   *         of the stored value (which may itself be null)
   */
  def getValue(bean: AnyRef, fieldName: String): Option[Any] = {
    val instance = rm.reflect(bean)
    termsNamed(bean.getClass, fieldName).headOption.map(term => instance.reflectField(term).get)
  }

  /**
   * Writes `value` into the field (or val/var) of `bean` named `fieldName`.
   * Does nothing when the bean's class has no term member with that name.
   *
   * @param bean      bean to modify
   * @param fieldName name of the field or accessor
   * @param value     new value to store
   */
  def setValue(bean: AnyRef, fieldName: String, value: Any): Unit = {
    val instance = rm.reflect(bean)
    termsNamed(bean.getClass, fieldName).foreach(term => instance.reflectField(term).set(value))
  }

  /**
   * Collects annotated fields of `beanClazz` and all of its superclasses.
   *
   * @param beanClazz   bean class to inspect
   * @param annotations annotation classes to look for; empty means collect every annotation
   * @return one entry per matching annotation: the instantiated annotation and the field name
   */
  def findFieldAnnotations(beanClazz: Class[_], annotations: Seq[Class[_ <: StaticAnnotation]] = Seq()): ArrayBuffer[FieldAnnotationInfo] = {
    val result = ArrayBuffer[FieldAnnotationInfo]()
    collectFieldAnnotations(result, beanClazz, annotations)
    result
  }

  /** Appends annotated non-method members of `beanClazz` to `container`, then repeats for its superclass. */
  @tailrec
  private def collectFieldAnnotations(container: ArrayBuffer[FieldAnnotationInfo], beanClazz: Class[_], annotations: Seq[Class[_ <: StaticAnnotation]]): Unit = {
    rm.classSymbol(beanClazz).toType.members.filterNot(_.isMethod).foreach { field =>
      field.annotations
        .filter(annotation => isRequested(annotation, annotations))
        .foreach(annotation => container += FieldAnnotationInfo(instantiate(annotation), field.name.toString.trim))
    }
    superclassOf(beanClazz) match {
      case Some(parent) => collectFieldAnnotations(container, parent, annotations)
      case None => ()
    }
  }

  /**
   * Collects annotated methods of `beanClazz` and all of its superclasses.
   * When the target is a Scala `object`, pass the result of its `getClass`.
   *
   * @param beanClazz   class to inspect
   * @param annotations annotation classes to look for; empty means collect every annotation
   * @return one entry per matching annotation: the instantiated annotation and the method symbol
   */
  def findMethodAnnotations(beanClazz: Class[_], annotations: Seq[Class[_ <: StaticAnnotation]] = Seq()): ArrayBuffer[methodAnnotationInfo] = {
    val result = ArrayBuffer[methodAnnotationInfo]()
    collectMethodAnnotations(result, beanClazz, annotations)
    result
  }

  /** Appends annotated methods of `beanClazz` to `container`, then repeats for its superclass. */
  @tailrec
  private def collectMethodAnnotations(container: ArrayBuffer[methodAnnotationInfo], beanClazz: Class[_], annotations: Seq[Class[_ <: StaticAnnotation]]): Unit = {
    rm.classSymbol(beanClazz).toType.members
      .filter(_.isMethod)
      // `members` already includes inherited methods; skip ones collected at a subclass level so
      // the superclass pass does not add them a second time.
      .filterNot(symbol => container.exists(_.method == symbol))
      .foreach { symbol =>
        // Use the member symbol itself: re-looking it up by name yields an overloaded
        // (non-method) symbol for overloaded methods, on which `asMethod` throws.
        val method = symbol.asMethod
        symbol.annotations
          // Java `throws` clauses appear as `throws[java....]` annotations; they are excluded here
          // (presumably because `instantiate` cannot build them — NOTE(review): confirm).
          .filterNot(_.toString.startsWith("throws[java."))
          .filter(annotation => isRequested(annotation, annotations))
          .foreach(annotation => container += methodAnnotationInfo(instantiate(annotation), method))
      }
    superclassOf(beanClazz) match {
      case Some(parent) => collectMethodAnnotations(container, parent, annotations)
      case None => ()
    }
  }

  /**
   * Returns a mirror for calling `method` on `obj`.
   *
   * @param obj    receiver instance
   * @param method method symbol to bind
   */
  def invoke(obj: Any, method: MethodSymbol): MethodMirror = {
    rm.reflect(obj).reflectMethod(method)
  }

  /**
   * Finds a class-level annotation, searching `beanClazz` first and then its superclasses.
   *
   * @tparam A annotation type
   * @param beanClazz class to inspect
   * @return the instantiated annotation, or `None` if no class in the hierarchy carries it
   */
  def getClassAnnotation[A: TypeTag](beanClazz: Class[_]): Option[A] =
    findClassAnnotation(typeOf[A], beanClazz).map(_.asInstanceOf[A])

  /** Instantiates the first annotation of type `typeAnnotation` on `beanClazz` or its nearest annotated superclass. */
  private def findClassAnnotation(typeAnnotation: Type, beanClazz: Class[_]): Option[Any] =
    rm.classSymbol(beanClazz).toType.typeSymbol.asClass.annotations
      .find(_.tree.tpe == typeAnnotation)
      .map(instantiate)
      .orElse(superclassOf(beanClazz).flatMap(parent => findClassAnnotation(typeAnnotation, parent)))

  /**
   * Decides whether `annotation` should be collected: always when `wanted` is empty, otherwise
   * when its rendered name equals the `getName` of one of the wanted classes.
   */
  private def isRequested(annotation: Annotation, wanted: Seq[Class[_ <: StaticAnnotation]]): Boolean = {
    val rendered = annotation.toString
    // Strip a trailing argument list `(...)` so the remainder can be compared with `Class.getName`.
    val annotationName = if (rendered.indexOf("(") == -1) rendered else rendered.substring(0, rendered.lastIndexOf("("))
    wanted.isEmpty || wanted.exists(_.getName == annotationName)
  }

  /** Creates an instance of `annotation` by calling its constructor with the annotation's arguments. */
  private def instantiate(annotation: Annotation): Any = {
    // Each argument tree is expected to be a literal constant; any other argument fails the cast.
    val args = annotation.tree.children.tail.map(_.productElement(0).asInstanceOf[Constant].value)
    val annotationType = annotation.tree.tpe
    rm.reflectClass(annotationType.typeSymbol.asClass)
      .reflectConstructor(annotationType.decl(termNames.CONSTRUCTOR).asMethod)(args: _*)
  }

  /**
   * Next class to search in a hierarchy walk, or `None` at `Object` (or for interfaces and
   * primitives). Uses `getSuperclass` so that generic parents are walked too.
   */
  private def superclassOf(clazz: Class[_]): Option[Class[_]] = {
    val parent: Class[_] = clazz.getSuperclass
    Option(parent).filter(_ != classOf[Object])
  }

  /** Public getter methods of `clazz`, minus those named in `filterNames` (null or empty keeps all). */
  private def publicGetters(clazz: Class[_], filterNames: Seq[String]): Iterable[MethodSymbol] =
    rm.classSymbol(clazz).toType.members.collect {
      case method: MethodSymbol if method.isGetter && method.isPublic
        && (filterNames == null || filterNames.isEmpty || !filterNames.contains(method.name.toString.trim)) => method
    }

  /**
   * Term members of `clazz` whose trimmed name equals `name`. Names are trimmed because runtime
   * reflection reports the underlying field of a val/var with a trailing space, so both the
   * accessor and its field match.
   */
  private def termsNamed(clazz: Class[_], name: String): Iterable[TermSymbol] =
    rm.classSymbol(clazz).toType.members.collect {
      case term: TermSymbol if term.name.toString.trim == name => term
    }

  /**
   * Resolves a type-name string to a `Class`.
   * Scala primitive names map to their primitive classes (e.g. "Int" -> `int`). Names starting
   * with Map/List/Set/Seq/Vector/Array map to the corresponding immutable collection or array class.
   * Any other name is loaded with `Class.forName`, after dropping a trailing `[...]` type-argument list.
   *
   * @param clazzStr type name to resolve
   * @return the resolved class
   */
  def getClassByStr(clazzStr: String): Class[_] = clazzStr match {
    case "Int" => classOf[Int]
    case "String" => classOf[String]
    case "Long" => classOf[Long]
    case "Float" => classOf[Float]
    case "Double" => classOf[Double]
    case "Boolean" => classOf[Boolean]
    case "Short" => classOf[Short]
    case "Byte" => classOf[Byte]
    case "Char" => classOf[Char]
    case s if s.startsWith("Map") => Class.forName("scala.collection.immutable.Map")
    case s if s.startsWith("List") || s.startsWith("scala.List") => Class.forName("scala.collection.immutable.List")
    case s if s.startsWith("Set") => Class.forName("scala.collection.immutable.Set")
    case s if s.startsWith("Seq") || s.startsWith("scala.Seq") => Class.forName("scala.collection.immutable.Seq")
    case s if s.startsWith("Vector") || s.startsWith("scala.Vector") => Class.forName("scala.collection.immutable.Vector")
    case s if s.startsWith("Array") => Class.forName("scala.Array")
    // Drop generic type arguments, e.g. "com.foo.Box[Int]" -> "com.foo.Box".
    case s if s.endsWith("]") => Class.forName(s.substring(0, s.indexOf("[")))
    case s => Class.forName(s)
  }

}

/**
 * `BeanUtilsBean` variant that ignores null source values, so copying never overwrites a
 * destination property with null.
 */
private class NullAwareBeanUtilsBean extends BeanUtilsBean {
  override def copyProperty(bean: scala.Any, name: String, value: scala.Any): Unit =
    value match {
      case null => () // nothing to copy: keep the destination's current value
      case present => super.copyProperty(bean, name, present)
    }
}

/**
 * An instantiated annotation paired with the name of the field it was found on.
 *
 * @param annotation the annotation instance
 * @param fieldName  name of the annotated field
 */
case class FieldAnnotationInfo(
  annotation: Any,
  fieldName: String
)

/**
 * An instantiated annotation paired with the reflection symbol of the method it was found on.
 *
 * @param annotation the annotation instance
 * @param method     symbol of the annotated method
 */
case class methodAnnotationInfo(
  annotation: Any,
  method: MethodSymbol
)

/**
 * Marker annotation excluded by default from `BeanHelper.findFields`.
 * The `field` meta-annotation makes it target the underlying field of an annotated val/var.
 */
@scala.annotation.meta.field
class Ignore() extends StaticAnnotation




© 2015 - 2025 Weber Informatics LLC | Privacy Policy