All downloads are free. Search and download functionality uses the official Maven repository.

com.github.chen0040.glm.data.BasicDataFrame Maven / Gradle / Ivy

package com.github.chen0040.glm.data;


import com.github.chen0040.glm.utils.CollectionUtils;
import com.github.chen0040.glm.utils.TupleTwo;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;


/**
 * Created by xschen on 1/5/2017.
 */
public class BasicDataFrame implements DataFrame {

   private final List rows = new ArrayList<>();
   private final List inputDataColumns = new ArrayList<>();
   private final List outputDataColumns = new ArrayList<>();
   private boolean locked = false;

   @Override public int rowCount() {
      return rows.size();
   }


   @Override public DataRow row(int i) {
      return rows.get(i);
   }


   @Override public List getInputColumns() {
      return inputDataColumns;
   }


   @Override public List getOutputColumns() {
      return outputDataColumns;
   }


   @Override public void unlock(){
      locked = false;
   }

   @Override
   public boolean isLocked() {
      return locked;
   }


   @Override public void lock() {
      Map> counts = new HashMap<>();
      Set numericOutputs = new HashSet<>();
      Set categoricalOutputs = new HashSet<>();
      for(DataRow row : rows){
         List keys = row.getColumnNames();
         for(String key: keys) {
            Set set;

            if(counts.containsKey(key)){
               set = counts.get(key);
            } else {
               set = new HashSet<>();
               counts.put(key, set);
            }

            set.add(row.getCell(key));
         }
         numericOutputs.addAll(row.getTargetColumnNames());
         categoricalOutputs.addAll(row.getCategoricalTargetColumnNames());
      }

      inputDataColumns.clear();
      for(Map.Entry> entry : counts.entrySet()){
         Set set = entry.getValue();
         InputDataColumn inputDataColumn = new InputDataColumn();
         inputDataColumn.setColumnName(entry.getKey());
         if(set.size() < rowCount() / 3) {
            inputDataColumn.setLevels(set);
         }
         inputDataColumns.add(inputDataColumn);
      }

      outputDataColumns.clear();
      outputDataColumns.addAll(numericOutputs.stream().map(o -> new OutputDataColumn(o, false)).collect(Collectors.toList()));
      outputDataColumns.addAll(categoricalOutputs.stream().map(o -> new OutputDataColumn(o, true)).collect(Collectors.toList()));

      inputDataColumns.sort((a, b) -> a.getColumnName().compareTo(b.getColumnName()));
      outputDataColumns.sort((a, b) -> a.getColumnName().compareTo(b.getColumnName()));

      List inputColumns = inputDataColumns.stream().map(InputDataColumn::getColumnName).collect(Collectors.toList());
      List numericOutputColumns = outputDataColumns.stream().filter(c -> !c.isCategorical()).map(OutputDataColumn::getColumnName).collect(Collectors.toList());
      List categoricalOutputColumns = outputDataColumns.stream().filter(OutputDataColumn::isCategorical).map(OutputDataColumn::getColumnName).collect(Collectors.toList());

      inputColumns.sort(String::compareTo);
      numericOutputColumns.sort(String::compareTo);
      categoricalOutputColumns.sort(String::compareTo);

      for(int i=0; i < rowCount(); ++i) {
         DataRow row = row(i);
         row.setColumnNames(inputColumns);
         row.setTargetColumnNames(numericOutputColumns);
         row.setCategoricalTargetColumnNames(categoricalOutputColumns);
      }

      locked = true;
   }


   @Override public DataRow newRow() {
      return new BasicDataRow();
   }


   @Override public void addRow(DataRow row) {
      if(locked) {
         throw new RuntimeException("Data frame is currently locked, please unlock first");
      }
      rows.add(row);
   }


   @Override public String head(int limit) {
      StringBuilder sb = new StringBuilder();
      int max = Math.min(limit, rowCount());
      for(int i=0; i < max; ++i) {
         if(i != 0){
            sb.append("\n");
         }
         sb.append(row(i));
      }
      return sb.toString();
   }


   @Override public DataFrame shuffle() {
      Random random = new Random(System.currentTimeMillis());
      for(int i=1; i < rows.size(); ++i) {
         int j = random.nextInt(i+1);
         CollectionUtils.exchange(rows, i, j);
      }
      return this;
   }


   @Override public TupleTwo split(double ratio) {
      assert this.locked;

      BasicDataFrame frame1 = new BasicDataFrame();
      BasicDataFrame frame2 = new BasicDataFrame();

      frame1.inputDataColumns.addAll(inputDataColumns.stream().map(InputDataColumn::makeCopy).collect(Collectors.toList()));
      frame2.inputDataColumns.addAll(inputDataColumns.stream().map(InputDataColumn::makeCopy).collect(Collectors.toList()));

      frame1.outputDataColumns.addAll(outputDataColumns.stream().map(OutputDataColumn::makeCopy).collect(Collectors.toList()));
      frame2.outputDataColumns.addAll(outputDataColumns.stream().map(OutputDataColumn::makeCopy).collect(Collectors.toList()));

      int split = (int)(rows.size() * ratio);
      for(int i=0; i < split; ++i) {
         frame1.addRow(rows.get(i).makeCopy());
      }
      for(int i=split; i < rows.size(); ++i){
         frame2.addRow(rows.get(i).makeCopy());
      }

      return new TupleTwo<>(frame1, frame2);

   }


   @Override public Stream stream() {
      return rows.stream();
   }


   @Override public Iterator iterator() {
      return rows.iterator();
   }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy