All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.jfirer.jsql.mapper.MapperGenerator Maven / Gradle / Ivy

package com.jfirer.jsql.mapper;

import com.jfirer.jsql.analyse.template.Template;
import com.jfirer.jsql.analyse.token.SqlLexer;
import com.jfirer.jsql.annotation.Sql;
import com.jfirer.jsql.metadata.Page;
import com.jfirer.jsql.metadata.TableEntityInfo;
import com.jfirer.jsql.session.SqlSession;
import com.jfirer.jsql.transfer.resultset.ResultMap;
import com.jfirer.jsql.transfer.resultset.ResultSetTransfer;
import com.jfirer.jsql.transfer.resultset.impl.*;
import com.jfirer.baseutil.reflect.ReflectUtil;
import com.jfirer.baseutil.smc.SmcHelper;
import com.jfirer.baseutil.smc.compiler.CompileHelper;
import com.jfirer.baseutil.smc.model.ClassModel;
import com.jfirer.baseutil.smc.model.FieldModel;
import com.jfirer.baseutil.smc.model.MethodModel;
import com.jfirer.jsql.transfer.resultset.impl.*;

import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Generates, at runtime, an implementation class for a mapper interface whose public methods are all
 * annotated with {@link Sql}.
 * <p>
 * For every interface method the generator emits Java source (through {@link ClassModel}) that renders the
 * method's SQL template with the call arguments and delegates to {@link SqlSession}. The assembled class
 * model is then compiled with the supplied {@link CompileHelper}.
 */
public class MapperGenerator
{
    /** Suffix counter that keeps generated class names unique within this class loader. */
    private static final AtomicInteger count = new AtomicInteger(0);

    /**
     * Generates and compiles a mapper implementation for the given interface.
     *
     * @param ckass            the mapper type; every method returned by {@code getMethods()} must carry {@link Sql}
     * @param tableEntityInfos table metadata handed to the SQL lexer when translating each annotated SQL statement
     * @param compiler         compiler used to turn the generated {@link ClassModel} into a class
     * @return the compiled class. It is built from a ClassModel created with {@link Mapper} and {@code ckass}
     *         as its parent types (presumably superclass and implemented interface; confirm against ClassModel)
     * @throws IllegalArgumentException if a method lacks {@link Sql}, declares more parameter names than it has
     *                                  parameters, or returns a List without a concrete element type
     * @throws UnsupportedOperationException if a SELECT method returns an unsupported primitive type
     */
    public static Class generate(Class ckass, Map tableEntityInfos, CompileHelper compiler)
    {
        Method[] methods = ckass.getMethods();
        for (Method method : methods)
        {
            if ( !method.isAnnotationPresent(Sql.class) )
            {
                throw new IllegalArgumentException("类:" + method.getDeclaringClass().getName() + "有方法没有打@Sql注解");
            }
        }
        ClassModel classModel = new ClassModel(ckass.getSimpleName() + "$Mapper$" + count.getAndIncrement(), Mapper.class, ckass);
        classModel.addImport(Mapper.class);
        classModel.addImport(Template.class);
        classModel.addImport(Map.class);
        classModel.addImport(HashMap.class);
        classModel.addImport(String.class);
        classModel.addImport(BeanTransfer.class);
        classModel.addImport(SqlSession.class);
        classModel.addImport(List.class);
        // Shared across all methods so template_N / transfer_N field names never collide.
        AtomicInteger fieldNameCount = new AtomicInteger(0);
        for (Method method : methods)
        {
            MethodModel methodModel = new MethodModel(method, classModel);
            methodModel.setBody(buildMethodBody(method, classModel, tableEntityInfos, fieldNameCount));
            classModel.putMethodModel(methodModel);
        }
        try
        {
            return compiler.compile(classModel);
        } catch (Exception e)
        {
            ReflectUtil.throwException(e);
            // Unreachable when throwException rethrows; required by the compiler.
            return null;
        }
    }

    /**
     * Builds the Java source for the body of one generated mapper method. It also registers the template
     * and transfer fields the body references on {@code classModel}.
     * <p>
     * The generated code reads {@code session}, {@code cachedVariables} and {@code cachedParams}. These are
     * presumably fields inherited from {@link Mapper}, which is not visible here.
     *
     * @param method           the interface method being implemented
     * @param classModel       the class model that receives the generated fields
     * @param tableEntityInfos table metadata used to translate the SQL
     * @param fieldNameCount   counter used to generate unique field names
     * @return the Java source of the method body
     */
    private static String buildMethodBody(Method method, ClassModel classModel, Map tableEntityInfos, AtomicInteger fieldNameCount)
    {
        Sql           annotation = method.getAnnotation(Sql.class);
        Class         returnType = method.getReturnType();
        StringBuilder cache      = new StringBuilder();
        cache.append("if(session==null){throw new NullPointerException(\"当前没有session\");}\r\n");
        cache.append("Map variables = cachedVariables.get();\r\n");
        cache.append("List params = cachedParams.get();\r\n");
        // The containers come from shared caches and are reused between calls. Clear them in a finally
        // block so a failing render/query cannot leak stale bindings into the next invocation.
        cache.append("try\r\n{\r\n");
        String formatSql = generateSqlAndTemplateField(tableEntityInfos, classModel, fieldNameCount, method, cache, annotation);
        if ( isSelect(formatSql) )
        {
            String transferFieldName = "transfer_" + fieldNameCount.getAndIncrement();
            if ( List.class.isAssignableFrom(returnType) )
            {
                addResultSetTransferField(classModel, method, transferFieldName, resolveListItemType(method));
                cache.append("List result = session.queryList(").append(transferFieldName).append(",sql,params);\r\n");
            }
            else
            {
                addResultSetTransferField(classModel, method, transferFieldName, returnType);
                // Primitive returns are declared as their wrapper so the query result can be assigned, then auto-unboxed on return.
                String returnTypeName = returnType.isPrimitive() ? ReflectUtil.wrapPrimitive(returnType).getName() : SmcHelper.getReferenceName(returnType, classModel);
                cache.append(returnTypeName).append(" result = session.query(").append(transferFieldName).append(",sql,params);\r\n");
            }
        }
        else
        {
            cache.append("int result = session.update(sql,params);\r\n");
        }
        // A void method (e.g. a fire-and-forget UPDATE) must not return a value.
        if ( returnType != void.class )
        {
            cache.append("return result;\r\n");
        }
        cache.append("}\r\nfinally\r\n{\r\n");
        cache.append("params.clear();\r\n");
        cache.append("variables.clear();\r\n");
        cache.append("}\r\n");
        return cache.toString();
    }

    /**
     * Returns whether the formatted SQL is a query. The check ignores leading whitespace and keyword case.
     */
    private static boolean isSelect(String formatSql)
    {
        return formatSql.trim().regionMatches(true, 0, "SELECT", 0, 6);
    }

    /**
     * Resolves the element type of a method whose return type is a List.
     *
     * @throws IllegalArgumentException if the element type is not a concrete class (raw List, wildcard, type variable...)
     */
    private static Class resolveListItemType(Method method)
    {
        java.lang.reflect.Type genericReturnType = method.getGenericReturnType();
        if ( genericReturnType instanceof ParameterizedType )
        {
            java.lang.reflect.Type itemType = ((ParameterizedType) genericReturnType).getActualTypeArguments()[0];
            if ( itemType instanceof Class )
            {
                return (Class) itemType;
            }
        }
        throw new IllegalArgumentException("Method " + method + " must declare its List return type with a concrete element class");
    }

    /**
     * Creates the ResultSetTransfer field and adds it to the ClassModel.
     *
     * @param classModel        the class model receiving the field
     * @param method            the mapper method; an explicit {@link ResultMap} on it overrides type-based selection
     * @param transferFieldName name of the field to generate
     * @param itemType          the type of the returned value. If the method returns a List, this is the List's
     *                          element type
     * @throws UnsupportedOperationException if {@code itemType} is a primitive without a matching transfer
     */
    private static void addResultSetTransferField(ClassModel classModel, Method method, String transferFieldName, Class itemType)
    {
        Class transferClass;
        if ( method.isAnnotationPresent(ResultMap.class) )
        {
            transferClass = method.getAnnotation(ResultMap.class).value();
        }
        else if ( itemType == String.class )
        {
            transferClass = StringTransfer.class;
        }
        else if ( Enum.class.isAssignableFrom(itemType) )
        {
            transferClass = EnumNameTransfer.class;
        }
        else if ( itemType == Date.class )
        {
            transferClass = SqlDateTransfer.class;
        }
        else if ( itemType == java.util.Date.class )
        {
            transferClass = UtilDateTransfer.class;
        }
        else if ( itemType == Timestamp.class )
        {
            transferClass = TimeStampTransfer.class;
        }
        else if ( itemType == Time.class )
        {
            transferClass = TimeTransfer.class;
        }
        else
        {
            // Primitives and their wrappers share the same scalar transfers. Previously only primitives were
            // matched, so e.g. Integer or List<Long> wrongly fell through to BeanTransfer.
            boolean primitive      = itemType.isPrimitive();
            Class   boxedType      = primitive ? ReflectUtil.wrapPrimitive(itemType) : itemType;
            Class   scalarTransfer = scalarTransferFor(boxedType);
            if ( scalarTransfer != null )
            {
                itemType = boxedType;
                transferClass = scalarTransfer;
            }
            else if ( primitive )
            {
                throw new UnsupportedOperationException("不支持的单类型转换:" + boxedType.getName());
            }
            else
            {
                transferClass = BeanTransfer.class;
            }
        }
        classModel.addImport(transferClass);
        FieldModel fieldModel = new FieldModel(transferFieldName, ResultSetTransfer.class, "new " + SmcHelper.getReferenceName(transferClass, classModel) + "().initialize(" + SmcHelper.getReferenceName(itemType, classModel) + ".class)", classModel);
        classModel.addField(fieldModel);
    }

    /**
     * Returns the scalar transfer class for a wrapper type, or {@code null} if there is none.
     */
    private static Class scalarTransferFor(Class boxedType)
    {
        if ( boxedType == Integer.class )
        {
            return IntegerTransfer.class;
        }
        if ( boxedType == Long.class )
        {
            return LongTransfer.class;
        }
        if ( boxedType == Short.class )
        {
            return ShortTransfer.class;
        }
        if ( boxedType == Float.class )
        {
            return FloatTransfer.class;
        }
        if ( boxedType == Double.class )
        {
            return DoubleTransfer.class;
        }
        if ( boxedType == Boolean.class )
        {
            return BooleanTransfer.class;
        }
        return null;
    }

    /**
     * Generates the template field and adds it to the ClassModel, and appends the code that binds the
     * arguments and renders the SQL to {@code cache}. Returns the formatted SQL.
     *
     * @param tableEntityInfos table metadata used by the SQL lexer
     * @param classModel       the class model receiving the template field
     * @param fieldNameCount   counter used to generate unique field names
     * @param method           the mapper method; its parameters are referenced as {@code $0, $1, ...}
     * @param cache            buffer receiving the generated method-body source
     * @param annotation       the method's {@link Sql} annotation
     * @return the formatted SQL produced by the lexer
     * @throws IllegalArgumentException if {@code paramNames} lists more names than the method has parameters
     */
    private static String generateSqlAndTemplateField(Map tableEntityInfos, ClassModel classModel, AtomicInteger fieldNameCount, Method method, StringBuilder cache, Sql annotation)
    {
        String formatSql         = SqlLexer.parse(annotation.sql()).transfer(tableEntityInfos).format();
        String templateFieldName = "template_" + fieldNameCount.getAndIncrement();
        // The SQL is embedded in generated Java source, so quotes, backslashes and line breaks must be escaped.
        FieldModel fieldModel = new FieldModel(templateFieldName, Template.class, "Template.parse(\"" + escapeJavaString(formatSql) + "\")", classModel);
        classModel.addField(fieldModel);
        Class[] parameterTypes = method.getParameterTypes();
        if ( parameterTypes.length != 0 )
        {
            String[] names = annotation.paramNames().split(",");
            if ( names.length > parameterTypes.length )
            {
                throw new IllegalArgumentException("Method " + method + " declares " + names.length + " @Sql paramNames but has only " + parameterTypes.length + " parameters");
            }
            for (int index = 0; index < names.length; index++)
            {
                // Trim so "a, b" binds "b" rather than " b". Blank entries bind nothing.
                String name = names[index].trim();
                if ( name.isEmpty() )
                {
                    continue;
                }
                cache.append("variables.put(\"").append(escapeJavaString(name)).append("\",$").append(index).append(");\r\n");
            }
        }
        cache.append("String sql =").append(templateFieldName).append(".render(variables,params);\r\n");
        // A trailing Page argument is appended to the JDBC params after rendering.
        if ( parameterTypes.length != 0 && parameterTypes[parameterTypes.length - 1] == Page.class )
        {
            cache.append("params.add($").append(parameterTypes.length - 1).append(");\r\n");
        }
        return formatSql;
    }

    /**
     * Escapes text so it can be placed between double quotes in generated Java source.
     */
    private static String escapeJavaString(String text)
    {
        StringBuilder escaped = new StringBuilder(text.length() + 16);
        for (int i = 0; i < text.length(); i++)
        {
            char c = text.charAt(i);
            switch (c)
            {
                case '\\':
                    escaped.append("\\\\");
                    break;
                case '"':
                    escaped.append("\\\"");
                    break;
                case '\n':
                    escaped.append("\\n");
                    break;
                case '\r':
                    escaped.append("\\r");
                    break;
                case '\t':
                    escaped.append("\\t");
                    break;
                default:
                    escaped.append(c);
            }
        }
        return escaped.toString();
    }
}