All downloads are free. Search and download functionality uses the official Maven repository.

com.tencent.angel.utils.SerdeUtils Maven / Gradle / Ivy

There is a newer version: 3.2.0
Show newest version
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.utils;

import com.tencent.angel.protobuf.generated.WorkerMasterServiceProtos.SplitInfoProto;
import com.tencent.angel.ps.storage.matrix.PSMatrixInit;
import com.tencent.angel.split.SplitClassification;
import com.tencent.angel.split.SplitInfo;
import io.netty.buffer.ByteBuf;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.serializer.Deserializer;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.io.serializer.Serializer;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Serialize/Deserialize tool for training data split.
 *
 * <p>Converts Hadoop input splits (both the {@code org.apache.hadoop.mapred} and the
 * {@code org.apache.hadoop.mapreduce} flavour) to and from {@link SplitInfo} /
 * {@link SplitInfoProto}, and {@link PSMatrixInit} functions to and from byte arrays.
 *
 * <p>The Hadoop {@link SerializationFactory} is created lazily from the {@link Configuration}
 * passed to whichever method is called first and is then reused; configurations passed to
 * later calls are not consulted for the factory.
 */
public class SerdeUtils {
  private static final Log LOG = LogFactory.getLog(SerdeUtils.class);

  /** Configuration key read to decide between the new (mapreduce) and old (mapred) API. */
  private static final String NEW_API_KEY = "mapred.mapper.new-api";

  /** Initial capacity, in bytes, of the buffer a single split is serialized into. */
  private static final int INITIAL_SPLIT_BUFFER_SIZE = 1024;

  /** Size, in bytes, of the int length prefix written before the init function class name. */
  private static final int NAME_LENGTH_PREFIX_BYTES = 4;

  // NOTE(review): lazily initialised without synchronization; if these methods can be called
  // from several threads, confirm that racing initialisations are acceptable.
  private static SerializationFactory factory;

  /**
   * Rebuilds a {@link SplitClassification} from protobuf split descriptions.
   *
   * @param splitInfoList serialized splits; each entry supplies the split class name and bytes
   * @param conf configuration; {@code mapred.mapper.new-api} (default {@code false}) selects
   *             which split API the entries are deserialized as
   * @return a classification holding the deserialized splits
   * @throws ClassNotFoundException if a split class named in an entry cannot be loaded
   * @throws IOException if a split cannot be deserialized
   */
  public static SplitClassification deSerilizeSplitProtos(List<SplitInfoProto> splitInfoList,
    Configuration conf) throws ClassNotFoundException, IOException {
    if (conf.getBoolean(NEW_API_KEY, false)) {
      List<org.apache.hadoop.mapreduce.InputSplit> splitList =
        new ArrayList<>(splitInfoList.size());
      for (SplitInfoProto splitInfo : splitInfoList) {
        splitList.add(
          deSerilizeNewSplit(splitInfo.getSplitClass(), splitInfo.getSplit().toByteArray(), conf));
      }
      return new SplitClassification(null, splitList, true);
    }

    List<org.apache.hadoop.mapred.InputSplit> splitList = new ArrayList<>(splitInfoList.size());
    for (SplitInfoProto splitInfo : splitInfoList) {
      splitList.add(
        deSerilizeOldSplit(splitInfo.getSplitClass(), splitInfo.getSplit().toByteArray(), conf));
    }
    // NOTE(review): the last argument is `true` here as well as in the new-API branch. If it is
    // the "use new API" flag, `false` looks intended for this branch — confirm against
    // SplitClassification before changing; kept as-is to preserve behaviour.
    return new SplitClassification(splitList, null, true);
  }

  /**
   * Serializes a new-API (mapreduce) input split.
   *
   * @param split the split to serialize
   * @param conf configuration used to create the serialization factory on first use
   * @return the split class name together with exactly the bytes the serializer wrote
   * @throws IOException if no serializer is available for the split class or writing fails
   */
  public static SplitInfo serilizeSplit(org.apache.hadoop.mapreduce.InputSplit split,
    Configuration conf) throws IOException {
    return serializeSplitObject(split, conf);
  }

  /**
   * Serializes an old-API (mapred) input split.
   *
   * @param split the split to serialize
   * @param conf configuration used to create the serialization factory on first use
   * @return the split class name together with exactly the bytes the serializer wrote
   * @throws IOException if no serializer is available for the split class or writing fails
   */
  public static SplitInfo serilizeSplit(org.apache.hadoop.mapred.InputSplit split,
    Configuration conf) throws IOException {
    return serializeSplitObject(split, conf);
  }

  /**
   * Serializes every split held by a {@link SplitClassification}.
   *
   * @param splits the classification; {@code isUseNewAPI()} decides which split list is read
   * @param conf configuration used to create the serialization factory on first use
   * @return one {@link SplitInfo} per split, in list order
   * @throws IOException if any split cannot be serialized
   */
  public static List<SplitInfo> serilizeSplits(SplitClassification splits, Configuration conf)
    throws IOException {
    List<SplitInfo> splitInfoList = new ArrayList<>();
    if (splits.isUseNewAPI()) {
      List<org.apache.hadoop.mapreduce.InputSplit> splitList = splits.getSplitsNewAPI();
      for (org.apache.hadoop.mapreduce.InputSplit split : splitList) {
        splitInfoList.add(serilizeSplit(split, conf));
      }
    } else {
      List<org.apache.hadoop.mapred.InputSplit> splitList = splits.getSplitsOldAPI();
      for (org.apache.hadoop.mapred.InputSplit split : splitList) {
        splitInfoList.add(serilizeSplit(split, conf));
      }
    }

    return splitInfoList;
  }

  /**
   * Deserializes a new-API (mapreduce) input split from a {@link SplitInfo}.
   *
   * @param splitInfo split class name and serialized bytes
   * @param conf configuration used to create the serialization factory on first use
   * @return the deserialized split
   * @throws IOException if no deserializer is available or reading fails
   * @throws ClassNotFoundException if the split class cannot be loaded
   */
  public static org.apache.hadoop.mapreduce.InputSplit deSerilizeNewSplit(SplitInfo splitInfo,
    Configuration conf) throws IOException, ClassNotFoundException {
    return deserializeSplit(org.apache.hadoop.mapreduce.InputSplit.class,
      splitInfo.getSplitClass(), splitInfo.getSplit(), conf);
  }

  /**
   * Deserializes an old-API (mapred) input split from a {@link SplitInfo}.
   *
   * @param splitInfo split class name and serialized bytes
   * @param conf configuration used to create the serialization factory on first use
   * @return the deserialized split
   * @throws ClassNotFoundException if the split class cannot be loaded
   * @throws IOException if no deserializer is available or reading fails
   */
  public static org.apache.hadoop.mapred.InputSplit deSerilizeOldSplit(SplitInfo splitInfo,
    Configuration conf) throws ClassNotFoundException, IOException {
    return deserializeSplit(org.apache.hadoop.mapred.InputSplit.class,
      splitInfo.getSplitClass(), splitInfo.getSplit(), conf);
  }

  /**
   * Rebuilds a {@link SplitClassification} from a list of {@link SplitInfo}.
   *
   * @param splitInfoList serialized splits
   * @param conf configuration; {@code mapred.mapper.new-api} (default {@code false}) selects
   *             which split API the entries are deserialized as
   * @return a classification holding the deserialized splits
   * @throws ClassNotFoundException if a split class named in an entry cannot be loaded
   * @throws IOException if a split cannot be deserialized
   */
  public static SplitClassification deSerilizeSplits(List<SplitInfo> splitInfoList,
    Configuration conf) throws ClassNotFoundException, IOException {
    if (conf.getBoolean(NEW_API_KEY, false)) {
      List<org.apache.hadoop.mapreduce.InputSplit> splitList =
        new ArrayList<>(splitInfoList.size());
      for (SplitInfo splitInfo : splitInfoList) {
        splitList.add(deSerilizeNewSplit(splitInfo, conf));
      }
      return new SplitClassification(null, splitList, true);
    }

    List<org.apache.hadoop.mapred.InputSplit> splitList = new ArrayList<>(splitInfoList.size());
    for (SplitInfo splitInfo : splitInfoList) {
      splitList.add(deSerilizeOldSplit(splitInfo, conf));
    }
    // NOTE(review): `true` is passed here as well as in the new-API branch; if the argument is
    // the "use new API" flag this looks inverted — confirm against SplitClassification.
    return new SplitClassification(splitList, null, true);
  }

  /**
   * Deserializes a new-API (mapreduce) input split from its class name and bytes.
   *
   * @param className fully qualified name of the split class
   * @param data serialized split bytes
   * @param conf configuration used to create the serialization factory on first use
   * @return the deserialized split
   * @throws IOException if no deserializer is available or reading fails
   * @throws ClassNotFoundException if {@code className} cannot be loaded
   */
  public static org.apache.hadoop.mapreduce.InputSplit deSerilizeNewSplit(String className,
    byte[] data, Configuration conf) throws IOException, ClassNotFoundException {
    return deserializeSplit(org.apache.hadoop.mapreduce.InputSplit.class, className, data, conf);
  }

  /**
   * Deserializes an old-API (mapred) input split from its class name and bytes.
   *
   * @param className fully qualified name of the split class
   * @param data serialized split bytes
   * @param conf configuration used to create the serialization factory on first use
   * @return the deserialized split
   * @throws ClassNotFoundException if {@code className} cannot be loaded
   * @throws IOException if no deserializer is available or reading fails
   */
  public static org.apache.hadoop.mapred.InputSplit deSerilizeOldSplit(String className,
    byte[] data, Configuration conf) throws ClassNotFoundException, IOException {
    return deserializeSplit(org.apache.hadoop.mapred.InputSplit.class, className, data, conf);
  }

  /**
   * Serializes a matrix init function as: int name length, class name bytes (UTF-8), then
   * whatever {@code initFunc.serialize} writes.
   *
   * @param initFunc the init function to serialize
   * @return the serialized bytes, readable by {@link #deserializeInitFunc(byte[])}
   */
  public static byte[] serializeInitFunc(PSMatrixInit initFunc) {
    String partParamClassName = initFunc.getClass().getName();
    LOG.info("func name=" + partParamClassName);
    // Explicit charset: the bytes are decoded in deserializeInitFunc, possibly on a JVM with a
    // different platform default.
    byte[] nameBytes = partParamClassName.getBytes(StandardCharsets.UTF_8);

    // Request room for the length prefix and the name in addition to bufferLen().
    // NOTE(review): assumes bufferLen() covers only what initFunc.serialize writes — confirm.
    ByteBuf buf = ByteBufUtils.newHeapByteBuf(
      NAME_LENGTH_PREFIX_BYTES + nameBytes.length + initFunc.bufferLen());
    buf.writeInt(nameBytes.length);
    buf.writeBytes(nameBytes);
    initFunc.serialize(buf);

    // Copy out everything written so far (reader index is still at the start).
    byte[] data = new byte[buf.writerIndex()];
    buf.readBytes(data);
    return data;
  }

  /**
   * Rebuilds a matrix init function from bytes produced by
   * {@link #serializeInitFunc(PSMatrixInit)}.
   *
   * @param data serialized init function
   * @return a new instance of the named class, populated through its {@code deserialize}
   * @throws RuntimeException if the name length is invalid, or if loading, instantiating or
   *                          deserializing the function fails (the failure is kept as cause)
   */
  public static PSMatrixInit deserializeInitFunc(byte[] data) {
    ByteBuf buf = ByteBufUtils.newHeapByteBuf(data.length);
    buf.writeBytes(data);
    int size = buf.readInt();
    // Reject a corrupt length before it drives an array allocation.
    if (size < 0 || size > buf.readableBytes()) {
      throw new RuntimeException("deserialize init func falied: invalid class name length " + size);
    }
    byte[] nameData = new byte[size];
    buf.readBytes(nameData);

    String className = new String(nameData, StandardCharsets.UTF_8);
    LOG.info("func name=" + className);
    PSMatrixInit iniFunc;

    try {
      // asSubclass rejects classes that are not a PSMatrixInit before any instance is created.
      iniFunc = Class.forName(className).asSubclass(PSMatrixInit.class).newInstance();
      iniFunc.deserialize(buf);
    } catch (Throwable e) {
      LOG.error("deserialize init func falied, ", e);
      // Keep the original failure as the cause so its stack trace is not lost.
      throw new RuntimeException("deserialize init func falied:" + e.getMessage(), e);
    }

    return iniFunc;
  }

  /** Returns the shared serialization factory, creating it from {@code conf} on first use. */
  private static SerializationFactory getFactory(Configuration conf) {
    if (factory == null) {
      factory = new SerializationFactory(conf);
    }
    return factory;
  }

  /** Serializes one split object into a {@link SplitInfo} holding only the written bytes. */
  @SuppressWarnings("unchecked")
  private static <T> SplitInfo serializeSplitObject(T split, Configuration conf)
    throws IOException {
    SerializationFactory serdeFactory = getFactory(conf);
    try (DataOutputBuffer out = new DataOutputBuffer(INITIAL_SPLIT_BUFFER_SIZE)) {
      // Safe: the class object of a T instance is a Class of T (or a subtype of it).
      Class<T> splitClass = (Class<T>) split.getClass();
      Serializer<T> serializer = serdeFactory.getSerializer(splitClass);
      if (serializer == null) {
        throw new IOException("No serializer found for split class " + splitClass.getName());
      }
      serializer.open(out);
      serializer.serialize(split);
      // getData() exposes the whole backing array; only the first getLength() bytes are valid.
      return new SplitInfo(splitClass.getName(), Arrays.copyOf(out.getData(), out.getLength()));
    }
  }

  /** Loads {@code className}, checks it against {@code splitType} and deserializes one split. */
  private static <T> T deserializeSplit(Class<T> splitType, String className, byte[] data,
    Configuration conf) throws IOException, ClassNotFoundException {
    SerializationFactory serdeFactory = getFactory(conf);
    Class<? extends T> splitClass = Class.forName(className).asSubclass(splitType);
    Deserializer<? extends T> deserializer = serdeFactory.getDeserializer(splitClass);
    if (deserializer == null) {
      throw new IOException("No deserializer found for split class " + className);
    }

    try (ByteArrayInputStream in = new ByteArrayInputStream(data)) {
      deserializer.open(in);
      // Passing null: no existing instance is offered for reuse, as in the original code.
      return deserializer.deserialize(null);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy