
com.tigergraph.spark.write.TigerGraphDataWriter Maven / Gradle / Ivy

/**
 * Copyright (c) 2023 TigerGraph Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
 * except in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the
 * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.tigergraph.spark.write;

import com.tigergraph.spark.TigerGraphConnection;
import com.tigergraph.spark.client.Write;
import com.tigergraph.spark.client.Write.LoadingResponse;
import com.tigergraph.spark.log.LoggerFactory;
import com.tigergraph.spark.util.Options;
import com.tigergraph.spark.util.Utils;
import java.io.IOException;
import java.sql.Date;
import java.sql.Timestamp;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampNTZType;
import org.apache.spark.sql.types.TimestampType;
import org.slf4j.Logger;

/** The data writer of an executor responsible for writing data for an input RDD partition. */
public class TigerGraphDataWriter implements DataWriter<InternalRow> {
  private static final Logger logger = LoggerFactory.getLogger(TigerGraphDataWriter.class);

  private final StructType schema;
  private final int partitionId;
  private final long taskId;
  private final long epochId;
  private final Write write;
  private final String version;
  private final String jobId;
  private final String graph;
  private final String sep;
  private final String eol;
  private final int maxBatchSizeInBytes;
  private final Map<String, Object> queryMap;
  private final List<BiFunction<InternalRow, Integer, String>> converters;

  private final StringBuilder sb = new StringBuilder();
  private int sbOffset = 0;
  // Metrics
  private long totalLines = 0;

  /** For Streaming write */
  TigerGraphDataWriter(
      StructType schema, TigerGraphConnection conn, int partitionId, long taskId, long epochId) {
    this.schema = schema;
    this.partitionId = partitionId;
    this.taskId = taskId;
    this.epochId = epochId;
    this.write = conn.getWrite();
    this.version = conn.getVersion();
    this.jobId = conn.getLoadingJobId();
    Options opts = conn.getOpts();
    this.graph = opts.getString(Options.GRAPH);
    this.maxBatchSizeInBytes = opts.getInt(Options.LOADING_BATCH_SIZE_BYTES);
    this.sep = opts.getString(Options.LOADING_SEPARATOR);
    this.eol = opts.getString(Options.LOADING_EOL);
    queryMap = new HashMap<>();
    queryMap.put("tag", opts.getString(Options.LOADING_JOB));
    queryMap.put("filename", opts.getString(Options.LOADING_FILENAME));
    queryMap.put("sep", opts.getString(Options.LOADING_SEPARATOR));
    queryMap.put("eol", opts.getString(Options.LOADING_EOL));
    queryMap.put("timeout", opts.getInt(Options.LOADING_TIMEOUT_MS));
    if (Utils.versionCmp(version, "3.9.4") >= 0) {
      queryMap.put("jobid", jobId);
      if (opts.containsOption(Options.LOADING_MAX_NUM_ERROR)) {
        queryMap.put("max_num_error", opts.getInt(Options.LOADING_MAX_NUM_ERROR));
      }
      if (opts.containsOption(Options.LOADING_MAX_PERCENT_ERROR)) {
        queryMap.put("max_percent_error", opts.getInt(Options.LOADING_MAX_PERCENT_ERROR));
      }
    }
    // Set converters for each field according to the data type
    converters = getConverters(schema);
    logger.info(
        "Created data writer for partition {}, task {}, epochId {}", partitionId, taskId, epochId);
  }

  /** For Batch write */
  TigerGraphDataWriter(StructType schema, TigerGraphConnection conn, int partitionId, long taskId) {
    this(schema, conn, partitionId, taskId, (long) -1);
  }

  @Override
  public void close() throws IOException {
    // no-op
  }

  @Override
  public void write(InternalRow record) throws IOException {
    String line =
        IntStream.range(0, record.numFields())
            .mapToObj(i -> record.isNullAt(i) ? "" : converters.get(i).apply(record, i))
            .collect(Collectors.joining(sep));
    if (sb.length() + line.length() + eol.length() > maxBatchSizeInBytes) {
      postToDDL();
    }
    sb.append(line).append(eol);
    sbOffset++;
  }

  private void postToDDL() {
    LoadingResponse resp = write.ddl(graph, sb.toString(), queryMap);
    logger.info("Upsert {} rows to TigerGraph graph {}", sbOffset, graph);
    resp.panicOnFail();
    Utils.removeUserData(resp.results);
    // process stats
    if (resp.hasInvalidRecord()) {
      logger.error("Found rejected rows, it won't abort the loading: ");
      logger.error(resp.results.toPrettyString());
    } else {
      logger.debug(resp.results.toPrettyString());
    }
    totalLines += sbOffset;
    sbOffset = 0;
    sb.setLength(0);
  }

  @Override
  public TigerGraphWriterCommitMessage commit() throws IOException {
    if (sb.length() > 0) {
      postToDDL();
    }
    logger.info(
        "Finished writing {} rows to TigerGraph. Partition {}, task {}, epoch {}.",
        totalLines, partitionId, taskId, epochId);
    return new TigerGraphWriterCommitMessage(totalLines, partitionId, taskId);
  }

  @Override
  public void abort() throws IOException {
    logger.error(
        "Write aborted with {} records loaded. Partition {}, task {}, epoch {}",
        totalLines, partitionId, taskId, epochId);
  }

  /** Pre-generate the converter for each field of the source dataframe's schema */
  protected static List<BiFunction<InternalRow, Integer, String>> getConverters(StructType schema) {
    return Stream.of(schema.fields())
        .map((f) -> getConverter(f.dataType()))
        .collect(Collectors.toList());
  }

  /**
   * Get converter for the specific Spark type for converting it to string, which then can be used
   * to build delimited data for loading job
   */
  private static BiFunction<InternalRow, Integer, String> getConverter(DataType dt) {
    if (dt instanceof IntegerType) {
      return (row, idx) -> String.valueOf(row.getInt(idx));
    } else if (dt instanceof LongType) {
      return (row, idx) -> String.valueOf(row.getLong(idx));
    } else if (dt instanceof DoubleType) {
      return (row, idx) -> String.valueOf(row.getDouble(idx));
    } else if (dt instanceof FloatType) {
      return (row, idx) -> String.valueOf(row.getFloat(idx));
    } else if (dt instanceof ShortType) {
      return (row, idx) -> String.valueOf(row.getShort(idx));
    } else if (dt instanceof ByteType) {
      return (row, idx) -> String.valueOf(row.getByte(idx));
    } else if (dt instanceof BooleanType) {
      return (row, idx) -> String.valueOf(row.getBoolean(idx));
    } else if (dt instanceof StringType) {
      return (row, idx) -> row.getString(idx);
    } else if (dt instanceof TimestampType || dt instanceof TimestampNTZType) {
      DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS");
      // TG DATETIME is essentially a timestamp_ntz, need to format the time in UTC
      dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
      // Spark timestamp is stored as the number of microseconds
      return (row, idx) -> dateFormat.format(new Timestamp(row.getLong(idx) / 1000));
    } else if (dt instanceof DateType) {
      DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd");
      // TG DATETIME is essentially a timestamp_ntz, need to format the time in UTC
      dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
      // Spark date is stored as the number of days counting from 1970-01-01
      return (row, idx) -> dateFormat.format(new Date(row.getInt(idx) * 24 * 60 * 60 * 1000L));
    } else if (dt instanceof DecimalType) {
      return (row, idx) ->
          row.getDecimal(idx, ((DecimalType) dt).precision(), ((DecimalType) dt).scale())
              .toString();
    } else {
      throw new UnsupportedOperationException(
          "Unsupported Spark type: "
              + dt.typeName()
              + ", please convert it to a string that matches the LOAD statement defined in the"
              + " loading job:"
              + " https://docs.tigergraph.com/gsql-ref/current/ddl-and-loading/creating-a-loading-job#_more_complex_attribute_expressions");
    }
  }
}

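This writer is package-private and instantiated by Spark once per partition task; applications drive it indirectly through the DataFrame writer API. Below is a minimal usage sketch; the data source short name and option keys are illustrative assumptions (they roughly mirror the Options constants referenced in the code, but are not confirmed by this listing):

// Hypothetical usage sketch for the TigerGraph Spark connector.
// Format name and option keys are assumptions; consult the connector docs for exact values.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class WriteExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("tg-write-example").getOrCreate();
    Dataset<Row> df = spark.read().option("header", "true").csv("persons.csv");
    df.write()
        .format("tigergraph")                      // assumed data source short name
        .option("url", "https://localhost:14240")  // assumed connection option
        .option("graph", "Social")                 // corresponds to Options.GRAPH
        .option("loading.job", "load_person")      // corresponds to Options.LOADING_JOB
        .option("loading.filename", "f1")          // corresponds to Options.LOADING_FILENAME
        .mode("append")
        .save();
    spark.stop();
  }
}

Each executor task then accumulates delimited rows up to Options.LOADING_BATCH_SIZE_BYTES and submits the batch via Write.ddl(), as shown in postToDDL() above.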



