All Downloads are FREE. Search and download functionalities are using the official Maven repository.

skinny.orm.feature.CRUDFeature.scala Maven / Gradle / Ivy

The newest version!
package skinny.orm.feature

import scalikejdbc._, SQLInterpolation._
import skinny._
import skinny.orm._
import skinny.orm.feature.associations._
import scala.collection.mutable
import skinny.orm.feature.associations.HasManyAssociation
import skinny.orm.feature.associations.BelongsToAssociation
import skinny.orm.feature.associations.HasOneAssociation

/**
 * Provides auto-generated CRUD feature.
 *
 * @tparam Entity entity
 */
trait CRUDFeature[Entity]
    extends SkinnyMapperBase[Entity]
    with SkinnyModel[Entity]
    with ConnectionPoolFeature
    with AutoSessionFeature
    with AssociationsFeature[Entity]
    with JoinsFeature[Entity]
    with IncludesFeature[Entity]
    with QueryingFeature[Entity]
    with FinderFeature[Entity]
    with StrongParametersFeature {

  /**
   * The primary key should be an auto-increment value if true.
   */
  def useAutoIncrementPrimaryKey = true

  /**
   * Returns default scope for update/delete operations.
   *
   * @return default scope
   */
  def defaultScopeForUpdateOperations: Option[SQLSyntax] = None

  override def joins(associations: Association[_]*): CRUDFeature[Entity] = {
    val _self = this
    val _associations = associations
    val _belongsTo = associations.filter(_.isInstanceOf[BelongsToAssociation[Entity]]).map(_.asInstanceOf[BelongsToAssociation[Entity]])
    val _hasOne = associations.filter(_.isInstanceOf[HasOneAssociation[Entity]]).map(_.asInstanceOf[HasOneAssociation[Entity]])
    val _hasMany = associations.filter(_.isInstanceOf[HasManyAssociation[Entity]]).map(_.asInstanceOf[HasManyAssociation[Entity]])

    new CRUDFeature[Entity] {
      override protected val underlying = _self
      override private[skinny] val belongsToAssociations = _self.belongsToAssociations ++ _belongsTo
      override private[skinny] val hasOneAssociations = _self.hasOneAssociations ++ _hasOne
      override private[skinny] val hasManyAssociations = _self.hasManyAssociations ++ _hasMany

      override val associations = _self.associations ++ _associations
      override val defaultJoinDefinitions = _self.defaultJoinDefinitions
      override val defaultBelongsToExtractors = _self.defaultBelongsToExtractors
      override val defaultHasOneExtractors = _self.defaultHasOneExtractors
      override val defaultOneToManyExtractors = _self.defaultOneToManyExtractors

      override def autoSession = underlying.autoSession
      override def connectionPoolName = underlying.connectionPoolName
      override def connectionPool = underlying.connectionPool

      def extract(rs: WrappedResultSet, n: SQLInterpolation.ResultName[Entity]) = underlying.extract(rs, n)
    }
  }

  /**
   * Attributes to be inserted when creation.
   */
  private[this] val attributesForCreation = new mutable.LinkedHashSet[(SQLSyntax, Any)]()

  /**
   * Accepted factories for attributesForCreation.
   */
  private[this] val attributesForCreationFactories = new mutable.LinkedHashSet[() => Boolean]()

  /**
   * Adds new attribute to be inserted when creation.
   *
   * @param namedValue named value
   * @return self
   */
  protected def addAttributeForCreation(namedValue: => (SQLSyntax, Any)) = {
    acceptAttributeForCreation(namedValue)
    this
  }

  /**
   * Attributes for creation are ready if true
   */
  private[this] lazy val attributesForCreationReady: Boolean = {
    attributesForCreationFactories.foreach(_.apply())
    true
  }

  /**
   * Accepts an attribute for creation.
   *
   * @param namedValue named value
   */
  private[this] def acceptAttributeForCreation(namedValue: => (SQLSyntax, Any)): Unit = {
    attributesForCreationFactories.add(() => attributesForCreation.add(namedValue))
  }

  /**
   * Merges already registered attributes to be inserted and parameters.
   *
   * @param namedValues named values
   * @return merged attributes
   */
  protected def mergeNamedValuesForCreation(namedValues: Seq[(SQLSyntax, Any)]): Seq[(SQLSyntax, Any)] = {
    if (!attributesForCreationReady) {
      throw new IllegalStateException("Attributes for creation query is not ready!")
    }

    namedValues.foldLeft(attributesForCreation) {
      case (xs, (column, newValue)) =>
        if (xs.exists(_._1 == column)) xs.map { case (c, v) => if (c == column) (column -> newValue) else (c, v) }
        else xs + (column -> newValue)
    }
    val toBeInserted = attributesForCreation.++(namedValues).toSeq
    toBeInserted
  }

  /**
   * Extracts named values from the permitted parameters.
   *
   * @param strongParameters permitted parameters
   * @return named values
   */
  protected def namedValuesForCreation(strongParameters: PermittedStrongParameters): Seq[(SQLSyntax, Any)] = {
    mergeNamedValuesForCreation(strongParameters.params.map {
      case (name, (value, paramType)) =>
        column.field(name) -> getTypedValueFromStrongParameter(name, value, paramType)
    }.toSeq)
  }

  /**
   * Creates a new entity with permitted strong parameters.
   *
   * @param strongParameters permitted parameters
   * @param s db session
   * @return created count
   */
  def createWithPermittedAttributes(strongParameters: PermittedStrongParameters)(implicit s: DBSession = autoSession): Long = {
    createWithNamedValues(namedValuesForCreation(strongParameters): _*)
  }

  /**
   * Creates a new entity with non-permitted parameters.
   *
   * CAUTION: If you use this method in some web apps, you might have mass assignment vulnerability.
   *
   * @param parameters parameters
   * @param s db session
   * @return created count
   */
  def createWithAttributes(parameters: (Symbol, Any)*)(implicit s: DBSession = autoSession): Long = {
    createWithNamedValues(mergeNamedValuesForCreation(parameters.map {
      case (name, value) => column.field(name.name) -> value
    }.toSeq): _*)
  }

  /**
   * #createWithNamedValues pre-execution.
   *
   * @param namedValues named values
   */
  protected def beforeCreate(namedValues: Seq[(SQLSyntax, Any)])(implicit s: DBSession = autoSession): Unit = {}

  /**
   * #createWithNamedValues post-execution.
   *
   * @param namedValues named values
   * @param generatedId generated id
   */
  protected def afterCreate(namedValues: Seq[(SQLSyntax, Any)], generatedId: Option[Long])(
    implicit s: DBSession = autoSession): Unit = {}

  /**
   * Creates a new entity with named values.
   *
   * @param namedValues named values
   * @param s db session
   * @return created count
   */
  def createWithNamedValues(namedValues: (SQLSyntax, Any)*)(implicit s: DBSession = autoSession): Long = {
    val allNamedValues = mergeNamedValuesForCreation(namedValues)
    beforeCreate(allNamedValues)
    if (useAutoIncrementPrimaryKey) {
      val id = withSQL { insert.into(this).namedValues(allNamedValues: _*) }.updateAndReturnGeneratedKey.apply()
      afterCreate(allNamedValues, Some(id))
      id
    } else {
      withSQL { insert.into(this).namedValues(allNamedValues: _*) }.update.apply()
      val idOpt = allNamedValues.find(v => v._1 == column.field(primaryKeyName)).map {
        case (_, value) =>
          try value.toString.toLong
          catch { case e: Exception => 0L }
      }
      afterCreate(allNamedValues, idOpt)
      idOpt.getOrElse(-1L)
    }
  }

  /**
   * Returns update query builder with condition.
   *
   * @param where where condition
   * @return update query builder
   */
  def updateBy(where: SQLSyntax): UpdateOperationBuilder = {
    new UpdateOperationBuilder(this, where, beforeUpdateByHandlers.toSeq, afterUpdateByHandlers.toSeq)
  }

  /**
   * Returns update query builder with primary key.
   *
   * @param id primary key
   * @return update query builder
   */
  def updateById(id: Long): UpdateOperationBuilder = updateBy(byId(id))

  /**
   * Updates entities with parameters.
   *
   * @param id primary key
   * @param parameters parameters
   * @return updated count
   */
  def updateById(id: Long, parameters: PermittedStrongParameters): Int = updateById(id).withPermittedAttributes(parameters)

  /**
   * Returns a query part which represents primary key search condition.
   *
   * @param id primary key
   * @return query part
   */
  protected def byId(id: Long) = sqls.eq(column.field(primaryKeyName), id)

  /**
   * #updateBy pre-execution handler.
   */
  type BeforeUpdateByHandler = (DBSession, SQLSyntax, Seq[(SQLSyntax, Any)]) => Unit

  /**
   * #updateBy post-execution handler.
   */
  type AfterUpdateByHandler = (DBSession, SQLSyntax, Seq[(SQLSyntax, Any)], Int) => Unit

  /**
   * Registered beforeUpdateByHandlers.
   */
  protected val beforeUpdateByHandlers = new scala.collection.mutable.ListBuffer[BeforeUpdateByHandler]

  /**
   * Registered afterUpdateByHandlers.
   */
  protected val afterUpdateByHandlers = new scala.collection.mutable.ListBuffer[AfterUpdateByHandler]

  /**
   * Registers #updateBy pre-execution handler.
   *
   * @param handler event handler
   */
  protected def beforeUpdateBy(handler: BeforeUpdateByHandler): Unit = beforeUpdateByHandlers.append(handler)

  /**
   * Registers #updateBy post-execution handler.
   *
   * @param handler event handler
   */
  protected def afterUpdateBy(handler: AfterUpdateByHandler): Unit = afterUpdateByHandlers.append(handler)

  /**
   * Update query builder/executor.
   *
   * @param mapper mapper
   * @param where condition
   */
  class UpdateOperationBuilder(
      mapper: CRUDFeature[Entity],
      where: SQLSyntax,
      beforeHandlers: Seq[BeforeUpdateByHandler],
      afterHandlers: Seq[AfterUpdateByHandler]) {

    /**
     * Attributes to be updated.
     */
    private[this] val attributesToBeUpdated = new mutable.HashMap[SQLSyntax, Any]()

    /**
     * Additional query parts after `set` keyword.
     */
    private[this] val additionalUpdateSQLs = new mutable.LinkedHashSet[SQLSyntax]()

    /**
     * Adds new attribute to be updated.
     *
     * @param namedValue named value
     * @return self
     */
    protected def addAttributeToBeUpdated(namedValue: (SQLSyntax, Any)): UpdateOperationBuilder = {
      attributesToBeUpdated.update(namedValue._1, namedValue._2)
      this
    }

    /**
     * Adds new query part.
     *
     * @param queryPart query part
     * @return self
     */
    protected def addUpdateSQLPart(queryPart: SQLSyntax): UpdateOperationBuilder = {
      additionalUpdateSQLs.add(queryPart)
      this
    }

    /**
     * Converts permitted strong parameters to named values.
     *
     * @param strongParameters permitted parameters
     * @return named values
     */
    protected def toNamedValuesToBeUpdated(strongParameters: PermittedStrongParameters): Seq[(SQLSyntax, Any)] = {
      strongParameters.params.map {
        case (name, (value, paramType)) =>
          column.field(name) -> getTypedValueFromStrongParameter(name, value, paramType)
      }.toSeq
    }

    /**
     * Merges already registered attributes to be updated and parameters.
     *
     * @param namedValues named values
     * @return merged attributes
     */
    protected def mergeNamedValues(namedValues: Seq[(SQLSyntax, Any)]): Seq[(SQLSyntax, Any)] = {
      namedValues.foldLeft(attributesToBeUpdated) {
        case (xs, (column, newValue)) =>
          if (xs.exists(_._1 == column)) xs.map { case (c, v) => if (c == column) (column -> newValue) else (c, v) }
          else xs.+=(column -> newValue)
      }.toSeq
    }

    /**
     * Merges additional query parts.
     *
     * @param queryBuilder query builder
     * @param othersAreEmpty other attributes to be updated is empty if true
     * @return query builder
     */
    protected def mergeAdditionalUpdateSQLs(queryBuilder: UpdateSQLBuilder, othersAreEmpty: Boolean): UpdateSQLBuilder = {
      if (additionalUpdateSQLs.isEmpty) {
        queryBuilder
      } else {
        val updates = sqls.csv(additionalUpdateSQLs.toSeq: _*)
        if (othersAreEmpty) queryBuilder.append(updates)
        else queryBuilder.append(sqls", ${updates}")
      }
    }

    /**
     * Updates entities with these permitted strong parameters.
     *
     * @param strongParameters permitted strong parameters
     * @param s db session
     * @return updated count
     */
    def withPermittedAttributes(strongParameters: PermittedStrongParameters)(implicit s: DBSession = autoSession): Int = {
      withNamedValues(toNamedValuesToBeUpdated(strongParameters): _*)
    }

    /**
     * Updates entities with these non-permitted parameters.
     *
     * CAUTION: If you use this method in some web apps, you might have mass assignment vulnerability.
     *
     * @param parameters unsafe parameters
     * @param s db session
     * @return updated count
     */
    def withAttributes(parameters: (Symbol, Any)*)(implicit s: DBSession = autoSession): Int = {
      withNamedValues(parameters.map {
        case (name, value) => column.field(name.name) -> value
      }: _*)
    }

    /**
     * Updates entities with named values.
     *
     * @param namedValues named values
     * @param s db session
     * @return updated count
     */
    def withNamedValues(namedValues: (SQLSyntax, Any)*)(implicit s: DBSession = autoSession): Int = {
      val allValues = mergeNamedValues(namedValues)
      beforeHandlers.foreach(_.apply(s, where, allValues))
      val updatedCount = withSQL {
        mergeAdditionalUpdateSQLs(update(mapper).set(allValues: _*), allValues.isEmpty)
          .where.append(where).and(defaultScopeForUpdateOperations)
      }.update.apply()
      afterHandlers.foreach(_.apply(s, where, allValues, updatedCount))
      updatedCount
    }

  }

  /**
   * Deletes entities by condition.
   *
   * @param where condition
   * @param s db session
   * @return deleted count
   */
  def deleteBy(where: SQLSyntax)(implicit s: DBSession = autoSession): Int = {
    beforeDeleteBy(where)
    val count = withSQL {
      delete.from(this).where(where).and(defaultScopeForUpdateOperations)
    }.update.apply()
    afterDeleteBy(where, count)
  }

  /**
   * #deleteBy pre-execution.
   *
   * @param where condition
   */
  protected def beforeDeleteBy(where: SQLSyntax)(implicit s: DBSession = autoSession): Unit = {}

  /**
   * #deleteBy post-execution.
   *
   * @param where condition
   * @param deletedCount deleted count
   * @return count
   */
  protected def afterDeleteBy(where: SQLSyntax, deletedCount: Int)(implicit s: DBSession = autoSession): Int = deletedCount

  /**
   * Deletes a single entity by primary key.
   *
   * @param id primary key
   * @param s db session
   * @return deleted count
   */
  def deleteById(id: Long)(implicit s: DBSession = autoSession): Int = deleteBy(byId(id))

  // for SkinnyModel

  override def createNewModel(parameters: PermittedStrongParameters) = createWithPermittedAttributes(parameters)

  override def findAllModels() = findAll()

  override def findModel(id: Long) = findById(id)

  override def updateModelById(id: Long, parameters: PermittedStrongParameters) = updateById(id, parameters)

  override def deleteModelById(id: Long) = deleteById(id)

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy