import edu.rice.cs.mint.runtime.Code;
//import edu.rice.cs.mint.runtime.SafeCode;

public class StagedLint {
    /**
     * An expression of the interpreted language.
     * <p>
     * Evaluation is staged: {@code eval} does not return an {@code int} but a
     * {@code Code<Integer>} value.
     */
    public static interface Exp {
        /**
         * Builds the code for this expression.
         *
         * @param e environment mapping variable names to code values
         * @param f function environment mapping function names to code values
         * @return the {@code Code<Integer>} built for this expression
         */
        public separable Code<Integer> eval(Env e, FEnv f);
    }
    
    /**
     * Integer constant. The constant is held as a code value so that
     * {@code eval} can return it as is.
     */
    public static class Int implements Exp {
        private Code<Integer> _value;

        /**
         * Creates a constant from a plain {@code int} by placing it inside
         * brackets.
         *
         * @param value the constant
         */
        public separable Int(final int value) {
            this._value = <|value|>;
        }

        /**
         * Creates a constant from a code value that already exists.
         *
         * @param value the code value to store
         */
        public separable Int(Code<Integer> value) {
            this._value = value;
        }

        /**
         * Returns the stored code value; neither environment is consulted.
         *
         * @param e unused
         * @param f unused
         * @return the code value given at construction time
         */
        public separable Code<Integer> eval(Env e, FEnv f) {
            return _value;
        }
    }
    
    /** Variable reference, resolved by name in the environment. */
    public static class Var implements Exp {
        private String _s;

        /**
         * @param s name of the variable
         */
        public separable Var(String s) {
            this._s = s;
        }

        /**
         * Looks the variable name up in {@code e}; {@code f} is not used.
         *
         * @param e environment queried with the variable name
         * @param f unused
         * @return whatever {@code e.get} returns for the name
         */
        public separable Code<Integer> eval(Env e, FEnv f) {
            final Code<Integer> bound = e.get(_s);
            return bound;
        }
    }
    
    public static class App implements Exp {
        private String _s;
        private Exp _body;
        public separable App(String s, Exp body) {
            _s = s;
            _body = body;
        }
        public separable Code<Integer> eval(Env e, FEnv f) {
            return <| `(f.get(_s)).apply(`(_body.eval(e,f))) |>;
        }
    }
    
    /**
     * Common base of the binary operators: stores the two operand
     * expressions and leaves {@code eval} to the subclasses.
     */
    public static abstract class BinOp implements Exp {
        protected Exp _left;
        protected Exp _right;

        /**
         * @param left  left operand
         * @param right right operand
         */
        public separable BinOp(Exp left, Exp right) {
            this._left = left;
            this._right = right;
        }
    }
    
    public static class Add extends BinOp {
        public separable Add(Exp left, Exp right) { super(left, right); }
        public separable Code<Integer> eval(Env e, FEnv f) {
            return <| `(_left.eval(e,f)) + `(_right.eval(e,f)) |>;
        }
    }
    
    public static class Sub extends BinOp {
        public separable Sub(Exp left, Exp right) { super(left, right); }
        public separable Code<Integer> eval(Env e, FEnv f) {
            return <| `(_left.eval(e,f)) - `(_right.eval(e,f)) |>;
        }
    }
    
    public static class Mul extends BinOp {
        public separable Mul(Exp left, Exp right) { super(left, right); }
        public separable Code<Integer> eval(Env e, FEnv f) {
            return <| `(_left.eval(e,f)) * `(_right.eval(e,f)) |>;
        }
    }
    
    public static class Div extends BinOp {
        public separable Div(Exp left, Exp right) { super(left, right); }
        public separable Code<Integer> eval(Env e, FEnv f) {
            return <| `(_left.eval(e,f)) / `(_right.eval(e,f)) |>;
        }
    }
    
    public static class Ifz implements Exp {
        private Exp _test, _conseq, _alt;
        public separable Ifz(Exp test, Exp conseq, Exp alt) {
            _test = test;
            _conseq = conseq;
            _alt = alt;
        }
        public separable Code<Integer> eval(Env e, FEnv f) {
            return <| ((`(_test.eval(e,f))==0)?
                           `(_conseq.eval(e,f)):
                           `(_alt.eval(e,f))) |>;
        }
    }
    
    /**
     * Unchecked exception thrown by the empty environments {@code env0} and
     * {@code fenv0} when a name is looked up that no extension has bound.
     * It carries no message.
     */
    public static class Yikes extends RuntimeException { }
    
    /**
     * Represents {@code int -> int} functions.
     * <p>
     * Function declarations of the interpreted program are turned into code
     * for anonymous implementations of this interface.
     */
    public static interface IntFun {
        /**
         * Applies the function.
         *
         * @param param the argument
         * @return the function's result for {@code param}
         */
        public int apply(int param);
    }
    
    /**
     * Variable environment, implemented as a function object from
     * {@code String} (variable name) to {@code Code<Integer>}.
     */
    public static interface Env {
        /**
         * Looks up a variable.
         *
         * @param y name of the variable
         * @return the code value bound to {@code y}
         */
        public separable Code<Integer> get(String y);
    }
    /**
     * The empty variable environment: every lookup fails with
     * {@link Yikes}.
     */
    public static final Env env0 = new Env() {
        public separable Code<Integer> get(String name) {
            // Nothing is bound here, so any name is unknown.
            throw new Yikes();
        }
    };
    
    public static separable Env ext(final Env env,
                                    final String x,
                                    final Code<Integer> v) {
        return new Env() {
            public separable Code<Integer> get(String y) {
                // error: if (x.equals(y))
                if (x==y)
                    return v;
                else
                    return env.get(y);
            }
        };
    }     
    
    /**
     * Function environment, implemented as a function object from
     * {@code String} (function name) to {@code Code<? extends IntFun>}.
     */
    public static interface FEnv {
        /**
         * Looks up a function.
         *
         * @param s name of the function
         * @return the code value bound to {@code s}
         */
        public separable Code<? extends IntFun> get(String s);
    }
    /**
     * The empty function environment: every lookup fails with
     * {@link Yikes}.
     */
    public static final FEnv fenv0 = new FEnv() {
        public separable Code<? extends IntFun> get(String name) {
            // Nothing is bound here, so any name is unknown.
            throw new Yikes();
        }
    };
    
    public static separable FEnv fext(final FEnv fenv, final String x, final Code<? extends IntFun> v) {
        return new FEnv() {
            public separable Code<? extends IntFun> get(String y) {
                // error: if (x.equals(y))
                if (x==y)
                    return v;
                else
                    return fenv.get(y);
            }
        };
    }
    
    /**
     * Declaration of a one-parameter function: function name, parameter
     * name and body expression.
     */
    public static class Declaration {
        private String _fun;
        private String _param;
        private Exp _body;

        /**
         * @param fun   name of the function
         * @param param name of its single parameter
         * @param body  expression forming the function body
         */
        public separable Declaration(String fun, String param, Exp body) {
            this._fun = fun;
            this._param = param;
            this._body = body;
        }

        /** @return the function name */
        public String fun() {
            return this._fun;
        }

        /** @return the parameter name */
        public String param() {
            return this._param;
        }

        /** @return the body expression */
        public Exp body() {
            return this._body;
        }
    }
    
    /**
     * A program: a list of function declarations followed by a main
     * expression. {@code peval} turns the whole program into a single
     * {@code Code<Integer>} value.
     */
    public static class Program {
        private Declaration[] _defs;
        private Exp _body;
        /**
         * @param body main expression of the program
         * @param defs function declarations; they are processed in array order
         */
        public separable Program(Exp body, Declaration ... defs) {
            _defs = defs;
            _body = body;
        }
        /** @return the declaration array passed to the constructor (not copied) */
        public Declaration[] defs() { return _defs; }
        /** @return the main expression */
        public Exp body() { return _body; }
        
        /**
         * Builds the code for the whole program, starting with the first
         * declaration.
         *
         * @param env  initial variable environment
         * @param fenv initial function environment
         * @return code for the main expression with all declarations bound
         */
        public separable Code<Integer> peval(Env env, FEnv fenv) {
            return peval(env,fenv,0);
        }
        
        /**
         * Processes the declarations from {@code defIndex} onwards and then
         * the main expression.
         *
         * @param env      variable environment
         * @param fenv     function environment holding the declarations
         *                 before {@code defIndex}
         * @param defIndex index of the next declaration to process
         * @return code for the remaining program
         */
        private separable Code<Integer> peval(final Env env, final FEnv fenv, int defIndex) {
            // match p with
            if (_defs.length<=defIndex) {
                //     Program ([],e) -> eval e env fenv
                // No declarations left: build the code of the main expression.
                return _body.eval(env,fenv);
            }
            else {
                //    |Program (Declaration (s1,s2,e1)::tl,e) ->
                //        let rec f x = eval e1 (ext env s2 x) (ext fenv s1 f)
                //        in peval (Program(tl,e)) env (ext fenv s1 f)
                final Declaration d = _defs[defIndex];
                // The fields are read directly; the accessors of Declaration
                // are not declared separable (presumably the reason -- confirm).
                final String dParam = d._param;
                final String dFun = d._fun;
                final Exp dBody = d._body;
                // Bind the function name to code for an anonymous IntFun and
                // continue with the next declaration. Inside that code, the
                // body of apply is the code of the declared body, built with
                // the parameter name bound to code for x and the function
                // name bound to code for fthis (the function object itself),
                // which is what makes recursive calls work.
                return peval(env,fext(fenv, dFun, <|new IntFun() {
                    public int apply(final int x) {
                        final IntFun fthis = this;
                        return `(dBody.eval(ext(env, dParam, <| x |>),
                                            fext(fenv, dFun, <| fthis |>)));
                    }
                }|>),defIndex+1);
            }
        }
    }
    
    /**
     * Sample program computing a factorial:
     * <pre>
     *   f(x) = ifz x then 1 else x * f(x - 1)
     *   main: f(10)
     * </pre>
     */
    public static Program termFact =
        new Program(
            new App("f", new Int(10)),
            new Declaration(
                "f",
                "x",
                new Ifz(
                    new Var("x"),
                    new Int(1),
                    new Mul(
                        new Var("x"),
                        new App(
                            "f",
                            new Sub(new Var("x"), new Int(1))
                        )
                    )
                )
            )
        );
    
    /**
     * Sample program computing a Fibonacci number:
     * <pre>
     *   f(x) = ifz x       then 0
     *          else ifz (x - 1) then 1
     *          else f(x - 1) + f(x - 2)
     *   main: f(10)
     * </pre>
     */
    public static Program termFib =
        new Program(
            new App("f", new Int(10)),
            new Declaration(
                "f",
                "x",
                new Ifz(
                    new Var("x"),
                    new Int(0),
                    new Ifz(
                        new Sub(new Var("x"), new Int(1)),
                        new Int(1),
                        new Add(
                            new App("f", new Sub(new Var("x"), new Int(1))),
                            new App("f", new Sub(new Var("x"), new Int(2)))
                        )
                    )
                )
            )
        );
    
    /**
     * Builds the code for both sample programs and runs it. For the
     * factorial program the code value itself is printed before its result;
     * for the Fibonacci program only the result is printed.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final Code<Integer> factCode = termFact.peval(env0, fenv0);
        System.out.println(factCode);
        System.out.println("fact(10) = " + factCode.run());

        final Code<Integer> fibCode = termFib.peval(env0, fenv0);
        System.out.println("fib(10) = " + fibCode.run());
    }
}
