All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.datavec.api.transform.analysis.DataAnalysis Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.transform.analysis;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.node.ArrayNode;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;

@AllArgsConstructor
@Data
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class DataAnalysis implements Serializable {
    private static final String COL_NAME = "columnName";
    private static final String COL_IDX = "columnIndex";
    private static final String COL_TYPE = "columnType";
    private static final String CATEGORICAL_STATE_NAMES = "stateNames";
    private static final String ANALYSIS = "analysis";
    private static final String DATA_ANALYSIS = "DataAnalysis";

    private Schema schema;
    private List columnAnalysis;

    protected DataAnalysis(){
        //No arg for JSON
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        int nCol = schema.numColumns();

        int maxNameLength = 0;
        for (String s : schema.getColumnNames()) {
            maxNameLength = Math.max(maxNameLength, s.length());
        }

        //Header:
        sb.append(String.format("%-6s", "idx")).append(String.format("%-" + (maxNameLength + 8) + "s", "name"))
                        .append(String.format("%-15s", "type")).append("analysis").append("\n");

        for (int i = 0; i < nCol; i++) {
            String colName = schema.getName(i);
            ColumnType type = schema.getType(i);
            ColumnAnalysis analysis = columnAnalysis.get(i);
            String paddedName = String.format("%-" + (maxNameLength + 8) + "s", "\"" + colName + "\"");
            sb.append(String.format("%-6d", i)).append(paddedName).append(String.format("%-15s", type)).append(analysis)
                            .append("\n");
        }

        return sb.toString();
    }

    public ColumnAnalysis getColumnAnalysis(String column) {
        return columnAnalysis.get(schema.getIndexOfColumn(column));
    }

    /**
     * Convert the DataAnalysis object to JSON format
     */
    public String toJson() {
        try{
            return new JsonSerializer().getObjectMapper().writeValueAsString(this);
        } catch (IOException e){
            throw new RuntimeException(e);
        }
    }

    /**
     * Convert the DataAnalysis object to YAML format
     */
    public String toYaml() {
        try{
            return new YamlSerializer().getObjectMapper().writeValueAsString(this);
        } catch (IOException e){
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserialize a JSON DataAnalysis String that was previously serialized with {@link #toJson()}
     */
    public static DataAnalysis fromJson(String json) {
        try{
            return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class);
        } catch (InvalidTypeIdException e){
            if(e.getMessage().contains("@class")){
                try{
                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
                    return JsonMappers.getLegacyMapper().readValue(json, DataAnalysis.class);
                } catch (IOException e2){
                    throw new RuntimeException(e2);
                }
            }
            throw new RuntimeException(e);
        } catch (Exception e){
            //Legacy format
            ObjectMapper om = new JsonSerializer().getObjectMapper();
            return fromMapper(om, json);
        }
    }

    /**
     * Deserialize a YAML DataAnalysis String that was previously serialized with {@link #toYaml()}
     */
    public static DataAnalysis fromYaml(String yaml) {
        try{
            return new YamlSerializer().getObjectMapper().readValue(yaml, DataAnalysis.class);
        } catch (Exception e){
            //Legacy format
            ObjectMapper om = new YamlSerializer().getObjectMapper();
            return fromMapper(om, yaml);
        }
    }

    private static DataAnalysis fromMapper(ObjectMapper om, String json) {

        List meta = new ArrayList<>();
        List analysis = new ArrayList<>();
        try {
            JsonNode node = om.readTree(json);
            Iterator fieldNames = node.fieldNames();
            boolean hasDataAnalysis = false;
            while (fieldNames.hasNext()) {
                if ("DataAnalysis".equals(fieldNames.next())) {
                    hasDataAnalysis = true;
                    break;
                }
            }
            if (!hasDataAnalysis) {
                throw new RuntimeException();
            }

            ArrayNode arrayNode = (ArrayNode) node.get("DataAnalysis");
            for (int i = 0; i < arrayNode.size(); i++) {
                JsonNode analysisNode = arrayNode.get(i);
                String name = analysisNode.get(COL_NAME).asText();
                int idx = analysisNode.get(COL_IDX).asInt();
                ColumnType type = ColumnType.valueOf(analysisNode.get(COL_TYPE).asText());

                JsonNode daNode = analysisNode.get(ANALYSIS);
                ColumnAnalysis dataAnalysis = om.treeToValue(daNode, ColumnAnalysis.class);

                if (type == ColumnType.Categorical) {
                    ArrayNode an = (ArrayNode) analysisNode.get(CATEGORICAL_STATE_NAMES);
                    List stateNames = new ArrayList<>(an.size());
                    Iterator iter = an.elements();
                    while (iter.hasNext()) {
                        stateNames.add(iter.next().asText());
                    }
                    meta.add(new CategoricalMetaData(name, stateNames));
                } else {
                    meta.add(type.newColumnMetaData(name));
                }

                analysis.add(dataAnalysis);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        Schema schema = new Schema(meta);
        return new DataAnalysis(schema, analysis);
    }

    @Deprecated //Legacy format, no longer used
    private Map>> getJsonRepresentation() {
        Map>> jsonRepresentation = new LinkedHashMap<>();
        List> list = new ArrayList<>();
        jsonRepresentation.put("DataAnalysis", list);

        for (String colName : schema.getColumnNames()) {
            Map current = new LinkedHashMap<>();
            int idx = schema.getIndexOfColumn(colName);
            current.put(COL_NAME, colName);
            current.put(COL_IDX, idx);
            ColumnType columnType = schema.getMetaData(colName).getColumnType();
            current.put(COL_TYPE, columnType);
            if (columnType == ColumnType.Categorical) {
                current.put(CATEGORICAL_STATE_NAMES,
                                ((CategoricalMetaData) schema.getMetaData(colName)).getStateNames());
            }
            current.put(ANALYSIS, Collections.singletonMap(columnAnalysis.get(idx).getClass().getSimpleName(),
                            columnAnalysis.get(idx)));

            list.add(current);
        }

        return jsonRepresentation;
    }

    private String toJson(Map>> jsonRepresentation) {
        ObjectMapper om = new JsonSerializer().getObjectMapper();
        try {
            return om.writeValueAsString(jsonRepresentation);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private String toYaml(Map>> jsonRepresentation) {
        ObjectMapper om = new YamlSerializer().getObjectMapper();
        try {
            return om.writeValueAsString(jsonRepresentation);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy