All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetradapp.editor.DirichletBayesImNodeCountsTable Maven / Gradle / Ivy

There is a newer version: 7.6.6
Show newest version
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetradapp.editor;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.DirichletBayesIm;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.JOptionUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

import javax.swing.*;
import javax.swing.table.AbstractTableModel;
import javax.swing.table.TableCellEditor;
import javax.swing.table.TableColumn;
import javax.swing.table.TableModel;
import java.awt.*;
import java.awt.event.FocusAdapter;
import java.awt.event.FocusEvent;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.beans.PropertyChangeListener;
import java.beans.PropertyChangeSupport;

/**
 * A JTable that displays, and lets the user edit, the pseudocounts of a single node of a Dirichlet Bayes IM (backed
 * by a {@link Model}).
 *
 * @author josephramsey
 */
class DirichletBayesImNodeCountsTable extends JTable {
    private int focusRow;
    private int focusCol;
    private int lastX;
    private int lastY;

    /**
     * Builds a counts table for one node of a Dirichlet Bayes IM, installs the number editor/renderer, and wires up
     * the selection, mouse and focus listeners.
     *
     * @param node             the node whose counts are shown; must be a node of {@code dirichletBayesIm}.
     * @param dirichletBayesIm the Dirichlet Bayes IM being displayed and edited.
     * @throws NullPointerException     if either argument is null.
     * @throws IllegalArgumentException if the IM reports a negative index for {@code node}.
     */
    public DirichletBayesImNodeCountsTable(Node node,
                                           DirichletBayesIm dirichletBayesIm) {
        if (node == null) {
            throw new NullPointerException();
        }

        if (dirichletBayesIm == null) {
            throw new NullPointerException();
        }

        boolean nodeUnknown = dirichletBayesIm.getNodeIndex(node) < 0;

        if (nodeUnknown) {
            throw new IllegalArgumentException(
                    "Node " + node + " is not a node in this DirichletBayesIm.");
        }

        resetModel(node, dirichletBayesIm);

        // Numeric columns are edited and rendered with the application-wide number format.
        NumberCellEditor numberEditor =
                new NumberCellEditor(NumberFormatUtil.getInstance().getNumberFormat());
        setDefaultEditor(Number.class, numberEditor);

        NumberCellRenderer numberRenderer =
                new NumberCellRenderer(NumberFormatUtil.getInstance().getNumberFormat());
        setDefaultRenderer(Number.class, numberRenderer);

        getTableHeader().setReorderingAllowed(false);
        getTableHeader().setResizingAllowed(true);
        setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
        setCellSelectionEnabled(true);

        // Track the anchor of the row selection as the focus row.
        getSelectionModel().addListSelectionListener(event -> {
            ListSelectionModel rows = (ListSelectionModel) event.getSource();
            setFocusRow(rows.getAnchorSelectionIndex());
        });

        // Track the anchor of the column selection as the focus column.
        getColumnModel().getSelectionModel().addListSelectionListener(event -> {
            ListSelectionModel columns = (ListSelectionModel) event.getSource();
            setFocusColumn(columns.getAnchorSelectionIndex());
        });

        addMouseListener(new MouseAdapter() {
            @Override
            public void mousePressed(MouseEvent e) {
                // Right button or control-press both open the context menu.
                boolean popupRequested =
                        SwingUtilities.isRightMouseButton(e) || e.isControlDown();

                if (popupRequested) {
                    showPopup(e);
                }
            }
        });

        // Any focus change notifies listeners that the table data changed.
        addFocusListener(new FocusAdapter() {
            @Override
            public void focusGained(FocusEvent e) {
                Model current = (Model) getModel();
                current.fireTableDataChanged();
            }

            @Override
            public void focusLost(FocusEvent e) {
                Model current = (Model) getModel();
                current.fireTableDataChanged();
            }
        });

        setFocusRow(0);
        setFocusColumn(0);
    }

    /**
     * Creates a new {@link Model} for the given node and installs it as this table's model. "editorValueChanged"
     * property changes fired by the model are re-fired as property changes of this table.
     *
     * @param node             the node whose counts the model exposes.
     * @param dirichletBayesIm the Dirichlet Bayes IM holding those counts.
     */
    private void resetModel(Node node, DirichletBayesIm dirichletBayesIm) {
        Model countsModel = new Model(node, dirichletBayesIm, this);

        countsModel.addPropertyChangeListener(event -> {
            boolean valueEdited = "editorValueChanged".equals(event.getPropertyName());

            if (valueEdited) {
                firePropertyChange("editorValueChanged", null, null);
            }
        });

        setModel(countsModel);
    }

    /**
     * Creates the default columns and then, when the model is a {@link Model}, widens every column whose preferred
     * width is too small to show its header text (header width plus 8 pixels).
     */
    @Override
    public void createDefaultColumnsFromModel() {
        super.createDefaultColumnsFromModel();

        // Other model types get the default columns only.
        if (!(getModel() instanceof Model)) {
            return;
        }

        FontMetrics metrics = getFontMetrics(getFont());
        Model countsModel = (Model) getModel();
        int numColumns = countsModel.getColumnCount();

        for (int col = 0; col < numColumns; col++) {
            TableColumn tableColumn = getColumnModel().getColumn(col);
            String header = countsModel.getColumnName(col);

            if (header == null) {
                continue;
            }

            int neededWidth = metrics.stringWidth(header) + 8;

            if (neededWidth > tableColumn.getPreferredWidth()) {
                tableColumn.setPreferredWidth(neededWidth);
            }
        }
    }

    /**
     * Builds and shows the context menu with the randomize/clear actions for this table's node (and, for
     * "Randomize all tables", for every node of the Dirichlet Bayes IM).
     * <p>
     * The click position is stored so that the row-level actions can work out which row was clicked when the
     * menu item is eventually chosen. Row-level actions do nothing if the click was not over a row. Destructive
     * actions proceed only when the user explicitly answers "Yes"; closing the confirmation dialog counts as
     * declining.
     *
     * @param e the mouse event that requested the popup; its source component anchors the menu.
     */
    private void showPopup(MouseEvent e) {
        JPopupMenu popup = new JPopupMenu();

        JMenuItem randomizeRow = new JMenuItem("Randomize this row");
        JMenuItem randomizeIncompleteRows =
                new JMenuItem("Randomize incomplete rows in table");
        JMenuItem randomizeEntireTable =
                new JMenuItem("Randomize entire table");
        JMenuItem randomizeAllTables = new JMenuItem("Randomize all tables");

        JMenuItem clearRow = new JMenuItem("Clear this row");
        JMenuItem clearEntireTable = new JMenuItem("Clear entire table");

        randomizeRow.addActionListener(event -> {
            int nodeIndex = getEditingTableModel().getNodeIndex();
            cancelActiveCellEdit();

            int rowIndex = rowAtPoint(new Point(getLastX(), getLastY()));

            // rowAtPoint returns -1 when the click was below the last row.
            if (rowIndex < 0) {
                return;
            }

            getDirichletBayesIm().randomizeRow(nodeIndex, rowIndex);
            getEditingTableModel().fireTableDataChanged();
        });

        randomizeIncompleteRows.addActionListener(event -> {
            int nodeIndex = getEditingTableModel().getNodeIndex();
            DirichletBayesIm dirichletBayesIm = getDirichletBayesIm();

            if (!existsIncompleteRow(dirichletBayesIm, nodeIndex)) {
                JOptionPane.showMessageDialog(JOptionUtils.centeringComp(),
                        "There are no incomplete rows in this table.");
                return;
            }

            cancelActiveCellEdit();
            dirichletBayesIm.randomizeIncompleteRows(nodeIndex);
            getEditingTableModel().fireTableDataChanged();
        });

        randomizeEntireTable.addActionListener(event -> {
            int nodeIndex = getEditingTableModel().getNodeIndex();
            DirichletBayesIm dirichletBayesIm = getDirichletBayesIm();

            // Only ask when there is something to lose.
            if (existsCompleteRow(dirichletBayesIm, nodeIndex)
                    && !confirmWithUser("This will modify all values in the table. " + "Continue?")) {
                return;
            }

            cancelActiveCellEdit();
            dirichletBayesIm.randomizeTable(nodeIndex);
            getEditingTableModel().fireTableDataChanged();
        });

        randomizeAllTables.addActionListener(event -> {
            if (!confirmWithUser("This will modify all values in the entire Dirichlet model! " + "Continue?")) {
                return;
            }

            DirichletBayesIm dirichletBayesIm = getDirichletBayesIm();

            // Cancelling the edit and notifying the view are loop-invariant, so do each once.
            cancelActiveCellEdit();

            for (int nodeIndex = 0;
                 nodeIndex < dirichletBayesIm.getNumNodes(); nodeIndex++) {
                dirichletBayesIm.randomizeTable(nodeIndex);
            }

            getEditingTableModel().fireTableDataChanged();
        });

        clearRow.addActionListener(event -> {
            int nodeIndex = getEditingTableModel().getNodeIndex();
            cancelActiveCellEdit();

            int rowIndex = rowAtPoint(new Point(getLastX(), getLastY()));

            // rowAtPoint returns -1 when the click was below the last row.
            if (rowIndex < 0) {
                return;
            }

            getDirichletBayesIm().clearRow(nodeIndex, rowIndex);
            getEditingTableModel().fireTableRowsUpdated(rowIndex, rowIndex);
        });

        clearEntireTable.addActionListener(event -> {
            int nodeIndex = getEditingTableModel().getNodeIndex();
            DirichletBayesIm dirichletBayesIm = getDirichletBayesIm();

            // Only ask when there is something to lose.
            if (existsCompleteRow(dirichletBayesIm, nodeIndex)
                    && !confirmWithUser("This will delete all values in the table. " + "Continue?")) {
                return;
            }

            cancelActiveCellEdit();
            dirichletBayesIm.clearTable(nodeIndex);
            getEditingTableModel().fireTableDataChanged();
        });

        popup.add(randomizeRow);
        popup.add(randomizeIncompleteRows);
        popup.add(randomizeEntireTable);
        popup.add(randomizeAllTables);
        popup.addSeparator();
        popup.add(clearRow);
        popup.add(clearEntireTable);

        this.lastX = e.getX();
        this.lastY = e.getY();

        popup.show((Component) e.getSource(), e.getX(), e.getY());
    }

    /**
     * Cancels the cell edit currently in progress on this table, if there is one.
     */
    private void cancelActiveCellEdit() {
        TableCellEditor cellEditor = getCellEditor();

        if (cellEditor != null) {
            cellEditor.cancelCellEditing();
        }
    }

    /**
     * Shows a Yes/No "Warning" dialog with the given message.
     *
     * @param message the question to put to the user.
     * @return true only if the user chose "Yes"; "No" and closing the dialog both return false.
     */
    private boolean confirmWithUser(String message) {
        int ret = JOptionPane.showConfirmDialog(
                JOptionUtils.centeringComp(), message, "Warning",
                JOptionPane.YES_NO_OPTION);

        return ret == JOptionPane.YES_OPTION;
    }

    /**
     * Reports whether the table of the given node has at least one row that the IM does not flag as incomplete.
     *
     * @param dirichletBayesIm the IM to inspect.
     * @param nodeIndex        the index of the node whose rows are scanned.
     * @return true if some row is not incomplete; false otherwise (including when there are no rows).
     */
    private boolean existsCompleteRow(DirichletBayesIm dirichletBayesIm,
                                      int nodeIndex) {
        int numRows = dirichletBayesIm.getNumRows(nodeIndex);

        for (int row = 0; row < numRows; row++) {
            if (!dirichletBayesIm.isIncomplete(nodeIndex, row)) {
                return true;
            }
        }

        return false;
    }

    /**
     * Reports whether the table of the given node has at least one row that the IM flags as incomplete.
     *
     * @param dirichletBayesIm the IM to inspect.
     * @param nodeIndex        the index of the node whose rows are scanned.
     * @return true if some row is incomplete; false otherwise (including when there are no rows).
     */
    private boolean existsIncompleteRow(DirichletBayesIm dirichletBayesIm,
                                        int nodeIndex) {
        int numRows = dirichletBayesIm.getNumRows(nodeIndex);

        for (int row = 0; row < numRows; row++) {
            if (dirichletBayesIm.isIncomplete(nodeIndex, row)) {
                return true;
            }
        }

        return false;
    }

    /**
     * Replaces the table's data model by deferring to {@link JTable#setModel(TableModel)}.
     *
     * @param model the new model; annotated as non-null.
     */
    @Override
    public void setModel(@NotNull TableModel model) {
        super.setModel(model);
    }

    /**
     * Sets the focus row to the anchor row currently being selected, then selects that row and starts editing the
     * focus cell.
     * <p>
     * If the model recorded a row whose last edit failed, that row takes precedence over {@code row} and the
     * record is cleared, so focus returns to the cell the user still has to fix.
     *
     * @param row the requested focus row; may be -1 when the selection model has no anchor, in which case the
     *            value is stored but no selection or edit is attempted.
     */
    private void setFocusRow(int row) {
        Model editingTableModel = (Model) getModel();
        int failedRow = editingTableModel.getFailedRow();

        if (failedRow != -1) {
            row = failedRow;
            editingTableModel.resetFailedRow();
        }

        this.focusRow = row;

        // The ROW index must be checked against the row count (the original compared the focus column), and a
        // negative index must be skipped because setRowSelectionInterval rejects out-of-range rows.
        if (this.focusRow >= 0 && this.focusRow < getRowCount()) {
            setRowSelectionInterval(this.focusRow, this.focusRow);
            editCellAt(this.focusRow, this.focusCol);
        }
    }

    /**
     * Sets the focus column to the anchor column currently being selected, never left of the first non-parent
     * column, then selects that column and starts editing the focus cell if the column exists.
     * <p>
     * If the model recorded a column whose last edit failed, that column takes precedence over {@code col} and the
     * record is cleared.
     *
     * @param col the requested focus column.
     */
    private void setFocusColumn(int col) {
        Model countsModel = (Model) getModel();
        int failedCol = countsModel.getFailedCol();
        int requested = col;

        if (failedCol != -1) {
            requested = failedCol;
            countsModel.resetFailedCol();
        }

        // Parent-value columns are read-only, so the focus is clamped to the first count column.
        this.focusCol = FastMath.max(requested, getNumParents());

        boolean columnExists = this.focusCol < getColumnCount();

        if (columnExists) {
            setColumnSelectionInterval(this.focusCol, this.focusCol);
            editCellAt(this.focusRow, this.focusCol);
        }
    }

    /**
     * @return the number of parents the IM reports for the node shown in this table.
     */
    private int getNumParents() {
        Model countsModel = (Model) getModel();
        return countsModel.getDirichletBayesIm().getNumParents(countsModel.getNodeIndex());
    }

    /**
     * @return this table's model, cast to {@link Model}.
     */
    private Model getEditingTableModel() {
        TableModel current = getModel();
        return (Model) current;
    }

    /**
     * @return the Dirichlet Bayes IM held by this table's model.
     */
    private DirichletBayesIm getDirichletBayesIm() {
        Model countsModel = getEditingTableModel();
        return countsModel.getDirichletBayesIm();
    }

    /**
     * @return the x coordinate stored the last time the popup menu was shown.
     */
    private int getLastX() {
        return lastX;
    }

    /**
     * @return the y coordinate stored the last time the popup menu was shown.
     */
    private int getLastY() {
        return lastY;
    }

    /**
     * The table model containing the Dirichlet pseudocounts to be edited for a given node N with parents
     * P1, P2, ..., Pn.  The first n columns of each row show one combination of parent values (v1, v2, ..., vn).
     * If there are m values for N, the next m columns contain the pseudocounts for N taking each of those values
     * given that combination of parent values.  A final "TOTAL COUNT" column shows the row's total pseudocount.
     * The pseudocounts and the total may be edited; the only condition is that entered values be greater than or
     * equal to 0.0.  Editing the total rescales the pseudocounts of that row.
     *
     * @author josephramsey
     */
    static final class Model extends AbstractTableModel {

        /**
         * The BayesIm being edited.
         */
        private final DirichletBayesIm dirichletBayesIm;

        /**
         * This table can only display pseudocounts for one node at a time. This is the index of that node.
         */
        private final int nodeIndex;

        private int failedRow = -1;
        private int failedCol = -1;
        private PropertyChangeSupport pcs;

        /**
         * Constructs a new editing table model for a given node in a given Dirichlet Bayes IM.
         *
         * @param node             the node whose counts are exposed; stored as its index in the IM.
         * @param dirichletBayesIm the IM holding the counts.
         * @param messageAnchor    only checked for null here; it is not stored.
         * @throws NullPointerException if any argument is null.
         */
        public Model(Node node, DirichletBayesIm dirichletBayesIm,
                     JComponent messageAnchor) {
            if (node == null) {
                throw new NullPointerException("Node must not be null.");
            }

            if (dirichletBayesIm == null) {
                throw new NullPointerException("Bayes IM must not be null.");
            }

            if (messageAnchor == null) {
                throw new NullPointerException("Message anchor must not be null.");
            }

            int indexOfNode = dirichletBayesIm.getNodeIndex(node);

            this.dirichletBayesIm = dirichletBayesIm;
            this.nodeIndex = indexOfNode;
        }

        /**
         * Names the given column: the parent's name for a parent column, "node=category" for a count column,
         * "TOTAL COUNT" for the totals column.
         *
         * @param col the column index.
         * @return the column name, or null for a column beyond the totals column.
         */
        @Override
        public String getColumnName(int col) {
            DirichletBayesIm im = getDirichletBayesIm();
            int index = getNodeIndex();

            Node node = im.getNode(index);
            int numParents = im.getNumParents(index);
            int numColumns = im.getNumColumns(index);

            if (col < numParents) {
                int parentIndex = im.getParent(index, col);
                return im.getNode(parentIndex).getName();
            }

            int valIndex = col - numParents;

            if (valIndex < numColumns) {
                String category = im.getBayesPm().getCategory(node, valIndex);
                return node.getName() + "=" + category;
            }

            // The single column after the count columns holds the row total.
            return valIndex == numColumns ? "TOTAL COUNT" : null;
        }

        /**
         * @return the number of rows the IM reports for this model's node.
         */
        @Override
        public int getRowCount() {
            int index = getNodeIndex();
            return getDirichletBayesIm().getNumRows(index);
        }

        /**
         * @return the total number of columns in the table: one per parent of the node, one per count column the
         * IM reports for the node, plus one extra column that displays the row total.
         */
        @Override
        public int getColumnCount() {
            DirichletBayesIm im = getDirichletBayesIm();
            int index = getNodeIndex();

            int parentColumns = im.getNumParents(index);
            int countColumns = im.getNumColumns(index);
            int totalsColumns = 1;

            return parentColumns + countColumns + totalsColumns;
        }

        /**
         * Returns the value shown at the given cell. The type depends on the column: parent columns yield the
         * category of that parent for the row's parent-value combination, count columns yield the pseudocount as a
         * Double, and the totals column yields the row's total pseudocount.
         *
         * @param tableRow the row index.
         * @param tableCol the column index.
         * @return the cell value; null for a pseudocount of -1 or for a column beyond the totals column.
         */
        @Override
        public Object getValueAt(int tableRow, int tableCol) {
            DirichletBayesIm im = getDirichletBayesIm();
            int index = getNodeIndex();

            int numParents = im.getNumParents(index);
            int numColumns = im.getNumColumns(index);
            int[] parentVals = im.getParentValues(index, tableRow);

            if (tableCol < numParents) {
                Node parentNode = im.getNode(im.getParent(index, tableCol));
                return im.getBayesPm().getCategory(parentNode, parentVals[tableCol]);
            }

            int colIndex = tableCol - numParents;

            if (colIndex < numColumns) {
                double value = im.getPseudocount(index, tableRow, colIndex);

                // A stored value of -1 is displayed as an empty cell.
                if (value == -1) {
                    return null;
                }

                return value;
            }

            if (colIndex == numColumns) {
                return im.getRowPseudocount(index, tableRow);
            }

            return null;
        }

        /**
         * Determines whether a cell is in the column range that allows editing: the count columns and the totals
         * column are editable, the parent-value columns are not.
         *
         * @param row the row index (not consulted).
         * @param col the column index.
         * @return true if {@code col} is a count column or the totals column.
         */
        @Override
        public boolean isCellEditable(int row, int col) {
            DirichletBayesIm im = getDirichletBayesIm();
            int index = getNodeIndex();

            int firstEditable = im.getNumParents(index);
            int totalsColumn = firstEditable + im.getNumColumns(index);

            return col >= firstEditable && col <= totalsColumn;
        }

        /**
         * Sets the value of the cell at (row, col) to 'aValue'. An edit in the totals column is routed to the
         * row-total handler; every other column is routed to the single-pseudocount handler.
         *
         * @param aValue the value produced by the cell editor.
         * @param row    the row index.
         * @param col    the column index.
         */
        @Override
        public void setValueAt(Object aValue, int row, int col) {
            DirichletBayesIm im = getDirichletBayesIm();
            int index = getNodeIndex();
            int totalsColumn = im.getNumParents(index) + im.getNumColumns(index);

            if (col != totalsColumn) {
                setPseudocount(row, col, aValue);
                return;
            }

            setTotal(row, aValue);
        }

        /**
         * Looks up a column by name by deferring to {@link AbstractTableModel#findColumn(String)}.
         *
         * @param columnName the column name to search for.
         * @return whatever the superclass lookup returns for that name.
         */
        @Override
        public int findColumn(String columnName) {
            int position = super.findColumn(columnName);
            return position;
        }

        /**
         * Handles an edit of the "TOTAL COUNT" cell of a row by rescaling the row's pseudocounts so that they add
         * up to the entered total.
         * <p>
         * Null and empty values are ignored. Values that cannot be read as a finite number, and negative totals,
         * are reported to the user and the row is recorded as failed. When the row's current total is positive,
         * each pseudocount is multiplied by {@code total / currentTotal}. When it is not (for example an all-zero
         * row), there are no proportions to preserve, so the total is spread evenly over the row; this avoids the
         * division by zero that would otherwise store NaN or infinite counts.
         *
         * @param row    the row being edited.
         * @param aValue the editor value; converted with {@code toString()} before parsing.
         */
        private void setTotal(int row, Object aValue) {
            if (aValue == null || "".equals(aValue)) {
                return;
            }

            double total;

            try {
                // toString() instead of a (String) cast so that Number values do not cause a ClassCastException.
                total = Double.parseDouble(aValue.toString());
            } catch (NumberFormatException e) {
                // Reported to the user just below, together with "NaN"/"Infinity" input.
                total = Double.NaN;
            }

            if (Double.isNaN(total) || Double.isInfinite(total)) {
                JOptionPane.showMessageDialog(JOptionUtils.centeringComp(),
                        "Could not interpret '" + aValue + "'");
                this.failedRow = row;
                return;
            }

            if (total < 0.0) {
                JOptionPane.showMessageDialog(JOptionUtils.centeringComp(),
                        "Total must be nonnegative.");
                this.failedRow = row;
                return;
            }

            DirichletBayesIm im = getDirichletBayesIm();
            int index = getNodeIndex();
            int numColumns = im.getNumColumns(index);
            double currentTotal = im.getRowPseudocount(index, row);

            if (currentTotal > 0.0) {
                double ratio = total / currentTotal;

                for (int i = 0; i < numColumns; i++) {
                    double count = im.getPseudocount(index, row, i);
                    im.setPseudocount(index, row, i, count * ratio);
                }
            } else {
                // Nothing to scale from: distribute the requested total evenly.
                for (int i = 0; i < numColumns; i++) {
                    im.setPseudocount(index, row, i, total / numColumns);
                }
            }

            fireTableRowsUpdated(row, row);
            getPcs().firePropertyChange("editorValueChanged", null, null);
        }

        /**
         * Handles an edit of a single pseudocount cell.
         * <p>
         * Null and empty values are ignored. Values that cannot be read as a finite number, and negative counts,
         * are reported to the user and the cell is recorded as failed (row and column). Otherwise the count is
         * stored in the IM, the row is refreshed and an "editorValueChanged" property change is fired.
         *
         * @param row    the row being edited.
         * @param col    the table column being edited; the parent columns are subtracted to get the IM column.
         * @param aValue the editor value; converted with {@code toString()} before parsing.
         */
        private void setPseudocount(int row, int col, Object aValue) {
            if (aValue == null || "".equals(aValue)) {
                return;
            }

            double pseudocount;

            try {
                // toString() instead of a (String) cast so that Number values do not cause a ClassCastException.
                pseudocount = Double.parseDouble(aValue.toString());
            } catch (NumberFormatException e) {
                // Reported to the user just below, together with "NaN"/"Infinity" input.
                pseudocount = Double.NaN;
            }

            if (Double.isNaN(pseudocount) || Double.isInfinite(pseudocount)) {
                JOptionPane.showMessageDialog(JOptionUtils.centeringComp(),
                        "Could not interpret '" + aValue + "'");
                this.failedRow = row;
                this.failedCol = col;
                return;
            }

            if (countIsNegative(pseudocount)) {
                JOptionPane.showMessageDialog(JOptionUtils.centeringComp(),
                        "Counts must be nonnegative.");
                this.failedRow = row;
                this.failedCol = col;
                return;
            }

            int numParents =
                    getDirichletBayesIm().getNumParents(getNodeIndex());
            int colIndex = col - numParents;

            getDirichletBayesIm().setPseudocount(getNodeIndex(), row,
                    colIndex, pseudocount);
            fireTableRowsUpdated(row, row);
            getPcs().firePropertyChange("editorValueChanged", null, null);
        }

        /**
         * Registers a listener with this model's property change support.
         *
         * @param l the listener to add.
         */
        public void addPropertyChangeListener(PropertyChangeListener l) {
            PropertyChangeSupport support = getPcs();
            support.addPropertyChangeListener(l);
        }

        /**
         * @return this model's property change support, created on first use.
         */
        private PropertyChangeSupport getPcs() {
            PropertyChangeSupport support = this.pcs;

            if (support == null) {
                support = new PropertyChangeSupport(this);
                this.pcs = support;
            }

            return support;
        }

        /**
         * @param count the count to test.
         * @return true if {@code count} is strictly below zero.
         */
        private boolean countIsNegative(double count) {
            return 0.0 > count;
        }

        /**
         * Reports the class of a column, which selects the default renderer and editor for it.
         *
         * @param col the column index.
         * @return {@code Object.class} for parent-value columns, {@code Number.class} for every other column.
         */
        @Override
        public Class<?> getColumnClass(int col) {
            boolean isParent =
                    col < getDirichletBayesIm().getNumParents(getNodeIndex());
            return isParent ? Object.class : Number.class;
        }

        /**
         * @return the Dirichlet Bayes IM whose counts this model exposes.
         */
        public DirichletBayesIm getDirichletBayesIm() {
            return dirichletBayesIm;
        }

        /**
         * @return the index, in the IM, of the node this model displays.
         */
        public int getNodeIndex() {
            return nodeIndex;
        }

        /**
         * @return the row of the last rejected edit, or -1 if none is recorded.
         */
        public int getFailedRow() {
            return failedRow;
        }

        /**
         * @return the column of the last rejected edit, or -1 if none is recorded.
         */
        public int getFailedCol() {
            return failedCol;
        }

        /**
         * Clears the recorded failed row by setting it back to -1.
         */
        public void resetFailedRow() {
            failedRow = -1;
        }

        /**
         * Clears the recorded failed column by setting it back to -1.
         */
        public void resetFailedCol() {
            failedCol = -1;
        }
    }
}










© 2015 - 2025 Weber Informatics LLC | Privacy Policy