All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mysema.codegen.ScalaWriter Maven / Gradle / Ivy

package com.mysema.codegen;

import static com.mysema.codegen.Symbols.ASSIGN;
import static com.mysema.codegen.Symbols.COMMA;
import static com.mysema.codegen.Symbols.DOT;
import static com.mysema.codegen.Symbols.QUOTE;
import static com.mysema.codegen.Symbols.SEMICOLON;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

import org.apache.commons.collections15.Transformer;
import org.apache.commons.lang.StringEscapeUtils;

import com.mysema.codegen.model.Parameter;
import com.mysema.codegen.model.Type;
import com.mysema.codegen.model.Types;

/**
 * @author tiwe
 *
 */
public class ScalaWriter extends AbstractCodeWriter{
    
    private static final String DEF = "def ";
    
    private static final String EXTENDS = " extends ";
    
    private static final String IMPLEMENTS = " implements ";

    private static final String IMPORT = "import ";

    private static final String IMPORT_STATIC = "import static ";

    private static final String PACKAGE = "package ";

    private static final String PRIVATE = "private ";
    
    private static final String PRIVATE_VAL = "private val ";
    
    private static final String PROTECTED = "protected ";
    
    private static final String PROTECTED_VAL = "protected val ";

    private static final String PUBLIC = "public ";

    private static final String PUBLIC_CLASS = "class ";
    
    private static final String PUBLIC_OBJECT = "object ";

    private static final String VAR = "var ";
    
    private static final String VAL = "val ";

    private static final String TRAIT = "trait ";
    
    private final Set classes = new HashSet();
    
    private final Set packages = new HashSet();
    
    private Type type;
    
    private final boolean compact;
    
    public ScalaWriter(Appendable appendable){
        this(appendable, false);
    }

    public ScalaWriter(Appendable appendable, boolean compact){
        super(appendable);
        this.classes.add("java.lang.String");
        this.classes.add("java.lang.Object");
        this.classes.add("java.lang.Integer");
        this.classes.add("java.lang.Comparable");
        this.compact = compact;
    }
    
    @Override
    public ScalaWriter annotation(Annotation annotation) throws IOException {
        beginLine().append("@").appendType(annotation.annotationType());
        Method[] methods = annotation.annotationType().getDeclaredMethods();
        if (methods.length == 1 && methods[0].getName().equals("value")){
            try {
                Object value = methods[0].invoke(annotation);
                append("(");
                annotationConstant(value);
                append(")");
            } catch (IllegalArgumentException e) {
                throw new CodegenException(e);
            } catch (IllegalAccessException e) {
                throw new CodegenException(e);
            } catch (InvocationTargetException e) {
                throw new CodegenException(e);
            }            
        }else{
            boolean first = true;        
            for (Method method : methods){            
                try {
                    Object value = method.invoke(annotation);
                    if (value == null 
                     || value.equals(method.getDefaultValue())
                     || (value.getClass().isArray() 
                     && Arrays.equals((Object[])value, (Object[])method.getDefaultValue()))){
                        continue;
                    }else if (!first){
                        append(COMMA);
                    }else{
                        append("(");
                    }
                    append(method.getName()+"=");                
                    annotationConstant(value);
                } catch (IllegalArgumentException e) {
                    throw new CodegenException(e);
                } catch (IllegalAccessException e) {
                    throw new CodegenException(e);
                } catch (InvocationTargetException e) {
                    throw new CodegenException(e);
                }
                first = false;
            }               
            if (!first){
                append(")");    
            }        
        }        
        return nl();
    }

    @Override
    public ScalaWriter annotation(Class annotation) throws IOException {
        return beginLine().append("@").appendType(annotation).nl();
    }
    
    @SuppressWarnings("unchecked")
    private void annotationConstant(Object value) throws IOException{
        if (value.getClass().isArray()){
            append("Array(");
            boolean first = true;
            for (Object o : (Object[])value){
                if (!first){
                    append(", ");
                }
                annotationConstant(o);
                first = false;
            }
            append(")");        
        }else if (value instanceof Class){
            append("classOf[");
            appendType((Class)value);
            append("]");
        }else if (value instanceof Number || value instanceof Boolean){
            append(value.toString());
        }else if (value instanceof Enum){
            Enum enumValue = (Enum)value;
            append(enumValue.getDeclaringClass().getName()+DOT+enumValue.name());
        }else if (value instanceof String){
            append(QUOTE + StringEscapeUtils.escapeJava(value.toString()) + QUOTE);
        }else{
            throw new IllegalArgumentException("Unsupported annotation value : " + value);
        }
    }
    
    private ScalaWriter appendType(Class type) throws IOException{
        if (classes.contains(type.getName()) || packages.contains(type.getPackage().getName())){
            append(type.getSimpleName());
        }else{
            append(type.getName());
        }
        return this;
    }
    
    public ScalaWriter beginObject(String header) throws IOException {
        line(PUBLIC_OBJECT, header, " {");
        goIn();
        return this;
    }
    
    public ScalaWriter beginClass(String header) throws IOException {
        line(PUBLIC_CLASS, header, " {");
        goIn();
        return this;
    }

    @Override
    public ScalaWriter beginClass(Type type) throws IOException {
        return beginClass(type, null);
    }

    @Override
    public ScalaWriter beginClass(Type type, Type superClass, Type... interfaces) throws IOException {
        packages.add(type.getPackageName());
        beginLine(PUBLIC_CLASS, type.getSimpleName());
        if (superClass != null){
            append(EXTENDS + getGenericName(false, superClass));
        }
        if (interfaces.length > 0){
            append(IMPLEMENTS);
            for (int i = 0; i < interfaces.length; i++){
                if (i > 0){
                    append(COMMA);
                }
                append(getGenericName(false, interfaces[i]));
            }
        }
        append(" {").nl().nl();
        goIn();
        this.type = type;
        return this;
    }

    @Override
    public  ScalaWriter beginConstructor(Collection parameters, Transformer transformer) throws IOException {
        beginLine(PUBLIC + type.getSimpleName()).params(parameters, transformer).append(" {").nl();
        return goIn();     
    }

    @Override
    public ScalaWriter beginConstructor(Parameter... params) throws IOException {
        beginLine(PUBLIC + type.getSimpleName()).params(params).append(" {").nl();
        return goIn();
    }

    @Override
    public ScalaWriter beginInterface(Type type, Type... interfaces)
            throws IOException {
        packages.add(type.getPackageName());
        beginLine(TRAIT + getGenericName(false, type));
        if (interfaces.length > 0){
            append(EXTENDS);
            for (int i = 0; i < interfaces.length; i++){
                if (i > 0){
                    append(COMMA);
                }
                append(getGenericName(false, interfaces[i]));
            }
        }
        append(" {").nl().nl();
        goIn();
        this.type = type;
        return this;   
    }

    private ScalaWriter beginMethod(String modifiers, Type returnType, String methodName, Parameter... args) throws IOException{
        if (returnType.equals(Types.VOID)){
            beginLine(modifiers + methodName).params(args).append(" {").nl();    
        }else{
            beginLine(modifiers + methodName).params(args).append(": " + getGenericName(true, returnType)).append(" {").nl();
        }
        
        return goIn();
    }
    
    @Override
    public  ScalaWriter beginPublicMethod(Type returnType, String methodName, Collection parameters, Transformer transformer) throws IOException {
        return beginMethod(DEF, returnType, methodName, transform(parameters, transformer));
    }

    @Override
    public ScalaWriter beginPublicMethod(Type returnType, String methodName, Parameter... args) throws IOException{
        return beginMethod(DEF, returnType, methodName, args);
    }
    
    @Override
    public  ScalaWriter beginStaticMethod(Type returnType, String methodName, Collection parameters, Transformer transformer) throws IOException {
        return beginMethod(DEF, returnType, methodName, transform(parameters, transformer));
    }

    @Override
    public ScalaWriter beginStaticMethod(Type returnType, String methodName, Parameter... args) throws IOException{
        return beginMethod(DEF, returnType, methodName, args);
    }

    @Override
    public ScalaWriter end() throws IOException {
        goOut();
        return line("}").nl();
    }
    
    public ScalaWriter field(Type type, String name) throws IOException {
        line(VAR + name + ": " + getGenericName(true, type) + SEMICOLON);
        return compact ? this : nl();
    }

    private ScalaWriter field(String modifier, Type type, String name) throws IOException{
        line(modifier + name + ": " + getGenericName(true, type) + SEMICOLON);
        return compact ? this : nl();
    }
    
    private ScalaWriter field(String modifier, Type type, String name, String value) throws IOException{
        line(modifier + name + ": " + getGenericName(true, type) + ASSIGN + value + SEMICOLON);
        return compact ? this : nl();
    }
    
    @Override
    public String getGenericName(boolean asArgType, Type type) {
        if (type.getParameters().isEmpty()){
            return getRawName(type);
        }else{
            StringBuilder builder = new StringBuilder();
            builder.append(getRawName(type));
            builder.append("[");
            boolean first = true;
            String fullName = type.getFullName();
            for (Type parameter : type.getParameters()){                
                if (!first){
                    builder.append(", ");
                }
                if (parameter == null || parameter.getFullName().equals(fullName)){
                    builder.append("?");
                }else{
                    builder.append(getGenericName(false, parameter));    
                }                
                first = false;
            }
            builder.append("]");
            return builder.toString();
        }
    }

    @Override
    public String getRawName(Type type) {
        String fullName = type.getFullName();
        String packageName = type.getPackageName();
        String rv = fullName;
        if (packages.contains(packageName) || classes.contains(fullName)){
            if (packageName.length() > 0){
                rv = fullName.substring(packageName.length()+1);    
            }  
        }
        if (rv.endsWith("[]")){
            rv = rv.substring(0, rv.length()-2);
            if (classes.contains(rv)){
                rv = rv.substring(packageName.length()+1);
            }            
            return "Array[" + rv + "]";
        }else{
            return rv;
        }
    }

    @Override
    public ScalaWriter imports(Class... imports) throws IOException{
        for (Class cl : imports){
            classes.add(cl.getName());
            line(IMPORT + cl.getName() + SEMICOLON);
        }
        nl();
        return this;
    }

    @Override
    public ScalaWriter imports(Package... imports) throws IOException {
        for (Package p : imports){
            packages.add(p.getName());
            line(IMPORT + p.getName() + "._;");
        }
        nl();
        return this;
    }
    
    @Override
    public ScalaWriter importClasses(String... imports) throws IOException{
        for (String cl : imports){
            classes.add(cl);
            line(IMPORT + cl + SEMICOLON);
        }
        nl();
        return this;
    }

    @Override
    public ScalaWriter importPackages(String... imports) throws IOException {
        for (String p : imports){
            packages.add(p);
            line(IMPORT + p + "._;");
        }
        nl();
        return this;
    }

    @Override
    public ScalaWriter javadoc(String... lines) throws IOException {
        line("/**");
        for (String line : lines){
            line(" * " + line);
        }
        return line(" */");
    }

    @Override
    public ScalaWriter packageDecl(String packageName) throws IOException {
        packages.add(packageName);
        return line(PACKAGE + packageName + SEMICOLON).nl();
    }
    
    private  ScalaWriter params(Collection parameters, Transformer transformer) throws IOException{
        append("(");
        boolean first = true;
        for (T param : parameters){
            if (!first){
                append(COMMA);
            }
            param(transformer.transform(param));
            first = false;
        }
        append(")");
        return this;
    }
    
    private ScalaWriter params(Parameter... params) throws IOException{
        append("(");
        for (int i = 0; i < params.length; i++){
            if (i > 0){
                append(COMMA);
            }
            param(params[i]);
        }
        append(")");
        return this;
    }
    
    
    private ScalaWriter param(Parameter parameter) throws IOException{
        append(parameter.getName());
        append(": ");
        append(getGenericName(true, parameter.getType()));
        return this;
    }

    @Override
    public ScalaWriter privateField(Type type, String name) throws IOException {
        return field(PRIVATE, type, name);
    }
    
    @Override
    public ScalaWriter privateFinal(Type type, String name) throws IOException {
        return field(PRIVATE_VAL, type, name);        
    }
    
    @Override
    public ScalaWriter privateFinal(Type type, String name, String value) throws IOException {
        return field(PRIVATE_VAL, type, name, value);
    }

    @Override
    public ScalaWriter privateStaticFinal(Type type, String name, String value) throws IOException {
        return field(PRIVATE_VAL, type, name, value);
    }
        
    @Override
    public ScalaWriter protectedField(Type type, String name) throws IOException {
        return field(PROTECTED, type, name);        
    }
    
    @Override
    public ScalaWriter protectedFinal(Type type, String name) throws IOException {
        return field(PROTECTED_VAL, type, name);        
    }

    @Override
    public ScalaWriter protectedFinal(Type type, String name, String value) throws IOException {
        return field(PROTECTED_VAL, type, name, value);
    }

    @Override
    public ScalaWriter publicField(Type type, String name) throws IOException {
        return field(VAR, type, name);
    }
    
    @Override
    public ScalaWriter publicFinal(Type type, String name) throws IOException {
        return field(VAL, type, name);        
    }
    
    @Override
    public ScalaWriter publicFinal(Type type, String name, String value) throws IOException {
        return field(VAL, type, name, value);
    }
    
    @Override
    public ScalaWriter publicStaticFinal(Type type, String name, String value) throws IOException {
        return field(VAL, type, name, value);
    }

    @Override
    public ScalaWriter staticimports(Class... imports) throws IOException {
        for (Class cl : imports){
            line(IMPORT_STATIC + cl.getName() + "._;");
        }
        return this;
    }

    @Override
    public ScalaWriter suppressWarnings(String type) throws IOException {
        return line("@SuppressWarnings(\"" + type +"\")");
    }
    
    private  Parameter[] transform(Collection parameters, Transformer transformer){
        Parameter[] rv = new Parameter[parameters.size()];
        int i = 0; 
        for (T value : parameters){
            rv[i++] = transformer.transform(value);
        }
        return rv;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy