All downloads are free. Search and download functionality uses the official Maven repository.

com.github.chen0040.glm.data.BasicDataRow Maven / Gradle / Ivy

package com.github.chen0040.glm.data;


import com.github.chen0040.glm.utils.StringUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;


/**
 * Created by xschen on 1/5/2017.
 * A data row consists of three kinds of columns:
 *
 * column: an input column whose values are numeric
 * target column: an output column whose values are numeric
 * categorical target column: an output column whose values are strings
 *
 * Cells are stored in maps keyed by column name. A numeric cell that was never set reads as 0.0,
 * and a categorical cell that was never set reads as "". Column-name lists can be set explicitly.
 * Whenever a list has fewer entries than there are stored cells of that kind, the missing names
 * are appended to it in sorted order.
 */
public class BasicDataRow implements DataRow {

   /** Numeric output values keyed by target column name. */
   private final Map<String, Double> targets = new HashMap<>();
   /** Categorical output values keyed by categorical target column name. */
   private final Map<String, String> categoricalTargets = new HashMap<>();

   /** Numeric input values keyed by input column name. */
   private final Map<String, Double> values = new HashMap<>();

   // Ordered column-name lists. They are returned directly by the getters below (not copied).
   private final List<String> columns = new ArrayList<>();
   private final List<String> targetColumns = new ArrayList<>();
   private final List<String> categoricalTargetColumns = new ArrayList<>();

   /**
    * Returns the value of the first numeric target column.
    *
    * @throws IndexOutOfBoundsException if the row has no target columns
    */
   @Override public double target() {
      return getTargetCell(targetColumnName());
   }

   /**
    * Returns the value of the first categorical target column.
    *
    * @throws IndexOutOfBoundsException if the row has no categorical target columns
    */
   @Override public String categoricalTarget() {
      return getCategoricalTargetCell(categoricalTargetColumnName());
   }

   /** Returns the numeric target stored under {@code columnName}, or 0.0 if none was stored. */
   @Override
   public double getTargetCell(String columnName){
      return targets.getOrDefault(columnName, 0.0);
   }

   /**
    * Returns the categorical target stored under {@code columnName}, or "" if the column was never
    * set. If the column was explicitly set to null, null is returned.
    */
   @Override
   public String getCategoricalTargetCell(String columnName){
      return categoricalTargets.getOrDefault(columnName, "");
   }

   /**
    * Stores a numeric target value. Zero is stored explicitly rather than dropped, so the column
    * still appears in {@link #getTargetColumnNames()}.
    */
   @Override
   public void setTargetCell(String columnName, double value) {
      targets.put(columnName, value);
   }

   /**
    * Stores a categorical target value. Empty values are stored explicitly rather than dropped,
    * so the column still appears in {@link #getCategoricalTargetColumnNames()}.
    */
   @Override
   public void setCategoricalTargetCell(String columnName, String value) {
      categoricalTargets.put(columnName, value);
   }

   /** Replaces the ordered input column names with a copy of {@code inputColumns}. */
   @Override public void setColumnNames(List<String> inputColumns) {
      replaceContents(columns, inputColumns);
   }

   /** Replaces the ordered numeric target column names with a copy of {@code outputColumns}. */
   @Override public void setTargetColumnNames(List<String> outputColumns) {
      replaceContents(targetColumns, outputColumns);
   }

   /** Replaces the ordered categorical target column names with a copy of {@code outputColumns}. */
   @Override public void setCategoricalTargetColumnNames(List<String> outputColumns) {
      replaceContents(categoricalTargetColumns, outputColumns);
   }

   /** Returns a new, independent row holding the same cells and column names as this one. */
   @Override public DataRow makeCopy() {
      DataRow clone = new BasicDataRow();
      clone.copy(this);
      return clone;
   }

   /**
    * Overwrites this row with the cells and column names of {@code that}. Only cells listed in
    * {@code that}'s column-name lists are copied.
    */
   @Override public void copy(DataRow that) {
      if (that == this) {
         // Clearing below would otherwise wipe this row before it is read back.
         return;
      }

      targets.clear();
      categoricalTargets.clear();
      values.clear();
      columns.clear();
      targetColumns.clear();
      categoricalTargetColumns.clear();

      for(String c : that.getTargetColumnNames()){
         targets.put(c, that.getTargetCell(c));
      }

      for(String c : that.getColumnNames()) {
         values.put(c, that.getCell(c));
      }

      for(String c : that.getCategoricalTargetColumnNames()) {
         categoricalTargets.put(c, that.getCategoricalTargetCell(c));
      }

      setColumnNames(that.getColumnNames());
      setTargetColumnNames(that.getTargetColumnNames());
      setCategoricalTargetColumnNames(that.getCategoricalTargetColumnNames());
   }

   /**
    * Returns the first numeric target column name.
    *
    * @throws IndexOutOfBoundsException if there are no target columns
    */
   @Override public String targetColumnName() {
      return getTargetColumnNames().get(0);
   }

   /**
    * Returns the first categorical target column name.
    *
    * @throws IndexOutOfBoundsException if there are no categorical target columns
    */
   @Override public String categoricalTargetColumnName() {
      return getCategoricalTargetColumnNames().get(0);
   }

   /** Returns the input cells as an array ordered like {@link #getColumnNames()}. */
   @Override public double[] toArray() {
      List<String> cols = getColumnNames();

      double[] result = new double[cols.size()];
      for(int i=0; i < cols.size(); ++i) {
         result[i] = getCell(cols.get(i));
      }
      return result;
   }

   private void buildColumns(){
      appendMissingSorted(columns, values);
   }

   private void buildTargetColumns(){
      appendMissingSorted(targetColumns, targets);
   }

   private void buildCategoricalTargetColumns(){
      appendMissingSorted(categoricalTargetColumns, categoricalTargets);
   }

   /**
    * Appends to {@code names} every key of {@code cells} that it does not already contain, in
    * sorted order. Names already in the list keep their position, and no duplicates are added.
    */
   private static void appendMissingSorted(List<String> names, Map<String, ?> cells) {
      Set<String> known = new HashSet<>(names);
      List<String> missing = cells.keySet().stream()
            .filter(name -> !known.contains(name))
            .sorted()
            .collect(Collectors.toList());
      names.addAll(missing);
   }

   /**
    * Replaces the contents of {@code target} with those of {@code source}. The source is
    * snapshotted first so that passing the target list itself does not empty it.
    */
   private static void replaceContents(List<String> target, List<String> source) {
      List<String> snapshot = new ArrayList<>(source);
      target.clear();
      target.addAll(snapshot);
   }

   /**
    * Stores a numeric input value. Zero is stored explicitly rather than dropped, so the column
    * still appears in {@link #getColumnNames()} and {@link #toArray()}.
    */
   @Override public void setCell(String columnName, double value) {
      values.put(columnName, value);
   }

   /**
    * Returns the ordered input column names. This is the internal list itself, so changes made
    * by the caller affect this row. Any input column without a name in the list is appended in
    * sorted order first.
    */
   @Override public List<String> getColumnNames() {
      if(columns.size() < values.size()) {
         buildColumns();
      }
      return columns;
   }

   /**
    * Returns the ordered numeric target column names, which is the internal list itself. Any
    * missing names are appended in sorted order first.
    */
   @Override
   public List<String> getTargetColumnNames() {
      if(targetColumns.size() < targets.size()){
         buildTargetColumns();
      }
      return targetColumns;
   }

   /**
    * Returns the ordered categorical target column names, which is the internal list itself.
    * Any missing names are appended in sorted order first.
    */
   @Override
   public List<String> getCategoricalTargetColumnNames() {
      if(categoricalTargetColumns.size() < categoricalTargets.size()){
         buildCategoricalTargetColumns();
      }
      return categoricalTargetColumns;
   }

   /** Returns the numeric input stored under {@code key}, or 0.0 if none was stored. */
   @Override public double getCell(String key) {
      return values.getOrDefault(key, 0.0);
   }

   /**
    * Renders the row as {@code "in1:v, in2:v => out1:v, label:c"}. Numeric and categorical
    * targets share one comma-separated list after the arrow.
    */
   @Override
   public String toString(){
      StringBuilder sb = new StringBuilder();
      String sep = "";
      for (String key : getColumnNames()) {
         sb.append(sep).append(key).append(":").append(getCell(key));
         sep = ", ";
      }
      sb.append(" => ");

      sep = "";
      for (String key : getTargetColumnNames()) {
         sb.append(sep).append(key).append(":").append(getTargetCell(key));
         sep = ", ";
      }
      // Keep the separator running so categorical targets are not glued onto the numeric ones.
      for (String key : getCategoricalTargetColumnNames()) {
         sb.append(sep).append(key).append(":").append(getCategoricalTargetCell(key));
         sep = ", ";
      }

      return sb.toString();
   }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy