/*
 * All Downloads are FREE. Search and download functionalities are using the official Maven repository.
 *
 * tech.ytsaurus.client.operations.SingleUploadFromClassPathJarsProcessor Maven / Gradle / Ivy
 *
 * The newest version!
 */
package tech.ytsaurus.client.operations;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.file.Files;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.jar.Attributes;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.jar.JarOutputStream;
import java.util.jar.Manifest;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;

import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.ytsaurus.client.FileWriter;
import tech.ytsaurus.client.TransactionalClient;
import tech.ytsaurus.client.request.CreateNode;
import tech.ytsaurus.client.request.GetFileFromCache;
import tech.ytsaurus.client.request.GetFileFromCacheResult;
import tech.ytsaurus.client.request.ListNode;
import tech.ytsaurus.client.request.MoveNode;
import tech.ytsaurus.client.request.PutFileToCache;
import tech.ytsaurus.client.request.RemoveNode;
import tech.ytsaurus.client.request.WriteFile;
import tech.ytsaurus.core.GUID;
import tech.ytsaurus.core.cypress.CypressNodeType;
import tech.ytsaurus.core.cypress.YPath;
import tech.ytsaurus.lang.NonNullApi;
import tech.ytsaurus.lang.NonNullFields;
import tech.ytsaurus.ysontree.YTree;
import tech.ytsaurus.ysontree.YTreeNode;


/**
 * Default implementation of {@link JarsProcessor}.
 * Upload jars and files to YT if it is necessary.
 */
@NonNullApi
@NonNullFields
public class SingleUploadFromClassPathJarsProcessor implements JarsProcessor {

    private static final Logger LOGGER = LoggerFactory.getLogger(SingleUploadFromClassPathJarsProcessor.class);

    private static final String NATIVE_FILE_EXTENSION = "so";
    protected static final int DEFAULT_JARS_REPLICATION_FACTOR = 10;

    private final YPath jarsDir;
    @Nullable
    protected final YPath cacheDir;
    private final int fileCacheReplicationFactor;

    private final Duration uploadTimeout;
    private final boolean uploadNativeLibraries;
    private final Map uploadedJars = new HashMap<>();
    private final Map> uploadMap = new HashMap<>();
    @Nullable
    private volatile Instant lastUploadTime;

    private static final char[] DIGITS = {
            '0', '1', '2', '3', '4', '5', '6', '7',
            '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
    };

    public SingleUploadFromClassPathJarsProcessor(YPath jarsDir, @Nullable YPath cacheDir) {
        this(jarsDir, cacheDir, false, Duration.ofMinutes(10), DEFAULT_JARS_REPLICATION_FACTOR);
    }

    public SingleUploadFromClassPathJarsProcessor(YPath jarsDir, @Nullable YPath cacheDir,
                                                  boolean uploadNativeLibraries) {
        this(jarsDir, cacheDir, uploadNativeLibraries, Duration.ofMinutes(10), DEFAULT_JARS_REPLICATION_FACTOR);
    }

    public SingleUploadFromClassPathJarsProcessor(
            YPath jarsDir,
            @Nullable YPath cacheDir,
            boolean uploadNativeLibraries,
            Duration uploadTimeout) {
        this(jarsDir, cacheDir, uploadNativeLibraries, uploadTimeout, DEFAULT_JARS_REPLICATION_FACTOR);
    }

    public SingleUploadFromClassPathJarsProcessor(
            YPath jarsDir,
            @Nullable YPath cacheDir,
            boolean uploadNativeLibraries,
            Duration uploadTimeout,
            @Nullable Integer fileCacheReplicationFactor) {
        this.jarsDir = jarsDir;
        this.cacheDir = cacheDir;
        this.uploadTimeout = uploadTimeout;
        this.uploadNativeLibraries = uploadNativeLibraries;
        this.fileCacheReplicationFactor = fileCacheReplicationFactor != null
                ? fileCacheReplicationFactor
                : DEFAULT_JARS_REPLICATION_FACTOR;
    }

    @Override
    public Set uploadJars(TransactionalClient yt, MapperOrReducer mapperOrReducer, boolean isLocalMode) {
        synchronized (this) {
            try {
                uploadIfNeeded(yt.getRootClient(), isLocalMode);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            return Set.copyOf(new HashSet<>(uploadedJars.values()));
        }
    }

    protected void withJar(File jarFile, Consumer consumer) {
        consumer.accept(jarFile);
    }

    protected void withClassPathDir(File classPathItem, byte[] jarBytes, BiConsumer consumer) {
        consumer.accept(classPathItem, jarBytes);
    }

    private boolean isUsingFileCache() {
        return cacheDir != null;
    }

    private void uploadIfNeeded(TransactionalClient yt, boolean isLocalMode) {
        uploadMap.clear();

        yt.createNode(CreateNode.builder()
                .setPath(jarsDir)
                .setType(CypressNodeType.MAP)
                .setRecursive(true)
                .setIgnoreExisting(true)
                .build()).join();

        if (!isUsingFileCache() && lastUploadTime != null && Instant.now()
                .isBefore(Objects.requireNonNull(lastUploadTime).plus(uploadTimeout))) {
            return;
        }

        uploadedJars.clear();

        collectJars(yt);
        if (uploadNativeLibraries) {
            collectNativeLibs();
        }
        doUpload(yt, isLocalMode);
    }

    protected void writeFile(TransactionalClient yt, YPath path, InputStream data) {
        yt.createNode(new CreateNode(path, CypressNodeType.FILE)).join();
        FileWriter writer = yt.writeFile(WriteFile.builder()
                .setPath(path.toString())
                .setComputeMd5(true)
                .build()).join();
        try {
            byte[] bytes = new byte[0x10000];
            for (; ; ) {
                int count = data.read(bytes);
                if (count < 0) {
                    break;
                }

                writer.write(bytes, 0, count);
                writer.readyEvent().join();
            }
            writer.close().join();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private YPath onFileChecked(TransactionalClient yt, @Nullable YPath path, String originalName, String md5,
                                Supplier fileContent) {
        YPath res = path;
        Objects.requireNonNull(cacheDir);

        if (res == null) {
            YPath tmpPath = jarsDir.child(GUID.create().toString());

            LOGGER.info("Uploading {} to cache", originalName);

            writeFile(yt, tmpPath, fileContent.get());

            res = yt.putFileToCache(new PutFileToCache(tmpPath, cacheDir, md5)).join().getPath();
            yt.removeNode(RemoveNode.builder().setPath(tmpPath).setRecursive(false).setForce(true).build()).join();
        }

        res = res
                .plusAdditionalAttribute("file_name", originalName)
                .plusAdditionalAttribute("md5", md5)
                .plusAdditionalAttribute("cache", cacheDir.toTree());

        return res;
    }

    @NonNullFields
    @NonNullApi
    protected static class CacheUploadTask {
        final CompletableFuture> cacheCheckResult;
        final String md5;
        final Map.Entry> entry;
        @Nullable
        Future result;

        public CacheUploadTask(
                CompletableFuture> cacheCheckResult,
                String md5,
                Map.Entry> entry
        ) {
            this.cacheCheckResult = cacheCheckResult;
            this.md5 = md5;
            this.entry = entry;
        }
    }

    protected List checkInCache(TransactionalClient yt, Map> uploadMap) {
        Objects.requireNonNull(cacheDir);
        List tasks = new ArrayList<>();
        for (Map.Entry> entry : uploadMap.entrySet()) {
            String md5 = calculateMd5(entry.getValue().get());
            CompletableFuture> future =
                    yt.getFileFromCache(new GetFileFromCache(cacheDir, md5))
                            .thenApply(GetFileFromCacheResult::getPath);
            tasks.add(new CacheUploadTask(future, md5, entry));
        }
        return tasks;
    }

    private void checkInCacheAndUpload(TransactionalClient yt, Map> uploadMap) {
        List tasks = checkInCache(yt, uploadMap);

        int threadsCount = Math.min(uploadMap.size(), 5);
        ExecutorService executor = Executors.newFixedThreadPool(threadsCount);

        try {
            for (CacheUploadTask task : tasks) {
                task.result = executor.submit(() -> {
                    try {
                        Optional path = task.cacheCheckResult.get();
                        return onFileChecked(yt, path.orElse(null), task.entry.getKey(), task.md5,
                                task.entry.getValue());
                    } catch (Exception ex) {
                        throw new RuntimeException(ex);
                    }
                });
            }

            for (CacheUploadTask task : tasks) {
                try {
                    // N.B. we filled `result` in the loop above.
                    Future result = Objects.requireNonNull(task.result);
                    uploadedJars.put(task.entry.getKey(), result.get());
                } catch (Exception ex) {
                    throw new RuntimeException(ex);
                }
            }
        } finally {
            executor.shutdown();
        }
    }

    static class UploadTask {
        Future result;
        String fileName;

        UploadTask(String fileName) {
            this.result = new CompletableFuture<>();
            this.fileName = fileName;
        }
    }

    private void uploadToTemp(
            TransactionalClient yt,
            Map> uploadMap,
            boolean isLocalMode
    ) {
        int threadsCount = Math.min(uploadMap.size(), 5);
        ExecutorService executor = Executors.newFixedThreadPool(threadsCount);
        try {
            ArrayList uploadTasks = new ArrayList<>();
            for (Map.Entry> entry : uploadMap.entrySet()) {
                String fileName = entry.getKey();
                UploadTask task = new UploadTask(fileName);
                task.result = executor.submit(() -> maybeUpload(yt, entry.getValue(), fileName, isLocalMode));
                uploadTasks.add(task);
            }

            for (UploadTask uploadTask : uploadTasks) {
                try {
                    YPath path = uploadTask.result.get();
                    uploadedJars.put(uploadTask.fileName, path);
                } catch (Exception ex) {
                    throw new RuntimeException(ex);
                }
            }
        } finally {
            executor.shutdown();
        }
    }

    private void doUpload(TransactionalClient yt, boolean isLocalMode) {
        if (uploadMap.isEmpty()) {
            return;
        }

        if (isUsingFileCache()) {
            checkInCacheAndUpload(yt, uploadMap);
        } else {
            uploadToTemp(yt, uploadMap, isLocalMode);
        }

        lastUploadTime = Instant.now();
    }

    private static void walk(File dir, Consumer consumer) {
        consumer.accept(dir);
        File[] files = dir.listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            walk(file, consumer);
        }
    }

    private File getParentFile(File file) {
        File parent = file.getParentFile();
        if (parent != null) {
            return parent;
        } else {
            String path = file.getPath();
            if (!path.contains("/") && !path.contains(".")) {
                return new File(".");
            }
            throw new RuntimeException(this + " has no parent");
        }
    }

    private static String toHex(byte[] data) {
        StringBuilder result = new StringBuilder();
        for (byte b : data) {
            result.append(DIGITS[(0xF0 & b) >>> 4]);
            result.append(DIGITS[0x0F & b]);
        }
        return result.toString();
    }

    protected static String calculateMd5(InputStream stream) {
        try {
            MessageDigest md = MessageDigest.getInstance("MD5");
            byte[] bytes = new byte[0x1000];
            for (; ; ) {
                int len = stream.read(bytes);
                if (len < 0) {
                    break;
                }
                md.update(bytes, 0, len);
            }

            return toHex(md.digest());
        } catch (NoSuchAlgorithmException | IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    private void collectJars(TransactionalClient yt) {
        yt.createNode(CreateNode.builder()
                .setPath(jarsDir)
                .setType(CypressNodeType.MAP)
                .setRecursive(true)
                .setIgnoreExisting(true)
                .build()).join();
        if (isUsingFileCache()) {
            yt.createNode(CreateNode.builder()
                    .setPath(cacheDir)
                    .setType(CypressNodeType.MAP)
                    .setRecursive(true)
                    .setIgnoreExisting(true)
                    .build()
            ).join();
        }

        Set existsJars = yt.listNode(new ListNode(jarsDir)).join().asList().stream()
                .map(YTreeNode::stringValue)
                .collect(Collectors.toSet());

        if (!isUsingFileCache() && !uploadedJars.isEmpty()) {
            if (uploadedJars.values().stream().allMatch(p -> existsJars.contains(p.name()))) {
                return;
            }
        }

        Set classPathParts = getClassPathParts();
        for (String classPathPart : classPathParts) {
            File classPathItem = new File(classPathPart);
            if (fileHasExtension(classPathItem, "jar")) {
                if (!classPathItem.exists()) {
                    throw new IllegalStateException("Can't find " + classPathItem);
                }
                if (classPathItem.isFile()) {
                    withJar(
                            classPathItem,
                            jar -> collectFile(() -> {
                                try {
                                    return new FileInputStream(jar);
                                } catch (FileNotFoundException ex) {
                                    throw new RuntimeException(ex);
                                }
                            }, classPathItem.getName(), existsJars));
                }
            } else if (classPathItem.isDirectory()) {
                byte[] jarBytes = getClassPathDirJarBytes(classPathItem);
                withClassPathDir(
                        classPathItem,
                        jarBytes,
                        (dir, bytes) -> collectFile(() ->
                                new ByteArrayInputStream(bytes), dir.getName() + ".jar", existsJars
                        )
                );
            }
        }
    }

    private static boolean fileHasExtension(File file, String extension) {
        String lowerExtension = "." + extension;
        return file.getName().toLowerCase().endsWith(lowerExtension);
    }

    private void collectNativeLibs() {
        String libPath = System.getProperty("java.library.path");
        if (libPath == null) {
            throw new IllegalStateException("System property 'java.library.path' is null");
        }
        LOGGER.info("Searching native libs in " + libPath);

        String[] classPathParts = libPath.split(File.pathSeparator);
        for (String classPathPart : classPathParts) {
            File classPathItem = new File(classPathPart);
            if (classPathItem.isDirectory()) {
                walk(classPathItem, elm -> {
                    if (elm.isFile() &&
                            !Files.isSymbolicLink(elm.toPath()) &&
                            fileHasExtension(elm, NATIVE_FILE_EXTENSION)
                    ) {
                        withJar(elm, dll -> collectFile(
                                () -> {
                                    try {
                                        return new FileInputStream(dll);
                                    } catch (FileNotFoundException ex) {
                                        throw new RuntimeException(ex);
                                    }
                                },
                                dll.getName(),
                                Collections.emptySet()));
                    }
                });
            }
        }
    }

    /**
     * @return set of classpath files (usually *.jar)
     */
    private Set getClassPathParts() {
        Set classPathParts = new HashSet<>();
        String classPath = System.getProperty("java.class.path");
        if (classPath == null) {
            throw new IllegalStateException("System property 'java.class.path' is null");
        }
        LOGGER.info("Searching libs in " + classPath);

        String[] classPathPartsRaw = classPath.split(File.pathSeparator);

        Attributes.Name classPathKey = new Attributes.Name("Class-Path");
        for (String classPathPart : classPathPartsRaw) {
            classPathParts.add(classPathPart);

            try {
                File jarFile = new File(classPathPart);
                Manifest m = new JarFile(classPathPart).getManifest();
                if (m != null) {
                    Attributes a = m.getMainAttributes();
                    if (a.containsKey(classPathKey)) {
                        String[] fileList = a.getValue(classPathKey).split(" ");

                        for (String entity : fileList) {
                            try {
                                File jarFileChild;
                                if (entity.startsWith("file:")) {
                                    jarFileChild = new File(new URI(entity));
                                } else {
                                    jarFileChild = new File(entity);
                                }
                                if (!jarFileChild.isAbsolute()) {
                                    jarFileChild = new File(getParentFile(jarFile), entity);
                                }

                                if (jarFileChild.exists()) {
                                    classPathParts.add(jarFileChild.getPath());
                                }
                            } catch (Throwable e) {
                                LOGGER.warn("Cannot open : {}", entity, e);
                            }
                        }
                    }
                }
            } catch (IOException ignored) {
            }
        }
        return classPathParts;
    }

    private static String calculateYPath(Supplier fileContent, String originalName) {
        String md5 = calculateMd5(fileContent.get());
        String[] parts = originalName.split("\\.");
        String ext = parts.length < 2 ? "" : parts[parts.length - 1];

        return md5 + "." + ext;
    }

    private void collectFile(Supplier fileContent, String originalName, Set existsFiles) {
        String fileName = calculateYPath(fileContent, originalName);
        boolean exists = existsFiles.contains(fileName);
        if (isUsingFileCache() || !exists) {
            if (!uploadMap.containsKey(originalName)) {
                uploadMap.put(originalName, fileContent);
            } else if (originalName.endsWith(".jar")) {
                String baseName = originalName.split("\\.")[0];
                uploadMap.put(baseName + "-" + calculateMd5(fileContent.get()) + ".jar", fileContent);
            }
        }
        if (!isUsingFileCache() && exists) {
            uploadedJars.put(originalName, jarsDir.child(fileName));
        }
    }

    private YPath maybeUpload(
            TransactionalClient yt,
            Supplier fileContent,
            String originalName,
            boolean isLocalMode
    ) {
        String md5 = calculateMd5(fileContent.get());
        YPath jarPath;
        if (originalName.endsWith(NATIVE_FILE_EXTENSION)) {
            // TODO: do we really need this?
            YPath dllDir = jarsDir.child(md5);
            yt.createNode(CreateNode.builder()
                    .setPath(dllDir)
                    .setType(CypressNodeType.MAP)
                    .setRecursive(true)
                    .setIgnoreExisting(true)
                    .build()).join();
            jarPath = dllDir.child(originalName);
        } else {
            jarPath = jarsDir.child(calculateYPath(fileContent, originalName));
        }

        YPath tmpPath = jarsDir.child(GUID.create().toString());

        LOGGER.info("Uploading {} as {} using tmpPath {}", originalName, jarPath, tmpPath);

        int actualFileCacheReplicationFactor = isLocalMode ? 1 : fileCacheReplicationFactor;

        yt.createNode(CreateNode.builder()
                .setPath(tmpPath)
                .setType(CypressNodeType.FILE)
                .addAttribute("replication_factor", YTree.integerNode(actualFileCacheReplicationFactor))
                .setIgnoreExisting(true)
                .build()).join();

        writeFile(yt, tmpPath, fileContent.get());

        yt.moveNode(MoveNode.builder()
                .setSource(tmpPath.toString())
                .setDestination(jarPath.toString())
                .setPreserveAccount(true)
                .setRecursive(true)
                .setForce(true)
                .build()).join();

        return jarPath.plusAdditionalAttribute("file_name", originalName);
    }

    private static byte[] getClassPathDirJarBytes(File dir) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try {
            JarOutputStream jar = new JarOutputStream(bytes) {
                @Override
                public void putNextEntry(ZipEntry ze) throws IOException {
                    // makes resulting jar md5 predictable to allow jar hashing at yt side
                    // https://stackoverflow.com/questions/26525936
                    ze.setTime(-1);
                    super.putNextEntry(ze);
                }
            };
            walk(dir, elm -> {
                String name = elm.getAbsolutePath().substring(dir.getAbsolutePath().length());
                if (name.length() > 0) {
                    try {
                        JarEntry entry = new JarEntry(name.substring(1).replace("\\", "/"));
                        jar.putNextEntry(entry);
                        if (elm.isFile()) {
                            Files.copy(elm.toPath(), jar);
                        }
                    } catch (IOException ex) {
                        throw new UncheckedIOException(ex);
                    }
                }
            });
            jar.close();
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
        return bytes.toByteArray();
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy