All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.campagnelab.dl.somatic.mappers.AbstractFeatureMapper Maven / Gradle / Ivy

package org.campagnelab.dl.somatic.mappers;

import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import org.campagnelab.dl.framework.mappers.FeatureNameMapper;
import org.campagnelab.dl.framework.mappers.MappedDimensions;
import org.campagnelab.dl.somatic.genotypes.GenotypeCountFactory;
import org.campagnelab.dl.varanalysis.protobuf.BaseInformationRecords;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Arrays;
import java.util.Collections;

/**
 * AbstractFeatureMapper encapsulates behavior common to many feature mappers.
 * Created by fac2003 on 6/3/16.
 *
 * @author Fabien Campagne
 */
public abstract class AbstractFeatureMapper implements FeatureNameMapper {
    public static final int MAX_GENOTYPES = 5;
    public static final int N_GENOTYPE_INDEX = 6;
    private float[] buffer;

    protected float[] getBuffer() {

        if (buffer == null) {
            buffer = new float[numberOfFeatures()];
        } else {
            Arrays.fill(buffer, 0f);
        }
        return buffer;
    }

    @Override
    public String getFeatureName(int featureIndex) {
        return null;
    }

    public boolean oneSampleHasTumor(java.util.List samples) {
        for (BaseInformationRecords.SampleInfo sample : samples) {
            if (sample.getIsTumor()) return true;
        }
        return false;

    }

    protected ObjectArrayList getAllCounts(BaseInformationRecords.BaseInformationOrBuilder record, GenotypeCountFactory factory, boolean isTumor) {
        return getAllCounts(record, factory, isTumor, true);
    }

    protected ObjectArrayList getAllCounts(BaseInformationRecords.BaseInformationOrBuilder record, GenotypeCountFactory factory, boolean isTumor, boolean sort) {
        int sampleIndex = isTumor ? 1 : 0;
        ObjectArrayList list = new ObjectArrayList<>();
        int genotypeIndex = 0;
        for (int i = 0; i < record.getSamples(0).getCountsCount(); i++) {
            int germCount = record.getSamples(0).getCounts(i).getGenotypeCountForwardStrand() + record.getSamples(0).getCounts(i).getGenotypeCountReverseStrand();
            BaseInformationRecords.CountInfo genoInfo = record.getSamples(sampleIndex).getCounts(i);
            int forwCount = genoInfo.getGenotypeCountForwardStrand();
            int revCount = genoInfo.getGenotypeCountReverseStrand();
            GenotypeCount count = factory.create();
            count.set(forwCount, revCount, genoInfo.getToSequence(), i, germCount);
            initializeCount(record.getSamples(sampleIndex).getCounts(i), count);
            list.add(count);
        }
        // DO not increment genotypeIndex. It must remain constant for all N bases
        int genotypeIndexFor_Ns = N_GENOTYPE_INDEX;
        // pad with zero until we have 10 elements:
        while (list.size() < MAX_GENOTYPES) {
            final GenotypeCount genotypeCount = getGenotypeCountFactory().create();
            genotypeCount.set(0, 0, "N", genotypeIndexFor_Ns, 0);
            list.add(genotypeCount);

        }
        //sort in decreasing order of counts:
        if (sort) {
            Collections.sort(list);
        }
        // trim the list at 5 elements because we will consider only the 5 genotypes with largest total counts:
        list.trim(MAX_GENOTYPES);

        return list;
    }


    protected abstract void initializeCount(BaseInformationRecords.CountInfo sampleCounts, GenotypeCount count);

    private BaseInformationRecords.BaseInformationOrBuilder recordCached[][] = new BaseInformationRecords.BaseInformationOrBuilder[2][2];
    private ObjectArrayList cachedResult[][] = new ObjectArrayList[2][2];

    protected ObjectArrayList getAllCounts(BaseInformationRecords.BaseInformationOrBuilder record,
                                                                    boolean isTumor, boolean sort) {
        ObjectArrayList cached = getCachedResult(isTumor, sort);
        if (cached != null && record.equals(recordCached[isTumor ? 1 : 0][sort ? 1 : 0])) {
            return cached;
        } else {

            assert oneSampleHasTumor(record.getSamplesList()) : "at least one sample must have hasTumor=true.";

            for (int i = 0; i < record.getSamplesCount(); i++) {
                if (isTumor != record.getSamples(i).getIsTumor()) continue;
                // a subclass is expected to override getGenotypeCountFactory to provide its own type for Genotype counts:
                cached = getAllCounts(record, getGenotypeCountFactory(), isTumor, sort);
                putInCache(record, cached, isTumor, sort);
                return cached;
            }
            throw new InternalError("At least one sample matching isTumor, and one matching not isTumor must be found.");
        }
    }

    private void putInCache(BaseInformationRecords.BaseInformationOrBuilder record, ObjectArrayList cached, boolean isTumor, boolean sort) {
        int index1 = isTumor ? 1 : 0;
        int index2 = sort ? 1 : 0;
        recordCached[index1][index2] = record;
        cachedResult[index1][index2] = cached;
    }

    private ObjectArrayList getCachedResult(boolean isTumor, boolean sort) {
        int index1 = isTumor ? 1 : 0;
        int index2 = sort ? 1 : 0;
        return cachedResult[index1][index2];
    }

    protected ObjectArrayList getAllCounts(BaseInformationRecords.BaseInformationOrBuilder record,
                                                                    boolean isTumor) {
        return getAllCounts(record, isTumor, true);
    }

    protected abstract GenotypeCountFactory getGenotypeCountFactory();

    /**
     * Default implementation assumes a 1-d tensor.
     *
     * @return
     */
    @Override
    public MappedDimensions dimensions() {
        return new MappedDimensions(numberOfFeatures());
    }

    @Override
    public boolean hasMask() {
        return false;
    }

    public void maskLabels(T record, INDArray mask, int indexOfRecord) {
        // do nothing implementation.
    }

    @Override
    public boolean isMasked(T record, int featureIndex) {
        return false;
    }

    @Override
    public void maskFeatures(BaseInformationRecords.BaseInformationOrBuilder record, INDArray mask, int indexOfRecord) {
         // do nothing implementation.
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy