package machine.processor.builtins;

import matcher.Matcher;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.Function;

import static parser.HighOrderCombinators.literal;

/**
 * Registry of built-in functions that look like ordinary rule calls.
 * <ul>
 *   <li>Handlers take fully-resolved string arguments.</li>
 *   <li>TEXT builtins return a single string result; GENERATOR builtins return a list of results.</li>
 *   <li>Boolean functions return {@code "T"} or {@code "F"}.</li>
 *   <li>Most builtins are pure; {@code load} and {@code save} read and write files.</li>
 * </ul>
 * Lookup tries an exact (name, arity) registration first, then a variadic registration for the
 * name when the call supplies at least its minimum arity. Both registries are filled once during
 * class initialization and only read afterwards.
 */
public final class Builtins {

    private enum Kind { TEXT, GENERATOR }

    /** Lookup key for fixed-arity builtins. */
    private record Sig(String name, int arity) {}

    /** Fixed-arity registration; the handler matching {@code kind} is non-null, the other is null. */
    private record Entry(Kind kind, BuiltinTextFn textFn, BuiltinGenFn genFn) {}

    /** Handler for a TEXT builtin: resolved arguments in, one string out. */
    @FunctionalInterface
    public interface BuiltinTextFn extends Function<List<String>, String> {}

    /** Handler for a GENERATOR builtin: resolved arguments in, a list of results out. */
    @FunctionalInterface
    public interface BuiltinGenFn {
        List<String> apply(List<String> args);
    }

    private static final Map<Sig, Entry> REGISTRY = new HashMap<>();

    /** Variadic registration: matches any call to {@code name} with at least {@code minArity} args. */
    private record VariadicEntry(Kind kind, int minArity, BuiltinTextFn textFn, BuiltinGenFn genFn) {}

    private static final Map<String, VariadicEntry> VARIADIC = new HashMap<>();

    static {
        registerStringBuiltins();
        registerNumericBuiltins();
        registerIoBuiltins();
    }

    // ===================================================================================
    // Registration
    // ===================================================================================

    /** Registers the string/matcher builtins; every needle argument is matched literally. */
    private static void registerStringBuiltins() {
        // contains(haystack, needle) -> "T"/"F"
        registerText("contains", 2, args -> {
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return bool(Matcher.count(literal(needle), args.get(0)) > 0);
        });

        // count(haystack, needle) -> number of non-overlapping matches
        registerText("count", 2, args -> {
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return Integer.toString(Matcher.count(literal(needle), args.get(0)));
        });

        // replaceFirst(haystack, needle, replacement)
        registerText("replaceFirst", 3, args -> {
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return Matcher.replace(literal(needle), args.get(0), args.get(2), 0);
        });

        // replaceNth(haystack, needle, replacement, index) - 0-based index
        registerText("replaceNth", 4, args -> {
            // The index is validated before the needle, so a bad index is reported first.
            int idx = parseNonNegativeInt("index", args.get(3));
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return Matcher.replace(literal(needle), args.get(0), args.get(2), idx);
        });

        // replaceAll(haystack, needle, replacement): non-overlapping, scanned left to right.
        // Text inserted by a replacement is never rescanned, so a replacement that contains the
        // needle cannot cascade.
        registerText("replaceAll", 3, args -> {
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return args.get(0).replace(needle, args.get(2));
        });

        // deleteFirst(haystack, needle)
        registerText("deleteFirst", 2, args -> {
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return Matcher.replace(literal(needle), args.get(0), "", 0);
        });

        // deleteNth(haystack, needle, index) - 0-based index
        registerText("deleteNth", 3, args -> {
            int idx = parseNonNegativeInt("index", args.get(2));
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return Matcher.replace(literal(needle), args.get(0), "", idx);
        });

        // deleteAll(haystack, needle): removes every non-overlapping occurrence
        registerText("deleteAll", 2, args -> {
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            return args.get(0).replace(needle, "");
        });

        // length(s) -> number of UTF-16 code units in s (String.length())
        registerText("length", 1, args -> Integer.toString(args.get(0).length()));

        // substring(s, start)
        registerText("substring", 2, args -> {
            var s = args.get(0);
            int start = parseNonNegativeInt("start", args.get(1));
            if (start > s.length()) {
                throw new IllegalArgumentException("substring: start out of bounds: " + start);
            }
            return s.substring(start);
        });

        // substring(s, start, end) -> characters in [start, end)
        registerText("substring", 3, args -> {
            var s = args.get(0);
            int start = parseNonNegativeInt("start", args.get(1));
            int end = parseNonNegativeInt("end", args.get(2));
            checkRange("substring", start, end, s.length());
            return s.substring(start, end);
        });

        // deleteRange(s, start, end) -> s with [start, end) removed
        registerText("deleteRange", 3, args -> {
            var s = args.get(0);
            int start = parseNonNegativeInt("start", args.get(1));
            int end = parseNonNegativeInt("end", args.get(2));
            checkRange("deleteRange", start, end, s.length());
            return s.substring(0, start) + s.substring(end);
        });

        // insertAt(s, index, part) -> s with 'part' inserted before position index
        registerText("insertAt", 3, args -> {
            var s = args.get(0);
            int index = parseNonNegativeInt("index", args.get(1));
            var part = args.get(2);
            if (index > s.length()) {
                throw new IllegalArgumentException("insertAt: index out of bounds: " + index);
            }
            return s.substring(0, index) + part + s.substring(index);
        });

        // nthIndexOf(haystack, needle, index) -> start of the index-th non-overlapping match, or "-1"
        registerText("nthIndexOf", 3, args -> {
            int idx = parseNonNegativeInt("index", args.get(2));
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            var mr = Matcher.get(literal(needle), args.get(0), idx);
            if (mr instanceof matcher.Matcher.SuccessfulMatch sm) {
                return Integer.toString(sm.start());
            }
            return "-1";
        });

        // getMatch(haystack, needle, index) -> text of the index-th non-overlapping match, or ""
        registerText("getMatch", 3, args -> {
            int idx = parseNonNegativeInt("index", args.get(2));
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            var mr = Matcher.get(literal(needle), args.get(0), idx);
            if (mr instanceof matcher.Matcher.SuccessfulMatch sm) {
                return sm.match();
            }
            return "";
        });

        // matchAtStart(haystack, needle) -> matched prefix if present, else ""
        registerText("matchAtStart", 2, args -> {
            var needle = args.get(1);
            requireNonEmpty("needle", needle);
            var m = Matcher.matchAtStartOrNull(literal(needle), args.get(0));
            return m == null ? "" : m;
        });
    }

    /** Registers arithmetic, rounding, comparison and numeric-predicate builtins. */
    private static void registerNumericBuiltins() {
        // Variadic folds over (x, y, ...), combined left to right.
        registerTextVariadic("add", 2, args -> foldNumbers(args, Builtins::add));
        registerTextVariadic("mul", 2, args -> foldNumbers(args, Builtins::mul));
        registerTextVariadic("min", 2, args -> foldNumbers(args, Builtins::min));
        registerTextVariadic("max", 2, args -> foldNumbers(args, Builtins::max));

        registerText("sub", 2, args -> formatNumber(sub(parse(args.get(0)), parse(args.get(1)))));
        // div(x, y) is truncating division toward zero; the result is always an integer.
        registerText("div", 2, args -> formatNumber(idiv(parse(args.get(0)), parse(args.get(1)))));
        registerText("mod", 2, args -> formatNumber(mod(parse(args.get(0)), parse(args.get(1)))));
        registerText("neg", 1, args -> formatNumber(neg(parse(args.get(0)))));
        registerText("abs", 1, args -> formatNumber(abs(parse(args.get(0)))));
        registerText("pow", 2, args -> formatNumber(pow(parse(args.get(0)), parseIntStrict(args.get(1), "exponent"))));

        // clamp(x, lo, hi) -> x limited to [lo, hi]; lo must not exceed hi
        registerText("clamp", 3, args -> {
            var x = parse(args.get(0));
            var lo = parse(args.get(1));
            var hi = parse(args.get(2));
            if (compare(lo, hi) > 0) throw new IllegalArgumentException("clamp: lo > hi");
            var y = compare(x, lo) < 0 ? lo : x;
            return formatNumber(compare(y, hi) > 0 ? hi : y);
        });

        // Rounding to an integer; integer inputs pass through unchanged.
        // round: halves go away from zero (HALF_UP); trunc: toward zero.
        registerText("round", 1, args -> formatNumber(toInteger(parse(args.get(0)), RoundingMode.HALF_UP)));
        registerText("floor", 1, args -> formatNumber(toInteger(parse(args.get(0)), RoundingMode.FLOOR)));
        registerText("ceil", 1, args -> formatNumber(toInteger(parse(args.get(0)), RoundingMode.CEILING)));
        registerText("trunc", 1, args -> formatNumber(toInteger(parse(args.get(0)), RoundingMode.DOWN)));

        // Numeric comparisons -> "T"/"F" (by value, so "1.0" eq "1")
        registerText("eq", 2, args -> bool(compareArgs(args) == 0));
        registerText("ne", 2, args -> bool(compareArgs(args) != 0));
        registerText("lt", 2, args -> bool(compareArgs(args) < 0));
        registerText("le", 2, args -> bool(compareArgs(args) <= 0));
        registerText("gt", 2, args -> bool(compareArgs(args) > 0));
        registerText("ge", 2, args -> bool(compareArgs(args) >= 0));

        // Predicates. isDec accepts anything BigDecimal parses, including plain integers.
        registerText("isInt", 1, args -> bool(isIntegerString(args.get(0))));
        registerText("isDec", 1, args -> bool(isDecimalString(args.get(0))));

        // Placeholders: these always throw here. NOTE(review): presumably a sampling-mode
        // evaluator supplies the real implementations - confirm.
        registerText("rand", 2, args -> {
            throw new IllegalStateException("rand(lo, hi) is only available in sampling mode");
        });
        registerText("randFloat01", 0, args -> {
            throw new IllegalStateException("randFloat01() is only available in sampling mode");
        });
    }

    /** Registers the file-system builtins (the only ones with side effects). */
    private static void registerIoBuiltins() {
        // load(path) -> entire file contents as UTF-8 string
        registerText("load", 1, args -> {
            var pathStr = args.get(0);
            try {
                return Files.readString(Paths.get(pathStr), StandardCharsets.UTF_8);
            } catch (IOException e) {
                throw new IllegalArgumentException("load: failed to read " + pathStr + ": " + e.getMessage(), e);
            }
        });

        // save(path, content) -> writes UTF-8 content, creating parent directories; returns content
        registerText("save", 2, args -> {
            var pathStr = args.get(0);
            var content = args.get(1);
            try {
                var p = Paths.get(pathStr);
                if (p.getParent() != null) {
                    Files.createDirectories(p.getParent());
                }
                Files.writeString(p, content, StandardCharsets.UTF_8);
                return content;
            } catch (IOException e) {
                throw new IllegalArgumentException("save: failed to write " + pathStr + ": " + e.getMessage(), e);
            }
        });
    }

    private static void registerText(String name, int arity, BuiltinTextFn fn) {
        REGISTRY.put(new Sig(name, arity), new Entry(Kind.TEXT, fn, null));
    }

    private static void registerGen(String name, int arity, BuiltinGenFn fn) {
        REGISTRY.put(new Sig(name, arity), new Entry(Kind.GENERATOR, null, fn));
    }

    private static void registerTextVariadic(String name, int minArity, BuiltinTextFn fn) {
        VARIADIC.put(name, new VariadicEntry(Kind.TEXT, minArity, fn, null));
    }

    private static void registerGenVariadic(String name, int minArity, BuiltinGenFn fn) {
        VARIADIC.put(name, new VariadicEntry(Kind.GENERATOR, minArity, null, fn));
    }

    // ===================================================================================
    // Public lookup / dispatch
    // ===================================================================================

    /**
     * Reports whether a builtin can be called with the given name and argument count.
     *
     * @param name  builtin name
     * @param arity number of arguments at the call site
     * @return true if an exact registration or a variadic one with {@code minArity <= arity} exists
     */
    public static boolean has(String name, int arity) {
        if (REGISTRY.containsKey(new Sig(name, arity))) return true;
        var v = VARIADIC.get(name);
        return v != null && arity >= v.minArity;
    }

    /**
     * Reports whether the builtin resolved for (name, arity) is a GENERATOR.
     *
     * @return false when no builtin matches, or the match is a TEXT builtin
     */
    public static boolean isGenerator(String name, int arity) {
        var e = REGISTRY.get(new Sig(name, arity));
        if (e != null) return e.kind == Kind.GENERATOR;
        var v = VARIADIC.get(name);
        return v != null && arity >= v.minArity && v.kind == Kind.GENERATOR;
    }

    /**
     * Invokes the builtin matching {@code name} and {@code args.size()}.
     *
     * @param name builtin name
     * @param args fully-resolved argument strings
     * @return the outputs; a TEXT builtin yields a singleton list
     * @throws IllegalArgumentException if no builtin matches, or a handler rejects its arguments
     * @throws IllegalStateException    from the sampling-only placeholders ({@code rand}, {@code randFloat01})
     */
    public static List<String> applyList(String name, List<String> args) {
        var e = REGISTRY.get(new Sig(name, args.size()));
        if (e != null) {
            return e.kind == Kind.TEXT ? List.of(e.textFn.apply(args)) : e.genFn.apply(args);
        }
        var v = VARIADIC.get(name);
        if (v != null && args.size() >= v.minArity) {
            return v.kind == Kind.TEXT ? List.of(v.textFn.apply(args)) : v.genFn.apply(args);
        }
        throw new IllegalArgumentException("No builtin found for " + name + "/" + args.size());
    }

    // ===================================================================================
    // Argument helpers
    // ===================================================================================

    private static String bool(boolean b) {
        return b ? "T" : "F";
    }

    private static void requireNonEmpty(String param, String s) {
        if (s == null || s.isEmpty()) {
            throw new IllegalArgumentException(param + " must not be empty");
        }
    }

    /** Validates a half-open range [start, end) against a string length (inputs already non-negative). */
    private static void checkRange(String fn, int start, int end, int len) {
        if (start > end || end > len) {
            throw new IllegalArgumentException(fn + ": invalid range [" + start + "," + end + ") for length " + len);
        }
    }

    /** Parses a trimmed, non-negative {@code int}; anything else is an IllegalArgumentException. */
    private static int parseNonNegativeInt(String param, String s) {
        try {
            int v = Integer.parseInt(s.trim());
            if (v < 0) throw new NumberFormatException("negative");
            return v;
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(param + " must be a non-negative integer: " + s, e);
        }
    }

    // ===================================================================================
    // Numbers: exact integers (BigInteger) or exact decimals (BigDecimal)
    // ===================================================================================

    /** Exactly one of {@code i} (when {@code isInt}) or {@code d} (otherwise) is non-null. */
    private record Num(boolean isInt, BigInteger i, BigDecimal d) {
        static Num ofInt(BigInteger v) {
            return new Num(true, v, null);
        }

        static Num ofDec(BigDecimal v) {
            return new Num(false, null, v);
        }
    }

    private static Num parse(String s) {
        s = s.trim();
        if (isIntegerString(s)) {
            return Num.ofInt(new BigInteger(s));
        }
        if (isDecimalString(s)) {
            return Num.ofDec(new BigDecimal(s));
        }
        throw new IllegalArgumentException("Not a number: " + s);
    }

    private static boolean isIntegerString(String s) {
        try {
            new BigInteger(s.trim());
            return true;
        } catch (RuntimeException ignored) {
            // Any parse failure (including null input) simply means "not an integer".
            return false;
        }
    }

    private static boolean isDecimalString(String s) {
        try {
            new BigDecimal(s.trim());
            return true;
        } catch (RuntimeException ignored) {
            // Any parse failure (including null input) simply means "not a decimal".
            return false;
        }
    }

    private static BigInteger parseIntStrict(String s, String name) {
        try {
            return new BigInteger(s.trim());
        } catch (RuntimeException e) {
            throw new IllegalArgumentException(name + " is not an integer: " + s, e);
        }
    }

    /** Integers print as-is; decimals drop trailing zeros and never use exponent notation. */
    private static String formatNumber(Num n) {
        return n.isInt ? n.i.toString() : n.d.stripTrailingZeros().toPlainString();
    }

    private static BigDecimal toDecimal(Num n) {
        return n.isInt ? new BigDecimal(n.i) : n.d;
    }

    private static int compareArgs(List<String> args) {
        return compare(parse(args.get(0)), parse(args.get(1)));
    }

    /** Left-to-right fold of all arguments with {@code op}, formatted as a number. */
    private static String foldNumbers(List<String> args, BinaryOperator<Num> op) {
        var acc = parse(args.get(0));
        for (int k = 1; k < args.size(); k++) {
            acc = op.apply(acc, parse(args.get(k)));
        }
        return formatNumber(acc);
    }

    private static int compare(Num a, Num b) {
        if (a.isInt && b.isInt) {
            return a.i.compareTo(b.i);
        }
        return toDecimal(a).compareTo(toDecimal(b));
    }

    private static Num add(Num a, Num b) {
        if (a.isInt && b.isInt) return Num.ofInt(a.i.add(b.i));
        return Num.ofDec(toDecimal(a).add(toDecimal(b)));
    }

    private static Num sub(Num a, Num b) {
        if (a.isInt && b.isInt) return Num.ofInt(a.i.subtract(b.i));
        return Num.ofDec(toDecimal(a).subtract(toDecimal(b)));
    }

    private static Num mul(Num a, Num b) {
        if (a.isInt && b.isInt) return Num.ofInt(a.i.multiply(b.i));
        return Num.ofDec(toDecimal(a).multiply(toDecimal(b)));
    }

    /**
     * Truncating division toward zero; always yields an integer.
     * The decimal path uses an exact integral quotient. Rounding to a fixed number of places
     * first would turn e.g. 1.99999999999999999 / 1 into 2 instead of 1.
     */
    private static Num idiv(Num a, Num b) {
        if (a.isInt && b.isInt) {
            if (b.i.signum() == 0) throw new IllegalArgumentException("Division by zero");
            return Num.ofInt(a.i.divide(b.i));
        }
        BigDecimal divisor = toDecimal(b);
        if (divisor.signum() == 0) throw new IllegalArgumentException("Division by zero");
        return Num.ofInt(toDecimal(a).divideToIntegralValue(divisor).toBigInteger());
    }

    /**
     * Integer modulo with floor semantics, like {@link Math#floorMod}. The result is zero or has
     * the divisor's sign, so it lies in [0, b) for positive b.
     */
    private static Num mod(Num a, Num b) {
        if (!a.isInt || !b.isInt) throw new IllegalArgumentException("mod expects integers");
        if (b.i.signum() == 0) throw new IllegalArgumentException("Division by zero");
        // BigInteger.mod rejects negative moduli, so reduce by |b| and shift into (b, 0] if needed.
        BigInteger r = a.i.mod(b.i.abs());
        if (b.i.signum() < 0 && r.signum() != 0) {
            r = r.add(b.i);
        }
        return Num.ofInt(r);
    }

    private static Num neg(Num a) {
        return a.isInt ? Num.ofInt(a.i.negate()) : Num.ofDec(a.d.negate());
    }

    private static Num abs(Num a) {
        return a.isInt ? Num.ofInt(a.i.abs()) : Num.ofDec(a.d.abs());
    }

    /** Exact {@code base^exponent} for a non-negative exponent that fits in an {@code int}. */
    private static Num pow(Num base, BigInteger exponent) {
        if (exponent.signum() < 0) throw new IllegalArgumentException("pow exponent must be >= 0");
        int n;
        try {
            n = exponent.intValueExact();
        } catch (ArithmeticException ex) {
            throw new IllegalArgumentException("pow exponent too large", ex);
        }
        if (base.isInt) {
            return Num.ofInt(base.i.pow(n));
        }
        try {
            // Exact, same value and scale as repeated multiplication, but without an O(n) loop.
            return Num.ofDec(base.d.pow(n));
        } catch (ArithmeticException ex) {
            throw new IllegalArgumentException("pow exponent too large", ex);
        }
    }

    private static Num min(Num a, Num b) {
        return compare(a, b) <= 0 ? a : b;
    }

    private static Num max(Num a, Num b) {
        return compare(a, b) >= 0 ? a : b;
    }

    /** Rounds a decimal to an integer with {@code mode}; integers are returned unchanged. */
    private static Num toInteger(Num a, RoundingMode mode) {
        if (a.isInt) return a;
        return Num.ofInt(a.d.setScale(0, mode).toBigInteger());
    }

    private Builtins() {}
}

