Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
/*
 * Copyright 2015 University of Basel, Graphics and Vision Research Group
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package scalismo.statisticalmodel.asm
import breeze.linalg.{convert, DenseVector}
import scalismo.common.UnstructuredPointsDomain.Create.CreateUnstructuredPointsDomain3D
import scalismo.common.{PointId, UnstructuredPointsDomain}
import scalismo.geometry.{_3D, Point, Point3D}
import scalismo.image.DiscreteImage
import scalismo.mesh.TriangleMesh
import scalismo.numerics.Sampler
import scalismo.registration.LandmarkRegistration
import scalismo.statisticalmodel.{MultivariateNormalDistribution, PointDistributionModel, StatisticalMeshModel}
import scalismo.transformations.{
import scalismo.utils.Random
import scala.collection.immutable
import scala.collection.parallel.immutable.ParVector
import scala.util.{Failure, Try}
/**
 * Companion object of [[ActiveShapeModel]]: defines the training-data type and the training procedure.
 *
 * NOTE(review): this copy of the file has lost its closing braces and parts of several expressions in this
 * object (each affected spot is marked below); restore it from the upstream scalismo sources before compiling.
 */
object ActiveShapeModel {
// One training example: an image, paired with the transformation that trainModel applies to
// `statisticalModel.reference` to place the reference mesh in that image.
type TrainingData = Iterator[(DiscreteImage[_3D, Float], Transformation[_3D])]
/**
* Train an active shape model using an existing PCA model
 *
 * For every profile point, a feature distribution is estimated from the features extracted at that point
 * in all training images.
 *
 * @param statisticalModel shape model; the profile points are chosen on its reference mesh.
 * @param trainingData     training images with their transformations; consumed one element at a time.
 * @param preprocessor     applied to each training image before features are extracted.
 * @param featureExtractor computes the feature vector at a mesh point of a preprocessed image.
 * @param sampler          builds a sampler on the reference mesh; its sample points determine the profile points.
 * @return an ActiveShapeModel made of `statisticalModel`, the estimated profiles, `preprocessor` and `featureExtractor`.
 */
def trainModel(statisticalModel: PointDistributionModel[_3D, TriangleMesh],
trainingData: TrainingData,
preprocessor: ImagePreprocessor,
featureExtractor: FeatureExtractor,
sampler: TriangleMesh[_3D] => Sampler[_3D]): ActiveShapeModel = {
// Sample locations on the reference mesh; only the first element of each sample tuple is kept.
val sampled = sampler(statisticalModel.reference).sample().map(_._1).toIndexedSeq
// NOTE(review): the right-hand side is missing in this copy; presumably the reference-mesh PointIds
// corresponding to `sampled` -- confirm against upstream.
val pointIds =
// preprocessed images can be expensive in terms of memory, so we go through them one at a time.
val imageFeatures = trainingData.flatMap {
case (image, transform) =>
// NOTE(review): the call that opens the `{ pointId => ... }` block below (presumably a map over pointIds)
// is missing in this copy.
val (pimg, mesh) = (preprocessor(image), statisticalModel.reference.transform(transform)) { pointId =>
featureExtractor(pimg, mesh.pointSet.point(pointId), mesh, pointId)
// NOTE(review): the end of the imageFeatures expression is missing in this copy. Because imageFeatures is
// indexed below, it must be converted to an indexed collection (an Iterator cannot be indexed).
// the structure is "wrongly nested" now, like: {img1:{pt1,pt2}, img2:{pt1,pt2}} (flattened).
// We merge the corresponding points together, then estimate an MVD.
val pointsLength = pointIds.length
// NOTE(review): integer division -- if the sampler yields no points, pointsLength is 0 and this throws
// ArithmeticException.
val imageRange = (0 until imageFeatures.length / pointsLength).toIndexedSeq
// NOTE(review): the call that opens the `{ pointIndex => ... }` block is missing in this copy.
val pointFeatures = (0 until pointsLength) { pointIndex =>
// Features of point `pointIndex` across all images; imageFeatures is image-major
// (all points of image 0, then all points of image 1, ...), hence the index arithmetic.
val featuresForPoint = imageRange.flatMap { imageIndex =>
imageFeatures(imageIndex * pointsLength + pointIndex).map(convert(_, Double))
// NOTE(review): estimating the per-point distribution from featuresForPoint (presumably via the imported
// MultivariateNormalDistribution) is missing in this copy.
// NOTE(review): the collection before `{ case (i, d) => ... }` (presumably pointIds paired with pointFeatures)
// is missing in this copy.
val profiles = new Profiles( { case (i, d) => Profile(i, d) })
ActiveShapeModel(statisticalModel, profiles, preprocessor, featureExtractor)
/**
 * Class of instances sampled from an Active Shape Model. A sample therefore comprises both a shape sampled from
 * the shape model and a set of sample features at the profile points.
 */
case class ASMSample(
  mesh: TriangleMesh[_3D], // the sampled shape
  featureField: DiscreteFeatureField[_3D, UnstructuredPointsDomain], // features at the profile points of `mesh`
  featureExtractor: FeatureExtractor // extractor of the model that produced this sample
)
/**
 * An Active Shape Model: a statistical shape model combined with a feature distribution per profile point,
 * together with the image preprocessing and feature extraction used to compare images against those distributions.
 *
 * @param statisticalModel shape model providing the reference mesh, mean, random samples and posteriors.
 * @param profiles         profile points (each referring to a mesh PointId) and their feature distributions.
 * @param preprocessor     applied to a target image before fitting (see fitIterator).
 * @param featureExtractor computes a feature vector at a point of a mesh in a preprocessed image.
 *
 * NOTE(review): this copy of the file has lost every closing brace and parts of several expressions in this
 * class (each affected spot is marked below); restore it from the upstream scalismo sources before compiling.
 */
case class ActiveShapeModel(statisticalModel: PointDistributionModel[_3D, TriangleMesh],
profiles: Profiles,
preprocessor: ImagePreprocessor,
featureExtractor: FeatureExtractor) {
/**
* Returns the mean mesh of the shape model, along with the mean feature profiles at the profile points
 *
 * @return the model's mean mesh, with a feature field defined on the profile points of that mesh.
 */
def mean(): ASMSample = {
val smean = statisticalModel.mean
// Profile points are located on the mean mesh through each profile's pointId.
// NOTE(review): the collection and lambda parameter before `=>` are missing in this copy.
val meanProfilePoints = => smean.pointSet.point(p.pointId))
// NOTE(review): the right-hand side is missing in this copy.
val meanFeatures =
// NOTE(review): the constructor arguments and closing parenthesis are missing in this copy.
val featureField = DiscreteFeatureField[_3D, UnstructuredPointsDomain](
ASMSample(smean, featureField, featureExtractor)
/**
* Returns a random sample mesh from the shape model, along with randomly sampled feature profiles at the profile points
 *
 * @param rand implicit random number source.
 * @return a random shape from the model, with a feature field defined on that shape's profile points.
 */
def sample()(implicit rand: Random): ASMSample = {
val sampleMesh = statisticalModel.sample()
// NOTE(review): the collection and lambda parameter before `=>` are missing in this copy.
val randomProfilePoints = => sampleMesh.pointSet.point(p.pointId))
// NOTE(review): the right-hand side is missing in this copy.
val randomFeatures =
// NOTE(review): the constructor arguments and closing parenthesis are missing in this copy.
val featureField = DiscreteFeatureField[_3D, UnstructuredPointsDomain](
ASMSample(sampleMesh, featureField, featureExtractor)
/**
* Utility function that allows to randomly sample different feature profiles, while keeping the profile points
* Meant to allow to easily inspect/debug the feature distribution
 *
 * The shape is always the model mean; only the features are random.
 *
 * @param rand implicit random number source.
 */
def sampleFeaturesOnly()(implicit rand: Random): ASMSample = {
val smean = statisticalModel.mean
// NOTE(review): the collection and lambda parameter before `=>` are missing in this copy.
val meanProfilePoints = => smean.pointSet.point(p.pointId))
// NOTE(review): the right-hand side is missing in this copy.
val randomFeatures =
// NOTE(review): the constructor arguments and closing parenthesis are missing in this copy.
val featureField = DiscreteFeatureField[_3D, UnstructuredPointsDomain](
ASMSample(smean, featureField, featureExtractor)
/**
* Returns an Active Shape Model where both the statistical shape Model and the profile points distributions are correctly transformed
* according to the provided rigid transformation
 *
 * Only `statisticalModel` is transformed here; `profiles` is copied unchanged. Profile positions are looked up
 * by PointId on the model's meshes (see mean()), so they follow the transformed model.
 *
 * @param rigidTransformation rigid transformation to apply.
 * @return a copy of this model with the transformed shape model.
 */
def transform(rigidTransformation: RigidTransformation[_3D]): ActiveShapeModel = {
val transformedModel = statisticalModel.transform(rigidTransformation)
this.copy(statisticalModel = transformedModel)
// Default starting transformations for fitting (per the fit() documentation: start from the mean shape,
// with no rigid transformation).
// NOTE(review): the beginning of this expression is missing in this copy; only its tail (the identity
// transformation of `RotationSpace3D(Point3D(0, 0, 0))`) survives.
private def noTransformations =
RotationSpace3D(Point3D(0, 0, 0)).identityTransformation)
/**
* Perform an ASM fitting for the given target image.
* This is logically equivalent to calling fitIterator(...).last
* @param targetImage target image to fit to.
* @param searchPointSampler sampler that defines the strategy where profiles are to be sampled.
* @param iterations maximum number of iterations for the fitting.
* @param config fitting configuration (thresholds). If omitted, uses [[FittingConfiguration.Default]]
* @param startingTransformations initial transformations to apply to the statistical model. If omitted, no transformations are applied (i.e. the fitting starts from the mean shape, with no rigid transformation)
* @return fitting result after the given number of iterations
 *
 * Returns a Failure(IllegalStateException) if the iterator yields no result. Iteration stops early after the
 * first failed step (see hasNext in fitIteratorPreprocessed).
 * @throws IllegalArgumentException if `iterations` is not strictly positive (checked by fitIteratorPreprocessed).
 */
def fit(targetImage: DiscreteImage[_3D, Float],
searchPointSampler: SearchPointSampler,
iterations: Int,
config: FittingConfiguration = FittingConfiguration.Default,
startingTransformations: ModelTransformations = noTransformations): Try[FittingResult] = {
// we're manually looping the iterator here because we're only interested in the last result -- no need to keep all intermediates.
val it = fitIterator(targetImage, searchPointSampler, iterations, config, startingTransformations)
if (!it.hasNext) {
Failure(new IllegalStateException("iterator was empty"))
} else {
// NOTE(review): the right-hand side is missing in this copy (presumably the first element of `it`).
var result =
while (it.hasNext) {
// NOTE(review): the right-hand side (presumably the next element of `it`) and the rest of this method are
// missing in this copy.
result =
/**
* Perform iterative ASM fitting for the given target image. This is essentially the same as the [[fit]] method, except that it returns the full iterator, so every step can be examined.
* @see [[fit()]] for a description of the parameters.
 *
 * `initialTransform` plays the role of `startingTransformations` in fit(). The target image is passed through
 * `preprocessor` once, before iterating.
 * @throws IllegalArgumentException if `iterations` is not strictly positive.
 */
def fitIterator(targetImage: DiscreteImage[_3D, Float],
searchPointSampler: SearchPointSampler,
iterations: Int,
config: FittingConfiguration = FittingConfiguration.Default,
initialTransform: ModelTransformations = noTransformations): Iterator[Try[FittingResult]] = {
fitIteratorPreprocessed(preprocessor(targetImage), searchPointSampler, iterations, config, initialTransform)
/**
* Perform iterative ASM fitting for the given preprocessed image. This is essentially the same as the [[fitIterator]] method, except that it uses the already preprocessed image.
* @see [[fit()]] for a description of the parameters.
 *
 * The returned iterator yields at most `iterations` results and stops after the first result that is a Failure.
 * Each fitting step runs when next() is called.
 * @throws IllegalArgumentException if `iterations` is not strictly positive (checked eagerly, before the iterator is created).
 */
def fitIteratorPreprocessed(
image: PreprocessedImage,
searchPointSampler: SearchPointSampler,
iterations: Int,
config: FittingConfiguration = FittingConfiguration.Default,
initialTransform: ModelTransformations = noTransformations
): Iterator[Try[FittingResult]] = {
require(iterations > 0, "number of iterations must be strictly positive")
new Iterator[Try[FittingResult]] {
// Result of the most recent step; None before the first call to next().
var lastResult: Option[Try[FittingResult]] = None
// Number of steps performed so far.
var nextCount = 0
// Continue while the iteration budget is not exhausted and no step has failed yet.
override def hasNext = nextCount < iterations && (lastResult.isEmpty || lastResult.get.isSuccess)
override def next() = {
// NOTE(review): this expression is truncated in this copy: `lastResult` is an Option[Try[FittingResult]], but
// fitOnce expects a TriangleMesh (presumably the previous result's mesh, or a starting mesh on the first step).
val mesh = lastResult
// NOTE(review): every step passes initialTransform.rigidTransform as the pose, not the rigid transform found by
// the previous step; findBestMatchingPointAtPoint uses this pose to map candidates into model space -- confirm
// that this is intended.
lastResult = Some(fitOnce(image, searchPointSampler, config, mesh, initialTransform.rigidTransform))
nextCount += 1
// NOTE(review): the return value of next() and the closing braces are missing in this copy.
/**
 * One fitting step:
 *  1. find image correspondences for the profile points of `mesh`;
 *  2. compute a rigid transform from them;
 *  3. reconstruct the shape as the model posterior mean in model space;
 *  4. clamp the model coefficients to the configured bounds;
 *  5. build the resulting mesh.
 *
 * @param poseTransform pose used to map candidate points back into model space when checking point distances.
 * @return the fitting result; any exception thrown while computing it is captured by the `Try`.
 */
private def fitOnce(image: PreprocessedImage,
sampler: SearchPointSampler,
config: FittingConfiguration,
mesh: TriangleMesh[_3D],
poseTransform: TranslationAfterRotation[_3D]): Try[FittingResult] = {
// Pairs of (reference-mesh PointId, matched point in the image).
val refPtIdsWithTargetPt = findBestCorrespondingPoints(image, mesh, sampler, config, poseTransform)
if (refPtIdsWithTargetPt.isEmpty) {
// NOTE(review): as it appears in this copy, the exception is constructed but not wrapped in Failure(...)
// (compare fit()); the wrapper was presumably lost -- restore it so this branch yields a Try.
new IllegalStateException("No point correspondences found. You may need to relax the configuration thresholds.")
} else
Try {
// NOTE(review): the collection and call before this pattern-matching block are missing in this copy.
val refPtsWithTargetPts = {
case (refPtId, tgtPt) => (statisticalModel.reference.pointSet.point(refPtId), tgtPt)
// NOTE(review): the right-hand side is missing in this copy (LandmarkRegistration is imported and is a likely
// source -- confirm against upstream).
val bestRigidTransform =
// Target points are mapped into model space with the inverse of the rigid transform.
// NOTE(review): the collection and call before this pattern-matching block are missing in this copy.
val refPtIdsWithTargetPtAtModelSpace = {
case (refPtId, tgtPt) => (refPtId, bestRigidTransform.inverse(tgtPt))
// Shape that best explains the model-space correspondences: the posterior mean. 1e-5 is a hard-coded
// noise parameter of the posterior (presumably a variance -- confirm in the PointDistributionModel API).
val bestReconstruction = statisticalModel.posterior(refPtIdsWithTargetPtAtModelSpace, 1e-5).mean
val coeffs = statisticalModel.coefficients(bestReconstruction)
// Clamp every coefficient to [-modelCoefficientBounds, modelCoefficientBounds].
// NOTE(review): the receiver of this lambda (presumably a map over `coeffs`) is missing in this copy.
val boundedCoeffs = { c =>
Math.min(config.modelCoefficientBounds, Math.max(-config.modelCoefficientBounds, c))
// Rebuild the shape from the clamped coefficients and apply the rigid transform.
val resultMesh = statisticalModel.instance(boundedCoeffs).transform(bestRigidTransform)
val transformations = ModelTransformations(boundedCoeffs, bestRigidTransform)
FittingResult(transformations, resultMesh)
// Model-space position of a profile's point, used for the point-distance check (presumably its position on
// the reference mesh).
// NOTE(review): the body of this method is missing in this copy.
private def refPoint(profileId: ProfileId): Point[_3D] =
/**
 * For every profile, searches for the best matching image point near the corresponding mesh point and keeps only
 * the profiles for which a match passed the thresholds in `config`. Profiles are processed in parallel (ParVector).
 */
private def findBestCorrespondingPoints(
img: PreprocessedImage,
mesh: TriangleMesh[_3D],
sampler: SearchPointSampler,
config: FittingConfiguration,
poseTransform: RigidTransformation[_3D]
): IndexedSeq[(PointId, Point[_3D])] = {
// NOTE(review): the start of the next line is missing in this copy; the unmatched `)` and the `_._1`/`_._2`
// accesses below suggest a tuple (presumably the profile's pointId paired with the optional match).
val matchingPts = new ParVector(profiles.ids.toVector).map { index =>
findBestMatchingPointAtPoint(img, mesh, index, sampler, config, profiles(index).pointId, poseTransform))
// Drop profiles without an acceptable match and unwrap the Option.
val matchingPtsWithinDist = matchingPts.filter(_._2.isDefined).map(p => (p._1, p._2.get))
// NOTE(review): the conversion to IndexedSeq and the closing braces are missing in this copy.
/**
 * Searches the candidate points produced by `searchPointSampler` for `pointId` on `mesh` and picks the one whose
 * features have the smallest distance to the profile's feature distribution.
 *
 * The best candidate is accepted only if its feature distance is within `config.featureDistanceThreshold` and its
 * model-space point distance is within `config.pointDistanceThreshold`. The branch bodies are missing in this
 * copy, so the exact return values are not visible (presumably Some(bestPoint) on acceptance, None otherwise).
 */
private def findBestMatchingPointAtPoint(image: PreprocessedImage,
mesh: TriangleMesh[_3D],
profileId: ProfileId,
searchPointSampler: SearchPointSampler,
config: FittingConfiguration,
pointId: PointId,
poseTransform: RigidTransformation[_3D]): Option[Point[_3D]] = {
val sampledPoints = searchPointSampler(mesh, pointId)
// Pair each candidate that yields a feature vector with its feature distance to the profile.
val pointsWithFeatureDistances = (for (point <- sampledPoints) yield {
// NOTE(review): the call after the extractor (before `{ fv =>`) and the end of this for-expression are
// missing in this copy.
val featureVectorOpt = featureExtractor(image, point, mesh, pointId) { fv =>
(point, featureDistance(profileId, fv))
if (pointsWithFeatureDistances.isEmpty) {
// none of the sampled points returned a valid feature vector
} else {
// Candidate with the smallest feature distance (minBy is safe: the collection is non-empty in this branch).
val (bestPoint, bestFeatureDistance) = pointsWithFeatureDistances.minBy { case (pt, dist) => dist }
if (bestFeatureDistance <= config.featureDistanceThreshold) {
val refPoint = this.refPoint(profileId)
/** Attention: checking for the deformation vector's pdf needs to be done in the model space !**/
val inversePoseTransform = poseTransform.inverse
// NOTE(review): the receiver of `.mahalanobisDistance` (presumably the model's marginal distribution at
// pointId) is missing in this copy.
val bestPointDistance =
.mahalanobisDistance((inversePoseTransform(bestPoint) - refPoint).toBreezeVector)
if (bestPointDistance <= config.pointDistanceThreshold) {
// NOTE(review): the accepting branch body is missing in this copy.
} else {
// point distance above user-set threshold
} else {
// feature distance above user-set threshold
/**
 * Distance between `features` and the feature distribution of profile `pid`. Per the FittingConfiguration
 * documentation, this is a Mahalanobis distance.
 * NOTE(review): the computation and closing brace of this method are missing in this copy.
 */
private def featureDistance(pid: ProfileId, features: DenseVector[Double]): Double = {
val mvdAtPoint = profiles(pid).distribution
/**
 * Fitting configuration, specifying thresholds and bounds.
 *
 * @param featureDistanceThreshold threshold for the feature distance: if the Mahalanobis distance of a candidate point's features to the corresponding profile's mean is larger than this value, then that candidate point is ignored during fitting.
 * @param pointDistanceThreshold threshold for the point distance: if the Mahalanobis distance of a candidate point to its corresponding marginal distribution is larger than this value, then that candidate point is ignored during fitting.
 * @param modelCoefficientBounds bounds to apply to the model coefficients. In other words, by setting this to n, all coefficients of the fitting result are restricted to the interval [-n, n].
 */
case class FittingConfiguration(
  featureDistanceThreshold: Double, // max Mahalanobis distance of a candidate's features to its profile
  pointDistanceThreshold: Double,   // max Mahalanobis distance of a candidate to its marginal distribution
  modelCoefficientBounds: Double    // result coefficients are clamped to [-modelCoefficientBounds, modelCoefficientBounds]
)
/** Companion of [[FittingConfiguration]] providing the default configuration. */
object FittingConfiguration {
/**
 * Configuration used by fit, fitIterator and fitIteratorPreprocessed when no `config` is passed:
 * both distance thresholds are 5.0, and model coefficients are clamped to [-3.0, 3.0].
 */
lazy val Default =
FittingConfiguration(featureDistanceThreshold = 5.0, pointDistanceThreshold = 5.0, modelCoefficientBounds = 3.0)
// NOTE(review): the closing brace of this object is missing in this copy of the file.
/**
 * Transformations to apply to a statistical shape model.
 *
 * Sample usage: {{{ val mesh = ssm.instance(t.coefficients).transform(t.rigidTransform) }}}
 *
 * @param coefficients model coefficients to apply. These determine the shape.
 * @param rigidTransform rigid transformation to apply. It determines the translation and rotation.
 */
case class ModelTransformations(
  coefficients: DenseVector[Double],            // shape part: the model coefficients
  rigidTransform: TranslationAfterRotation[_3D] // pose part: the rigid transformation
)
/**
 * Fitting results.
 *
 * Note that the fields are redundant: the mesh is completely determined by the transformations.
 * It is provided for user convenience, because it would very likely be (re-)constructed from the transformations anyway.
 *
 * @param transformations transformations to apply to the model
 * @param mesh the mesh resulting from applying these transformations
 */
case class FittingResult(
  transformations: ModelTransformations, // coefficients and rigid pose found by the fit
  mesh: TriangleMesh[_3D]                // model instance for `transformations`, kept for convenience
)