// com.github.tonivade.purefun.HigherKindTransformer Maven / Gradle / Ivy
/*
* Copyright (c) 2018-2020, Antonio Gabriel Muñoz Conejo
* Distributed under the terms of the MIT License
*/
package com.github.tonivade.purefun;
import com.sun.tools.javac.code.Flags;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCAnnotation;
import com.sun.tools.javac.tree.JCTree.JCAssign;
import com.sun.tools.javac.tree.JCTree.JCBlock;
import com.sun.tools.javac.tree.JCTree.JCClassDecl;
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.JCTree.JCExpression;
import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCLiteral;
import com.sun.tools.javac.tree.JCTree.JCMethodDecl;
import com.sun.tools.javac.tree.JCTree.JCReturn;
import com.sun.tools.javac.tree.JCTree.JCTypeApply;
import com.sun.tools.javac.tree.JCTree.JCTypeCast;
import com.sun.tools.javac.tree.JCTree.JCTypeParameter;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.List;
import com.sun.tools.javac.util.Name;
import java.util.Optional;
/**
 * Class transformer that makes a generic class annotated with {@code HigherKind}
 * implement the matching {@code HigherN} type.
 *
 * <p>For a class with one, two or three type parameters it appends:
 * <ul>
 *   <li>a {@code public static final} nested witness class implementing {@code Kind},
 *       named after the annotation's {@code name} argument ({@code µ} by default);</li>
 *   <li>one {@code public static narrowK} method per {@code HigherN} encoding of the
 *       class, each of which casts its argument back to the annotated class;</li>
 *   <li>the {@code HigherN<Class.Witness, ...>} type to the class' {@code implements} list.</li>
 * </ul>
 * Classes with any other number of type parameters are not transformed.
 */
public class HigherKindTransformer extends AbstractClassTransformer {

  /** Default simple name of the generated witness class. */
  private static final String WITNESS = "µ";
  /** Name of every generated narrowing method. */
  private static final String NARROWK = "narrowK";
  private static final String KIND = "com.github.tonivade.purefun.Kind";
  private static final String HIGHER1 = "com.github.tonivade.purefun.Higher1";
  private static final String HIGHER2 = "com.github.tonivade.purefun.Higher2";
  private static final String HIGHER3 = "com.github.tonivade.purefun.Higher3";

  public HigherKindTransformer(Context context) {
    super(context);
  }

  /**
   * Adds imports for {@code Kind}, {@code Higher1}, {@code Higher2} and {@code Higher3}
   * to the compilation unit, so the generated members can refer to them by simple name.
   *
   * <p>The first definition of the unit is kept in first position and the new imports
   * are inserted right after it, in the order {@code Higher3, Higher2, Higher1, Kind}.
   * NOTE(review): this presumes the first definition is the package declaration —
   * confirm for units without one.
   *
   * <p>NOTE(review): the return type is left raw, as in the original source; its type
   * argument should match the declaration in {@code AbstractClassTransformer}.
   *
   * @param unit compilation unit to modify in place
   * @return an {@code Optional} holding the same, modified unit
   */
  @Override
  public Optional transform(JCCompilationUnit unit) {
    JCTree first = unit.defs.head;
    unit.defs = unit.defs.tail
        .prepend(importOf(KIND))
        .prepend(importOf(HIGHER1))
        .prepend(importOf(HIGHER2))
        .prepend(importOf(HIGHER3))
        .prepend(first);
    return Optional.of(unit);
  }

  /**
   * Adds the witness class, the {@code narrowK} methods and the {@code HigherN}
   * super-interface to the given class, depending on its number of type parameters.
   *
   * <p>NOTE(review): the return type is left raw, as in the original source; its type
   * argument should match the declaration in {@code AbstractClassTransformer}.
   *
   * @param clazz class declaration to modify in place
   * @return an {@code Optional} holding the modified class when it has one, two or
   *     three type parameters, an empty {@code Optional} otherwise
   * @throws IllegalStateException if the class carries no {@code HigherKind} annotation
   */
  @Override
  public Optional transform(JCClassDecl clazz) {
    // resolved before looking at the type parameters, so a missing annotation always fails
    Name kindName = kindName(clazz);
    Name varName = elements.getName("hkt");
    switch (clazz.typarams.length()) {
      case 1:
        return Optional.of(generateHigher1Kind(clazz, kindName, varName));
      case 2:
        return Optional.of(generateHigher2Kind(clazz, kindName, varName));
      case 3:
        return Optional.of(generateHigher3Kind(clazz, kindName, varName));
      default:
        return Optional.empty();
    }
  }

  /** Builds a non-static import of the type with the given fully qualified name. */
  private JCTree importOf(String className) {
    return maker.Import(maker.QualIdent(elements.getTypeElement(className)), false);
  }

  /**
   * Reads the witness name from the {@code name} argument of the {@code HigherKind}
   * annotation, falling back to {@link #WITNESS} when the argument is absent.
   */
  private Name kindName(JCClassDecl clazz) {
    Optional<String> name = findAnnotation(clazz).args.stream()
        .map(expr -> (JCAssign) expr)
        .filter(assign -> ((JCIdent) assign.lhs).name.contentEquals("name"))
        .findFirst()
        .map(assign -> (JCLiteral) assign.rhs)
        .map(literal -> (String) literal.value);
    return elements.getName(name.orElse(WITNESS));
  }

  /**
   * Finds the {@code HigherKind} annotation among the modifiers of the class.
   *
   * @throws IllegalStateException if the class is not annotated with {@code HigherKind}
   */
  private JCAnnotation findAnnotation(JCClassDecl clazz) {
    String expected = HigherKind.class.getName();
    return clazz.mods.annotations.stream()
        .filter(annotation -> annotation.annotationType.type.toString().equals(expected))
        .findFirst()
        .orElseThrow(() -> new IllegalStateException(
            "@" + HigherKind.class.getSimpleName() + " annotation not found on class " + clazz.name));
  }

  /**
   * Transforms a class with a single type parameter {@code A}: adds the witness,
   * {@code narrowK(Higher1<C.W, A>)} and {@code implements Higher1<C.W, A>}.
   */
  private JCClassDecl generateHigher1Kind(JCClassDecl clazz, Name kindName, Name varName) {
    JCTypeParameter typeParam = clazz.typarams.head;

    JCClassDecl witness = kindWitness(kindName);
    JCTypeApply higher1 = higher1Kind(clazz.name, kindName, typeParam);
    JCMethodDecl narrowKOf1 = narrowKind(higher1, clazz.name, varName, typeParam);

    // fixPos is inherited; the position arguments and their evaluation order are kept as is
    fixPos(witness, clazz.pos);
    fixPos(narrowKOf1, clazz.pos + witness.pos);

    // a fresh type tree is built for the implements clause rather than reusing higher1
    clazz.implementing = clazz.implementing.append(higher1Kind(clazz.name, kindName, typeParam));
    clazz.defs = clazz.defs.append(witness).append(narrowKOf1);
    return clazz;
  }

  /**
   * Transforms a class with two type parameters {@code A, B}: adds the witness,
   * {@code narrowK(Higher1<Higher1<C.W, A>, B>)}, {@code narrowK(Higher2<C.W, A, B>)}
   * and {@code implements Higher2<C.W, A, B>}.
   */
  private JCTree.JCClassDecl generateHigher2Kind(JCClassDecl clazz, Name kindName, Name varName) {
    JCTypeParameter typeParam1 = clazz.typarams.head;
    JCTypeParameter typeParam2 = clazz.typarams.tail.head;

    JCClassDecl witness = kindWitness(kindName);
    JCTypeApply higher1 = nestedHigher1(clazz.name, kindName, typeParam1, typeParam2);
    JCTypeApply higher2 = higher2Kind(clazz.name, kindName, typeParam1, typeParam2);
    JCMethodDecl narrowKOf1 = narrowKind(higher1, clazz.name, varName, typeParam1, typeParam2);
    JCMethodDecl narrowKOf2 = narrowKind(higher2, clazz.name, varName, typeParam1, typeParam2);

    // fixPos is inherited; the position arguments and their evaluation order are kept as is
    fixPos(witness, clazz.pos);
    fixPos(narrowKOf1, clazz.pos + witness.pos);
    fixPos(narrowKOf2, clazz.pos + witness.pos + narrowKOf1.pos);

    clazz.implementing =
        clazz.implementing.append(higher2Kind(clazz.name, kindName, typeParam1, typeParam2));
    clazz.defs = clazz.defs.append(witness).append(narrowKOf1).append(narrowKOf2);
    return clazz;
  }

  /**
   * Transforms a class with three type parameters {@code A, B, C}: adds the witness,
   * one {@code narrowK} for each of the {@code Higher1}, {@code Higher2} and
   * {@code Higher3} encodings, and {@code implements Higher3<C.W, A, B, C>}.
   */
  private JCTree.JCClassDecl generateHigher3Kind(JCClassDecl clazz, Name kindName, Name varName) {
    JCTypeParameter typeParam1 = clazz.typarams.head;
    JCTypeParameter typeParam2 = clazz.typarams.tail.head;
    JCTypeParameter typeParam3 = clazz.typarams.tail.tail.head;

    JCClassDecl witness = kindWitness(kindName);
    JCTypeApply higher1 = nestedHigher1(clazz.name, kindName, typeParam1, typeParam2, typeParam3);
    JCTypeApply higher2 = nestedHigher2(clazz.name, kindName, typeParam1, typeParam2, typeParam3);
    JCTypeApply higher3 = higher3Kind(clazz.name, kindName, typeParam1, typeParam2, typeParam3);
    JCMethodDecl narrowKOf1 =
        narrowKind(higher1, clazz.name, varName, typeParam1, typeParam2, typeParam3);
    JCMethodDecl narrowKOf2 =
        narrowKind(higher2, clazz.name, varName, typeParam1, typeParam2, typeParam3);
    JCMethodDecl narrowKOf3 =
        narrowKind(higher3, clazz.name, varName, typeParam1, typeParam2, typeParam3);

    // fixPos is inherited; the position arguments and their evaluation order are kept as is
    fixPos(witness, clazz.pos);
    fixPos(narrowKOf1, clazz.pos + witness.pos);
    fixPos(narrowKOf2, clazz.pos + witness.pos + narrowKOf1.pos);
    fixPos(narrowKOf3, clazz.pos + witness.pos + narrowKOf1.pos + narrowKOf2.pos);

    clazz.implementing = clazz.implementing
        .append(higher3Kind(clazz.name, kindName, typeParam1, typeParam2, typeParam3));
    clazz.defs = clazz.defs.append(witness).append(narrowKOf1).append(narrowKOf2).append(narrowKOf3);
    return clazz;
  }

  /** Builds the empty {@code public static final class <name> implements Kind} witness. */
  private JCClassDecl kindWitness(Name name) {
    return maker.ClassDef(
        maker.Modifiers(Flags.PUBLIC | Flags.STATIC | Flags.FINAL),
        name,
        List.nil(),
        null,
        List.of(implementsKind()),
        List.nil());
  }

  /** Simple-name reference to {@code Kind}; relies on the import added to the unit. */
  private JCIdent implementsKind() {
    return maker.Ident(elements.getName("Kind"));
  }

  /**
   * Builds a {@code public static narrowK} method that takes a single parameter of the
   * given type and returns it cast to {@code className<typeParams>}.
   *
   * @param param type of the method's only parameter
   * @param className simple name of the annotated class
   * @param varName name of the method's only parameter
   * @param typeParams type parameters of the annotated class, used both for the method's
   *     own type parameters (through the inherited {@code params2Type}) and for the
   *     return/cast type
   */
  private JCMethodDecl narrowKind(JCExpression param,
                                  Name className,
                                  Name varName,
                                  JCTypeParameter... typeParams) {
    // the return type and the cast type are built separately so no tree node is shared
    JCTypeCast cast = maker.TypeCast(returnType(className, typeParams), maker.Ident(varName));
    JCBlock body = maker.Block(0, List.of(maker.Return(cast)));
    return maker.MethodDef(
        maker.Modifiers(Flags.PUBLIC | Flags.STATIC),
        elements.getName(NARROWK),
        returnType(className, typeParams),
        params2Type(typeParams),
        List.of(variable(varName, param)),
        List.nil(),
        body,
        null);
  }

  /** Builds a method parameter declaration with the given name and type, no initializer. */
  private JCVariableDecl variable(Name varName, JCExpression typeDef) {
    return maker.VarDef(
        maker.Modifiers(Flags.ReceiverParamFlags),
        varName,
        typeDef,
        null);
  }

  /** Builds {@code Higher1<className.kindName, A>}. */
  private JCTypeApply higher1Kind(Name className,
                                  Name kindName,
                                  JCTypeParameter typeParam) {
    return higher1Kind(higher1(select(className, kindName), typeParam));
  }

  /** Builds {@code Higher2<className.kindName, A, B>}. */
  private JCTypeApply higher2Kind(Name className,
                                  Name kindName,
                                  JCTypeParameter typeParam1,
                                  JCTypeParameter typeParam2) {
    return higher2Kind(higher2(select(className, kindName), typeParam1, typeParam2));
  }

  /** Builds {@code Higher3<className.kindName, A, B, C>}. */
  private JCTypeApply higher3Kind(Name className,
                                  Name kindName,
                                  JCTypeParameter typeParam1,
                                  JCTypeParameter typeParam2,
                                  JCTypeParameter typeParam3) {
    return higher3Kind(higher3(select(className, kindName), typeParam1, typeParam2, typeParam3));
  }

  /** Applies the given type arguments to the simple name {@code Higher1}. */
  private JCTypeApply higher1Kind(List<JCExpression> type) {
    return maker.TypeApply(maker.Ident(elements.getName("Higher1")), type);
  }

  /** Applies the given type arguments to the simple name {@code Higher2}. */
  private JCTypeApply higher2Kind(List<JCExpression> type) {
    return maker.TypeApply(maker.Ident(elements.getName("Higher2")), type);
  }

  /** Applies the given type arguments to the simple name {@code Higher3}. */
  private JCTypeApply higher3Kind(List<JCExpression> type) {
    return maker.TypeApply(maker.Ident(elements.getName("Higher3")), type);
  }

  /** Builds {@code Higher1<Higher1<className.kindName, A>, B>}. */
  private JCTypeApply nestedHigher1(Name className,
                                    Name kindName,
                                    JCTypeParameter typeParam1,
                                    JCTypeParameter typeParam2) {
    JCTypeApply inner = higher1Kind(className, kindName, typeParam1);
    return higher1Kind(higher1(inner, typeParam2));
  }

  /** Builds {@code Higher1<Higher1<Higher1<className.kindName, A>, B>, C>}. */
  private JCTypeApply nestedHigher1(Name className,
                                    Name kindName,
                                    JCTypeParameter typeParam1,
                                    JCTypeParameter typeParam2,
                                    JCTypeParameter typeParam3) {
    JCTypeApply inner = nestedHigher1(className, kindName, typeParam1, typeParam2);
    return higher1Kind(higher1(inner, typeParam3));
  }

  /** Builds {@code Higher2<Higher1<className.kindName, A>, B, C>}. */
  private JCTypeApply nestedHigher2(Name className,
                                    Name kindName,
                                    JCTypeParameter typeParam1,
                                    JCTypeParameter typeParam2,
                                    JCTypeParameter typeParam3) {
    JCTypeApply inner = higher1Kind(className, kindName, typeParam1);
    return higher2Kind(higher2(inner, typeParam2, typeParam3));
  }

  /** Type-argument list {@code [nested, A]}. */
  private List<JCExpression> higher1(JCExpression nested, JCTypeParameter typeParam) {
    return List.of(nested, maker.Ident(typeParam.name));
  }

  /** Type-argument list {@code [nested, A, B]}. */
  private List<JCExpression> higher2(JCExpression nested,
                                     JCTypeParameter typeParam1,
                                     JCTypeParameter typeParam2) {
    return List.of(nested, maker.Ident(typeParam1.name), maker.Ident(typeParam2.name));
  }

  /** Type-argument list {@code [nested, A, B, C]}. */
  private List<JCExpression> higher3(JCExpression nested,
                                     JCTypeParameter typeParam1,
                                     JCTypeParameter typeParam2,
                                     JCTypeParameter typeParam3) {
    return List.of(
        nested,
        maker.Ident(typeParam1.name),
        maker.Ident(typeParam2.name),
        maker.Ident(typeParam3.name));
  }

  /**
   * Builds {@code className<...>}, with the type arguments produced by the inherited
   * {@code params2Ident} from the given type parameters.
   */
  private JCTypeApply returnType(Name className, JCTypeParameter... typeParams) {
    return maker.TypeApply(
        maker.Ident(className),
        params2Ident(typeParams));
  }

  /** Builds the qualified reference {@code className.kindName} to the witness class. */
  private JCFieldAccess select(Name className, Name kindName) {
    return maker.Select(maker.Ident(className), kindName);
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy