All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.mahout.classifier.df.data.Dataset Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.data;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.codehaus.jackson.map.ObjectMapper;
import org.codehaus.jackson.type.TypeReference;

import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Contains information about the attributes.
 */
public class Dataset {

  /**
   * Attributes type
   */
  public enum Attribute {
    IGNORED,
    NUMERICAL,
    CATEGORICAL,
    LABEL;

    public boolean isNumerical() {
      return this == NUMERICAL;
    }

    public boolean isCategorical() {
      return this == CATEGORICAL;
    }

    public boolean isLabel() {
      return this == LABEL;
    }

    public boolean isIgnored() {
      return this == IGNORED;
    }
    
    private static Attribute fromString(String from) {
      Attribute toReturn = LABEL;
      if (NUMERICAL.toString().equalsIgnoreCase(from)) {
        toReturn = NUMERICAL;
      } else if (CATEGORICAL.toString().equalsIgnoreCase(from)) {
        toReturn = CATEGORICAL;
      } else if (IGNORED.toString().equalsIgnoreCase(from)) {
        toReturn = IGNORED;
      }
      return toReturn;
    }
  }

  private Attribute[] attributes;

  /**
   * list of ignored attributes
   */
  private int[] ignored;

  /**
   * distinct values (CATEGORIAL attributes only)
   */
  private String[][] values;

  /**
   * index of the label attribute in the loaded data (without ignored attributed)
   */
  private int labelId;

  /**
   * number of instances in the dataset
   */
  private int nbInstances;
  
  /** JSON serial/de-serial-izer */
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  // Some literals for JSON representation
  static final String TYPE = "type";
  static final String VALUES = "values";
  static final String LABEL = "label";

  protected Dataset() {}

  /**
   * Should only be called by a DataLoader
   *
   * @param attrs  attributes description
   * @param values distinct values for all CATEGORICAL attributes
   */
  Dataset(Attribute[] attrs, List[] values, int nbInstances, boolean regression) {
    validateValues(attrs, values);

    int nbattrs = countAttributes(attrs);

    // the label values are set apart
    attributes = new Attribute[nbattrs];
    this.values = new String[nbattrs][];
    ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs

    labelId = -1;
    int ignoredId = 0;
    int ind = 0;
    for (int attr = 0; attr < attrs.length; attr++) {
      if (attrs[attr].isIgnored()) {
        ignored[ignoredId++] = attr;
        continue;
      }

      if (attrs[attr].isLabel()) {
        if (labelId != -1) {
          throw new IllegalStateException("Label found more than once");
        }
        labelId = ind;
        if (regression) {
          attrs[attr] = Attribute.NUMERICAL;
        } else {
          attrs[attr] = Attribute.CATEGORICAL;
        }
      }

      if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) {
        this.values[ind] = new String[values[attr].size()];
        values[attr].toArray(this.values[ind]);
      }

      attributes[ind++] = attrs[attr];
    }

    if (labelId == -1) {
      throw new IllegalStateException("Label not found");
    }

    this.nbInstances = nbInstances;
  }

  public int nbValues(int attr) {
    return values[attr].length;
  }

  public String[] labels() {
    return Arrays.copyOf(values[labelId], nblabels());
  }

  public int nblabels() {
    return values[labelId].length;
  }

  public int getLabelId() {
    return labelId;
  }

  public double getLabel(Instance instance) {
    return instance.get(getLabelId());
  }
  
  public Attribute getAttribute(int attr) {
    return attributes[attr];
  }

  /**
   * Returns the code used to represent the label value in the data
   *
   * @param label label's value to code
   * @return label's code
   */
  public int labelCode(String label) {
    return ArrayUtils.indexOf(values[labelId], label);
  }

  /**
   * Returns the label value in the data
   * This method can be used when the criterion variable is the categorical attribute.
   *
   * @param code label's code
   * @return label's value
   */
  public String getLabelString(double code) {
    // handle the case (prediction is NaN)
    if (Double.isNaN(code)) {
      return "unknown";
    }
    return values[labelId][(int) code];
  }
  
  @Override
  public String toString() {
    return "attributes=" + Arrays.toString(attributes);
  }

  /**
   * Converts a token to its corresponding integer code for a given attribute
   *
   * @param attr attribute index
   */
  public int valueOf(int attr, String token) {
    Preconditions.checkArgument(!isNumerical(attr), "Only for CATEGORICAL attributes");
    Preconditions.checkArgument(values != null, "Values not found (equals null)");
    return ArrayUtils.indexOf(values[attr], token);
  }

  public int[] getIgnored() {
    return ignored;
  }

  /**
   * @return number of attributes that are not IGNORED
   */
  private static int countAttributes(Attribute[] attrs) {
    int nbattrs = 0;
    for (Attribute attr : attrs) {
      if (!attr.isIgnored()) {
        nbattrs++;
      }
    }
    return nbattrs;
  }

  private static void validateValues(Attribute[] attrs, List[] values) {
    Preconditions.checkArgument(attrs.length == values.length, "attrs.length != values.length");
    for (int attr = 0; attr < attrs.length; attr++) {
      Preconditions.checkArgument(!attrs[attr].isCategorical() || values[attr] != null,
          "values not found for attribute " + attr);
    }
  }

  /**
   * @return number of attributes
   */
  public int nbAttributes() {
    return attributes.length;
  }

  /**
   * Is this a numerical attribute ?
   *
   * @param attr index of the attribute to check
   * @return true if the attribute is numerical
   */
  public boolean isNumerical(int attr) {
    return attributes[attr].isNumerical();
  }

  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof Dataset)) {
      return false;
    }

    Dataset dataset = (Dataset) obj;

    if (!Arrays.equals(attributes, dataset.attributes)) {
      return false;
    }

    for (int attr = 0; attr < nbAttributes(); attr++) {
      if (!Arrays.equals(values[attr], dataset.values[attr])) {
        return false;
      }
    }

    return labelId == dataset.labelId && nbInstances == dataset.nbInstances;
  }

  @Override
  public int hashCode() {
    int hashCode = labelId + 31 * nbInstances;
    for (Attribute attr : attributes) {
      hashCode = 31 * hashCode + attr.hashCode();
    }
    for (String[] valueRow : values) {
      if (valueRow == null) {
        continue;
      }
      for (String value : valueRow) {
        hashCode = 31 * hashCode + value.hashCode();
      }
    }
    return hashCode;
  }

  /**
   * Loads the dataset from a file
   *
   * @throws java.io.IOException
   */
  public static Dataset load(Configuration conf, Path path) throws IOException {
    FileSystem fs = path.getFileSystem(conf);
    long bytesToRead = fs.getFileStatus(path).getLen();
    byte[] buff = new byte[Long.valueOf(bytesToRead).intValue()];
    FSDataInputStream input = fs.open(path);
    try {
      input.readFully(buff);
    } finally {
      Closeables.close(input, true);
    }
    String json = new String(buff, Charset.defaultCharset());
    return fromJSON(json);
  }
  

  /**
   * Serialize this instance to JSON
   * @return some JSON
   */
  public String toJSON() {
    List> toWrite = new LinkedList<>();
    // attributes does not include ignored columns and it does include the class label
    int ignoredCount = 0;
    for (int i = 0; i < attributes.length + ignored.length; i++) {
      Map attribute;
      int attributesIndex = i - ignoredCount;
      if (ignoredCount < ignored.length && i == ignored[ignoredCount]) {
        // fill in ignored atttribute
        attribute = getMap(Attribute.IGNORED, null, false);
        ignoredCount++;
      } else if (attributesIndex == labelId) {
        // fill in the label
        attribute = getMap(attributes[attributesIndex], values[attributesIndex], true);
      } else  {
        // normal attribute
        attribute = getMap(attributes[attributesIndex], values[attributesIndex], false);
      }
      toWrite.add(attribute);
    }
    try {
      return OBJECT_MAPPER.writeValueAsString(toWrite);
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }
  }

  /**
   * De-serialize an instance from a string
   * @param json From which an instance is created
   * @return A shiny new Dataset
   */
  public static Dataset fromJSON(String json) {
    List> fromJSON;
    try {
      fromJSON = OBJECT_MAPPER.readValue(json, new TypeReference>>() {});
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }
    List attributes = new LinkedList<>();
    List ignored = new LinkedList<>();
    String[][] nominalValues = new String[fromJSON.size()][];
    Dataset dataset = new Dataset();
    for (int i = 0; i < fromJSON.size(); i++) {
      Map attribute = fromJSON.get(i);
      if (Attribute.fromString((String) attribute.get(TYPE)) == Attribute.IGNORED) {
        ignored.add(i);
      } else {
        Attribute asAttribute = Attribute.fromString((String) attribute.get(TYPE));
        attributes.add(asAttribute);
        if ((Boolean) attribute.get(LABEL)) {
          dataset.labelId = i - ignored.size();
        }
        if (attribute.get(VALUES) != null) {
          List get = (List) attribute.get(VALUES);
          String[] array = get.toArray(new String[get.size()]);
          nominalValues[i - ignored.size()] = array;
        }
      }
    }
    dataset.attributes = attributes.toArray(new Attribute[attributes.size()]);
    dataset.ignored = new int[ignored.size()];
    dataset.values = nominalValues;
    for (int i = 0; i < dataset.ignored.length; i++) {
      dataset.ignored[i] = ignored.get(i);
    }
    return dataset;
  }
  
  /**
   * Generate a map to describe an attribute
   * @param type The type
   * @param values - values
   * @param isLabel - is a label
   * @return map of (AttributeTypes, Values)
   */
  private Map getMap(Attribute type, String[] values, boolean isLabel) {
    Map attribute = new HashMap<>();
    attribute.put(TYPE, type.toString().toLowerCase(Locale.getDefault()));
    attribute.put(VALUES, values);
    attribute.put(LABEL, isLabel);
    return attribute;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy