/*
 * All Downloads are FREE. Search and download functionalities are using the official Maven repository.
 *
 * net.maizegenetics.pangenome.api.CreateGraphUtils Maven / Gradle / Ivy
 *
 * There is a newer version: 1.10
 * Show newest version
 */
package net.maizegenetics.pangenome.api;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedSet;
import net.maizegenetics.dna.map.Chromosome;
import net.maizegenetics.pangenome.db_loading.DBLoadingUtils;
import net.maizegenetics.taxa.TaxaList;
import net.maizegenetics.taxa.TaxaListBuilder;
import net.maizegenetics.taxa.Taxon;
import net.maizegenetics.util.Tuple;
import org.apache.log4j.Logger;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

/**
 * @author Terry Casstevens Created August 21, 2017
 */
public class CreateGraphUtils {

    /** Log4j logger for this class. */
    private static final Logger myLogger = Logger.getLogger(CreateGraphUtils.class);

    /**
     * Method name {@value}. Going by the constant's name this is presumably the haplotype method
     * whose haplotypes did not go through a consensus step — TODO confirm with its users.
     */
    public static final String NO_CONSENSUS_METHOD = "Haplotype_caller";

    /**
     * Not instantiable; this is a static utility class.
     */
    private CreateGraphUtils() {
    }

    /**
     * Opens a database connection described by a properties file.
     *
     * <p>Delegates to {@code DBLoadingUtils.connection} with database creation disabled, so a
     * database that does not exist is not created; in that case {@code null} is returned.
     *
     * @param propertiesFile properties file describing the database
     *
     * @return database connection, or {@code null} if the database doesn't exist
     */
    public static Connection connection(String propertiesFile) {
        final boolean createIfMissing = false;
        return DBLoadingUtils.connection(propertiesFile, createIfMissing);
    }

    /**
     * Opens a SQLite database connection.
     *
     * <p>The JDBC URL is {@code "jdbc:sqlite:" + dbName}; {@code host} is not used. {@code user}
     * and {@code password} are passed through to {@link DriverManager#getConnection(String, String, String)}.
     *
     * @param host hostname (ignored for SQLite; kept for API compatibility)
     * @param user user id
     * @param password password
     * @param dbName database name, appended to {@code jdbc:sqlite:}
     *
     * @return SQLite database connection
     *
     * @throws IllegalStateException if the SQLite JDBC driver class cannot be loaded or the
     *         connection cannot be opened; the underlying exception is attached as the cause
     */
    public static Connection connection(String host, String user, String password, String dbName) {

        String url = "jdbc:sqlite:" + dbName;
        myLogger.info("Database URL: " + url);

        Connection connection;
        try {
            // Make sure the SQLite driver class is loaded before asking DriverManager for it.
            Class.forName("org.sqlite.JDBC");
            connection = DriverManager.getConnection(url, user, password);
        } catch (ClassNotFoundException e) {
            myLogger.error(e.getMessage(), e);
            throw new IllegalStateException("CreateGraph: connection: org.sqlite.JDBC can't be found", e);
        } catch (SQLException e) {
            myLogger.error(e.getMessage(), e);
            throw new IllegalStateException("CreateGraph: connection: problem connecting to database: " + e.getMessage(), e);
        }

        myLogger.info("Connected to database:  " + url + "\n");
        return connection;

    }

    /**
     * Retrieves all ReferenceRange instances, keyed by their reference_ranges.ref_range_id.
     *
     * <p>Every range is joined with the REF_RANGE_GROUP methods it belongs to. The names of those
     * methods become the range's method-name set. Each ReferenceRange also records the reference
     * line name returned by {@code getRefLineName(database)}.
     *
     * @param database database connection
     *
     * @return immutable map of ReferenceRanges, key is reference_ranges.ref_range_id
     *
     * @throws IllegalArgumentException if {@code database} is null
     * @throws IllegalStateException if querying the database fails (the original exception is
     *         attached as the cause)
     */
    public static Map<Integer, ReferenceRange> referenceRangeMap(Connection database) {

        if (database == null) {
            throw new IllegalArgumentException("CreateGraphUtils: referenceRangesAsMap: Must specify database connection.");
        }

        long time = System.nanoTime();

        // Reference line name stored on every ReferenceRange created below
        String refLine = getRefLineName(database);

        StringBuilder querySB = new StringBuilder();
        querySB.append("select reference_ranges.ref_range_id, chrom, range_start, range_end, methods.name from reference_ranges ");
        querySB.append(" INNER JOIN ref_range_ref_range_method on ref_range_ref_range_method.ref_range_id=reference_ranges.ref_range_id ");
        querySB.append(" INNER JOIN methods on ref_range_ref_range_method.method_id = methods.method_id ");
        querySB.append(" AND methods.method_type = ");
        querySB.append(DBLoadingUtils.MethodType.REF_RANGE_GROUP.getValue());
        // Ordering by id makes all method rows of one range consecutive, which the loop below relies on
        querySB.append(" ORDER BY reference_ranges.ref_range_id");

        String query = querySB.toString();
        myLogger.info("referenceRangesAsMap: query statement: " + query);

        ImmutableMap.Builder<Integer, ReferenceRange> builder = ImmutableMap.builder();
        // Close the Statement too: closing only its ResultSet leaves the Statement open.
        try (Statement statement = database.createStatement();
             ResultSet rs = statement.executeQuery(query)) {

            String currentChromosome = null;
            int currentStart = -1;
            int currentEnd = -1;
            int currentRefRangeId = -1; // -1: no range read yet
            ImmutableSet.Builder<String> methodNameSet = ImmutableSet.builder();
            while (rs.next()) {
                int id = rs.getInt("ref_range_id");
                String chromosome = rs.getString("chrom");
                int start = rs.getInt("range_start");
                int end = rs.getInt("range_end");
                String methodName = rs.getString("name");
                if (currentRefRangeId == -1) {
                    currentRefRangeId = id;
                    currentChromosome = chromosome;
                    currentStart = start;
                    currentEnd = end;
                    methodNameSet.add(methodName);
                } else if (currentRefRangeId == id) {
                    // Another group method for the same range
                    methodNameSet.add(methodName);
                } else {
                    // New range: emit the finished one and start collecting the next
                    builder.put(currentRefRangeId, new ReferenceRange(refLine, Chromosome.instance(currentChromosome), currentStart, currentEnd, currentRefRangeId, methodNameSet.build()));
                    methodNameSet = ImmutableSet.builder();
                    currentRefRangeId = id;
                    currentChromosome = chromosome;
                    currentStart = start;
                    currentEnd = end;
                    methodNameSet.add(methodName);
                }
            }

            // Emit the last range; the set is empty only when the query returned no rows
            ImmutableSet<String> methods = methodNameSet.build();
            myLogger.debug("referenceRangesAsMap: methods size: " + methods.size());
            if (!methods.isEmpty()) {
                builder.put(currentRefRangeId, new ReferenceRange(refLine, Chromosome.instance(currentChromosome), currentStart, currentEnd, currentRefRangeId, methods));
            }

        } catch (Exception se) {
            myLogger.debug(se.getMessage(), se);
            throw new IllegalStateException("CreateGraphUtils: referenceRangesAsMap: Problem querying the database: " + se.getMessage(), se);
        }

        Map<Integer, ReferenceRange> result = builder.build();

        myLogger.info("referenceRangesAsMap: number of reference ranges: " + result.size());

        myLogger.info("referenceRangesAsMap: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        return result;

    }

    /**
     * Retrieves all ReferenceRange instances as a sorted set.
     *
     * <p>Every range is joined with the REF_RANGE_GROUP methods it belongs to. The names of those
     * methods become the range's method-name set. Each ReferenceRange also records the reference
     * line name returned by {@code getRefLineName(database)}.
     *
     * @param database database connection
     *
     * @return immutable ReferenceRanges in their natural order
     *
     * @throws IllegalArgumentException if {@code database} is null
     * @throws IllegalStateException if querying the database fails (the original exception is
     *         attached as the cause)
     */
    public static SortedSet<ReferenceRange> referenceRanges(Connection database) {

        if (database == null) {
            throw new IllegalArgumentException("CreateGraphUtils: referenceRanges: Must specify database connection.");
        }

        long time = System.nanoTime();

        // Reference line name stored on every ReferenceRange created below
        String refLine = getRefLineName(database);

        StringBuilder querySB = new StringBuilder();
        querySB.append("select reference_ranges.ref_range_id, chrom, range_start, range_end, methods.name from reference_ranges ");
        querySB.append(" INNER JOIN ref_range_ref_range_method on ref_range_ref_range_method.ref_range_id=reference_ranges.ref_range_id ");
        querySB.append(" INNER JOIN methods on ref_range_ref_range_method.method_id = methods.method_id ");
        querySB.append(" AND methods.method_type = ");
        querySB.append(DBLoadingUtils.MethodType.REF_RANGE_GROUP.getValue());
        // Ordering by id makes all method rows of one range consecutive, which the loop below relies on
        querySB.append(" ORDER BY reference_ranges.ref_range_id");

        String query = querySB.toString();
        myLogger.info("referenceRanges: query statement: " + query);

        ImmutableSortedSet.Builder<ReferenceRange> builder = ImmutableSortedSet.naturalOrder();
        // Close the Statement too: closing only its ResultSet leaves the Statement open.
        try (Statement statement = database.createStatement();
             ResultSet rs = statement.executeQuery(query)) {

            String currentChromosome = null;
            int currentStart = -1;
            int currentEnd = -1;
            int currentRefRangeId = -1; // -1: no range read yet
            ImmutableSet.Builder<String> methodNameSet = ImmutableSet.builder();
            while (rs.next()) {
                int id = rs.getInt("ref_range_id");
                String chromosome = rs.getString("chrom");
                int start = rs.getInt("range_start");
                int end = rs.getInt("range_end");
                String methodName = rs.getString("name");
                if (currentRefRangeId == -1) {
                    currentRefRangeId = id;
                    currentChromosome = chromosome;
                    currentStart = start;
                    currentEnd = end;
                    methodNameSet.add(methodName);
                } else if (currentRefRangeId == id) {
                    // Another group method for the same range
                    methodNameSet.add(methodName);
                } else {
                    // New range: emit the finished one and start collecting the next
                    builder.add(new ReferenceRange(refLine, Chromosome.instance(currentChromosome), currentStart, currentEnd, currentRefRangeId, methodNameSet.build()));
                    methodNameSet = ImmutableSet.builder();
                    currentRefRangeId = id;
                    currentChromosome = chromosome;
                    currentStart = start;
                    currentEnd = end;
                    methodNameSet.add(methodName);
                }
            }

            // Emit the last range; the set is empty only when the query returned no rows
            ImmutableSet<String> methods = methodNameSet.build();
            myLogger.debug("referenceRanges: methods size: " + methods.size());
            if (!methods.isEmpty()) {
                builder.add(new ReferenceRange(refLine, Chromosome.instance(currentChromosome), currentStart, currentEnd, currentRefRangeId, methods));
            }

        } catch (Exception se) {
            myLogger.debug(se.getMessage(), se);
            throw new IllegalStateException("CreateGraphUtils: referenceRanges: Problem querying the database: " + se.getMessage(), se);
        }

        SortedSet<ReferenceRange> result = builder.build();

        myLogger.info("referenceRanges: number of reference ranges: " + result.size());

        myLogger.info("referenceRanges: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        return result;

    }

    /**
     * Retrieves the taxa of every gamete group.
     *
     * <p>A group's taxa are the genotypes.line_name values of the gametes listed for that group
     * in gamete_haplotypes. All lists that contain the same line name share one Taxon instance.
     *
     * @param database database connection
     *
     * @return immutable map of TaxaList, key is gamete_groups.gamete_grp_id; empty when the
     *         query returns no rows
     *
     * @throws IllegalArgumentException if {@code database} is null
     * @throws IllegalStateException if querying the database fails (the original exception is
     *         attached as the cause)
     */
    public static Map<Integer, TaxaList> taxaListMap(Connection database) {

        if (database == null) {
            throw new IllegalArgumentException("CreateGraphUtils: taxaListMap: Must specify database connection.");
        }

        long time = System.nanoTime();

        //
        // select gamete_haplotypes.gamete_grp_id, genotypes.line_name
        // from gamete_haplotypes
        // inner join gametes on gamete_haplotypes.gameteid = gametes.gameteid
        // inner join genotypes on gametes.genoid = genotypes.genoid
        // order by gamete_haplotypes.gamete_grp_id;
        //

        StringBuilder builder = new StringBuilder();
        builder.append("SELECT gamete_haplotypes.gamete_grp_id, genotypes.line_name ");
        builder.append("FROM gamete_haplotypes ");
        builder.append("INNER JOIN gametes ON gamete_haplotypes.gameteid = gametes.gameteid ");
        builder.append("INNER JOIN genotypes on gametes.genoid = genotypes.genoid ");
        builder.append("ORDER BY gamete_haplotypes.gamete_grp_id;");

        String query = builder.toString();
        myLogger.info("taxaListMap: query statement: " + query);

        // Close the Statement too: closing only its ResultSet leaves the Statement open.
        try (Statement statement = database.createStatement();
             ResultSet rs = statement.executeQuery(query)) {

            Map<String, Taxon> taxaCache = new HashMap<>();
            ImmutableMap.Builder<Integer, TaxaList> resultBuilder = ImmutableMap.builder();

            // Rows are ordered by gamete_grp_id, so each group's rows are consecutive.
            boolean inGroup = false;
            int currentGroupId = -1;
            List<Taxon> currentTaxaList = new ArrayList<>();
            while (rs.next()) {
                int groupId = rs.getInt("gamete_grp_id");
                if (inGroup && groupId != currentGroupId) {
                    // Group changed: emit the finished group and start a new one
                    TaxaListBuilder taxaBuilder = new TaxaListBuilder();
                    taxaBuilder.addAll(currentTaxaList);
                    resultBuilder.put(currentGroupId, taxaBuilder.build());
                    currentTaxaList = new ArrayList<>();
                }
                inGroup = true;
                currentGroupId = groupId;
                currentTaxaList.add(taxon(rs.getString("line_name"), taxaCache));
            }

            // Emit the final group. With no rows there is no group at all, so nothing is added
            // (instead of an empty list under the placeholder id -1).
            if (inGroup) {
                TaxaListBuilder taxaBuilder = new TaxaListBuilder();
                taxaBuilder.addAll(currentTaxaList);
                resultBuilder.put(currentGroupId, taxaBuilder.build());
            }

            Map<Integer, TaxaList> result = resultBuilder.build();

            myLogger.info("taxaListMap: number of taxa lists: " + result.size());

            myLogger.info("taxaListMap: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

            return result;

        } catch (Exception e) {
            myLogger.debug(e.getMessage(), e);
            throw new IllegalStateException("CreateGraphUtils: taxaListMap: Problem querying the database: " + e.getMessage(), e);
        }

    }

    /**
     * Gets the Taxon for a name, reusing the instance already in {@code taxaCache} when there is
     * one. This way every occurrence of a name shares a single Taxon object.
     *
     * @param name taxon name
     * @param taxaCache Taxon instances keyed by name; a new entry is added on a cache miss
     *
     * @return Taxon instance
     */
    private static Taxon taxon(String name, Map<String, Taxon> taxaCache) {
        return taxaCache.computeIfAbsent(name, key -> new Taxon(key));
    }


    /**
     * Creates lists of HaplotypeNodes organized by reference Range based on the given method.
     *
     * @param database database connection
     * @param referenceRangeMap ReferenceRange map ({@link #referenceRangeMap(Connection)}
     * @param taxaListMap TaxaList map {@link #taxaListMap(Connection)}
     * @param methods
     * @param includeVariantContext whether to include variant contexts in haplotype nodes
     * @param includeHapids includes specified hapids. include everything if null
     * @return Map of HaplotypeNode Lists (keys are ReferenceRange)
     */

    private static final String ALL_CHROMOSOMES = "ALL_CHROMOSOMES";

    public static TreeMap> createHaplotypeNodes(Connection database, Map referenceRangeMap,
                                                                                    Map taxaListMap, List> methods,
                                                                                    boolean includeSequences, boolean includeVariantContext,
                                                                                    SortedSet includeHapids,
                                                                                    List chromosomes, TaxaList taxaToKeep) {

        if (database == null) {
            throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodes: Must specify database connection.");
        }

        if (includeHapids != null && !includeHapids.isEmpty() && taxaToKeep != null && !taxaToKeep.isEmpty()) {
            throw new IllegalStateException("CreateGraphUtils: createHaplotypeNodes: can't specify both hapids and taxa");
        }

        if (methods == null && includeHapids != null && !includeHapids.isEmpty()) {
            return createHaplotypeNodes(database, referenceRangeMap, taxaListMap, includeSequences, includeVariantContext,
                    includeHapids, chromosomes);
        } else if (methods == null) {
            throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodes: either methods or haplotypeIds must be specified.");
        }

        long time = System.nanoTime();

        TreeMap> result = new TreeMap<>();
        for (Tuple methodPair : methods) {

            String haplotypeMethod = methodPair.x;
            if (haplotypeMethod == null || haplotypeMethod.isEmpty()) {
                throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodes: haplotype method must be specified.");
            }

            String rangeGroupMethod = methodPair.y;
            if (rangeGroupMethod == null || rangeGroupMethod.isEmpty()) {
                rangeGroupMethod = null;
            }

            myLogger.info("createHaplotypeNodes: haplotype method: " + haplotypeMethod + " range group method: " + rangeGroupMethod);

            int methodId = methodId(database, haplotypeMethod);

            //
            // select gamete_grp_id, ref_range_id, sequence, seq_hash, variant_list
            // from haplotype where method_id = method_id
            // AND haplotypes_id in (11945906, 11945907, 11945909)
            //

            StringBuilder builder = new StringBuilder();
            builder.append("SELECT haplotypes_id, gamete_grp_id, haplotypes.ref_range_id, asm_contig, asm_start_coordinate," +
                    " asm_end_coordinate, asm_strand, genome_file_id");
            if (includeSequences) {
                builder.append(", sequence");
            }
            builder.append(", seq_hash, seq_len");
            if (includeVariantContext) {
                builder.append(", gvcf_file_id");
            }
            builder.append(" FROM haplotypes ");

            // If getting subset by chromosomes, join with reference_ranges table
            // because that's where chromosome is defined.
            if (chromosomes != null && !chromosomes.isEmpty()) {
                builder.append("inner join reference_ranges on haplotypes.ref_range_id = reference_ranges.ref_range_id ");
            }
            builder.append("WHERE method_id = ");
            builder.append(methodId);

            // Add clause to query only chromosomes specified
            if (chromosomes != null && !chromosomes.isEmpty()) {
                StringJoiner joiner = new StringJoiner(",");
                chromosomes.stream().map(s -> "'" + s + "'").forEach(s -> joiner.add(s));
                builder.append(" AND chrom in (");
                builder.append(joiner.toString());
                builder.append(")");
            }

            if (includeHapids != null && !includeHapids.isEmpty()) {
                builder.append(" AND haplotypes_id in (");
                boolean notFirst = false;
                for (int id : includeHapids) {
                    if (notFirst) {
                        builder.append(", ");
                    } else {
                        notFirst = true;
                    }
                    builder.append(id);
                }
                builder.append(")");
            }

            builder.append(";");

            String query = builder.toString();
            String msg = "createHaplotypeNodes: query statement: " + query;
            if (msg.length() > 200) msg = msg.substring(0, 200) + "...";
            myLogger.info(msg);

            addNodes(result, database, query, referenceRangeMap, taxaListMap, includeSequences, includeVariantContext,
                    rangeGroupMethod, taxaToKeep);

        }

        myLogger.info("createHaplotypeNodes: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        if (includeHapids != null && !includeHapids.isEmpty()) {
            warnIfMissingHapids(includeHapids, result);
        }

        return result;

    }

    private static TreeMap> createHaplotypeNodes(Connection database, Map referenceRangeMap, Map taxaListMap, boolean includeSequences,
                                                                                     boolean includeVariantContext,
                                                                                     SortedSet includeHapids,
                                                                                     List chromosomes) {

        if (includeHapids == null || includeHapids.isEmpty()) {
            throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodes: haplotypeIds must be specified.");
        }

        long time = System.nanoTime();

        TreeMap> result = new TreeMap<>();

        //
        // select gamete_grp_id, ref_range_id, sequence, seq_hash, variant_list
        // from haplotype where haplotypes_id in (11945906, 11945907, 11945909)
        //

        StringBuilder builder = new StringBuilder();
        builder.append("SELECT haplotypes_id, gamete_grp_id, haplotypes.ref_range_id, asm_contig, asm_start_coordinate, " +
                "asm_end_coordinate, asm_strand, genome_file_id");
        if (includeSequences) {
            builder.append(", sequence");
        }
        builder.append(", seq_hash, seq_len");
        if (includeVariantContext) {
            builder.append(", gvcf_file_id");
        }
        builder.append(" FROM haplotypes ");

        // If getting subset by chromosomes, join with reference_ranges table
        // because that's where chromosome is defined.
        if (chromosomes != null && !chromosomes.isEmpty()) {
            builder.append("inner join reference_ranges on haplotypes.ref_range_id = reference_ranges.ref_range_id ");
        }

        builder.append("WHERE haplotypes_id in (");
        boolean notFirst = false;
        for (int id : includeHapids) {
            if (notFirst) {
                builder.append(", ");
            } else {
                notFirst = true;
            }
            builder.append(id);
        }
        builder.append(")");

        // Add clause to query only chromosomes specified
        if (chromosomes != null && !chromosomes.isEmpty()) {
            StringJoiner joiner = new StringJoiner(",");
            chromosomes.stream().map(s -> "'" + s + "'").forEach(s -> joiner.add(s));
            builder.append(" AND chrom in (");
            builder.append(joiner.toString());
            builder.append(")");
        }

        builder.append(";");

        String query = builder.toString();
        String msg = "createHaplotypeNodes: query statement: " + query;
        if (msg.length() > 200) msg = msg.substring(0, 200) + "...";
        myLogger.info(msg);

        addNodes(result, database, query, referenceRangeMap, taxaListMap, includeSequences, includeVariantContext, null, null);

        myLogger.info("createHaplotypeNodes: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        warnIfMissingHapids(includeHapids, result);

        return result;

    }

    private static void warnIfMissingHapids(SortedSet includeHapids, TreeMap> result) {

        HashSet resultIDs = result.entrySet().stream()
                .map(entry -> entry.getValue())
                .flatMap(List::stream)
                .map(node -> node.id())
                .collect(Collectors.toCollection(HashSet::new));

        long numMissing = includeHapids.stream()
                .filter(id -> !resultIDs.contains(id))
                .count();

        if (numMissing != 0) {
            myLogger.warn("warnIfMissingHapids: the graph is missing this number of specified hapids : " + numMissing);
        }

    }

    /**
     * Creates lists of HaplotypeNodes with variant contexts corresponding to the specified nodes organized by reference
     * Range.
     * LCJ June 16, 2022- this function is only called from MergeGVCFPlugin:extractNodesWithVariants()
     * MergeGVCFPlugin has been deprecated, so I am not making changes here to pull gvcf
     * files for the variants.
     *
     * If this function is called at some point from a non-deprecated method, we
     * can re-look at what changes are needed.
     *
     * @param database database connection
     * @param includeHapNodes includes specified hapids
     *
     * @return Map of HaplotypeNode Lists (keys are ReferenceRange)
     */
    public static TreeMap> createHaplotypeNodesWithVariants(Connection database,
                                                                                                Set includeHapNodes) {

        if (database == null) {
            throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodesWithVariants: Must specify database connection.");
        }

        if (includeHapNodes == null || includeHapNodes.isEmpty()) {
            throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodesWithVariants: Must specify at least one haplotype node to include.");
        }

        long time = System.nanoTime();

        Map nodeMap = new HashMap<>();
        for (HaplotypeNode node : includeHapNodes) {
            nodeMap.put(node.id(), node);
        }

        TreeMap> result = getNodesWithVariants(database, nodeMap);

        myLogger.info("createHaplotypeNodesWithVariants: number of reference ranges: " + result.size());

        myLogger.info("createHaplotypeNodesWithVariants: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        return result;

    }

    /**
     * Creates HaplotypeGraph with variant contexts corresponding to the given HaplotypeGraph.
     *
     * @param database database connection
     * @param graph graph without variant contexts
     *
     * @return graph with variant contexts
     */
    public static HaplotypeGraph createHaplotypeNodesWithVariants(Connection database, HaplotypeGraph graph) {

        if (database == null) {
            throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodesWithVariants: Must specify database connection.");
        }

        if (graph == null) {
            throw new IllegalArgumentException("CreateGraphUtils: createHaplotypeNodesWithVariants: Must specify haplotype graph.");
        }

        long time = System.nanoTime();

        Map nodeMap = graph.nodeStream().parallel()
                .collect(() -> new HashMap<>(),
                        (nodeMap1, node) -> nodeMap1.put(node.id(), node),
                        (BiConsumer, Map>) (nodeMap01, nodeMap02) -> nodeMap01.putAll(nodeMap02));

        TreeMap> nodes = getNodesWithVariants(database, nodeMap);

        HaplotypeGraph result = new HaplotypeGraph(createEdges(nodes));

        myLogger.info("createHaplotypeNodesWithVariants: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        return result;

    }

    /**
     * Get nodes with variants corresponding to nodes in map.
     *
     * @param database database connection
     * @param nodeMap node map
     *
     * @return tree map of reference range to list of nodes.
     */
    private static TreeMap> getNodesWithVariants(Connection database, Map nodeMap) {

        //
        // select haplotypes_id, variant_list
        // from haplotype where haplotypes_id in (11945906, 11945907, 11945909)
        //

        // Set timeout in case we are writing while we are running select
        try {
            database.createStatement().executeQuery("pragma busy_timeout=300000;");
        } catch (Exception e) {
            myLogger.warn("CreateGraphUtils: getNodesWithVariants: Unable to set the timeout.");
        }

        StringBuilder builder = new StringBuilder();
        builder.append("SELECT haplotypes_id, gvcf_file_id, asm_contig, asm_start_coordinate, asm_end_coordinate,asm_strand, genome_file_id");
        builder.append(" FROM haplotypes WHERE ");

        builder.append(" haplotypes_id in (");
        boolean notFirst = false;
        for (Integer id : nodeMap.keySet()) {
            if (notFirst) {
                builder.append(", ");
            } else {
                notFirst = true;
            }
            builder.append(id);
        }
        builder.append(")");
        builder.append(";");

        String query = builder.toString();
        myLogger.info("getNodesWithVariants: query statement: " + query);

        TreeMap> result = new TreeMap<>();

        try (ResultSet rs = database.createStatement().executeQuery(query)) {

            while (rs.next()) {
                int hapId = rs.getInt("haplotypes_id");
                String asmContig = rs.getString("asm_contig");
                int asmStart = rs.getInt("asm_start_coordinate");
                int asmEnd = rs.getInt("asm_end_coordinate");
                String asmStrand = rs.getString("asm_strand");
                int genomeFileID = rs.getInt("genome_file_id");
                int gvcfFileID = rs.getInt("gvcf_file_id");

                HaplotypeNode existingNode = nodeMap.get(hapId);
                if (existingNode == null) {
                    throw new IllegalStateException("CreateGraphUtils: getNodesWithVariants: includeHapNodes doesn't have id: " + hapId);
                }

                ReferenceRange refRange = existingNode.referenceRange();
                TaxaList taxa = existingNode.taxaList();
                HaplotypeSequence hapSeq = existingNode.haplotypeSequence();

                List nodes = result.get(refRange);
                if (nodes == null) {
                    nodes = new ArrayList<>();
                    result.put(refRange, nodes);
                }

                nodes.add(new HaplotypeNode(hapSeq, taxa, hapId, asmContig, asmStart, asmEnd, asmStrand, genomeFileID, gvcfFileID));
            }

            return result;

        } catch (Exception e) {
            myLogger.debug(e.getMessage(), e);
            throw new IllegalStateException("CreateGraphUtils: getNodesWithVariants: Problem querying the database: " + e.getMessage());
        }

    }

    private static int addNodes(TreeMap> result, Connection database, String query,
                                Map referenceRangeMap, Map taxaListMap,
                                boolean includeSequences, boolean includeVariantContext, String rangeGroupMethod,
                                TaxaList taxaToKeep) {

        myLogger.info("CreateGraphUtils:addNodes - query=" + query);
        try (ResultSet rs = database.createStatement().executeQuery(query)) {

            int numNodes = 0;

            Set rangesAdded = new HashSet<>();
            while (rs.next()) {

                int gameteGrp = rs.getInt("gamete_grp_id");

                TaxaList taxa = taxaListMap.get(gameteGrp);
                if (taxa == null) {
                    throw new IllegalStateException("CreateGraphUtils: addNodes: no taxa list for gamete_grp_id: " + gameteGrp);
                }

                if (taxaToKeep != null && !taxaToKeep.isEmpty()) {
                    taxa = taxa.stream()
                            .filter(taxaToKeep::contains)
                            .collect(TaxaList.collect());
                    if (taxa.isEmpty()) continue;
                }

                int hapId = rs.getInt("haplotypes_id");

                String asmContig = rs.getString("asm_contig");
                int asmStart = rs.getInt("asm_start_coordinate");
                int asmEnd = rs.getInt("asm_end_coordinate");
                String asmStrand = rs.getString("asm_strand");
                int genomeFileID = rs.getInt("genome_file_id");
                int refRangeId = rs.getInt("ref_range_id");
                byte[] sequence = null;
                if (includeSequences) {
                    sequence = rs.getBytes("sequence");
                }
                String seqHash = rs.getString("seq_hash");
                int seqLen = rs.getInt("seq_len");
                int gvcfFileId = -1;
                if (includeVariantContext) {
                    gvcfFileId = rs.getInt("gvcf_file_id");
                }

                ReferenceRange refRange = referenceRangeMap.get(refRangeId);
                if (refRange == null) {
                    throw new IllegalStateException("CreateGraphUtils: addNodes: no reference range in map for ref_range_id: " + refRangeId);
                }

                if (rangeGroupMethod != null && !refRange.isPartOf(rangeGroupMethod)) {
                    continue;
                }

                rangesAdded.add(refRangeId);

                HaplotypeSequence hapSeq = HaplotypeSequence.getInstance(sequence, refRange, 0.0, seqHash, seqLen);

                List nodes = result.get(refRange);
                if (nodes == null) {
                    nodes = new ArrayList<>();
                    result.put(refRange, nodes);
                }
                if (includeVariantContext) {
                    nodes.add(new HaplotypeNode(hapSeq, taxa, hapId, asmContig, asmStart, asmEnd, asmStrand, genomeFileID, gvcfFileId));
                } else {
                    nodes.add(new HaplotypeNode(hapSeq, taxa, hapId, asmContig, asmStart, asmEnd, asmStrand, genomeFileID, gvcfFileId));
                }
                numNodes++;
            } // end processing db query results

            myLogger.info("addNodes: number of nodes: " + numNodes);
            myLogger.info("addNodes: number of reference ranges: " + rangesAdded.size());

            return numNodes;

        } catch (Exception e) {
            myLogger.debug(e.getMessage(), e);
            throw new IllegalStateException("CreateGraphUtils: addNodes: Problem querying the database: " + e.getMessage());
        }

    }

    /**
     * Returns a new graph in which every reference range also holds a placeholder node for the taxa
     * that appear somewhere in the graph but in none of that range's nodes. Node lists are copied
     * with {@link #tree(HaplotypeGraph)} first, so the given graph's own node lists are not modified,
     * and edges are regenerated from the resulting nodes.
     *
     * @param graph source graph
     *
     * @return new graph including missing-sequence placeholder nodes
     */
    public static HaplotypeGraph addMissingSequenceNodes(HaplotypeGraph graph) {
        TreeMap<ReferenceRange, List<HaplotypeNode>> nodesByRange = tree(graph);
        addMissingSequenceNodes(nodesByRange);
        List<HaplotypeEdge> edges = createEdges(nodesByRange);
        return new HaplotypeGraph(edges);
    }

    public static int addMissingSequenceNodes(TreeMap> result) {

        TaxaList allTaxa = taxaInNodes(result);
        int numberMissingNodesAdded = 0;

        for (ReferenceRange range : result.keySet()) {

            Set taxa = new TreeSet<>();
            taxa.addAll(allTaxa);
            for (HaplotypeNode node : result.get(range)) {
                for (Taxon taxon : node.taxaList()) {
                    taxa.remove(taxon);
                }
            }

            if (!taxa.isEmpty()) {
                TaxaListBuilder builder = new TaxaListBuilder();
                builder.addAll(taxa);
                HaplotypeSequence seq = HaplotypeSequence.getInstance("NNNNNNNNN", range, 0.0, "NNNNNNNNN");
                HaplotypeNode missingNode = new HaplotypeNode(seq, builder.build());
                result.get(range).add(missingNode);
                numberMissingNodesAdded++;
            }

        }

        return numberMissingNodesAdded;

    }

    public static TaxaList taxaInNodes(TreeMap> nodes) {

        Set taxa = new TreeSet<>();

        for (List nodeList : nodes.values()) {
            for (HaplotypeNode node : nodeList) {
                for (Taxon taxon : node.taxaList()) {
                    taxa.add(taxon);
                }
            }
        }

        TaxaListBuilder builder = new TaxaListBuilder();
        builder.addAll(taxa);
        return builder.build();

    }

    public static TreeMap> tree(HaplotypeGraph graph) {

        TreeMap> result = new TreeMap<>();
        graph.referenceRangeStream()
                .forEach(range -> {
                    List nodes = new ArrayList<>();
                    nodes.addAll(graph.nodes(range));
                    result.put(range, nodes);
                });
        return result;

    }

    /**
     * Creates a graph in which every node holds exactly one taxon. Each original node is split into
     * one copy per taxon, keeping the same sequence, id, assembly coordinates and file ids.
     * Consecutive reference ranges on the same chromosome are then fully connected. The edge to the
     * node carrying the same taxon gets probability {@code sameTaxonPercent}, and the remaining
     * probability is shared equally among the other edges.
     *
     * @param graph source graph
     * @param sameTaxonPercent probability in [0.0, 1.0] given to the edge between nodes of the same taxon
     *
     * @return new graph of single-taxon nodes
     *
     * @throws IllegalArgumentException if sameTaxonPercent is outside [0.0, 1.0] or is NaN
     */
    public static HaplotypeGraph nodesSplitByIndividualTaxa(HaplotypeGraph graph, double sameTaxonPercent) {

        // Written as a negated range test so NaN (which fails every comparison) is rejected too;
        // the previous "< 0.0 || > 1.0" test let NaN through and produced NaN edge probabilities.
        if (!(sameTaxonPercent >= 0.0 && sameTaxonPercent <= 1.0)) {
            throw new IllegalArgumentException("CreateGraphUtils: nodesSplitByIndividualTaxa: sameTaxonPercent should be between 0.0 and 1.0: " + sameTaxonPercent);
        }

        TreeMap<ReferenceRange, List<HaplotypeNode>> rangeToNode = new TreeMap<>();
        graph.nodeStream()
                .forEach(node -> {
                    ReferenceRange range = node.referenceRange();
                    for (Taxon taxon : node.taxaList()) {
                        TaxaListBuilder singleTaxon = new TaxaListBuilder();
                        singleTaxon.add(taxon);
                        HaplotypeNode newNode = new HaplotypeNode(node.haplotypeSequence(), singleTaxon.build(), node.id(), node.asmContig(), node.asmStart(), node.asmEnd(), node.asmStrand(), node.genomeFileID(), node.gvcfFileID());
                        // the range entry is only created once a taxon is actually added to it
                        rangeToNode.computeIfAbsent(range, key -> new ArrayList<>()).add(newNode);
                    }
                });

        return new HaplotypeGraph(createEdgesFullyConnectedSingleTaxonNodes(rangeToNode, sameTaxonPercent));

    }

    private static List createEdgesFullyConnectedSingleTaxonNodes(TreeMap> rangeToNode, double sameTaxonPercent) {

        myLogger.info("createEdgesFullyConnected: creating edges from nodes.");

        long time = System.nanoTime();

        List result = new ArrayList<>();

        List leftNodes = null;
        int numLeftNodes = 0;
        Chromosome leftChr = null;
        for (Map.Entry> entry : rangeToNode.entrySet()) {

            // for first reference range
            if (leftNodes == null) {
                leftNodes = entry.getValue();
                numLeftNodes = leftNodes.size();
                leftChr = entry.getKey().chromosome();
            } else {

                List rightNodes = entry.getValue();
                int numRightNodes = rightNodes.size();

                // If transitioning to different chromosome, then make no edges
                if (!leftChr.equals(entry.getKey().chromosome())) {
                    leftNodes = rightNodes;
                    numLeftNodes = leftNodes.size();
                    leftChr = entry.getKey().chromosome();
                    continue;
                }

                for (int l = 0; l < numLeftNodes; l++) {

                    HaplotypeNode left = leftNodes.get(l);

                    if (left.numTaxa() != 1) {
                        throw new IllegalStateException("CreateGraphUtils: createEdgesFullyConnectedSingleTaxonNodes: all nodes must have one taxon: " + left.numTaxa());
                    }

                    Taxon currentTaxon = left.taxaList().get(0);

                    double percentForNonIdentityNodes = (1.0 - sameTaxonPercent) / (double) (numRightNodes - 1);

                    boolean currentTaxonFound = false;
                    for (int r = 0; r < numRightNodes; r++) {

                        HaplotypeNode right = rightNodes.get(r);

                        if (right.taxaList().contains(currentTaxon)) {
                            if (currentTaxonFound) {
                                throw new IllegalStateException("CreateGraphUtils: createEdgesFullyConnectedSingleTaxonNodes: Taxon already found.");
                            }
                            result.add(new HaplotypeEdge(left, right, sameTaxonPercent));
                            currentTaxonFound = true;
                        } else {
                            result.add(new HaplotypeEdge(left, right, percentForNonIdentityNodes));
                        }

                    }

                    if (!currentTaxonFound) {
                        throw new IllegalStateException("CreateGraphUtils: createEdgesFullyConnectedSingleTaxonNodes: Taxon not found.");
                    }

                }

                // right nodes become left nodes to progress to next edges
                leftNodes = rightNodes;
                numLeftNodes = leftNodes.size();

            }

        }

        myLogger.info("createEdgesFullyConnected: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        return result;

    }

    /**
     * Convenience overload that builds haplotype nodes grouped by reference range. It first loads the
     * two lookup maps via {@code referenceRangeMap(database)} and {@code taxaListMap(database)}. It
     * then delegates, with every other argument passed through unchanged, to the overload of
     * createHaplotypeNodes that accepts those maps.
     *
     * @param database database connection
     * @param methods passed through to the delegate. NOTE(review): the generic element type is missing
     * from this copy of the source; presumably method-name pairs — confirm.
     * @param includeSequences passed through to the delegate
     * @param includeVariantContext passed through to the delegate
     * @param includeHapids passed through to the delegate
     * @param chromosomes passed through to the delegate
     * @param taxaToKeep passed through to the delegate
     *
     * @return reference range to node list map produced by the delegate
     */
    public static TreeMap> createHaplotypeNodes(Connection database, List> methods,
                                                                                    boolean includeSequences, boolean includeVariantContext,
                                                                                    SortedSet includeHapids,
                                                                                    List chromosomes, TaxaList taxaToKeep) {

        // NOTE(review): generic type arguments on the declarations below appear lost in this copy of the source
        Map referenceRangeMap = referenceRangeMap(database);
        Map taxaListMap = taxaListMap(database);
        TreeMap> result = createHaplotypeNodes(database, referenceRangeMap, taxaListMap,
                methods, includeSequences, includeVariantContext, includeHapids, chromosomes, taxaToKeep);
        return result;

    }

    /**
     * Generates Edges based on HaplotypeNodes.  Database information not used.
     * <p>
     * Nodes are grouped by reference range, duplicate nodes (per {@code equals}) within a range are
     * dropped, and edges are then built between consecutive ranges.
     *
     * @param haplotypeNodes HaplotypeNodes
     *
     * @return Generated HaplotypeEdges
     */
    public static List<HaplotypeEdge> createEdges(Collection<HaplotypeNode> haplotypeNodes) {

        TreeMap<ReferenceRange, List<HaplotypeNode>> byRange = new TreeMap<>();
        for (HaplotypeNode node : haplotypeNodes) {
            List<HaplotypeNode> nodesInRange = byRange.computeIfAbsent(node.referenceRange(), key -> new ArrayList<>());
            // skip duplicates so the same node is not connected twice
            if (!nodesInRange.contains(node)) {
                nodesInRange.add(node);
            }
        }

        return createEdges(byRange);
    }

    public static List createEdges(NavigableMap> rangeToNode) {

        myLogger.info("createEdges: creating edges from nodes.");

        long time = System.nanoTime();

        List result = new ArrayList<>();

        List leftNodes = null;
        for (Map.Entry> entry : rangeToNode.entrySet()) {

            // for first reference range
            if (leftNodes == null) {
                leftNodes = entry.getValue();
            } else {
                List rightNodes = entry.getValue();
                result.addAll(createEdges(leftNodes, rightNodes));
                leftNodes = rightNodes;
            }

        }

        myLogger.info("createEdges: time: " + ((double) (System.nanoTime() - time) / 1_000_000_000.0) + " secs.");

        return result;

    }

    /**
     * Creates weighted edges between the nodes of two neighboring reference ranges.
     * <p>
     * Edge probabilities are built in three steps:
     * <ol>
     * <li>The probability of the edge left -> right starts as the fraction of the left node's taxa
     * that also appear in the right node. Taxa are matched by name.</li>
     * <li>If a left node's fractions sum to more than 0.001 below 1.0, the shortfall is spread over
     * every right node in proportion to its number of taxa.</li>
     * <li>If some right node still has no incoming edge, every edge is scaled by 0.9. The freed 10% is
     * then spread over all right nodes in proportion to their number of taxa, so every right node is
     * reachable.</li>
     * </ol>
     *
     * @param leftNodes nodes of the upstream reference range
     * @param rightNodes nodes of the downstream reference range
     *
     * @return edges with non-zero probability. The list is empty if either input is null or empty, or
     * if the first nodes of the two lists lie on different chromosomes.
     */
    public static List<HaplotypeEdge> createEdges(List<HaplotypeNode> leftNodes, List<HaplotypeNode> rightNodes) {

        if (leftNodes == null || leftNodes.isEmpty() || rightNodes == null || rightNodes.isEmpty()) {
            return Collections.emptyList();
        }

        // never connect across chromosomes; judged from the first node of each list
        Chromosome leftChr = leftNodes.get(0).referenceRange().chromosome();
        Chromosome rightChr = rightNodes.get(0).referenceRange().chromosome();
        if (!leftChr.equals(rightChr)) {
            return Collections.emptyList();
        }

        int numLeftNodes = leftNodes.size();
        int numRightNodes = rightNodes.size();

        // total number of taxa in right range, used to weight redistributed probability
        int totalRightTaxa = 0;
        for (HaplotypeNode right : rightNodes) {
            totalRightTaxa += right.numTaxa();
        }

        double[][] possibleEdges = new double[numLeftNodes][numRightNodes];

        // Tracked by index rather than in a HashSet of nodes, so right nodes that happen to be
        // equal() to each other are still accounted for individually.
        boolean[] hasIncomingEdge = new boolean[numRightNodes];

        // Map each right-side taxon name to the index of the right node holding it
        Map<String, Integer> taxonToNode = new HashMap<>();
        for (int r = 0; r < numRightNodes; r++) {
            for (Taxon taxon : rightNodes.get(r).taxaList()) {
                taxonToNode.put(taxon.getName(), r);
            }
        }

        // Count the taxa shared by each left/right pair. The lookup key must be the taxon name to
        // match the keys above. Looking up the Taxon object itself never matched a String key, so
        // every edge used to fall through to the proportional fallback below.
        for (int l = 0; l < numLeftNodes; l++) {
            for (Taxon taxon : leftNodes.get(l).taxaList()) {
                Integer rightIndex = taxonToNode.get(taxon.getName());
                if (rightIndex != null) {
                    possibleEdges[l][rightIndex]++;
                }
            }
        }

        for (int l = 0; l < numLeftNodes; l++) {

            // Divide each count by the number of left node taxa to get a probability
            double numLeftTaxa = (double) leftNodes.get(l).numTaxa();
            double totalProbabilityRemaining = 1.0;
            for (int r = 0; r < numRightNodes; r++) {
                if (possibleEdges[l][r] != 0.0) {
                    possibleEdges[l][r] /= numLeftTaxa;
                    totalProbabilityRemaining -= possibleEdges[l][r];
                    hasIncomingEdge[r] = true;
                }
            }

            // If this left node's outgoing edges don't reach 100%, spread the remainder
            // over every right node in proportion to its number of taxa
            if (totalProbabilityRemaining > 0.001) {
                for (int r = 0; r < numRightNodes; r++) {
                    double share = (double) rightNodes.get(r).numTaxa() / (double) totalRightTaxa * totalProbabilityRemaining;
                    possibleEdges[l][r] += share;
                    hasIncomingEdge[r] = true;
                }
            }
        }

        boolean everyRightNodeReached = true;
        for (boolean reached : hasIncomingEdge) {
            if (!reached) {
                everyRightNodeReached = false;
                break;
            }
        }

        // Make sure every right node has an incoming edge. To do so, take a small percent from the
        // existing edges and spread it over all edges (new and existing), weighted by the number of
        // taxa in each right node.
        if (!everyRightNodeReached) {
            double smallPercent = 0.1;
            for (int l = 0; l < numLeftNodes; l++) {
                for (int r = 0; r < numRightNodes; r++) {
                    double probability = (double) rightNodes.get(r).numTaxa() / (double) totalRightTaxa * smallPercent;
                    possibleEdges[l][r] = possibleEdges[l][r] * (1.0 - smallPercent) + probability;
                }
            }
        }

        // make edges for every non-zero probability
        List<HaplotypeEdge> result = new ArrayList<>();
        for (int l = 0; l < numLeftNodes; l++) {
            for (int r = 0; r < numRightNodes; r++) {
                if (possibleEdges[l][r] != 0.0) {
                    result.add(new HaplotypeEdge(leftNodes.get(l), rightNodes.get(r), possibleEdges[l][r]));
                }
            }
        }

        return result;

    }

    public static void compareEdges(List edges1, List edges2) {

        if (edges1.size() != edges2.size()) {
            System.out.println("edges1 size: " + edges1.size() + "  edges2 size: " + edges2.size());
        }

        Map, HaplotypeEdge> edges1Map = new HashMap<>();
        for (HaplotypeEdge edge : edges1) {
            edges1Map.put(new Tuple<>(edge.leftHapNode(), edge.rightHapNode()), edge);
        }

        for (HaplotypeEdge edge : edges2) {
            HaplotypeEdge edge2 = edges1Map.get(new Tuple<>(edge.leftHapNode(), edge.rightHapNode()));
            if (edge2 == null) {
                System.out.println("no edge1 for edge2: " + edge);
            } else if (Math.abs(edge.edgeProbability() - edge2.edgeProbability()) > 0.00001) {
                System.out.println("edge probability differ by: " + (edge.edgeProbability() - edge2.edgeProbability()));
            }

        }

    }

    /**
     * Returns the line name of the reference genotype, i.e. the single row of the {@code genotypes}
     * table whose {@code is_reference} column is 1.
     *
     * @param database open connection to a database containing a {@code genotypes} table
     *
     * @return reference line name
     *
     * @throws IllegalStateException if the query fails, or if zero or more than one line is marked as
     * reference (the underlying problem is attached as the cause)
     */
    public static String getRefLineName(Connection database) {
        String refLineQuery = "select line_name from genotypes where is_reference=1";
        // Close the Statement as well as the ResultSet; closing only the ResultSet leaked the Statement.
        try (java.sql.Statement statement = database.createStatement();
             ResultSet rs = statement.executeQuery(refLineQuery)) {
            if (!rs.next()) {
                throw new IllegalArgumentException("CreateGraphUtils: getRefLineName: genotypes table has no line marked as reference");
            }
            String refLine = rs.getString("line_name");
            if (rs.next()) {
                throw new IllegalArgumentException("CreateGraphUtils: getRefLineName: genotypes table has multiple lines marked as reference ");
            }
            return refLine;
        } catch (Exception exc) {
            // Wrapped as IllegalStateException, as before, but now with the cause kept
            // and the correct method name in the message.
            throw new IllegalStateException("CreateGraphUtils: getRefLineName: db failure getting reference line name " + exc.getMessage(), exc);
        }
    }

    /**
     * Returns method id (methods.method_id) for given method name.
     *
     * @param database database connection
     * @param method_name method name. It is bound as a query parameter, so quotes and other SQL
     * metacharacters in the name are matched literally.
     *
     * @return method id
     *
     * @throws IllegalArgumentException if no method, or more than one method, has this name, or if the
     * query fails (the underlying problem is attached as the cause)
     */
    public static int methodId(Connection database, String method_name) {

        // Bind the name instead of concatenating it into the SQL text. Concatenation allowed SQL
        // injection and failed on names containing a single quote.
        String query = "SELECT method_id from methods where name=?";
        try (java.sql.PreparedStatement statement = database.prepareStatement(query)) {
            statement.setString(1, method_name);
            try (ResultSet rs = statement.executeQuery()) {
                if (!rs.next()) {
                    throw new IllegalArgumentException("CreateGraphUtils: methodId: no method name " + method_name);
                }
                int methodid = rs.getInt("method_id");
                if (rs.next()) {
                    throw new IllegalArgumentException("CreateGraphUtils: methodId: method table has multiple  matchs for: " + method_name);
                }
                return methodid;
            }
        } catch (Exception exc) {
            myLogger.debug(exc.getMessage(), exc);
            throw new IllegalArgumentException("CreateGraphUtils: methodId: Problem getting id for method: " + method_name + "\n" + exc.getMessage(), exc);
        }

    }

    /**
     * Create graph that's a subset of the given graph which contains only nodes from the taxa list.
     * <p>
     * Each node is copied with its taxa restricted to those in {@code taxa} (in the order of
     * {@code taxa}). Nodes left with no taxa are dropped, and edges are regenerated.
     *
     * @param graph original graph
     * @param taxa taxa list
     *
     * @return subset graph
     */
    public static HaplotypeGraph subsetGraph(HaplotypeGraph graph, TaxaList taxa) {

        TreeMap<ReferenceRange, List<HaplotypeNode>> retained = new TreeMap<>();

        graph.referenceRangeStream().forEach(range -> {
            for (HaplotypeNode node : graph.nodes(range)) {

                TaxaListBuilder wanted = new TaxaListBuilder();
                for (Taxon taxon : taxa) {
                    if (node.taxaList().contains(taxon)) {
                        wanted.add(taxon);
                    }
                }

                if (wanted.numberOfTaxa() == 0) {
                    continue; // none of the requested taxa are in this node
                }

                HaplotypeNode trimmed = new HaplotypeNode(node.haplotypeSequence(), wanted.build(), node.id(), node.asmContig(), node.asmStart(), node.asmEnd(), node.asmStrand(), node.genomeFileID(), node.gvcfFileID());
                retained.computeIfAbsent(trimmed.referenceRange(), key -> new ArrayList<>()).add(trimmed);
            }
        });

        return new HaplotypeGraph(CreateGraphUtils.createEdges(retained));
    }

    /**
     * Removes reference ranges from given graph that represent less than given minimum percent of total taxa.
     * A range is kept when (taxa in range) / (total taxa in graph) &gt;= {@code minPercentTaxa}. The
     * ratio is not multiplied by 100, so the threshold is a fraction.
     *
     * @param graph graph
     * @param minPercentTaxa minimum fraction of total taxa a range must represent
     *
     * @return new graph
     */
    public static HaplotypeGraph removeRefRanges(HaplotypeGraph graph, double minPercentTaxa) {

        final int totalTaxa = graph.totalNumberTaxa();

        TreeMap<ReferenceRange, List<HaplotypeNode>> retained = graph.referenceRangeStream()
                .filter(range -> (double) graph.numberTaxa(range) / (double) totalTaxa >= minPercentTaxa)
                .collect(Collectors.toMap(
                        range -> range,
                        range -> graph.nodes(range),
                        (first, second) -> {
                            throw new IllegalStateException("should be no merging");
                        },
                        () -> new TreeMap<>()));

        return new HaplotypeGraph(CreateGraphUtils.createEdges(retained));
    }

    /**
     * Removes reference ranges from given graph that represent less than given minimum number of taxa.
     * A range is kept when its taxa count is at least {@code minCountTaxa}.
     *
     * @param graph graph
     * @param minCountTaxa minimum number of taxa
     *
     * @return new graph
     */
    public static HaplotypeGraph removeRefRanges(HaplotypeGraph graph, int minCountTaxa) {

        TreeMap<ReferenceRange, List<HaplotypeNode>> retained = new TreeMap<>();

        graph.referenceRanges().forEach(range -> {
            if (graph.numberTaxa(range) < minCountTaxa) {
                return; // too few taxa represented; drop this range
            }
            retained.put(range, graph.nodes(range));
        });

        return new HaplotypeGraph(CreateGraphUtils.createEdges(retained));
    }

    /**
     * Removes specified reference ranges from graph. Membership is tested with
     * {@link List#contains(Object)}, i.e. by {@code equals}.
     *
     * @param graph graph
     * @param ranges reference ranges to remove
     *
     * @return new graph
     */
    public static HaplotypeGraph removeRefRanges(HaplotypeGraph graph, List<ReferenceRange> ranges) {

        TreeMap<ReferenceRange, List<HaplotypeNode>> retained = new TreeMap<>();
        graph.referenceRangeStream()
                .filter(range -> !ranges.contains(range))
                .forEach(range -> retained.put(range, graph.nodes(range)));

        return new HaplotypeGraph(CreateGraphUtils.createEdges(retained));
    }

    /**
     * Creates graph that includes specified reference ranges. Ranges for which the graph returns no
     * node list ({@code null}) are skipped.
     *
     * @param graph graph
     * @param ranges ranges to keep
     *
     * @return new graph
     */
    public static HaplotypeGraph keepRefRanges(HaplotypeGraph graph, List<ReferenceRange> ranges) {

        TreeMap<ReferenceRange, List<HaplotypeNode>> retained = new TreeMap<>();

        ranges.forEach(range -> {
            List<HaplotypeNode> nodes = graph.nodes(range);
            if (nodes == null) {
                return; // graph has no node list for this range
            }
            retained.put(range, nodes);
        });

        return new HaplotypeGraph(CreateGraphUtils.createEdges(retained));
    }

    /**
     * Creates a graph that keeps only the reference ranges whose ids appear in {@code rangeIDs}. Ids
     * with no matching range in the graph are ignored.
     *
     * @param graph graph
     * @param rangeIDs ids of the reference ranges to keep
     *
     * @return new graph
     */
    public static HaplotypeGraph keepRefRangeIDs(HaplotypeGraph graph, List<Integer> rangeIDs) {

        Map<Integer, ReferenceRange> rangeById = new HashMap<>();
        graph.referenceRangeStream().forEach(range -> rangeById.put(range.id(), range));

        List<ReferenceRange> selected = new ArrayList<>();
        for (Integer id : rangeIDs) {
            ReferenceRange match = rangeById.get(id);
            if (match != null) {
                selected.add(match);
            }
        }

        return keepRefRanges(graph, selected);
    }

    /**
     * Filters the given graph to keep only the specified haplotype ids. Reference ranges left with
     * no nodes are dropped.
     *
     * @param graph input graph
     * @param hapids list of haplotype ids to keep
     *
     * @return resulting graph
     */
    public static HaplotypeGraph keepHapIDs(HaplotypeGraph graph, SortedSet<Integer> hapids) {

        TreeMap<ReferenceRange, List<HaplotypeNode>> retained = new TreeMap<>();

        graph.referenceRangeStream().forEach(range -> {
            List<HaplotypeNode> kept = graph.nodes(range).stream()
                    .filter(node -> hapids.contains(node.id()))
                    .collect(Collectors.toCollection(ArrayList::new));
            if (!kept.isEmpty()) {
                retained.put(range, kept);
            }
        });

        return new HaplotypeGraph(retained);
    }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy