/*
 * Decompiled with CFR 0.152.
 */
package polyglot.objinl.hj.visit;

import java.util.Collection;
import java.util.LinkedList;
import polyglot.ast.Assign;
import polyglot.ast.Binary;
import polyglot.ast.Block;
import polyglot.ast.Call;
import polyglot.ast.CanonicalTypeNode;
import polyglot.ast.Conditional;
import polyglot.ast.Eval;
import polyglot.ast.Expr;
import polyglot.ast.Field;
import polyglot.ast.For;
import polyglot.ast.Formal;
import polyglot.ast.Id;
import polyglot.ast.IntLit;
import polyglot.ast.Local;
import polyglot.ast.LocalDecl;
import polyglot.ast.Node;
import polyglot.ast.Receiver;
import polyglot.ast.Stmt;
import polyglot.ast.TypeNode;
import polyglot.ext.hj.ExtensionInfo;
import polyglot.ext.hj.ast.ConstantDistMaker;
import polyglot.ext.hj.ast.ForLoop;
import polyglot.ext.hj.ast.HjFormal;
import polyglot.ext.hj.ast.HjLoop;
import polyglot.ext.hj.ast.HjNodeFactory;
import polyglot.ext.hj.ast.Range;
import polyglot.ext.hj.ast.Range_c;
import polyglot.ext.hj.types.HjTypeSystem;
import polyglot.frontend.Job;
import polyglot.objinl.hj.visit.HjPointInlineSafetyAnalyzer;
import polyglot.objinl.hj.visit.HjPointRankTypeAnalyzer;
import polyglot.objinl.util.TreeUtils;
import polyglot.types.ClassType;
import polyglot.types.Flags;
import polyglot.types.LocalInstance;
import polyglot.types.MethodInstance;
import polyglot.types.PrimitiveType;
import polyglot.types.ReferenceType;
import polyglot.types.Type;
import polyglot.types.VarInstance;
import polyglot.util.Position;
import polyglot.visit.NodeVisitor;

/*
 * Decompiler note: this class file declares version 49.0 (Java 5) but uses Java 6
 * signatures, so Java 6 was assumed when decompiling it.
 */
/**
 * Lowers HJ {@code for} loops whose formal has a {@code point} type into a nest of plain
 * Java {@code for} loops over {@code int} indices, one loop per dimension.
 *
 * <p>For ranges {@code lo_k, hi_k, stride_k} the generated shape is roughly:
 * <pre>
 * for (int size = &lt;region size&gt;, s0 = size != 0 ? lo_0 : -1; size &gt; 0 &amp;&amp; s0 &lt;= hi_0; s0 += stride_0) {
 *     final int i0 = s0;
 *     for (int s1 = lo_1; s1 &lt;= hi_1; s1 += stride_1) {
 *         final int i1 = s1;
 *         ... original body ...
 *     }
 * }
 * </pre>
 * The per-dimension ranges are taken from {@link HjPointRankTypeAnalyzer} when it recorded
 * them for the loop domain. Otherwise they are synthesized as
 * {@code 0 .. region.rank(k).size() - 1} (stride 1) when the domain is a region or
 * distribution reached from an HJ array whose rank is known. Loops whose point is listed in
 * {@link HjPointInlineSafetyAnalyzer#unSafePointsToInline}, or for which no ranges can be
 * found, are returned unchanged.
 *
 * <p>NOTE(review): generated names come from a static, unsynchronized counter; this presumes
 * the converter is never run concurrently — confirm.
 */
public class HjPointForLoopConverter
extends NodeVisitor {
    /** Counter shared by all instances; appended to generated local names to keep them unique. */
    private static int cntr = 1;
    // job, jobs and scheduler are stored by the constructor but not read within this class.
    Job job;
    Collection jobs;
    HjTypeSystem ts;
    HjNodeFactory nf;
    ExtensionInfo.HjScheduler scheduler;

    public HjPointForLoopConverter(Job job, HjTypeSystem ts, HjNodeFactory nf, ExtensionInfo.HjScheduler scheduler, Collection jobs) {
        this.job = job;
        this.jobs = jobs;
        this.ts = ts;
        this.nf = nf;
        this.scheduler = scheduler;
    }

    /** Visits every child with this same visitor. */
    @Override
    public NodeVisitor enter(Node n) {
        return this;
    }

    /**
     * Rewrites HJ {@code for} loops (nodes that are both {@link HjLoop} and {@link ForLoop})
     * after their children have been visited; all other nodes pass through untouched.
     */
    @Override
    public Node leave(Node old, Node n, NodeVisitor v) {
        if (n instanceof HjLoop && n instanceof ForLoop) {
            return this.hjLoopHelper((HjLoop)n);
        }
        return n;
    }

    /**
     * Attempts to lower a single HJ loop whose formal is a point.
     *
     * @return the rewritten statement, or {@code hjLoop} itself if it is not a point loop
     *         or cannot be lowered
     */
    private Node hjLoopHelper(HjLoop hjLoop) {
        Formal formal = hjLoop.formal();
        if (!(formal instanceof HjFormal)) {
            return hjLoop;
        }
        HjFormal hjFormal = (HjFormal)formal;
        if (!this.ts.isPoint(hjFormal.declType())) {
            return hjLoop;
        }
        LocalInstance[] indiceInsts = hjFormal.localInstances();
        Expr domain = hjLoop.domain();
        LinkedList<Range> rangeList = null;
        // First choice: ranges recorded directly for the domain variable.
        VarInstance regionInst = TreeUtils.getVarInstance((Node)domain);
        if (regionInst != null) {
            rangeList = HjPointRankTypeAnalyzer.regionToRangesMap.get(regionInst);
        }
        // Fallback: the domain is e.g. "x.dist" / "x.region()"; look through its target.
        if (rangeList == null) {
            if (domain instanceof Field) {
                rangeList = HjPointForLoopConverter.getRangesFromArrayDistr(((Field)domain).target(), this.ts);
            } else if (domain instanceof Call) {
                rangeList = HjPointForLoopConverter.getRangesFromArrayDistr(((Call)domain).target(), this.ts);
            }
        }
        return this.transformHjForLoopWithPossibleArrayDomain(hjLoop, hjFormal, rangeList, indiceInsts);
    }

    /**
     * Lowers the loop using {@code rangeList}. When {@code rangeList} is {@code null}, ranges
     * are synthesized from the rank of an HJ array found in the (region- or distribution-typed)
     * domain expression.
     *
     * @return the lowered loop nest, or {@code hjLoop} unchanged if the point is unsafe to
     *         inline or no ranges can be determined
     */
    private Node transformHjForLoopWithPossibleArrayDomain(HjLoop hjLoop, HjFormal hjFormal, LinkedList<Range> rangeList, LocalInstance[] indiceInsts) {
        LocalInstance pointInst = hjFormal.localInstance();
        if (HjPointInlineSafetyAnalyzer.unSafePointsToInline.contains(pointInst)) {
            return hjLoop;
        }
        Expr domain = hjLoop.domain();
        if (rangeList == null) {
            Receiver hjArray = this.findHjArrayInLoopDomain((Receiver)domain);
            Type domainType = domain.type();
            if (hjArray == null || !(this.ts.isRegion(domainType) || this.ts.isDistribution(domainType))) {
                return hjLoop;
            }
            VarInstance arrayInst = TreeUtils.getVarInstance((Node)hjArray);
            Integer numDims = arrayInst == null ? null : HjPointRankTypeAnalyzer.arrayToRankMap.get(arrayInst);
            if (numDims == null) {
                return hjLoop;
            }
            rangeList = this.transformHjLoopHeaderUsingHjArrayLength(hjLoop.position(), numDims, domain);
        }
        return this.transformHjForLoop(indiceInsts, rangeList, 0, hjLoop.body(), domain, pointInst);
    }

    /**
     * Builds one range per dimension: {@code 0 .. size(dim) - 1} with stride 1, where
     * {@code size(dim)} is {@code domain[.region()].rank(dim).size()}.
     */
    private LinkedList<Range> transformHjLoopHeaderUsingHjArrayLength(Position pos, Integer numDims, Expr domain) {
        LinkedList<Range> rangeList = new LinkedList<Range>();
        for (int dim = 0; dim < numDims; ++dim) {
            IntLit lowerBound = this.typedIntLit(pos, 0L);
            IntLit one = this.typedIntLit(pos, 1L);
            Expr arraySize = this.createRegionDimSize((Receiver)domain, dim);
            Binary upperBound = this.createBinaryExpr(pos, arraySize, Binary.SUB, (Expr)one, lowerBound.type());
            rangeList.add(new Range_c(pos, (Expr)lowerBound, (Expr)upperBound, (Expr)one));
        }
        return rangeList;
    }

    /**
     * Builds {@code receiver.rank(dim).size()} for a region, or
     * {@code receiver.region().rank(dim).size()} for a distribution.
     *
     * @return the size expression, or {@code null} if {@code receiver} is neither a region
     *         nor a distribution
     */
    private Expr createRegionDimSize(Receiver receiver, int dim) {
        Type type = receiver.type();
        if (this.ts.isRegion(type)) {
            Call rank = HjPointForLoopConverter.createRankCall(receiver, dim, this.nf, this.ts);
            return HjPointForLoopConverter.createCallHelper((Receiver)rank, "size", (Type)this.ts.Int(), this.nf, this.ts);
        }
        if (this.ts.isDistribution(type)) {
            Call region = HjPointForLoopConverter.createCallHelper(receiver, "region", (Type)this.ts.region(), this.nf, this.ts);
            return this.createRegionDimSize((Receiver)region, dim);
        }
        return null;
    }

    /**
     * Builds {@code receiver.size()} for a region, or {@code receiver.region().size()} for a
     * distribution.
     *
     * @return the size expression, or {@code null} if {@code receiver} is neither a region
     *         nor a distribution
     */
    private Expr createRegionSize(Receiver receiver) {
        Type type = receiver.type();
        if (this.ts.isRegion(type)) {
            return HjPointForLoopConverter.createCallHelper(receiver, "size", (Type)this.ts.Int(), this.nf, this.ts);
        }
        if (this.ts.isDistribution(type)) {
            Call region = HjPointForLoopConverter.createCallHelper(receiver, "region", (Type)this.ts.region(), this.nf, this.ts);
            return this.createRegionSize((Receiver)region);
        }
        return null;
    }

    /**
     * Builds a fully typed call {@code receiver.rank(dim)} whose method instance returns the
     * region type.
     */
    public static Call createRankCall(Receiver receiver, int dim, HjNodeFactory nf, HjTypeSystem ts) {
        Position pos = receiver.position();
        ClassType regionType = ts.region();
        Id id = nf.Id(pos, "rank");
        IntLit lit = (IntLit)nf.IntLit(id.position(), IntLit.INT, dim).type((Type)ts.Int());
        LinkedList<Expr> args = new LinkedList<Expr>();
        args.add(lit);
        LinkedList<Type> argTypes = new LinkedList<Type>();
        argTypes.add(ts.Int());
        LinkedList<Type> excTypes = new LinkedList<Type>();
        Call call = nf.Call(pos, receiver, id, args);
        MethodInstance mi = ts.methodInstance(pos, (ReferenceType)receiver.type(), Flags.NONE, (Type)regionType, "rank", argTypes, excTypes);
        call = call.methodInstance(mi);
        return (Call)call.type(mi.returnType());
    }

    /**
     * Builds a fully typed zero-argument call {@code receiver.name()} returning {@code retType}.
     * The receiver's static type must be a {@link ReferenceType} (it is cast unconditionally).
     */
    public static Call createCallHelper(Receiver receiver, String name, Type retType, HjNodeFactory nf, HjTypeSystem ts) {
        Position pos = receiver.position();
        Id id = nf.Id(pos, name);
        Call call = nf.Call(pos, receiver, id, new LinkedList<Expr>());
        MethodInstance mi = ts.methodInstance(pos, (ReferenceType)receiver.type(), Flags.NONE, retType, name, new LinkedList<Type>(), new LinkedList<Type>());
        call = call.methodInstance(mi);
        return (Call)call.type(mi.returnType());
    }

    /**
     * Walks down field-access / call targets of the domain expression looking for a
     * sub-expression whose type is an HJ array.
     *
     * @return that sub-expression, or {@code null} if none is found
     */
    private Receiver findHjArrayInLoopDomain(Receiver domainExpr) {
        if (this.ts.isHjArray(domainExpr.type())) {
            return domainExpr;
        }
        if (domainExpr instanceof Field) {
            return this.findHjArrayInLoopDomain(((Field)domainExpr).target());
        }
        if (domainExpr instanceof Call) {
            return this.findHjArrayInLoopDomain(((Call)domainExpr).target());
        }
        return null;
    }

    /**
     * Recursively emits one Java {@code for} loop per range, starting at {@code index}.
     * Each loop iterates a fresh synthetic {@code int} and exposes it to its body as a
     * {@code final int} named after the point's index local for that dimension. The
     * outermost loop also caches the region size and skips the whole nest when it is
     * non-positive.
     *
     * @return the loop nest for dimensions {@code index..}, or {@code loopbody} once every
     *         range has been consumed
     */
    private Stmt transformHjForLoop(LocalInstance[] indiceInsts, LinkedList<Range> rangeList, int index, Stmt loopbody, Expr domain, LocalInstance pointInst) {
        int rangeListSize = rangeList.size();
        if (index >= rangeListSize) {
            return loopbody;
        }
        Range range = rangeList.get(index);
        Position pos = loopbody.position();
        PrimitiveType bool = this.ts.Boolean();
        indiceInsts = this.resolveIndexInstances(indiceInsts, rangeListSize, pointInst);
        // NOTE(review): assumes indices recorded for pointInst cover every dimension; a shorter
        // array would throw ArrayIndexOutOfBoundsException here — confirm with the analyzer.
        LocalInstance locIndiceInst = indiceInsts[index];
        LocalInstance syncInst = this.ts.localInstance(pos, Flags.NONE, (Type)this.ts.Int(), this.createUniqueName());

        // Loop initializers.
        LinkedList<LocalDecl> inits = new LinkedList<LocalDecl>();
        LocalDecl regionSizeDecl = null;
        if (index == 0) {
            // int size = <region size>, s = size != 0 ? lo : -1;
            String regionSizeName = this.createUniqueName();
            Expr regionSize = this.createRegionSize((Receiver)domain);
            regionSizeDecl = this.createLocalDecl(pos, regionSizeName, (Type)this.ts.Int(), Flags.NONE, regionSize);
            inits.add(regionSizeDecl);
            Local sizeRef = this.createLocalRef(pos, regionSizeDecl.localInstance());
            Binary nonEmpty = this.createBinaryExpr(pos, (Expr)sizeRef, Binary.NE, (Expr)this.typedIntLit(pos, 0L), (Type)bool);
            Expr lowerBound = (Expr)range.lowerBound().visit((NodeVisitor)this);
            Conditional init = this.nf.Conditional(pos, (Expr)nonEmpty, lowerBound, (Expr)this.typedIntLit(pos, -1L));
            init = (Conditional)init.type((Type)this.ts.Int());
            inits.add(this.createLocalDecl(pos, syncInst, (Expr)init));
        } else {
            Expr lowerBound = (Expr)range.lowerBound().visit((NodeVisitor)this);
            inits.add(this.createLocalDecl(pos, syncInst, lowerBound));
        }

        // Loop condition: s <= hi, guarded by size > 0 on the outermost loop.
        Expr upperBound = (Expr)range.upperBound().visit((NodeVisitor)this);
        upperBound = (Expr)this.rePosition(domain, (Receiver)upperBound);
        Binary cond = this.createBinaryExpr(pos, (Expr)this.createLocalRef(pos, syncInst), Binary.LE, upperBound, (Type)bool);
        if (regionSizeDecl != null) {
            Local sizeRef = this.createLocalRef(pos, regionSizeDecl.localInstance());
            Binary nonEmptyRegion = this.createBinaryExpr(pos, (Expr)sizeRef, Binary.GT, (Expr)this.typedIntLit(pos, 0L), (Type)bool);
            cond = this.createBinaryExpr(pos, (Expr)nonEmptyRegion, Binary.COND_AND, (Expr)cond, (Type)bool);
        }

        // Loop update: s += stride.
        Expr stride = (Expr)range.stride().visit((NodeVisitor)this);
        stride = stride.type((Type)this.ts.Int());
        Assign increment = this.nf.Assign(pos, (Expr)this.createLocalRef(pos, syncInst), Assign.ADD_ASSIGN, stride);
        increment = (Assign)increment.type(syncInst.type());
        LinkedList<Eval> incrs = new LinkedList<Eval>();
        incrs.add(this.nf.Eval(pos, (Expr)increment));

        // Body: "final int <index> = s;" followed by the inner dimensions (or the original body).
        Stmt innerBody = this.transformHjForLoop(indiceInsts, rangeList, index + 1, loopbody, domain, pointInst);
        Local syncRef = this.createLocalRef(pos, syncInst);
        LocalDecl indexDecl = this.createLocalDecl(pos, locIndiceInst.name(), (Type)this.ts.Int(), Flags.FINAL, (Expr)syncRef);
        Block body = this.addStmtToLoopBody(innerBody, (Stmt)indexDecl);
        return this.nf.For(pos, inits, (Expr)cond, incrs, body);
    }

    /**
     * Returns the index locals for {@code pointInst}. In order of preference, it uses the
     * caller-supplied array, then indices recorded in
     * {@link HjPointRankTypeAnalyzer#pointToIndicesMap}, and finally {@code rank} freshly
     * created ones, which are recorded in that map.
     */
    private LocalInstance[] resolveIndexInstances(LocalInstance[] indiceInsts, int rank, LocalInstance pointInst) {
        if (indiceInsts != null && indiceInsts.length != 0) {
            return indiceInsts;
        }
        LocalInstance[] known = (LocalInstance[])HjPointRankTypeAnalyzer.pointToIndicesMap.get(pointInst);
        if (known != null && known.length != 0) {
            return known;
        }
        LocalInstance[] created = HjPointForLoopConverter.createInlinedIndices(rank, pointInst, this.ts);
        HjPointRankTypeAnalyzer.pointToIndicesMap.put((VarInstance)pointInst, (VarInstance[])created);
        return created;
    }

    /** Returns a new block containing {@code stmt} followed by the statements of {@code loopBody}. */
    private Block addStmtToLoopBody(Stmt loopBody, Stmt stmt) {
        LinkedList<Stmt> statements = new LinkedList<Stmt>();
        statements.add(stmt);
        if (loopBody instanceof Block) {
            statements.addAll(((Block)loopBody).statements());
        } else {
            statements.add(loopBody);
        }
        return this.nf.Block(loopBody.position(), statements);
    }

    /** Creates an {@code int} literal whose type is already set to {@code int}. */
    private IntLit typedIntLit(Position pos, long value) {
        return (IntLit)this.nf.IntLit(pos, IntLit.INT, value).type((Type)this.ts.Int());
    }

    /** Creates a typed reference to the local {@code li}. */
    private Local createLocalRef(Position pos, LocalInstance li) {
        Local local = this.nf.Local(pos, this.nf.Id(pos, li.name()));
        local = local.localInstance(li);
        return (Local)local.type(li.type());
    }

    /** Declares the existing local {@code li} (with no flags) initialized to {@code init}. */
    private LocalDecl createLocalDecl(Position pos, LocalInstance li, Expr init) {
        CanonicalTypeNode typeNode = this.nf.CanonicalTypeNode(pos, li.type());
        LocalDecl ld = this.nf.LocalDecl(pos, Flags.NONE, (TypeNode)typeNode, this.nf.Id(pos, li.name()), init);
        return ld.localInstance(li);
    }

    /**
     * Declares a new local {@code name} of type {@code type} with {@code flags}, initialized to
     * {@code init}, together with a matching {@link LocalInstance}.
     */
    private LocalDecl createLocalDecl(Position pos, String name, Type type, Flags flags, Expr init) {
        // The instance now carries the requested type, so it agrees with the declared type node.
        LocalInstance li = this.ts.localInstance(pos, flags, type, name);
        CanonicalTypeNode typeNode = this.nf.CanonicalTypeNode(pos, type);
        LocalDecl ld = this.nf.LocalDecl(pos, flags, (TypeNode)typeNode, this.nf.Id(pos, name), init);
        return ld.localInstance(li);
    }

    /** Creates {@code left op right} with its type set to {@code type}. */
    private Binary createBinaryExpr(Position pos, Expr left, Binary.Operator op, Expr right, Type type) {
        Binary binary = this.nf.Binary(pos, left, op, right);
        return (Binary)binary.type(type);
    }

    /**
     * Descends through call targets of {@code cond}. A field or local reference that denotes
     * the same variable as {@code domain} is given {@code domain}'s position; other nodes are
     * returned unchanged.
     */
    private Receiver rePosition(Expr domain, Receiver cond) {
        if (cond instanceof Call) {
            Call call = (Call)cond;
            return call.target(this.rePosition(domain, call.target()));
        }
        if (cond instanceof Field || cond instanceof Local) {
            VarInstance domInst = TreeUtils.getVarInstance((Node)domain);
            VarInstance condInst = TreeUtils.getVarInstance((Node)cond);
            if (domInst != null && condInst != null && domInst.equals(condInst)) {
                return (Receiver)cond.position(domain.position());
            }
        }
        return cond;
    }

    /**
     * Finds the recorded ranges of the region behind {@code dist2}. If {@code dist2} names a
     * distribution variable, that variable's recorded region is used. Otherwise, when
     * {@code dist2} is a {@link ConstantDistMaker}, its first region-typed argument is used.
     *
     * @return the recorded ranges, or {@code null} if none are known
     */
    public static LinkedList<Range> getRangesFromArrayDistr(Receiver dist2, HjTypeSystem typeSystem) {
        VarInstance distInst = TreeUtils.getVarInstance((Node)dist2);
        if (distInst != null && typeSystem.isDistribution(distInst.type())) {
            return HjPointForLoopConverter.getRangesFromArrayDistr(distInst, typeSystem);
        }
        VarInstance regInst = HjPointForLoopConverter.getDistRegion(dist2, typeSystem);
        return regInst == null ? null : HjPointRankTypeAnalyzer.regionToRangesMap.get(regInst);
    }

    /**
     * Returns the recorded ranges of the region associated with the distribution variable
     * {@code distInst}, or {@code null} if it is not a distribution or nothing is recorded.
     */
    public static LinkedList<Range> getRangesFromArrayDistr(VarInstance distInst, HjTypeSystem typeSystem) {
        if (distInst == null || !typeSystem.isDistribution(distInst.type())) {
            return null;
        }
        VarInstance regInst = HjPointRankTypeAnalyzer.distToRegionMap.get(distInst);
        return regInst == null ? null : HjPointRankTypeAnalyzer.regionToRangesMap.get(regInst);
    }

    /**
     * For a {@link ConstantDistMaker}, returns the variable behind its first region-typed
     * argument (possibly {@code null}); returns {@code null} for any other receiver.
     */
    private static VarInstance getDistRegion(Receiver right, HjTypeSystem ts) {
        if (right instanceof ConstantDistMaker) {
            for (Expr currArg : ((ConstantDistMaker)right).arguments()) {
                if (ts.isRegion(currArg.type())) {
                    return TreeUtils.getVarInstance((Node)currArg);
                }
            }
        }
        return null;
    }

    /**
     * Creates {@code size} fresh {@code int} locals named {@code <p>_i<n>}. They take their
     * position and flags from {@code p}, and {@code n} comes from the shared counter.
     */
    public static LocalInstance[] createInlinedIndices(int size, LocalInstance p, HjTypeSystem ts) {
        LocalInstance[] indices = new LocalInstance[size];
        for (int i = 0; i < size; ++i) {
            String name = p.name() + "_i" + cntr++;
            indices[i] = ts.localInstance(p.position(), p.flags(), (Type)ts.Int(), name);
        }
        return indices;
    }

    /** Returns a fresh {@code syncForConvert_<n>} name using the shared counter. */
    private String createUniqueName() {
        return "syncForConvert_" + cntr++;
    }
}

