All downloads are free. Search and download functionality uses the official Maven repository.

io.projectglow.functions.scala Maven / Gradle / Ivy

/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, LambdaFunction, Literal, UnresolvedNamedLambdaVariable}

import io.projectglow.sql.expressions.ExpressionHelper

// scalastyle:off
// format: off

/**
 * Functions provided by Glow. These functions can be used with Spark's DataFrame API.
 * @group complex_type_manipulation 
 * @group etl 
 * @group quality_control 
 * @group gwas_functions 
 */
object functions {
  /**
   * Lifts a Catalyst expression into a DataFrame `Column`.
   *
   * The expression is first passed through `ExpressionHelper.rewrite`, and the result is then
   * passed through `ExpressionHelper.wrapAggregate` before being wrapped in a `Column`.
   */
  private def withExpr(expr: Expression): Column = {
    new Column(ExpressionHelper.wrapAggregate(ExpressionHelper.rewrite(expr)))
  }

  /**
   * Builds a one-argument Catalyst `LambdaFunction` from a Scala function on columns.
   * The argument is bound to an unresolved lambda variable named `x`.
   */
  private def createLambda(f: Column => Column): LambdaFunction = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val function = f(new Column(x)).expr
    LambdaFunction(function, Seq(x))
  }

  /**
   * Builds a two-argument Catalyst `LambdaFunction` from a Scala function on columns.
   * The arguments are bound, in order, to unresolved lambda variables named `x` and `y`.
   */
  private def createLambda(f: (Column, Column) => Column): LambdaFunction = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val y = UnresolvedNamedLambdaVariable(Seq("y"))
    val function = f(new Column(x), new Column(y)).expr
    LambdaFunction(function, Seq(x, y))
  }

  /**
   * Adds fields to a struct.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param struct The struct to which fields will be added
   * @param fields The new fields to add. The arguments must alternate between string-typed literal field names and field values.
   * @return A struct consisting of the input struct and the added fields
   */
  def add_struct_fields(struct: Column, fields: Column*): Column = withExpr {
    new io.projectglow.sql.expressions.AddStructFields(struct.expr, fields.map(_.expr))
  }

  /**
   * Computes the minimum, maximum, mean, and standard deviation for an array of numerics.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param arr An array of any numeric type
   * @return A struct containing double ``mean``, ``stdDev``, ``min``, and ``max`` fields
   */
  def array_summary_stats(arr: Column): Column = withExpr {
    new io.projectglow.sql.expressions.ArrayStatsSummary(arr.expr)
  }

  /**
   * Converts an array of numerics into a ``spark.ml`` ``DenseVector``.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param arr The array of numerics
   * @return A ``spark.ml`` ``DenseVector``
   */
  def array_to_dense_vector(arr: Column): Column = withExpr {
    new io.projectglow.sql.expressions.ArrayToDenseVector(arr.expr)
  }

  /**
   * Converts an array of numerics into a ``spark.ml`` ``SparseVector``.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param arr The array of numerics
   * @return A ``spark.ml`` ``SparseVector``
   */
  def array_to_sparse_vector(arr: Column): Column = withExpr {
    new io.projectglow.sql.expressions.ArrayToSparseVector(arr.expr)
  }

  /**
   * Promotes fields of a nested struct to top-level columns similar to using ``struct.*`` from SQL, but can be used in more contexts.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param struct The struct to expand
   * @return Columns corresponding to fields of the input struct
   */
  def expand_struct(struct: Column): Column = withExpr {
    new io.projectglow.sql.expressions.ExpandStruct(struct.expr)
  }

  /**
   * Explodes a ``spark.ml`` ``Matrix`` (sparse or dense) into multiple arrays, one per row of the matrix.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param matrix The ``spark.ml`` ``Matrix`` to explode
   * @return An array column in which each row is a row of the input matrix
   */
  def explode_matrix(matrix: Column): Column = withExpr {
    new io.projectglow.sql.expressions.ExplodeMatrix(matrix.expr)
  }

  /**
   * Selects fields from a struct.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param struct Struct from which to select fields
   * @param fields Names of the fields to select
   * @return A struct containing only the indicated fields
   */
  def subset_struct(struct: Column, fields: String*): Column = withExpr {
    new io.projectglow.sql.expressions.SubsetStruct(struct.expr, fields.map(Literal(_)))
  }

  /**
   * Converts a ``spark.ml`` ``Vector`` (sparse or dense) to an array of doubles.
   * @group complex_type_manipulation
   * @since 0.3.0
   *
   * @param vector Vector to convert
   * @return An array of doubles
   */
  def vector_to_array(vector: Column): Column = withExpr {
    new io.projectglow.sql.expressions.VectorToArray(vector.expr)
  }

  /**
   * Converts an array of probabilities to hard calls. The probabilities are assumed to be diploid. See :ref:`variant-data-transformations` for more details.
   * @group etl
   * @since 0.3.0
   *
   * @param probabilities The array of probabilities to convert
   * @param numAlts The number of alternate alleles
   * @param phased Whether the probabilities are phased. If phased, we expect ``2 * numAlts`` values in the probabilities array. If unphased, we expect one probability per possible genotype.
   * @param threshold The minimum probability to make a call. If no probability falls into the range of ``[0, 1 - threshold]`` or ``[threshold, 1]``, a no-call (represented by ``-1`` s) will be emitted. If not provided, this parameter defaults to ``0.9``.
   * @return An array of hard calls
   */
  def hard_calls(probabilities: Column, numAlts: Column, phased: Column, threshold: Double): Column = withExpr {
    new io.projectglow.sql.expressions.HardCalls(probabilities.expr, numAlts.expr, phased.expr, Literal(threshold))
  }

  /**
   * Converts an array of probabilities to hard calls without an explicit ``threshold``, so the
   * expression's default threshold (documented as ``0.9``) applies. See :ref:`variant-data-transformations` for more details.
   * @group etl
   *
   * @param probabilities The array of probabilities to convert
   * @param numAlts The number of alternate alleles
   * @param phased Whether the probabilities are phased
   * @return An array of hard calls
   */
  def hard_calls(probabilities: Column, numAlts: Column, phased: Column): Column = withExpr {
    new io.projectglow.sql.expressions.HardCalls(probabilities.expr, numAlts.expr, phased.expr)
  }

  /**
   * Performs liftover for the coordinates of a variant. To perform liftover of alleles and add additional metadata, see :ref:`liftover`.
   * @group etl
   * @since 0.3.0
   *
   * @param contigName The current contig name
   * @param start The current start
   * @param end The current end
   * @param chainFile Location of the chain file on each node in the cluster
   * @param minMatchRatio Minimum fraction of bases that must remap to do liftover successfully. If not provided, defaults to ``0.95``.
   * @return A struct containing ``contigName``, ``start``, and ``end`` fields after liftover
   */
  def lift_over_coordinates(contigName: Column, start: Column, end: Column, chainFile: String, minMatchRatio: Double): Column = withExpr {
    new io.projectglow.sql.expressions.LiftOverCoordinatesExpr(contigName.expr, start.expr, end.expr, Literal(chainFile), Literal(minMatchRatio))
  }

  /**
   * Performs liftover for the coordinates of a variant without an explicit ``minMatchRatio``, so the
   * expression's default (documented as ``0.95``) applies. See :ref:`liftover`.
   * @group etl
   *
   * @param contigName The current contig name
   * @param start The current start
   * @param end The current end
   * @param chainFile Location of the chain file on each node in the cluster
   * @return A struct containing ``contigName``, ``start``, and ``end`` fields after liftover
   */
  def lift_over_coordinates(contigName: Column, start: Column, end: Column, chainFile: String): Column = withExpr {
    new io.projectglow.sql.expressions.LiftOverCoordinatesExpr(contigName.expr, start.expr, end.expr, Literal(chainFile))
  }

  /**
   * Normalizes the variant with a behavior similar to vt normalize or bcftools norm.
   * Creates a StructType column including the normalized ``start``, ``end``, ``referenceAllele`` and
   * ``alternateAlleles`` fields (whether they are changed or unchanged as the result of
   * normalization) as well as a StructType field called ``normalizationStatus`` that
   * contains the following fields:
   *
   *    ``changed``: A boolean field indicating whether the variant data was changed as a result of normalization
   *
   *    ``errorMessage``: An error message in case the attempt at normalizing the row hit an error. In this case, the ``changed`` field will be set to ``false``. If no errors occur, this field will be ``null``.
   *
   * In case of an error, the ``start``, ``end``, ``referenceAllele`` and ``alternateAlleles`` fields in the generated struct will be ``null``.
   *
   * @group etl
   * @since 0.3.0
   *
   * @param contigName The current contig name
   * @param start The current start
   * @param end The current end
   * @param refAllele The current reference allele
   * @param altAlleles The current array of alternate alleles
   * @param refGenomePathString A path to the reference genome ``.fasta`` file. The ``.fasta`` file must be accompanied by a ``.fai`` index file in the same folder.
   * @return A struct as explained above
   */
  def normalize_variant(contigName: Column, start: Column, end: Column, refAllele: Column, altAlleles: Column, refGenomePathString: String): Column = withExpr {
    new io.projectglow.sql.expressions.NormalizeVariantExpr(contigName.expr, start.expr, end.expr, refAllele.expr, altAlleles.expr, Literal(refGenomePathString))
  }

  /**
   * Substitutes the missing values of a numeric array using the mean of the non-missing values. Any values that are NaN, null or equal to the missing value parameter are considered missing. See :ref:`variant-data-transformations` for more details.
   * @group etl
   * @since 0.4.0
   *
   * @param array A numeric array that may contain missing values
   * @param missingValue A value that should be considered missing. If not provided, this parameter defaults to ``-1``.
   * @return A numeric array with substituted missing values
   */
  def mean_substitute(array: Column, missingValue: Column): Column = withExpr {
    new io.projectglow.sql.expressions.MeanSubstitute(array.expr, missingValue.expr)
  }

  /**
   * Substitutes the missing values of a numeric array using the mean of the non-missing values,
   * without an explicit ``missingValue``, so the expression's default (documented as ``-1``) applies.
   * See :ref:`variant-data-transformations` for more details.
   * @group etl
   *
   * @param array A numeric array that may contain missing values
   * @return A numeric array with substituted missing values
   */
  def mean_substitute(array: Column): Column = withExpr {
    new io.projectglow.sql.expressions.MeanSubstitute(array.expr)
  }

  /**
   * Computes custom per-sample aggregates.
   * @group quality_control
   * @since 0.3.0
   *
   * @param arr Array of values
   * @param initialValue The initial value
   * @param update Update function; a two-argument lambda whose arguments are bound to lambda variables ``x`` and ``y``
   * @param merge Merge function; a two-argument lambda whose arguments are bound to lambda variables ``x`` and ``y``
   * @param evaluate Evaluate function; a one-argument lambda whose argument is bound to lambda variable ``x``
   * @return An array of aggregated values. The number of elements in the array is equal to the number of samples.
   */
  def aggregate_by_index(arr: Column, initialValue: Column, update: (Column, Column) => Column, merge: (Column, Column) => Column, evaluate: Column => Column): Column = withExpr {
    new io.projectglow.sql.expressions.UnwrappedAggregateByIndex(arr.expr, initialValue.expr, createLambda(update), createLambda(merge), createLambda(evaluate))
  }

  /**
   * Computes custom per-sample aggregates without an explicit ``evaluate`` function; the
   * underlying expression decides how the final per-sample value is produced.
   * @group quality_control
   *
   * @param arr Array of values
   * @param initialValue The initial value
   * @param update Update function; a two-argument lambda whose arguments are bound to lambda variables ``x`` and ``y``
   * @param merge Merge function; a two-argument lambda whose arguments are bound to lambda variables ``x`` and ``y``
   * @return An array of aggregated values. The number of elements in the array is equal to the number of samples.
   */
  def aggregate_by_index(arr: Column, initialValue: Column, update: (Column, Column) => Column, merge: (Column, Column) => Column): Column = withExpr {
    new io.projectglow.sql.expressions.UnwrappedAggregateByIndex(arr.expr, initialValue.expr, createLambda(update), createLambda(merge))
  }

  /**
   * Computes call summary statistics for an array of genotype structs. See :ref:`variant-qc` for more details.
   * @group quality_control
   * @since 0.3.0
   *
   * @param genotypes The array of genotype structs with ``calls`` field
   * @return A struct containing ``callRate``, ``nCalled``, ``nUncalled``, ``nHet``, ``nHomozygous``, ``nNonRef``, ``nAllelesCalled``, ``alleleCounts``, ``alleleFrequencies`` fields. See :ref:`variant-qc`.
   */
  def call_summary_stats(genotypes: Column): Column = withExpr {
    new io.projectglow.sql.expressions.CallStats(genotypes.expr)
  }

  /**
   * Computes summary statistics for the depth field from an array of genotype structs. See :ref:`variant-qc`.
   * @group quality_control
   * @since 0.3.0
   *
   * @param genotypes An array of genotype structs with ``depth`` field
   * @return A struct containing ``mean``, ``stdDev``, ``min``, and ``max`` of genotype depths
   */
  def dp_summary_stats(genotypes: Column): Column = withExpr {
    new io.projectglow.sql.expressions.DpSummaryStats(genotypes.expr)
  }

  /**
   * Computes statistics relating to the Hardy Weinberg equilibrium. See :ref:`variant-qc` for more details.
   * @group quality_control
   * @since 0.3.0
   *
   * @param genotypes The array of genotype structs with ``calls`` field
   * @return A struct containing two fields, ``hetFreqHwe`` (the expected heterozygous frequency according to Hardy-Weinberg equilibrium) and ``pValueHwe`` (the associated p-value)
   */
  def hardy_weinberg(genotypes: Column): Column = withExpr {
    new io.projectglow.sql.expressions.HardyWeinberg(genotypes.expr)
  }

  /**
   * Computes summary statistics about the genotype quality field for an array of genotype structs. See :ref:`variant-qc`.
   * @group quality_control
   * @since 0.3.0
   *
   * @param genotypes The array of genotype structs with ``conditionalQuality`` field
   * @return A struct containing ``mean``, ``stdDev``, ``min``, and ``max`` of genotype qualities
   */
  def gq_summary_stats(genotypes: Column): Column = withExpr {
    new io.projectglow.sql.expressions.GqSummaryStats(genotypes.expr)
  }

  /**
   * Computes per-sample call summary statistics. See :ref:`sample-qc` for more details.
   * @group quality_control
   * @since 0.3.0
   *
   * @param genotypes An array of genotype structs with ``calls`` field
   * @param refAllele The reference allele
   * @param alternateAlleles An array of alternate alleles
   * @return A struct containing ``sampleId``, ``callRate``, ``nCalled``, ``nUncalled``, ``nHomRef``, ``nHet``, ``nHomVar``, ``nSnp``, ``nInsertion``, ``nDeletion``, ``nTransition``, ``nTransversion``, ``nSpanningDeletion``, ``rTiTv``, ``rInsertionDeletion``, ``rHetHomVar`` fields. See :ref:`sample-qc`.
   */
  def sample_call_summary_stats(genotypes: Column, refAllele: Column, alternateAlleles: Column): Column = withExpr {
    new io.projectglow.sql.expressions.CallSummaryStats(genotypes.expr, refAllele.expr, alternateAlleles.expr)
  }

  /**
   * Computes per-sample summary statistics about the depth field in an array of genotype structs.
   * @group quality_control
   * @since 0.3.0
   *
   * @param genotypes An array of genotype structs with ``depth`` field
   * @return An array of structs where each struct contains ``mean``, ``stDev``, ``min``, and ``max`` of the genotype depths for a sample. If ``sampleId`` is present in a genotype, it will be propagated to the resulting struct as an extra field.
   */
  def sample_dp_summary_stats(genotypes: Column): Column = withExpr {
    new io.projectglow.sql.expressions.SampleDpSummaryStatistics(genotypes.expr)
  }

  /**
   * Computes per-sample summary statistics about the genotype quality field in an array of genotype structs.
   * @group quality_control
   * @since 0.3.0
   *
   * @param genotypes An array of genotype structs with ``conditionalQuality`` field
   * @return An array of structs where each struct contains ``mean``, ``stDev``, ``min``, and ``max`` of the genotype qualities for a sample. If ``sampleId`` is present in a genotype, it will be propagated to the resulting struct as an extra field.
   */
  def sample_gq_summary_stats(genotypes: Column): Column = withExpr {
    new io.projectglow.sql.expressions.SampleGqSummaryStatistics(genotypes.expr)
  }

  /**
   * Performs a linear regression association test optimized for performance in a GWAS setting. See :ref:`linear-regression` for details.
   * @group gwas_functions
   * @since 0.3.0
   *
   * @param genotypes A numeric array of genotypes
   * @param phenotypes A numeric array of phenotypes
   * @param covariates A ``spark.ml`` ``Matrix`` of covariates
   * @return A struct containing ``beta``, ``standardError``, and ``pValue`` fields. See :ref:`linear-regression`.
   */
  def linear_regression_gwas(genotypes: Column, phenotypes: Column, covariates: Column): Column = withExpr {
    new io.projectglow.sql.expressions.LinearRegressionExpr(genotypes.expr, phenotypes.expr, covariates.expr)
  }

  /**
   * Performs a logistic regression association test optimized for performance in a GWAS setting. See :ref:`logistic-regression` for more details.
   * @group gwas_functions
   * @since 0.3.0
   *
   * @param genotypes A numeric array of genotypes
   * @param phenotypes A double array of phenotype values
   * @param covariates A ``spark.ml`` ``Matrix`` of covariates
   * @param test Which logistic regression test to use. Can be ``LRT`` or ``Firth``
   * @param offset An optional double array of offset values. The offset vector is added with coefficient 1 to the linear predictor term X*b.
   * @return A struct containing ``beta``, ``oddsRatio``, ``waldConfidenceInterval``, and ``pValue`` fields. See :ref:`logistic-regression`.
   */
  def logistic_regression_gwas(genotypes: Column, phenotypes: Column, covariates: Column, test: String, offset: Column): Column = withExpr {
    new io.projectglow.sql.expressions.LogisticRegressionExpr(genotypes.expr, phenotypes.expr, covariates.expr, Literal(test), offset.expr)
  }

  /**
   * Performs a logistic regression association test optimized for performance in a GWAS setting,
   * without an ``offset`` column. See :ref:`logistic-regression` for more details.
   * @group gwas_functions
   *
   * @param genotypes A numeric array of genotypes
   * @param phenotypes A double array of phenotype values
   * @param covariates A ``spark.ml`` ``Matrix`` of covariates
   * @param test Which logistic regression test to use. Can be ``LRT`` or ``Firth``
   * @return A struct containing ``beta``, ``oddsRatio``, ``waldConfidenceInterval``, and ``pValue`` fields. See :ref:`logistic-regression`.
   */
  def logistic_regression_gwas(genotypes: Column, phenotypes: Column, covariates: Column, test: String): Column = withExpr {
    new io.projectglow.sql.expressions.LogisticRegressionExpr(genotypes.expr, phenotypes.expr, covariates.expr, Literal(test))
  }

  /**
   * Gets the number of alternate alleles for an array of genotype structs. Returns ``-1`` if there are any ``-1`` s (no-calls) in the calls array.
   * @group gwas_functions
   * @since 0.3.0
   *
   * @param genotypes An array of genotype structs with ``calls`` field
   * @return An array of integers containing the number of alternate alleles in each call array
   */
  def genotype_states(genotypes: Column): Column = withExpr {
    new io.projectglow.sql.expressions.GenotypeStates(genotypes.expr)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy