All downloads are free. The search and download functionality uses the official Maven repository.

com.mysema.codegen.ScalaWriter Maven / Gradle / Ivy

package com.mysema.codegen;

import static com.mysema.codegen.Symbols.ASSIGN;
import static com.mysema.codegen.Symbols.COMMA;
import static com.mysema.codegen.Symbols.DOT;
import static com.mysema.codegen.Symbols.QUOTE;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

import com.google.common.base.Function;
import com.mysema.codegen.model.Parameter;
import com.mysema.codegen.model.Type;
import com.mysema.codegen.model.Types;
import com.mysema.codegen.support.ScalaSyntaxUtils;

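/**
 * Code writer that emits Scala source: classes, traits, objects, case classes, fields,
 * methods, imports and annotations.
 *
 * @author tiwe
 */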
public class ScalaWriter extends AbstractCodeWriter {

    /** Java primitive type names; these are rendered with a capitalized first letter. */
    private static final Set<String> PRIMITIVE_TYPES = new HashSet<String>(Arrays.asList("boolean",
            "byte", "char", "int", "long", "short", "double", "float"));

    // Keyword fragments used to assemble the generated source. Each one carries its own
    // trailing (and, where needed, leading) space so callers can concatenate directly.

    /** Prefix for every generated method header. */
    private static final String DEF = "def ";

    /** Introduces the first parent type of a class or trait. */
    private static final String EXTENDS = " extends ";

    /** Introduces an additional (mixin) parent type. */
    private static final String WITH = " with ";

    private static final String IMPORT = "import ";

    // Same text as IMPORT; kept as a separate name for the static-import code path.
    private static final String IMPORT_STATIC = "import ";

    private static final String PACKAGE = "package ";

    private static final String PRIVATE = "private ";

    private static final String PRIVATE_VAL = "private val ";

    private static final String PROTECTED = "protected ";

    private static final String PROTECTED_VAL = "protected val ";

    // Used as the prefix of constructor headers.
    private static final String PUBLIC = "public ";

    // Despite the name, the emitted text carries no visibility modifier.
    private static final String PUBLIC_CLASS = "class ";

    // Despite the name, the emitted text carries no visibility modifier.
    private static final String PUBLIC_OBJECT = "object ";

    private static final String CASE_CLASS = "case class ";

    /** Prefix for mutable field declarations. */
    private static final String VAR = "var ";

    /** Prefix for immutable field declarations. */
    private static final String VAL = "val ";

    private static final String TRAIT = "trait ";

    /** Fully qualified class names that may be written without their package prefix. */
    private final Set<String> classes = new HashSet<String>();

    /** Package names whose types may be written without their package prefix. */
    private final Set<String> packages = new HashSet<String>();

    /**
     * Type most recently opened through {@code beginClass(Type, ...)} or
     * {@code beginInterface}; its simple name is used for constructor headers.
     */
    private Type type;

    /** When {@code true}, field declarations are not followed by an extra newline. */
    private final boolean compact;

    public ScalaWriter(Appendable appendable) {
        this(appendable, false);
    }

    public ScalaWriter(Appendable appendable, boolean compact) {
        super(appendable, 2);
        this.classes.add("java.lang.String");
        this.classes.add("java.lang.Object");
        this.classes.add("java.lang.Integer");
        this.classes.add("java.lang.Comparable");
        this.compact = compact;
    }

    /**
     * Writes {@code annotation} on a line of its own in Scala syntax.
     *
     * <p>An annotation whose only member is {@code value} is written as {@code @Type(constant)}.
     * Otherwise every member whose value differs from its declared default is written as
     * {@code name=constant}; the parentheses are omitted when no member is written.
     *
     * @param annotation annotation instance to render
     * @return this writer, positioned after the newline
     * @throws IOException if the underlying output fails
     * @throws CodegenException if a member cannot be read reflectively or has an unsupported value
     */
    @Override
    public ScalaWriter annotation(Annotation annotation) throws IOException {
        beginLine().append("@").appendType(annotation.annotationType());
        Method[] methods = annotation.annotationType().getDeclaredMethods();
        if (methods.length == 1 && methods[0].getName().equals("value")) {
            try {
                Object value = methods[0].invoke(annotation);
                append("(");
                annotationConstant(value);
                append(")");
            } catch (IllegalArgumentException e) {
                throw new CodegenException(e);
            } catch (IllegalAccessException e) {
                throw new CodegenException(e);
            } catch (InvocationTargetException e) {
                throw new CodegenException(e);
            }
        } else {
            boolean first = true;
            for (Method method : methods) {
                try {
                    Object value = method.invoke(annotation);
                    // Wrapping both sides in a one-element array lets deepEquals compare plain
                    // values, object arrays and primitive arrays (int[], boolean[], ...) alike,
                    // and it copes with members that declare no default (null).
                    if (value == null
                            || Arrays.deepEquals(new Object[] { value },
                                    new Object[] { method.getDefaultValue() })) {
                        continue;
                    } else if (!first) {
                        append(COMMA);
                    } else {
                        append("(");
                    }
                    append(escape(method.getName()) + "=");
                    annotationConstant(value);
                } catch (IllegalArgumentException e) {
                    throw new CodegenException(e);
                } catch (IllegalAccessException e) {
                    throw new CodegenException(e);
                } catch (InvocationTargetException e) {
                    throw new CodegenException(e);
                }
                first = false;
            }
            // Close the argument list only if at least one member was written.
            if (!first) {
                append(")");
            }
        }
        return nl();
    }

    /**
     * Writes a marker annotation ({@code @Type}) on a line of its own.
     *
     * @param annotation annotation type to reference
     * @return the writer returned by the final newline call
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter annotation(Class annotation) throws IOException {
        ScalaWriter afterAt = beginLine().append("@");
        ScalaWriter afterType = afterAt.appendType(annotation);
        return afterType.nl();
    }

    /**
     * Writes a single annotation member value as a Scala expression.
     *
     * <p>Arrays become {@code Array(...)}, classes {@code classOf[...]}, numbers and booleans
     * their {@code toString()}, enums their constant name (qualified with the enum class name
     * unless the enum class or its package is registered), and strings a quoted, escaped literal.
     *
     * @param value member value to render; must not be {@code null}
     * @throws IOException if the underlying output fails
     * @throws IllegalArgumentException if the value is of an unsupported kind
     */
    @SuppressWarnings("unchecked")
    private void annotationConstant(Object value) throws IOException {
        if (value.getClass().isArray()) {
            append("Array(");
            // Reflective access works for primitive arrays too; a cast to Object[] would fail
            // with ClassCastException for int[], boolean[] and the like.
            int length = java.lang.reflect.Array.getLength(value);
            for (int i = 0; i < length; i++) {
                if (i > 0) {
                    append(", ");
                }
                annotationConstant(java.lang.reflect.Array.get(value, i));
            }
            append(")");
        } else if (value instanceof Class) {
            append("classOf[");
            appendType((Class<?>) value);
            append("]");
        } else if (value instanceof Number || value instanceof Boolean) {
            append(value.toString());
        } else if (value instanceof Enum) {
            Enum<?> enumValue = (Enum<?>) value;
            // getDeclaringClass() is the enum type even for constants with their own body,
            // where getClass() would be an anonymous subclass.
            Class<?> enumType = enumValue.getDeclaringClass();
            Package enumPackage = enumType.getPackage();
            if (classes.contains(enumType.getName())
                    || (enumPackage != null && packages.contains(enumPackage.getName()))) {
                append(enumValue.name());
            } else {
                append(enumType.getName() + DOT + enumValue.name());
            }
        } else if (value instanceof String) {
            append(QUOTE + StringUtils.escapeJava(value.toString()) + QUOTE);
        } else {
            throw new IllegalArgumentException("Unsupported annotation value : " + value);
        }
    }

    /**
     * Appends the name used to reference {@code type}: the capitalized name for primitives,
     * the simple name when the type has no package or its class/package is registered, and
     * the full class name otherwise.
     *
     * @param type class to reference
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    private ScalaWriter appendType(Class type) throws IOException {
        final String rendered;
        if (type.isPrimitive()) {
            rendered = StringUtils.capitalize(type.getName());
        } else {
            boolean shortNameAllowed = type.getPackage() == null
                    || classes.contains(type.getName())
                    || packages.contains(type.getPackage().getName());
            rendered = shortNameAllowed ? type.getSimpleName() : type.getName();
        }
        append(rendered);
        return this;
    }

    /**
     * Opens an {@code object} definition with the given header text and increases the indent.
     *
     * @param header text written between {@code object } and the opening brace
     * @return the writer returned by the indent call
     * @throws IOException if the underlying output fails
     */
    public ScalaWriter beginObject(String header) throws IOException {
        line(PUBLIC_OBJECT, header, " {");
        return goIn();
    }

    /**
     * Opens a {@code class} definition with the given header text and increases the indent.
     *
     * @param header text written between {@code class } and the opening brace
     * @return the writer returned by the indent call
     * @throws IOException if the underlying output fails
     */
    public ScalaWriter beginClass(String header) throws IOException {
        line(PUBLIC_CLASS, header, " {");
        return goIn();
    }

    /**
     * Opens a class for {@code type} with no superclass and no interfaces.
     *
     * @param type type being declared
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter beginClass(Type type) throws IOException {
        // The cast makes it explicit that the null is the (absent) superclass argument.
        return beginClass(type, (Type) null);
    }

    /**
     * Opens a class declaration for {@code type} and increases the indent.
     *
     * <p>The first parent (the superclass if given, otherwise the first interface) is
     * introduced with {@code extends}; every further parent is introduced with its own
     * {@code with}, as Scala requires. The type's package is registered so that types from it
     * are written unqualified, and {@code type} becomes the current type for constructors.
     *
     * @param type type being declared
     * @param superClass superclass, or {@code null} for none
     * @param interfaces traits/interfaces to mix in; may be empty
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter beginClass(Type type, Type superClass, Type... interfaces)
            throws IOException {
        packages.add(type.getPackageName());
        beginLine(PUBLIC_CLASS, getGenericName(false, type));
        // Keyword that introduces the next parent: " extends " once, " with " afterwards.
        String keyword = EXTENDS;
        if (superClass != null) {
            append(keyword + getGenericName(false, superClass));
            keyword = WITH;
        }
        for (Type parent : interfaces) {
            append(keyword);
            append(getGenericName(false, parent));
            keyword = WITH;
        }
        append(" {").nl().nl();
        goIn();
        this.type = type;
        return this;
    }

    /**
     * Opens a constructor header for the current type, converting each element of
     * {@code parameters} to a {@link Parameter} with {@code transformer}, and increases
     * the indent.
     *
     * @param <T> element type of the source collection
     * @param parameters source objects, one per constructor parameter
     * @param transformer converts a source object into a parameter
     * @return the writer returned by the indent call
     * @throws IOException if the underlying output fails
     */
    @Override
    public <T> ScalaWriter beginConstructor(Collection<T> parameters,
            Function<T, Parameter> transformer) throws IOException {
        // NOTE(review): this emits a Java-style "public Name(...)" header - confirm it is intended
        beginLine(PUBLIC + type.getSimpleName()).params(parameters, transformer).append(" {").nl();
        return goIn();
    }

    /**
     * Opens a constructor header for the current type with the given parameters and
     * increases the indent.
     *
     * @param params constructor parameters, in order
     * @return the writer returned by the indent call
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter beginConstructor(Parameter... params) throws IOException {
        beginLine(PUBLIC + type.getSimpleName()).params(params).append(" {").nl();
        return goIn();
    }

    /**
     * Opens a {@code trait} declaration for {@code type} and increases the indent.
     *
     * <p>The first parent is introduced with {@code extends} and every further parent with
     * its own {@code with}, as Scala requires. The type's package is registered so that types
     * from it are written unqualified, and {@code type} becomes the current type.
     *
     * @param type trait being declared
     * @param interfaces parent traits/interfaces; may be empty
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter beginInterface(Type type, Type... interfaces) throws IOException {
        packages.add(type.getPackageName());
        beginLine(TRAIT + getGenericName(false, type));
        // Keyword that introduces the next parent: " extends " once, " with " afterwards.
        String keyword = EXTENDS;
        for (Type parent : interfaces) {
            append(keyword);
            append(getGenericName(false, parent));
            keyword = WITH;
        }
        append(" {").nl().nl();
        goIn();
        this.type = type;
        return this;
    }

    /**
     * Writes a method header and increases the indent. The {@code : ReturnType} part is left
     * out when the return type equals {@code Types.VOID}.
     *
     * @param modifiers text placed in front of the method name (including the {@code def} keyword)
     * @param returnType declared return type
     * @param methodName method name; escaped if it is a reserved word
     * @param args method parameters, in order
     * @return the writer returned by the indent call
     * @throws IOException if the underlying output fails
     */
    private ScalaWriter beginMethod(String modifiers, Type returnType, String methodName,
            Parameter... args) throws IOException {
        ScalaWriter header = beginLine(modifiers + escape(methodName)).params(args);
        boolean returnsValue = !returnType.equals(Types.VOID);
        if (returnsValue) {
            header = header.append(": " + getGenericName(true, returnType));
        }
        header.append(" {").nl();
        return goIn();
    }

    /**
     * Opens a {@code def} method whose parameters are derived from {@code parameters} via
     * {@code transformer}.
     *
     * @param <T> element type of the source collection
     * @param returnType declared return type
     * @param methodName method name
     * @param parameters source objects, one per method parameter
     * @param transformer converts a source object into a parameter
     * @return the writer positioned inside the method body
     * @throws IOException if the underlying output fails
     */
    @Override
    public <T> ScalaWriter beginPublicMethod(Type returnType, String methodName,
            Collection<T> parameters, Function<T, Parameter> transformer) throws IOException {
        return beginMethod(DEF, returnType, methodName, transform(parameters, transformer));
    }

    /**
     * Opens a {@code def} method with the given parameters.
     *
     * @param returnType declared return type
     * @param methodName method name
     * @param args method parameters, in order
     * @return the writer positioned inside the method body
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter beginPublicMethod(Type returnType, String methodName, Parameter... args)
            throws IOException {
        return beginMethod(DEF, returnType, methodName, args);
    }

    /**
     * Opens a method for the "static" case; the emitted header is the same plain {@code def}
     * as for a public method. Parameters are derived from {@code parameters} via
     * {@code transformer}.
     *
     * @param <T> element type of the source collection
     * @param returnType declared return type
     * @param methodName method name
     * @param parameters source objects, one per method parameter
     * @param transformer converts a source object into a parameter
     * @return the writer positioned inside the method body
     * @throws IOException if the underlying output fails
     */
    @Override
    public <T> ScalaWriter beginStaticMethod(Type returnType, String methodName,
            Collection<T> parameters, Function<T, Parameter> transformer) throws IOException {
        return beginMethod(DEF, returnType, methodName, transform(parameters, transformer));
    }

    /**
     * Opens a method for the "static" case; the emitted header is the same plain {@code def}
     * as for a public method.
     *
     * @param returnType declared return type
     * @param methodName method name
     * @param args method parameters, in order
     * @return the writer positioned inside the method body
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter beginStaticMethod(Type returnType, String methodName, Parameter... args)
            throws IOException {
        return beginMethod(DEF, returnType, methodName, args);
    }

    /**
     * Writes a one-line {@code case class} declaration: the header text followed by the
     * parameter list and a newline. No body is opened.
     *
     * @param header text written after {@code case class }
     * @param parameters constructor parameters, in order
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    public ScalaWriter caseClass(String header, Parameter... parameters) throws IOException {
        ScalaWriter declaration = beginLine(CASE_CLASS, header);
        declaration.params(parameters).nl();
        return this;
    }

    /**
     * Closes the current block: decreases the indent, writes {@code }} on its own line and
     * adds a newline after it.
     *
     * @return the writer returned by the final newline call
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter end() throws IOException {
        goOut();
        ScalaWriter afterBrace = line("}");
        return afterBrace.nl();
    }

    /**
     * Writes a {@code var name: Type} declaration.
     *
     * @param type field type
     * @param name field name; escaped if it is a reserved word
     * @return this writer in compact mode, otherwise the writer returned by the newline call
     * @throws IOException if the underlying output fails
     */
    public ScalaWriter field(Type type, String name) throws IOException {
        return field(VAR, type, name);
    }

    /**
     * Writes {@code <modifier>name: Type} on its own line; outside compact mode an extra
     * newline follows.
     */
    private ScalaWriter field(String modifier, Type type, String name) throws IOException {
        String declaration = modifier + escape(name) + ": " + getGenericName(true, type);
        line(declaration);
        if (compact) {
            return this;
        }
        return nl();
    }

    /**
     * Writes {@code <modifier>name: Type}, the assignment symbol and {@code value} on its own
     * line; outside compact mode an extra newline follows.
     */
    private ScalaWriter field(String modifier, Type type, String name, String value)
            throws IOException {
        String declaration = modifier + escape(name) + ": " + getGenericName(true, type);
        line(declaration + ASSIGN + value);
        if (compact) {
            return this;
        }
        return nl();
    }

    /**
     * Returns the Scala name of {@code type} including its type arguments in square brackets.
     * A type argument that is {@code null} or has the same full name as {@code type} itself
     * is written as the wildcard {@code _}.
     *
     * @param asArgType not consulted by this implementation
     * @param type type to render
     * @return raw name, followed by {@code [A, B, ...]} when the type has parameters
     */
    @Override
    public String getGenericName(boolean asArgType, Type type) {
        if (type.getParameters().isEmpty()) {
            return getRawName(type);
        }
        StringBuilder rendered = new StringBuilder();
        rendered.append(getRawName(type)).append("[");
        String ownName = type.getFullName();
        String separator = "";
        for (Type argument : type.getParameters()) {
            rendered.append(separator);
            boolean wildcard = argument == null || argument.getFullName().equals(ownName);
            rendered.append(wildcard ? "_" : getGenericName(false, argument));
            separator = ", ";
        }
        return rendered.append("]").toString();
    }

    /**
     * Returns the Scala name of {@code type} without type arguments.
     *
     * <p>Primitive names are capitalized, the package prefix is dropped when the package or
     * the class is registered with this writer, and a name ending in {@code []} is rendered
     * as {@code Array[Component]}.
     *
     * @param type type to render
     * @return name to use in the generated source
     */
    @Override
    public String getRawName(Type type) {
        String fullName = type.getFullName();
        // Java primitive names map to the capitalized form (int -> Int).
        if (PRIMITIVE_TYPES.contains(fullName)) {
            fullName = StringUtils.capitalize(fullName);
        }
        String packageName = type.getPackageName();
        String rv = fullName;
        // Covers types that report themselves as primitive but are not in PRIMITIVE_TYPES.
        if (type.isPrimitive() && packageName.isEmpty()) {
            rv = Character.toUpperCase(rv.charAt(0)) + rv.substring(1);
        }
        // Registered package or class: strip "<package>." from the front.
        if (packages.contains(packageName) || classes.contains(fullName)) {
            if (packageName.length() > 0) {
                rv = fullName.substring(packageName.length() + 1);
            }
        }
        if (rv.endsWith("[]")) {
            // Only one trailing "[]" is removed here.
            // NOTE(review): a multi-dimensional name would keep an inner "[]" - confirm whether that can occur
            rv = rv.substring(0, rv.length() - 2);
            if (PRIMITIVE_TYPES.contains(rv)) {
                rv = StringUtils.capitalize(rv);
            } else if (classes.contains(rv)) {
                // Component class is registered although the array name itself was not.
                rv = rv.substring(packageName.length() + 1);
            }
            return "Array[" + rv + "]";
        } else {
            return rv;
        }
    }

    /**
     * Writes one {@code import} line per class, registers each class so later references use
     * its short name, and finishes with a newline.
     *
     * @param imports classes to import
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter imports(Class... imports) throws IOException {
        for (int i = 0; i < imports.length; i++) {
            String className = imports[i].getName();
            classes.add(className);
            line(IMPORT + className);
        }
        nl();
        return this;
    }

    /**
     * Writes one wildcard {@code import pkg._} line per package, registers each package so
     * its types are written unqualified, and finishes with a newline.
     *
     * @param imports packages to import
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter imports(Package... imports) throws IOException {
        for (int i = 0; i < imports.length; i++) {
            String packageName = imports[i].getName();
            packages.add(packageName);
            line(IMPORT + packageName + "._");
        }
        nl();
        return this;
    }

    /**
     * Writes one {@code import} line per class name, registers each name so later references
     * use the short form, and finishes with a newline.
     *
     * @param imports fully qualified class names
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter importClasses(String... imports) throws IOException {
        for (int i = 0; i < imports.length; i++) {
            String className = imports[i];
            classes.add(className);
            line(IMPORT + className);
        }
        nl();
        return this;
    }

    /**
     * Writes one wildcard {@code import pkg._} line per package name, registers each package
     * so its types are written unqualified, and finishes with a newline.
     *
     * @param imports package names
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter importPackages(String... imports) throws IOException {
        for (int i = 0; i < imports.length; i++) {
            String packageName = imports[i];
            packages.add(packageName);
            line(IMPORT + packageName + "._");
        }
        nl();
        return this;
    }

    /**
     * Writes a documentation comment: an opening line, one {@code  * }-prefixed line per
     * entry of {@code lines}, and a closing line.
     *
     * @param lines comment text, one element per output line
     * @return the writer returned by writing the closing line
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter javadoc(String... lines) throws IOException {
        line("/**");
        for (int i = 0; i < lines.length; i++) {
            line(" * " + lines[i]);
        }
        return line(" */");
    }

    /**
     * Writes the {@code package} declaration followed by a newline and registers the package
     * so its types are written unqualified.
     *
     * @param packageName name of the package being declared
     * @return the writer returned by the final newline call
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter packageDecl(String packageName) throws IOException {
        packages.add(packageName);
        ScalaWriter declaration = line(PACKAGE + packageName);
        return declaration.nl();
    }

    /**
     * Appends a parenthesized parameter list, converting each element of {@code parameters}
     * to a {@link Parameter} with {@code transformer}.
     *
     * @param <T> element type of the source collection
     * @param parameters source objects, one per parameter
     * @param transformer converts a source object into a parameter
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    private <T> ScalaWriter params(Collection<T> parameters, Function<T, Parameter> transformer)
            throws IOException {
        append("(");
        boolean first = true;
        for (T param : parameters) {
            if (!first) {
                append(COMMA);
            }
            param(transformer.apply(param));
            first = false;
        }
        append(")");
        return this;
    }

    /**
     * Appends a parenthesized, comma-separated parameter list.
     *
     * @param params parameters, in order
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    private ScalaWriter params(Parameter... params) throws IOException {
        append("(");
        boolean needsSeparator = false;
        for (Parameter current : params) {
            if (needsSeparator) {
                append(COMMA);
            }
            param(current);
            needsSeparator = true;
        }
        append(")");
        return this;
    }

    /**
     * Appends a single parameter as {@code name: Type}; the name is escaped when it is a
     * reserved word.
     *
     * @param parameter parameter to write
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    private ScalaWriter param(Parameter parameter) throws IOException {
        String escapedName = escape(parameter.getName());
        append(escapedName);
        append(": ");
        String typeName = getGenericName(true, parameter.getType());
        append(typeName);
        return this;
    }

    /**
     * Writes a mutable private field as {@code private var name: Type}.
     *
     * @param type field type
     * @param name field name
     * @return this writer in compact mode, otherwise the writer returned by the newline call
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter privateField(Type type, String name) throws IOException {
        // A Scala member needs the var keyword; the bare modifier alone is not a declaration.
        return field(PRIVATE + VAR, type, name);
    }

    /**
     * Writes an immutable private field as {@code private val name: Type}.
     *
     * @param type field type
     * @param name field name
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter privateFinal(Type type, String name) throws IOException {
        return field(PRIVATE_VAL, type, name);
    }

    /**
     * Writes an immutable private field with an initial value.
     *
     * @param type field type
     * @param name field name
     * @param value source text of the initial value
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter privateFinal(Type type, String name, String value) throws IOException {
        return field(PRIVATE_VAL, type, name, value);
    }

    /**
     * Writes a private constant; the output is the same {@code private val} form as
     * {@code privateFinal} with a value.
     *
     * @param type field type
     * @param name field name
     * @param value source text of the initial value
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter privateStaticFinal(Type type, String name, String value) throws IOException {
        return field(PRIVATE_VAL, type, name, value);
    }

    /**
     * Writes a mutable protected field as {@code protected var name: Type}.
     *
     * @param type field type
     * @param name field name
     * @return this writer in compact mode, otherwise the writer returned by the newline call
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter protectedField(Type type, String name) throws IOException {
        // A Scala member needs the var keyword; the bare modifier alone is not a declaration.
        return field(PROTECTED + VAR, type, name);
    }

    /**
     * Writes an immutable protected field as {@code protected val name: Type}.
     *
     * @param type field type
     * @param name field name
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter protectedFinal(Type type, String name) throws IOException {
        return field(PROTECTED_VAL, type, name);
    }

    /**
     * Writes an immutable protected field with an initial value.
     *
     * @param type field type
     * @param name field name
     * @param value source text of the initial value
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter protectedFinal(Type type, String name, String value) throws IOException {
        return field(PROTECTED_VAL, type, name, value);
    }

    /**
     * Writes a mutable field as {@code var name: Type}.
     *
     * @param type field type
     * @param name field name
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter publicField(Type type, String name) throws IOException {
        return field(VAR, type, name);
    }

    /**
     * Writes a mutable field with an initial value.
     *
     * @param type field type
     * @param name field name
     * @param value source text of the initial value
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter publicField(Type type, String name, String value) throws IOException {
        return field(VAR, type, name, value);
    }

    /**
     * Writes an immutable field as {@code val name: Type}.
     *
     * @param type field type
     * @param name field name
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter publicFinal(Type type, String name) throws IOException {
        return field(VAL, type, name);
    }

    /**
     * Writes an immutable field with an initial value.
     *
     * @param type field type
     * @param name field name
     * @param value source text of the initial value
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter publicFinal(Type type, String name, String value) throws IOException {
        return field(VAL, type, name, value);
    }

    /**
     * Writes a constant; the output is the same {@code val} form as {@code publicFinal}
     * with a value.
     *
     * @param type field type
     * @param name field name
     * @param value source text of the initial value
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter publicStaticFinal(Type type, String name, String value) throws IOException {
        return field(VAL, type, name, value);
    }

    /**
     * Writes one wildcard member import per class, in the form {@code import pkg.Cls._;}.
     * Unlike the other import methods, nothing is registered and no trailing newline is added.
     *
     * @param imports classes whose members are imported
     * @return this writer
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter staticimports(Class... imports) throws IOException {
        for (int i = 0; i < imports.length; i++) {
            String className = imports[i].getName();
            line(IMPORT_STATIC + className + "._;");
        }
        return this;
    }

    /**
     * Writes a {@code @SuppressWarnings("...")} annotation line for the given warning type.
     * The text is inserted as given, without escaping.
     *
     * @param type warning category to suppress
     * @return the writer returned by writing the line
     * @throws IOException if the underlying output fails
     */
    @Override
    public ScalaWriter suppressWarnings(String type) throws IOException {
        StringBuilder annotationLine = new StringBuilder("@SuppressWarnings(\"");
        annotationLine.append(type).append("\")");
        return line(annotationLine.toString());
    }

    /**
     * Converts every element of {@code parameters} to a {@link Parameter}, keeping the
     * collection's iteration order.
     *
     * @param <T> element type of the source collection
     * @param parameters source objects
     * @param transformer converts a source object into a parameter
     * @return array with one parameter per source object
     */
    private <T> Parameter[] transform(Collection<T> parameters,
            Function<T, Parameter> transformer) {
        Parameter[] rv = new Parameter[parameters.size()];
        int i = 0;
        for (T value : parameters) {
            rv[i++] = transformer.apply(value);
        }
        return rv;
    }

    /**
     * Wraps {@code token} in backticks when {@code ScalaSyntaxUtils} reports it as reserved;
     * otherwise returns it unchanged.
     *
     * @param token identifier to check
     * @return identifier safe to use in the generated source
     */
    private String escape(String token) {
        return ScalaSyntaxUtils.isReserved(token) ? "`" + token + "`" : token;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy