All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.facebook.drift.idl.generator.ThriftIdlGenerator Maven / Gradle / Ivy

/*
 * Copyright (C) 2012 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.drift.idl.generator;

import com.facebook.drift.annotations.ThriftService;
import com.facebook.drift.codec.ThriftCodecManager;
import com.facebook.drift.codec.ThriftProtocolType;
import com.facebook.drift.codec.metadata.FieldKind;
import com.facebook.drift.codec.metadata.MetadataErrorException;
import com.facebook.drift.codec.metadata.MetadataErrors.Monitor;
import com.facebook.drift.codec.metadata.MetadataWarningException;
import com.facebook.drift.codec.metadata.ThriftCatalog;
import com.facebook.drift.codec.metadata.ThriftFieldMetadata;
import com.facebook.drift.codec.metadata.ThriftMethodMetadata;
import com.facebook.drift.codec.metadata.ThriftServiceMetadata;
import com.facebook.drift.codec.metadata.ThriftStructMetadata;
import com.facebook.drift.codec.metadata.ThriftType;
import com.facebook.drift.codec.metadata.ThriftTypeReference;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import java.util.function.Predicate;

import static com.facebook.drift.codec.metadata.ReflectionHelper.getEffectiveClassAnnotations;
import static com.facebook.drift.idl.generator.ThriftIdlRenderer.renderThriftIdl;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.io.Files.getNameWithoutExtension;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * Generates Thrift IDL from Java classes described by the Drift {@link ThriftCatalog}.
 *
 * <p>Each input class is converted either to a {@link ThriftType} or, when it carries an
 * effective {@link ThriftService} annotation, to a {@link ThriftServiceMetadata}. Before
 * rendering, every type referenced by the inputs must be built in, generated in the same IDL,
 * or provided by one of the configured include files. Types are reordered so that each struct
 * comes after the types its (non-recursive) fields reference.
 *
 * <p>An instance is effectively single-use. Verification switches off recursive discovery and
 * replaces the collected type and service lists with immutable copies, so a second call to
 * {@link #generate(Iterable)} is not expected to work.
 */
public class ThriftIdlGenerator
{
    // Types that are always available in Thrift IDL and never need to be generated or included.
    private static final Set<ThriftType> BUILT_IN_TYPES = ImmutableSet.<ThriftType>builder()
            .add(ThriftType.BOOL)
            .add(ThriftType.BYTE)
            .add(ThriftType.I16)
            .add(ThriftType.I32)
            .add(ThriftType.I64)
            .add(ThriftType.DOUBLE)
            .add(ThriftType.STRING)
            .add(ThriftType.BINARY)
            .add(new ThriftType(ThriftType.BOOL, Boolean.class))
            .add(new ThriftType(ThriftType.BYTE, Byte.class))
            .add(new ThriftType(ThriftType.I16, Short.class))
            .add(new ThriftType(ThriftType.I32, Integer.class))
            .add(new ThriftType(ThriftType.I64, Long.class))
            .add(new ThriftType(ThriftType.DOUBLE, Double.class))
            .add(new ThriftType(ThriftType.STRING, String.class))
            .add(new ThriftType(ThriftType.BINARY, byte[].class))
            .build();

    private final ClassLoader classLoader;
    private final Consumer<String> verboseLogger;
    private final ThriftCodecManager codecManager;
    // Either "" or a package name ending in '.', prepended to class names that contain no '.'.
    private final String defaultPackage;
    // Converted include classes (types or services) mapped to the IDL file that provides them.
    private final Map<Object, String> includes = new HashMap<>();
    // Included types actually referenced. Insertion-ordered so the include list passed to the
    // renderer follows first use rather than ThriftType hash codes.
    private final Set<ThriftType> usedIncludedTypes = new LinkedHashSet<>();
    private final Map<String, String> namespaces;

    private Set<ThriftType> knownTypes = new HashSet<>(BUILT_IN_TYPES);
    private ThriftTypeRenderer typeRenderer = new ThriftTypeRenderer(ImmutableMap.of());
    // Copy-on-write because verifyField() appends discovered types while this list is being iterated.
    private List<ThriftType> thriftTypes = new CopyOnWriteArrayList<>();
    private List<ThriftServiceMetadata> thriftServices = new CopyOnWriteArrayList<>();
    private boolean recursive;

    /**
     * Creates a generator that loads classes through the thread context class loader.
     * It falls back to the system class loader when there is no context class loader.
     *
     * @param config generator configuration
     */
    public ThriftIdlGenerator(ThriftIdlGeneratorConfig config)
    {
        this(config, firstNonNull(Thread.currentThread().getContextClassLoader(), ClassLoader.getSystemClassLoader()));
    }

    /**
     * Creates a generator that loads classes through {@code classLoader}.
     *
     * <p>Every class named in the configured includes is loaded and converted eagerly, so a
     * missing or invalid include class fails construction.
     *
     * @param config generator configuration
     * @param classLoader loader used to resolve input and include class names
     * @throws NullPointerException if {@code classLoader} is null
     * @throws ThriftIdlGeneratorException if an include class cannot be loaded or converted
     */
    public ThriftIdlGenerator(ThriftIdlGeneratorConfig config, ClassLoader classLoader)
    {
        this.classLoader = requireNonNull(classLoader, "classLoader is null");

        Monitor monitor = createMonitor(config.getErrorLogger(), config.getWarningLogger());
        this.codecManager = new ThriftCodecManager(new ThriftCatalog(monitor));

        this.verboseLogger = config.getVerboseLogger();
        String defaultPackage = config.getDefaultPackage();
        this.defaultPackage = defaultPackage.isEmpty() ? "" : defaultPackage + ".";

        Map<String, String> paramIncludeMap = config.getIncludes();
        for (Map.Entry<String, String> entry : paramIncludeMap.entrySet()) {
            // load() throws for unknown classes, so the result is never null
            Class<?> clazz = load(entry.getKey());
            Object result = convertToThrift(clazz);
            this.includes.put(result, entry.getValue());
        }

        this.namespaces = config.getNamespaces();
        this.recursive = config.isRecursive();
    }

    /**
     * Generates Thrift IDL for the given classes.
     *
     * @param inputs class names to render; names without a '.' are resolved against the
     *     configured default package
     * @return the rendered Thrift IDL
     * @throws ThriftIdlGeneratorException if a class cannot be loaded or converted, or if a
     *     referenced type is neither built in, generated here, nor provided by an include
     */
    public String generate(Iterable<String> inputs)
    {
        for (String className : inputs) {
            Object result = convertToThrift(load(className));
            if (result instanceof ThriftType) {
                thriftTypes.add((ThriftType) result);
            }
            else {
                thriftServices.add((ThriftServiceMetadata) result);
            }
            // a class that is generated here must not also be pulled in through an include
            includes.remove(result);
        }

        if (!verify()) {
            throw new ThriftIdlGeneratorException("Errors found during verification.");
        }

        return generate();
    }

    // Qualifies a bare class name (no '.') with the configured default package.
    private String getFullClassName(String className)
    {
        if (className.indexOf('.') == -1) {
            return defaultPackage + className;
        }
        return className;
    }

    // Optionally discovers transitive dependencies (recursive mode), then checks and orders
    // the collected types and services.
    @SuppressWarnings("NonShortCircuitBooleanExpression")
    @SuppressFBWarnings("NS_DANGEROUS_NON_SHORT_CIRCUIT")
    private boolean verify()
    {
        if (recursive) {
            // In recursive mode verifyField() appends every unknown referenced type to thriftTypes.
            // Re-run verifyStruct until that list stops growing. This collects all transitive
            // dependencies of the inputs before verifyTypes() orders them.
            while (true) {
                int size = thriftTypes.size();
                for (ThriftType type : thriftTypes) {
                    verifyStruct(type, true);
                }
                if (size == thriftTypes.size()) {
                    break;
                }
            }

            // verifyService() only ever appends to thriftTypes (never to thriftServices), and each
            // type it discovers is verified recursively on the spot, so one pass is enough.
            for (ThriftServiceMetadata service : thriftServices) {
                verifyService(service, true);
            }

            // Discovery is done: reset the bookkeeping so the strict pass below starts from scratch.
            recursive = false;
            usedIncludedTypes.clear();
            knownTypes = new HashSet<>(BUILT_IN_TYPES);
        }
        // '&' rather than '&&' so verifyServices() runs even when verifyTypes() reports failure.
        return verifyTypes() & verifyServices();
    }

    // Verifies that every collected type is resolvable and reorders thriftTypes so that each
    // struct follows the types its fields reference. On failure, throws a descriptive exception
    // for the first unresolvable field found.
    private boolean verifyTypes()
    {
        SuccessAndResult<ThriftType> output = topologicalSort(thriftTypes, type -> {
            ThriftProtocolType proto = type.getProtocolType();
            if (proto == ThriftProtocolType.ENUM || proto == ThriftProtocolType.STRUCT) {
                return verifyStruct(type, true);
            }
            throw new VerifyException("Top-level non-enum and non-struct?");
        });
        if (output.isSuccess()) {
            thriftTypes = output.getResult();
            return true;
        }
        for (ThriftType type : output.getResult()) {
            // expected to fail: re-run non-quietly so the exception names the offending field
            verifyStruct(type, false);
        }
        return false;
    }

    // Verifies that every type used by each service is resolvable. On failure, throws a
    // descriptive exception for the first unresolvable parameter, exception or return type.
    private boolean verifyServices()
    {
        SuccessAndResult<ThriftServiceMetadata> output = topologicalSort(thriftServices, metadata -> verifyService(metadata, true));
        if (output.isSuccess()) {
            thriftServices = output.getResult();
            return true;
        }
        for (ThriftServiceMetadata s : output.getResult()) {
            // expected to fail: re-run non-quietly so the exception names the offending method
            verifyService(s, false);
        }
        return false;
    }

    // Outcome of topologicalSort: on success the values in accepted order, otherwise the
    // values that were never accepted. The list is an immutable copy.
    private static class SuccessAndResult<T>
    {
        private final boolean success;
        private final List<T> result;

        SuccessAndResult(boolean success, List<T> result)
        {
            this.success = success;
            this.result = ImmutableList.copyOf(result);
        }

        public boolean isSuccess()
        {
            return success;
        }

        public List<T> getResult()
        {
            return result;
        }
    }

    // Orders values by repeated passes over the not-yet-accepted values. Each pass appends
    // values accepted by isKnown, in order, and keeps the rest. It stops when a pass accepts
    // nothing new. isKnown is expected to have side effects (e.g. marking a type as known) that
    // let later values be accepted. Succeeds only when every value was accepted.
    private static <T> SuccessAndResult<T> topologicalSort(List<T> list, Predicate<T> isKnown)
    {
        List<T> remaining = list;
        List<T> newList = new ArrayList<>();
        int prevSize = 0;
        while (prevSize != remaining.size()) {
            prevSize = remaining.size();
            List<T> bad = new ArrayList<>();
            for (T value : remaining) {
                if (isKnown.test(value)) {
                    newList.add(value);
                }
                else {
                    bad.add(value);
                }
            }
            remaining = bad;
        }
        if (prevSize == 0) {
            return new SuccessAndResult<>(true, newList);
        }
        return new SuccessAndResult<>(false, remaining);
    }

    // Checks every parameter, declared exception and non-void return type of each method.
    // When quiet is false, the first failure throws ThriftIdlGeneratorException instead of
    // returning false.
    private boolean verifyService(ThriftServiceMetadata service, boolean quiet)
    {
        boolean ok = true;

        for (ThriftMethodMetadata method : service.getMethods()) {
            for (ThriftFieldMetadata field : method.getParameters()) {
                if (!verifyField(field.getThriftType())) {
                    ok = false;
                    if (!quiet) {
                        throw new ThriftIdlGeneratorException(format("Unknown argument type %s in %s.%s", typeName(field.getThriftType()), service.getName(), method.getName()));
                    }
                }
            }

            for (ThriftType exception : method.getExceptions().values()) {
                if (!verifyField(exception)) {
                    ok = false;
                    if (!quiet) {
                        throw new ThriftIdlGeneratorException(format("Unknown exception type %s in %s.%s", typeName(exception), service.getName(), method.getName()));
                    }
                }
            }

            if (!method.getReturnType().equals(ThriftType.VOID) && !verifyField(method.getReturnType())) {
                ok = false;
                if (!quiet) {
                    throw new ThriftIdlGeneratorException(format("Unknown return type %s in %s.%s", typeName(method.getReturnType()), service.getName(), method.getName()));
                }
            }
        }

        return ok;
    }

    // Checks a container element type. Recursive references are accepted without checking
    // unless recursive discovery is active.
    private boolean verifyElementType(ThriftTypeReference type)
    {
        if (!recursive && type.isRecursive()) {
            return true;
        }
        return verifyField(type.get());
    }

    // Returns whether a field type can be resolved. Containers are resolved through their
    // element types. Other types must be known or included; included types are recorded as
    // used. In recursive mode an unknown type is appended to thriftTypes and verified in turn.
    @SuppressWarnings("NonShortCircuitBooleanExpression")
    @SuppressFBWarnings("NS_DANGEROUS_NON_SHORT_CIRCUIT")
    private boolean verifyField(ThriftType type)
    {
        ThriftProtocolType proto = type.getProtocolType();
        if (proto == ThriftProtocolType.SET || proto == ThriftProtocolType.LIST) {
            return verifyElementType(type.getValueTypeReference());
        }
        if (proto == ThriftProtocolType.MAP) {
            // '&' so both key and value are checked (and recorded) even if the key fails
            return verifyElementType(type.getKeyTypeReference()) & verifyElementType(type.getValueTypeReference());
        }

        if (knownTypes.contains(type)) {
            return true;
        }

        if (includes.containsKey(type)) {
            usedIncludedTypes.add(type);
            return true;
        }

        if (recursive) {
            // recursive but type is unknown - add it to the list and recurse
            thriftTypes.add(type);
            return verifyStruct(type, true);
        }
        return false;
    }

    // Marks an enum or struct as known if all of its Thrift fields resolve. Enums are always
    // accepted. When quiet is false, the first unresolved field throws
    // ThriftIdlGeneratorException.
    private boolean verifyStruct(ThriftType type, boolean quiet)
    {
        if (type.getProtocolType() == ThriftProtocolType.ENUM) {
            knownTypes.add(type);
            return true;
        }
        ThriftStructMetadata metadata = type.getStructMetadata();
        boolean ok = true;

        // Mark as known before checking fields so self-referencing structs terminate.
        knownTypes.add(type);

        for (ThriftFieldMetadata fieldMetadata : metadata.getFields(FieldKind.THRIFT_FIELD)) {
            if (!recursive && fieldMetadata.isTypeReferenceRecursive()) {
                continue;
            }

            boolean fieldOk = verifyField(fieldMetadata.getThriftType());
            if (!fieldOk) {
                ok = false;
                if (!quiet) {
                    throw new ThriftIdlGeneratorException(format("Unknown type %s in %s.%s", typeName(fieldMetadata.getThriftType()), metadata.getStructName(), fieldMetadata.getName()));
                }
            }
        }

        if (!ok) {
            knownTypes.remove(type);
        }
        return ok;
    }

    // Loads a class by (possibly unqualified) name. Never returns null.
    private Class<?> load(String className)
    {
        className = getFullClassName(className);
        try {
            return classLoader.loadClass(className);
        }
        catch (ClassNotFoundException e) {
            throw new ThriftIdlGeneratorException("Class not found: " + className);
        }
    }

    // Converts a class to its Thrift representation, wrapping metadata errors.
    private Object convertToThrift(Class<?> clazz)
    {
        try {
            return getThriftType(clazz);
        }
        catch (MetadataErrorException e) {
            throw new ThriftIdlGeneratorException(e);
        }
    }

    // Returns ThriftServiceMetadata for classes with an effective @ThriftService annotation,
    // otherwise the ThriftType from the catalog.
    private Object getThriftType(Class<?> clazz)
    {
        Set<ThriftService> serviceAnnotations = getEffectiveClassAnnotations(clazz, ThriftService.class);
        if (serviceAnnotations.isEmpty()) {
            ThriftType thriftType = codecManager.getCatalog().getThriftType(clazz);
            verboseLogger.accept("Found Thrift type: " + typeName(thriftType));
            return thriftType;
        }

        ThriftServiceMetadata serviceMetadata = new ThriftServiceMetadata(clazz, codecManager.getCatalog());
        verboseLogger.accept("Found Thrift service: " + clazz.getSimpleName());
        return serviceMetadata;
    }

    // Renders the verified types and services together with the includes they actually use.
    private String generate()
    {
        ImmutableMap.Builder<ThriftType, String> typesBuilder = ImmutableMap.builder();
        ImmutableSet.Builder<String> includesBuilder = ImmutableSet.builder();
        for (ThriftType type : usedIncludedTypes) {
            String filename = includes.get(type);
            includesBuilder.add(filename);
            // maps the type to the include file's base name (presumably the renderer's qualifier)
            typesBuilder.put(type, getNameWithoutExtension(filename));
        }
        typeRenderer = new ThriftTypeRenderer(typesBuilder.build());
        return renderThriftIdl(namespaces, includesBuilder.build(), thriftTypes, thriftServices, typeRenderer);
    }

    private String typeName(ThriftType type)
    {
        return typeRenderer.toString(type);
    }

    // Forwards catalog errors and warnings to the configured loggers as plain messages.
    private static Monitor createMonitor(Consumer<String> errorLogger, Consumer<String> warningLogger)
    {
        return new Monitor()
        {
            @Override
            public void onError(MetadataErrorException e)
            {
                errorLogger.accept(e.getMessage());
            }

            @Override
            public void onWarning(MetadataWarningException e)
            {
                warningLogger.accept(e.getMessage());
            }
        };
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy