org.deeplearning4j.spark.api.stats.SparkTrainingStats Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.spark.api.stats;
import org.apache.spark.SparkContext;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingWorkerStats;
import org.deeplearning4j.spark.stats.EventStats;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Set;
/**
 * Keyed access to statistics collected during Spark training.
 *
 * <p>Each statistic is identified by a {@code String} key (see {@link #getKeySet()}) and maps to a
 * list of {@link EventStats} (see {@link #getValue(String)}). Instances can be combined with
 * {@link #addOtherTrainingStats(SparkTrainingStats)}, may wrap nested stats
 * ({@link #getNestedTrainingStats()}), and can be rendered as text ({@link #statsAsString()}) or
 * exported as CSV files ({@link #exportStatFiles(String, SparkContext)}).
 *
 * @see ParameterAveragingTrainingMasterStats
 * @see ParameterAveragingTrainingWorkerStats
 */
public interface SparkTrainingStats extends Serializable {

    /**
     * Default indentation (field width, in characters) for {@link #statsAsString()}.
     */
    int PRINT_INDENT = 55;

    /**
     * Default formatter used for {@link #statsAsString()}: a left-justified string padded to
     * {@link #PRINT_INDENT} characters (i.e. {@code "%-55s"}).
     */
    String DEFAULT_PRINT_FORMAT = "%-" + PRINT_INDENT + "s";

    /**
     * Return the keys of all statistics held by this instance.
     *
     * @return Set of keys that can be used with {@link #getValue(String)}
     */
    Set<String> getKeySet();

    /**
     * Get the statistic values for this key.
     *
     * @param key Key to get the value for; should be one of the keys returned by {@link #getKeySet()}
     * @return Statistics for this key. Implementations may throw an exception if the key is invalid.
     */
    List<EventStats> getValue(String key);

    /**
     * Return a short (display) name for the given key.
     *
     * @param key Key
     * @return Short/display name for key
     */
    String getShortNameForKey(String key);

    /**
     * Whether the statistic for the given key should be included in plots by default.
     *
     * <p>When plotting statistics, we don't necessarily want to plot everything. For example, some
     * statistics/measurements are made up of multiple smaller components; it does not always make
     * sense to plot both the larger stat and the components that make it up.
     *
     * @param key Key to check for default plotting behaviour
     * @return Whether the specified key should be included in plots by default or not
     */
    boolean defaultIncludeInPlots(String key);

    /**
     * Combine the other training stats instance into this one. Usually, the two objects must be of
     * the same type.
     *
     * @param other Other training stats to add to (combine with) this instance
     */
    void addOtherTrainingStats(SparkTrainingStats other);

    /**
     * Return the nested training stats - if any.
     *
     * @return The nested stats, if present/applicable, or null otherwise
     */
    SparkTrainingStats getNestedTrainingStats();

    /**
     * Get a String representation of the stats. This functionality is implemented as a separate
     * method (as opposed to {@code toString()}) as the resulting String can be very large.
     *
     * <p>NOTE: The String representation typically includes only duration information. To get full
     * statistics (including machine IDs, etc) use {@link #getValue(String)} or export full data via
     * {@link #exportStatFiles(String, SparkContext)}.
     *
     * @return A String representation of the training statistics
     */
    String statsAsString();

    /**
     * Export the stats as a collection of files. Stats are comma-delimited (CSV) with 1 header line.
     *
     * @param outputPath Base directory to write files to
     * @param sc         Spark context used when writing the files
     * @throws IOException If writing the output files fails
     */
    void exportStatFiles(String outputPath, SparkContext sc) throws IOException;
}