#!/usr/bin/env python3
"""
WML OS post-login console - OPNsense-style numbered menu.

Replaces /bin/bash as the root tty1 login shell on the installed system.
Every option has a real working implementation (no stubs except items 12/13
which point to the web UI, mirroring pfSense/OPNsense convention).

Strict input validation throughout.  User can NEVER drop to a shell except
via menu option 8 (which asks for confirmation).
"""
import os
import sys
import signal
import subprocess

sys.path.insert(0, "/usr/local/lib/wml")
from tui import (
    header, hr, section, menu, prompt, prompt_int, prompt_yn,
    prompt_password, prompt_ip_cidr, prompt_ip, pause, transition,
    Spinner, run_status, get_hostname, get_os_release_name,
    get_interfaces, BOLD, DIM, PINK, WHITE, YELLOW, RED, GREEN, PURPLE, RESET
)


# (number, label) pairs in display order.  The number is the position in the
# list below and is also the key used to look up the handler in HANDLERS.
MENU_OPTIONS = list(enumerate([
    "Logout",
    "Assign interfaces",
    "Set interface IP",
    "Reset root password",
    "Reset to factory defaults",
    "Power off",
    "Reboot",
    "Ping host",
    "Shell",
    "Show top processes",
    "Firewall log",
    "Reload all services",
    "Update from console",
    "Restore a backup",
    "Re-detect interfaces (after cable swap)",
]))


# ---------------------------------------------------------------------
# Header
# ---------------------------------------------------------------------
def show_header():
    """Print the console banner (OS name + machine arch) and an interface table.

    Each row shows the interface role (blank if unassigned), name, link state
    and IPv4 address.  Plain text is padded *before* colour codes are added,
    so the columns line up no matter how long the ANSI sequences from ``tui``
    are.
    """
    osname = get_os_release_name()
    # Same value `uname -m` prints, without spawning a subprocess.
    arch = os.uname().machine
    header(f"{osname} ({arch})")

    ifaces = get_interfaces()
    if not ifaces:
        section("Interfaces", [f"{DIM}No interfaces detected{RESET}"])
        return

    body = []
    for name, role, ipv4, up in ifaces:
        role_str = f"{PINK}{role:<4}{RESET}" if role else f"{DIM}{'':<4}{RESET}"
        # "down" is the widest state; pad the visible text to 4 columns.
        up_str = f"{GREEN}{'up':<4}{RESET}" if up else f"{RED}{'down':<4}{RESET}"
        ip_str = ipv4 if ipv4 else f"{DIM}(no IP){RESET}"
        body.append(f"{role_str} {BOLD}{name:<7}{RESET} {up_str} {ip_str}")
    section("Interfaces", body)


# ---------------------------------------------------------------------
# Option handlers
# ---------------------------------------------------------------------
def opt_logout():
    """End the console session.

    When stdin is a pseudo-terminal (/dev/pts/*, e.g. SSH) the process exits
    so the client disconnects; otherwise (physical console, or stdin not a
    tty) the process is replaced by /bin/login.
    """
    print(f"\n  {DIM}Logging out...{RESET}")
    transition(0.3)
    try:
        on_pseudo_tty = os.ttyname(0).startswith("/dev/pts/")
    except OSError:
        on_pseudo_tty = False
    if on_pseudo_tty:
        sys.exit(0)
    os.execv("/bin/login", ["login"])


def opt_assign_interfaces():
    """Swap which physical NIC is WAN vs LAN."""
    transition()
    header("Assign Interfaces")

    ifaces = get_interfaces()
    if len(ifaces) < 2:
        print(f"  {RED}Need at least 2 interfaces to swap.  Found {len(ifaces)}.{RESET}")
        pause()
        return

    section("Current assignment",
            [f"{BOLD}{role or '-':<4}{RESET}  {name:<8}  {ipv4 or '(no IP)'}"
             for name, role, ipv4, _ in ifaces])

    print(f"\n  {YELLOW}Pick the NIC that should be {BOLD}WAN{RESET}{YELLOW} (internet side):{RESET}\n")
    for i, (name, role, ipv4, _) in enumerate(ifaces, 1):
        print(f"    {BOLD}{i}){RESET} {name}  current role: {role or '-'}")
    choice = prompt_int("Selection", 1, len(ifaces))
    wan_iface = ifaces[choice - 1][0]

    # First non-WAN becomes LAN
    lan_iface = next((n for n, _, _, _ in ifaces if n != wan_iface), None)

    print()
    print(f"  {BOLD}WAN:{RESET} {wan_iface}    {BOLD}LAN:{RESET} {lan_iface}")
    if not prompt_yn("Apply this assignment", default=False):
        print(f"  {DIM}Cancelled.{RESET}")
        pause()
        return

    # --- Pre-apply validation: probe each port for DHCP ---
    # If the user's chosen WAN doesn't get an upstream lease (or worse, gets a
    # lease from our own LAN subnet), the cables are likely reversed. Detect
    # and offer to auto-swap instead of bricking their network for 2 min.
    print()
    print(f"  {DIM}Validating cable placement (this takes ~10s)...{RESET}")
    LAN_FAMILY = "172.30"

    def _probe(iface):
        """Return DHCP lease IP or None within 6s, non-destructive."""
        try:
            subprocess.run(["dhclient", "-r", iface],
                           capture_output=True, timeout=3)
        except subprocess.TimeoutExpired:
            pass
        r = subprocess.run(
            ["timeout", "6", "dhclient", "-1", "-v",
             "-lf", f"/tmp/wml-probe-{iface}.lease",
             "-pf", f"/tmp/wml-probe-{iface}.pid", iface],
            capture_output=True, text=True)
        lease = None
        out = (r.stdout + r.stderr).split("\n")
        for line in out:
            if "bound to" in line:
                parts = line.split("bound to ")
                if len(parts) > 1:
                    lease = parts[1].split()[0].rstrip(",;")
                    break
        # Always release after probe
        subprocess.run(["dhclient", "-r",
                        "-lf", f"/tmp/wml-probe-{iface}.lease",
                        "-pf", f"/tmp/wml-probe-{iface}.pid", iface],
                       capture_output=True, timeout=3)
        for f in (f"/tmp/wml-probe-{iface}.lease", f"/tmp/wml-probe-{iface}.pid"):
            try:
                os.remove(f)
            except OSError:
                pass
        return lease

    wan_lease = _probe(wan_iface)
    suggested_swap = False
    if wan_lease and wan_lease.startswith(LAN_FAMILY + "."):
        # User picked the LAN side as WAN
        print(f"  {RED}✗{RESET} {wan_iface} got DHCP {wan_lease} from our OWN LAN subnet.")
        print(f"     That means {wan_iface} is the LAN side, not WAN.")
        suggested_swap = True
    elif not wan_lease:
        # No upstream DHCP. Check if the OTHER NIC has one (cables reversed).
        print(f"  {YELLOW}!{RESET} {wan_iface} got no DHCP response in 6s.")
        if lan_iface:
            print(f"  {DIM}Checking {lan_iface}...{RESET}")
            lan_lease = _probe(lan_iface)
            if lan_lease and not lan_lease.startswith(LAN_FAMILY + "."):
                print(f"  {GREEN}✓{RESET} {lan_iface} got upstream lease {lan_lease}.")
                suggested_swap = True
            elif lan_lease:
                print(f"  {DIM}{lan_iface} also got LAN-family lease {lan_lease}.{RESET}")
            else:
                print(f"  {YELLOW}!{RESET} Neither port has internet. Check cables.")
    else:
        print(f"  {GREEN}✓{RESET} {wan_iface} got upstream lease {wan_lease}.")

    if suggested_swap:
        print()
        print(f"  {BOLD}Recommended:{RESET} WAN={lan_iface}, LAN={wan_iface} (cables are reversed)")
        if prompt_yn("Apply recommended assignment instead", default=True):
            wan_iface, lan_iface = lan_iface, wan_iface
            print(f"  {GREEN}✓{RESET} Will apply: WAN={wan_iface}, LAN={lan_iface}")
        else:
            if not prompt_yn(f"Apply YOUR choice WAN={wan_iface} anyway (no internet expected)",
                             default=False):
                print(f"  {DIM}Cancelled. No changes.{RESET}")
                pause()
                return

    # Write /etc/wml/interfaces.conf
    os.makedirs("/etc/wml", exist_ok=True)
    with open("/etc/wml/interfaces.conf", "w") as f:
        f.write(f"WAN_IF={wan_iface}\n")
        f.write(f"LAN_IF={lan_iface}\n")
        f.write("LAN_IP=172.30.30.1\n")
        f.write("LAN_PREFIX=24\n")
        f.write("LAN_NETWORK=172.30.30.0/24\n")

    # FAST path: write networkd configs DIRECTLY (don't re-run wml-firstboot,
    # which would re-probe DHCP for 20s+). Just apply the user's explicit
    # choice. OPNsense does this in <2 seconds.
    print()
    with Spinner("Applying configuration"):
        # 1) Remove old WML network configs
        net_dir = "/etc/systemd/network"
        os.makedirs(net_dir, exist_ok=True)
        for old in os.listdir(net_dir):
            if "wml-" in old:
                try:
                    os.remove(f"{net_dir}/{old}")
                except OSError:
                    pass

        # 2) Write new WAN config (DHCP)
        with open(f"{net_dir}/10-wml-wan.network", "w") as f:
            f.write(f"[Match]\nName={wan_iface}\n\n"
                    "[Network]\nDHCP=ipv4\nIPv6AcceptRA=true\n\n"
                    "[DHCP]\nUseDNS=false\nUseDomains=false\nUseHostname=false\n")

        # 3) Write new LAN config (static 172.30.30.1/24)
        if lan_iface:
            with open(f"{net_dir}/20-wml-lan.network", "w") as f:
                f.write(f"[Match]\nName={lan_iface}\n\n"
                        "[Network]\nAddress=172.30.30.1/24\n"
                        "DHCPServer=no\nIPMasquerade=no\n")

        # 4) Restart networkd to apply (non-blocking; we don't wait for it)
        try:
            subprocess.run(["systemctl", "restart", "--no-block", "systemd-networkd"],
                           capture_output=True, timeout=5)
        except subprocess.TimeoutExpired:
            pass

        # 5) Reapply NAT for the new WAN interface
        subprocess.run(["nft", "flush", "ruleset"], capture_output=True)
        subprocess.run(["nft", "add", "table", "ip", "wml-nat"],
                       capture_output=True)
        subprocess.run(["nft", "add", "chain", "ip", "wml-nat", "postrouting",
                        "{ type nat hook postrouting priority 100; }"],
                       capture_output=True)
        subprocess.run(["nft", "add", "rule", "ip", "wml-nat", "postrouting",
                        "oifname", f'"{wan_iface}"', "masquerade"],
                       capture_output=True)

        # 6) Restart kea so it binds to the new LAN interface (non-blocking)
        if lan_iface:
            try:
                subprocess.run(["systemctl", "restart", "--no-block", "kea-dhcp4-server"],
                               capture_output=True, timeout=5)
            except subprocess.TimeoutExpired:
                pass  # kea will catch up; user doesn't need to wait

    print(f"\n  {GREEN}✓{RESET} Interfaces assigned: WAN={wan_iface}, LAN={lan_iface}")
    print(f"  {DIM}(applied without restart - networkd reconfigured in place){RESET}")
    pause()


def _update_kea_for_new_lan(iface, addr_iface):
    """Rewrite Kea DHCP4 config for a changed LAN subnet."""
    import ipaddress
    import json as _json
    net = addr_iface.network
    gw = str(addr_iface.ip)
    # Pool: .100 - .200 in the new subnet (matches wml-firstboot defaults)
    hosts = list(net.hosts())
    if len(hosts) < 50:
        # tiny subnet; pick middle quarter
        start = hosts[len(hosts) // 4]
        end = hosts[3 * len(hosts) // 4]
    else:
        start = ipaddress.ip_address(int(net.network_address) + 100)
        end = ipaddress.ip_address(int(net.network_address) + 200)
    cfg = {
        "Dhcp4": {
            "interfaces-config": {"interfaces": [iface]},
            "valid-lifetime": 86400,
            "lease-database": {
                "type": "memfile", "persist": True,
                "name": "/var/lib/kea/dhcp4.leases"},
            "subnet4": [{
                "id": 1,
                "subnet": str(net),
                "pools": [{"pool": f"{start} - {end}"}],
                "option-data": [
                    {"name": "routers", "data": gw},
                    {"name": "domain-name-servers", "data": gw},
                ],
            }],
        }
    }
    with open("/etc/kea/kea-dhcp4.conf", "w") as f:
        _json.dump(cfg, f, indent=2)
    subprocess.run(["systemctl", "restart", "--no-block", "kea-dhcp4-server"],
                   capture_output=True, timeout=5)


def opt_set_interface_ip():
    """Reconfigure WAN or LAN IP (DHCP / Static / PPPoE)."""
    transition()
    header("Set Interface IP")

    ifaces = get_interfaces()
    if not ifaces:
        print(f"  {RED}No interfaces detected.{RESET}")
        pause()
        return

    print(f"  {YELLOW}Which interface to reconfigure?{RESET}\n")
    for i, (name, role, ipv4, _) in enumerate(ifaces, 1):
        print(f"    {BOLD}{i}){RESET} {name}  ({role or 'unassigned'})  {ipv4 or '(no IP)'}")
    pick = prompt_int("Selection", 1, len(ifaces))
    iface = ifaces[pick - 1][0]
    role = ifaces[pick - 1][1] or "ROLE"

    print(f"\n  Interface: {BOLD}{iface}{RESET}  Role: {BOLD}{role}{RESET}\n")
    print(f"  {YELLOW}Configuration method:{RESET}")
    print(f"    {BOLD}1){RESET} DHCP (automatic)")
    print(f"    {BOLD}2){RESET} Static IP")
    print(f"    {BOLD}3){RESET} PPPoE")
    method = prompt_int("Selection", 1, 3, default=1)

    network_dir = "/etc/systemd/network"
    os.makedirs(network_dir, exist_ok=True)
    config_file = f"{network_dir}/10-wml-{iface}.network"

    import ipaddress

    def _read_wan_lan():
        wan, lan, lan_net = "", "", "172.30.30.0/24"
        try:
            with open("/etc/wml/interfaces.conf") as f:
                for line in f:
                    if line.startswith("WAN_IF="):
                        wan = line.split("=", 1)[1].strip().strip('"')
                    elif line.startswith("LAN_IF="):
                        lan = line.split("=", 1)[1].strip().strip('"')
                    elif line.startswith("LAN_NETWORK="):
                        lan_net = line.split("=", 1)[1].strip().strip('"')
        except OSError:
            pass
        return wan, lan, lan_net

    cur_wan, cur_lan, cur_lan_net = _read_wan_lan()
    is_wan = (iface == cur_wan)
    is_lan = (iface == cur_lan)

    if method == 1:
        # DHCP — only meaningful for WAN. Warn if user picks DHCP on LAN.
        if is_lan:
            print(f"\n  {YELLOW}!{RESET} {iface} is the LAN side. DHCP here will lose"
                  f" the static gateway and break clients.")
            if not prompt_yn("Apply DHCP on LAN anyway", default=False):
                print(f"  {DIM}Cancelled.{RESET}")
                pause()
                return
        with open(config_file, "w") as f:
            f.write(f"[Match]\nName={iface}\n\n[Network]\nDHCP=ipv4\n")
    elif method == 2:
        ip_cidr = prompt_ip_cidr("Static IP (e.g. 192.168.1.1/24)")
        gw = prompt_ip("Gateway IP")
        dns = prompt_ip("DNS server")

        # --- Validation ---
        errors, warnings = [], []
        try:
            addr_iface = ipaddress.ip_interface(ip_cidr)
            ip_obj = addr_iface.ip
            net = addr_iface.network
            if ip_obj.is_loopback or ip_obj.is_multicast or ip_obj.is_unspecified:
                errors.append(f"IP {ip_obj} is reserved/loopback/multicast.")
            if net.prefixlen >= 31:
                errors.append(f"Prefix /{net.prefixlen} too small for a usable subnet.")

            gw_obj = ipaddress.ip_address(gw)
            if gw_obj not in net:
                errors.append(f"Gateway {gw} is not in subnet {net}.")
            # On WAN, gateway should be DIFFERENT from interface IP (upstream router).
            # On LAN, gateway IS the interface IP (we are the gateway).
            if is_wan and gw_obj == ip_obj:
                errors.append("WAN gateway equals interface IP — the firewall can't"
                              " be its own upstream router.")

            # WAN-specific: if there's an active DHCP lease, mention it.
            if is_wan:
                cur_ip = next((ipv4 for n, _, ipv4, _ in ifaces if n == iface), None)
                if cur_ip and cur_ip.startswith(("192.168.", "10.", "172.")):
                    warnings.append(f"WAN currently has DHCP lease {cur_ip}."
                                    " Static will override it.")
            # LAN-specific: subnet must not overlap any current WAN net.
            if is_lan:
                for n, _, ipv4, _ in ifaces:
                    if n == iface or not ipv4:
                        continue
                    try:
                        other_net = ipaddress.ip_interface(ipv4).network
                        if net.overlaps(other_net):
                            errors.append(f"Subnet {net} overlaps {n} network "
                                          f"{other_net} — would break routing.")
                    except ValueError:
                        pass
                # On LAN, gateway should equal the iface IP (firewall is the gateway).
                if gw_obj != ip_obj:
                    warnings.append(f"LAN gateway is usually the firewall itself"
                                    f" ({ip_obj}), but you entered {gw}.")
        except (ValueError, ipaddress.AddressValueError) as e:
            errors.append(str(e))

        if errors:
            print()
            for e in errors:
                print(f"  {RED}✗{RESET} {e}")
            print()
            print(f"  {DIM}Fix the above and try again.{RESET}")
            pause()
            return
        if warnings:
            print()
            for w in warnings:
                print(f"  {YELLOW}!{RESET} {w}")
            print()
            if not prompt_yn("Apply anyway", default=False):
                print(f"  {DIM}Cancelled.{RESET}")
                pause()
                return

        with open(config_file, "w") as f:
            f.write(f"[Match]\nName={iface}\n\n[Network]\n")
            f.write(f"Address={ip_cidr}\n")
            # Only emit Gateway= for WAN-style configs (when gw != iface IP).
            # On LAN, the firewall IS the gateway; no upstream route needed.
            if gw_obj != ip_obj:
                f.write(f"Gateway={gw}\n")
            f.write(f"DNS={dns}\n")

        # LAN change: also update Kea pool + unbound listen to match new subnet
        if is_lan:
            try:
                _update_kea_for_new_lan(iface, addr_iface)
            except Exception as e:
                print(f"  {YELLOW}!{RESET} Couldn't update Kea: {e}")
    elif method == 3:
        print(f"  {DIM}PPPoE: enter your ISP-provided credentials.{RESET}")
        user = prompt("PPPoE username")
        password = prompt_password("PPPoE password", min_len=1)
        os.makedirs("/etc/ppp/peers", exist_ok=True)
        with open(f"/etc/ppp/peers/wml-{iface}", "w") as f:
            f.write(f"plugin rp-pppoe.so {iface}\nuser \"{user}\"\nnoauth\n"
                    "defaultroute\npersist\nnoipdefault\n")
        with open("/etc/ppp/pap-secrets", "a") as f:
            f.write(f'"{user}" * "{password}" *\n')

    print()
    with Spinner("Restarting networkd"):
        subprocess.run(["systemctl", "restart", "systemd-networkd"], capture_output=True)
    print(f"\n  {GREEN}✓{RESET} Configuration applied to {iface}")
    pause()


def _set_root_credentials(new_pw):
    """Apply *new_pw* to /etc/shadow and /etc/wml/admin.json.

    Returns None on success, or a human-readable error string.  If chpasswd
    fails nothing else is touched.  The plaintext goes to ``chpasswd`` and
    ``openssl passwd -6 -stdin`` on stdin, never on argv, where any local user
    could read it via ps or /proc/<pid>/cmdline.  admin.json is written
    to a 0600 temp file and atomically renamed into place.
    """
    import json
    import time as _t

    r = subprocess.run(["chpasswd"], input=f"root:{new_pw}\n",
                       text=True, capture_output=True)
    if r.returncode != 0:
        return "chpasswd failed: " + (r.stderr.strip() or f"exit {r.returncode}")

    h = subprocess.run(["openssl", "passwd", "-6", "-stdin"],
                       input=f"{new_pw}\n", text=True, capture_output=True)
    pw_hash = h.stdout.strip()
    if h.returncode != 0 or not pw_hash:
        return ("console/SSH password changed, but hashing for the web UI failed: "
                + (h.stderr.strip() or f"exit {h.returncode}"))

    os.makedirs("/etc/wml", exist_ok=True)
    path = "/etc/wml/admin.json"
    tmp = path + ".tmp"
    record = {
        "username": "root",
        "password_hash": pw_hash,
        "role": "superadmin",
        "must_change_password": False,
        # The trailing "Z" promises UTC, so format gmtime(), not local time.
        "created_at": _t.strftime("%Y-%m-%dT%H:%M:%SZ", _t.gmtime()),
    }
    fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        json.dump(record, f, indent=2)
    os.chmod(tmp, 0o600)  # in case a stale temp file existed with a wider mode
    os.replace(tmp, path)
    return None


def opt_reset_root_password():
    """Change root's password for the console, SSH (chpasswd) and web UI (admin.json).

    Failures are reported to the user instead of claiming success.
    """
    transition()
    header("Reset Root Password")

    section("This will change the password for:",
            [f"• {BOLD}root{RESET} console login",
             f"• {BOLD}root{RESET} SSH login",
             f"• {BOLD}root{RESET} Web UI login at https://172.30.30.1/"])
    print()

    if not prompt_yn("Continue", default=False):
        print(f"  {DIM}Cancelled.{RESET}")
        pause()
        return

    new_pw = prompt_password("New password", min_len=8)
    print()
    with Spinner("Updating credentials"):
        try:
            error = _set_root_credentials(new_pw)
        except OSError as e:
            error = f"couldn't update credentials: {e}"
    if error:
        print(f"\n  {RED}✗{RESET} {error}")
    else:
        print(f"\n  {GREEN}✓{RESET} Password updated for console, SSH, and web UI.")
    pause()


def opt_factory_reset():
    """Confirm, delete WML configuration files, then reboot after 3 seconds.

    Removes a fixed list of WML config files plus two globbed sets:
    - every ``*wml-*`` file in /etc/systemd/network, including the per-NIC
      ``10-wml-<iface>.network`` files written by "Set interface IP";
    - ``/etc/ppp/peers/wml-*``.

    This way no console-made network setting survives the reset.
    NOTE(review): PPPoE credentials appended to /etc/ppp/pap-secrets are not
    removed.
    """
    import glob
    import time

    transition()
    header("Factory Reset")
    section("WARNING",
            [f"{RED}This will erase ALL WML OS configuration:{RESET}",
             "  • Network settings (WAN/LAN/DHCP)",
             "  • Firewall rules and aliases",
             "  • Admin credentials (back to install default)",
             "  • Setup wizard completion state",
             "",
             f"{BOLD}The system will reboot.{RESET}  Disk data is NOT erased."])
    print()
    if not prompt_yn(f"{RED}Proceed with factory reset?{RESET}", default=False):
        print(f"  {DIM}Cancelled.{RESET}")
        pause()
        return
    print()
    targets = [
        "/etc/wml/config.json",
        "/etc/wml/interfaces.conf",
        "/etc/wml/admin.json",
        "/etc/wml/setup-done",
        "/var/lib/wml/firstboot-done",
        "/etc/systemd/network/10-wml-wan.network",
        "/etc/systemd/network/20-wml-lan.network",
        "/etc/nftables.conf",
    ]
    targets += glob.glob("/etc/systemd/network/*wml-*")
    targets += glob.glob("/etc/ppp/peers/wml-*")
    with Spinner("Resetting"):
        subprocess.run(["rm", "-rf", "--", *targets], capture_output=True)
    print(f"  {GREEN}✓{RESET} Reset complete.  Rebooting in 3 seconds...")
    time.sleep(3)
    subprocess.run(["systemctl", "reboot"])


def opt_poweroff():
    """Ask for confirmation, then shut the machine down with ``systemctl poweroff``."""
    import time

    transition()
    header("Power Off")
    notice = [
        f"This will shut down the {BOLD}WML OS{RESET} firewall.",
        "Network traffic will stop immediately.",
        "Use physical power button to power on again.",
    ]
    section("Confirm", notice)
    print()
    if not prompt_yn(f"{RED}Power off now?{RESET}", default=False):
        print(f"  {DIM}Cancelled.{RESET}")
        pause()
        return
    # Short pause so the spinner is visible before the machine goes down.
    with Spinner("Powering off"):
        time.sleep(1)
    subprocess.run(["systemctl", "poweroff"])


def opt_reboot():
    """Ask for confirmation, then restart the machine with ``systemctl reboot``."""
    import time

    transition()
    header("Reboot")
    notice = [
        "This will restart the WML OS firewall.",
        "Network traffic will stop for ~30 seconds during reboot.",
    ]
    section("Confirm", notice)
    print()
    if not prompt_yn("Reboot now?", default=False):
        print(f"  {DIM}Cancelled.{RESET}")
        pause()
        return
    # Short pause so the spinner is visible before the machine goes down.
    with Spinner("Rebooting"):
        time.sleep(1)
    subprocess.run(["systemctl", "reboot"])


def opt_ping():
    transition()
    header("Ping Host")
    host = prompt("Hostname or IP to ping", default="8.8.8.8")
    print()
    print(f"  {DIM}--- ping {host} ---{RESET}")
    print()
    subprocess.run(["ping", "-c", "4", "-W", "2", host])
    pause()


def opt_shell():
    """After confirmation, run an interactive bash and return here when it exits.

    bash is started with --norc/--noprofile, a fixed PS1 and HOME=/root; the
    rest of the environment is inherited from this process.
    """
    transition()
    header("Shell")
    section("Drop to bash",
            [f"{YELLOW}Advanced users only.{RESET}",
             "Type 'exit' or Ctrl-D to return to this menu.",
             "Random commands here can break the firewall."])
    print()
    if prompt_yn("Open shell", default=False):
        print(f"\n  {DIM}Type 'exit' to return to WML console.{RESET}\n")
        shell_env = dict(os.environ)
        shell_env.update(PS1=r"root@\h:\w# ", HOME="/root")
        # run() blocks until bash exits, so exit/Ctrl-D comes back to the menu.
        subprocess.run(["/bin/bash", "--norc", "--noprofile", "-i"], env=shell_env)
        print(f"\n  {DIM}Returned to WML console.{RESET}")


def opt_top():
    """System status: WML services, traffic, activity counters."""
    transition()
    header("System Status")

    # WML services state + RAM
    svcs = [
        ("wml-api",            "WML API"),
        ("kea-dhcp4-server",   "DHCP server"),
        ("unbound",            "DNS resolver"),
        ("suricata",           "IDS/IPS"),
        ("clamav-daemon",      "Antivirus engine"),
        ("c-icap",             "ICAP gateway"),
        ("squid",              "Web proxy"),
        ("nginx",              "Admin UI server"),
        ("ssh",                "SSH"),
        ("wml-wan-monitor",    "WAN health monitor"),
        ("wml-carrier-watcher","Carrier watchdog"),
    ]
    body = []
    for unit, label in svcs:
        r = subprocess.run(["systemctl", "is-active", unit],
                           capture_output=True, text=True)
        state = r.stdout.strip()
        color = GREEN if state == "active" else (YELLOW if state == "activating" else RED)
        body.append(f"  {label:<24} {color}{state:<10}{RESET} {DIM}{unit}{RESET}")
    section("WML services", body)

    # System resources
    try:
        with open("/proc/meminfo") as f:
            mem_total = mem_avail = 0
            for line in f:
                if line.startswith("MemTotal:"):
                    mem_total = int(line.split()[1])
                elif line.startswith("MemAvailable:"):
                    mem_avail = int(line.split()[1])
        load = open("/proc/loadavg").read().split()[:3]
        try:
            with open("/proc/uptime") as f:
                up_sec = int(float(f.read().split()[0]))
            up_h, up_m = divmod(up_sec // 60, 60)
            up_d, up_h = divmod(up_h, 24)
            uptime_str = f"{up_d}d {up_h}h {up_m}m"
        except Exception:
            uptime_str = "?"
        section("System", [
            f"  Uptime:      {uptime_str}",
            f"  Load avg:    {load[0]} / {load[1]} / {load[2]}",
            f"  RAM used:    {(mem_total - mem_avail) // 1024} MB / {mem_total // 1024} MB",
        ])
    except Exception as e:
        section("System", [f"  (couldn't read: {e})"])

    # Network activity
    try:
        # Conntrack count
        try:
            with open("/proc/sys/net/netfilter/nf_conntrack_count") as f:
                ct = f.read().strip()
        except OSError:
            ct = "?"
        # Kea active leases (count lines in lease file minus header)
        leases = "?"
        try:
            with open("/var/lib/kea/dhcp4.leases") as f:
                lines = [l for l in f.read().splitlines()
                         if l and not l.startswith("#") and "," in l]
                leases = str(len(lines) - 1) if lines else "0"
        except OSError:
            pass
        section("Activity", [
            f"  Active connections (conntrack):  {ct}",
            f"  DHCP leases:                     {leases}",
        ])
    except Exception:
        pass

    pause()


def opt_firewall_log():
    """Real packet-level firewall events: nftables drops + Suricata IDS alerts."""
    transition()
    header("Firewall Log")
    print(f"  {DIM}Last 50 packet drops + IDS alerts (kernel + Suricata){RESET}\n")

    # 1) nftables drops from kernel log (we log with prefix "wml-fw input drop:")
    print(f"  {BOLD}Recent packet drops:{RESET}")
    r = subprocess.run(
        "dmesg -T 2>/dev/null | grep -E 'wml-fw .* drop:' | tail -25",
        shell=True, capture_output=True, text=True)
    drops = r.stdout.strip()
    if drops:
        for line in drops.split("\n"):
            # Extract just src/dst/dport for readability
            src = dst = dpt = proto = ""
            for part in line.split():
                if part.startswith("SRC="): src = part[4:]
                elif part.startswith("DST="): dst = part[4:]
                elif part.startswith("DPT="): dpt = part[4:]
                elif part.startswith("PROTO="): proto = part[6:]
            ts = "?"
            if "[" in line and "]" in line:
                ts = line[line.index("[")+1:line.index("]")]
            if src and dst:
                print(f"    {DIM}{ts}{RESET} {RED}DROP{RESET} {src} -> {dst}:{dpt} {DIM}({proto}){RESET}")
            else:
                print(f"    {DIM}{line[:120]}{RESET}")
    else:
        print(f"    {DIM}(no recent drops in kernel log){RESET}")

    # 2) Suricata IDS alerts
    print(f"\n  {BOLD}IDS alerts:{RESET}")
    eve = "/var/log/suricata/eve.json"
    if os.path.exists(eve):
        r = subprocess.run(
            f"tail -200 {eve} 2>/dev/null | grep -E '\"event_type\":\"alert\"' | tail -15",
            shell=True, capture_output=True, text=True)
        alerts = r.stdout.strip()
        if alerts:
            import json as _json
            for line in alerts.split("\n"):
                try:
                    e = _json.loads(line)
                    ts = e.get("timestamp", "")[:19].replace("T", " ")
                    src = e.get("src_ip", "?")
                    dst = e.get("dest_ip", "?")
                    sig = e.get("alert", {}).get("signature", "?")
                    print(f"    {DIM}{ts}{RESET} {YELLOW}ALERT{RESET} {src} -> {dst}: {sig[:80]}")
                except Exception:
                    pass
        else:
            print(f"    {DIM}(no alerts in {eve}){RESET}")
    else:
        print(f"    {DIM}(Suricata log not yet available — IDS may still be starting){RESET}")

    pause()


def opt_reload_services():
    """Confirm, then run /usr/local/sbin/wml-fix in the foreground.

    The script's own output goes straight to the terminal.  A missing script
    or a non-zero exit status is reported instead of crashing the menu or
    being silently ignored.
    """
    transition()
    header("Reload All Services")
    print()
    if not prompt_yn("Reload all WML services now?", default=True):
        return
    print()
    try:
        rc = subprocess.run(["/usr/local/sbin/wml-fix"]).returncode
    except OSError as e:
        print(f"  {RED}✗{RESET} Couldn't run wml-fix: {e}")
    else:
        if rc != 0:
            print(f"\n  {RED}✗{RESET} wml-fix exited with status {rc}")
    pause()


def opt_update():
    """Explain that updates are installed from the web UI, then wait for a key."""
    transition()
    header("Update from Console")
    instructions = [
        "Console updates aren't supported in this version.",
        "",
        "To update WML OS:",
        f"  1. Open {BOLD}https://172.30.30.1/{RESET} in a browser",
        f"  2. Sign in as {BOLD}root{RESET}",
        f"  3. Navigate to {BOLD}System -> Updates{RESET}",
        f"  4. Click {BOLD}Check for Updates{RESET}",
    ]
    section("System updates", instructions)
    pause()


def opt_restore_backup():
    """Explain that backups are restored from the web UI, then wait for a key."""
    transition()
    header("Restore a Backup")
    instructions = [
        "Console restore isn't supported in this version.",
        "",
        "To restore a backup:",
        f"  1. Open {BOLD}https://172.30.30.1/{RESET}",
        f"  2. Sign in as {BOLD}root{RESET}",
        f"  3. Go to {BOLD}System -> Backup/Restore{RESET}",
        f"  4. Click {BOLD}Restore from File{RESET}",
    ]
    section("Backup restore", instructions)
    pause()


def opt_redetect_interfaces():
    """Re-run WAN/LAN auto-detection (use after physically swapping cables).

    Deletes the /var/lib/wml/firstboot-done marker and runs
    /usr/local/sbin/wml-firstboot with a 60 s limit.  Then it shows the WAN/LAN
    lines from /etc/wml/interfaces.conf, covering both the ``*_IF=`` keys this
    console writes and ``*_IFS=`` keys.  A timeout or a missing script is
    reported as a failure rather than crashing the menu.
    """
    transition()
    header("Re-detect Interfaces")
    section("Information", [
        "This will re-probe every NIC and reassign WAN/LAN roles.",
        "Use this after physically swapping which cable goes to which port.",
        "Existing connections may briefly drop while networkd reapplies.",
    ])
    print()
    if not prompt_yn("Re-detect now", default=False):
        print(f"  {DIM}Cancelled.{RESET}")
        pause()
        return

    print()
    result = None
    failure = None
    with Spinner("Re-detecting"):
        try:
            os.remove("/var/lib/wml/firstboot-done")
        except OSError:
            pass
        try:
            result = subprocess.run(["/usr/local/sbin/wml-firstboot"],
                                    capture_output=True, text=True, timeout=60)
        except subprocess.TimeoutExpired:
            failure = "timed out after 60s"
        except OSError as e:
            failure = str(e)

    if result is None:
        print(f"\n  {RED}✗{RESET} Re-detect failed ({failure})")
    elif result.returncode == 0:
        print(f"\n  {GREEN}✓{RESET} Interfaces re-detected.")
        try:
            with open("/etc/wml/interfaces.conf") as f:
                for line in f:
                    if line.startswith(("WAN_IF=", "LAN_IF=", "WAN_IFS=", "LAN_IFS=")):
                        print(f"  {DIM}{line.rstrip()}{RESET}")
        except OSError:
            pass
    else:
        print(f"\n  {RED}✗{RESET} Re-detect failed (exit {result.returncode})")
        if result.stderr:
            for line in result.stderr.strip().split("\n")[-5:]:
                print(f"  {DIM}{line}{RESET}")
    pause()


# ---------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------
# Menu number -> handler.  A handler's position in this tuple is its menu
# number, so the order must stay in step with MENU_OPTIONS.
HANDLERS = dict(enumerate((
    opt_logout,                # 0
    opt_assign_interfaces,     # 1
    opt_set_interface_ip,      # 2
    opt_reset_root_password,   # 3
    opt_factory_reset,         # 4
    opt_poweroff,              # 5
    opt_reboot,                # 6
    opt_ping,                  # 7
    opt_shell,                 # 8
    opt_top,                   # 9
    opt_firewall_log,          # 10
    opt_reload_services,       # 11
    opt_update,                # 12
    opt_restore_backup,        # 13
    opt_redetect_interfaces,   # 14
)))


def wait_for_boot_ready():
    """Block until WML OS firewall services are all up.

    Fast path: if /var/lib/wml/firstboot-done exists, skip the splash. That
    marker is written by wml-firstboot at step 3a, once interface assignment
    is complete. The fast path only waits up to 8 s for any interface to
    report an IPv4 address, because networkd may still be configuring links
    right after boot.

    Slow path (first console launch after install):
      * show a splash;
      * wait up to 25 s for each critical systemd unit to report
        "active" or "activating";
      * run /usr/local/sbin/wml-fix to re-apply NAT and restart services
        that depend on networkd.

    A unit that never comes up, or a non-zero wml-fix exit, is reported on
    screen and does not raise.  Failing to launch wml-fix, or its 30 s
    timeout, propagates to the caller.
    """
    import time

    # All deadlines use the monotonic clock.  NTP commonly steps the wall
    # clock right after boot, which would shorten or stretch time.time() waits.
    if os.path.exists("/var/lib/wml/firstboot-done"):
        deadline = time.monotonic() + 8
        while time.monotonic() < deadline:
            try:
                if any(ipv4 for _, _, ipv4, _ in get_interfaces()):
                    break
            except Exception:
                # Best effort only: a probe failure must not block the menu.
                pass
            time.sleep(0.5)
        return

    critical_services = [
        ("systemd-networkd", "Network interfaces"),
        ("nginx",            "Web admin server"),
        ("wml-api",          "API backend"),
    ]
    header(f"Starting {get_hostname()}")
    section("Initialization",
            ["Bringing up firewall services...",
             "This takes ~10 seconds on first boot."])
    print()

    # Wait for each unit in turn; "activating" is accepted as ready too.
    for unit, desc in critical_services:
        deadline = time.monotonic() + 25
        was_active = False
        sys.stdout.write(f"  {DIM}…{RESET} {desc}")
        sys.stdout.flush()
        while time.monotonic() < deadline:
            r = subprocess.run(["systemctl", "is-active", unit],
                               capture_output=True, text=True)
            if r.stdout.strip() in ("active", "activating"):
                was_active = True
                break
            time.sleep(0.5)
        if was_active:
            sys.stdout.write(f"\r  {GREEN}✓{RESET} {desc}{' ' * 30}\n")
        else:
            sys.stdout.write(f"\r  {YELLOW}!{RESET} {desc} (not ready, continuing){' ' * 10}\n")
        sys.stdout.flush()

    # Best-effort: re-apply NAT + restart services that depend on networkd.
    # This catches the case where wml-firstboot fired before cables were
    # plugged in, OR where cables got swapped between boots.
    sys.stdout.write(f"  {DIM}…{RESET} Applying firewall state")
    sys.stdout.flush()
    r = subprocess.run(["/usr/local/sbin/wml-fix"],
                       capture_output=True, text=True, timeout=30)
    if r.returncode == 0:
        sys.stdout.write(f"\r  {GREEN}✓{RESET} Firewall state applied{' ' * 20}\n")
    else:
        # Report the failure instead of claiming success; still continue.
        sys.stdout.write(f"\r  {YELLOW}!{RESET} Firewall state not fully applied "
                         f"(wml-fix exit {r.returncode}, continuing)\n")
    sys.stdout.flush()

    time.sleep(1)



def main():
    """Run the numbered console menu loop.

    Start-up:
      * Ctrl-C (SIGINT) and Ctrl-Z (SIGTSTP) are ignored.
      * On the first launch per boot, i.e. while /run/wml-console-ready does
        not exist, wait_for_boot_ready() runs and the marker is written.
        Any exception there is only printed as a warning.

    Each iteration of the loop:
      * redraw the header and menu;
      * read an option number;
      * dispatch it through HANDLERS.

    Handler outcomes:
      * EOFError / KeyboardInterrupt return the user to the menu;
      * any other exception is printed with its traceback.

    The loop never ends on its own.  The session ends through a handler
    (presumably option 0), or when stdin reaches end-of-file.
    """
    import codecs
    import select
    import termios
    import time
    import traceback
    import tty

    # Ignore Ctrl-C, Ctrl-Z so users can't accidentally drop out of the menu.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    signal.signal(signal.SIGTSTP, signal.SIG_IGN)

    # On first console launch after boot, wait for the firewall to be fully up
    # and apply NAT/services so the user sees a healthy system. Skip the wait
    # if the marker file shows we've already done it this session.
    SESSION_MARKER = "/run/wml-console-ready"
    if not os.path.exists(SESSION_MARKER):
        try:
            wait_for_boot_ready()
            os.makedirs("/run", exist_ok=True)
            with open(SESSION_MARKER, "w"):
                pass
        except Exception as e:
            # Don't block the menu even if startup checks fail
            print(f"\n  {YELLOW}WARN:{RESET} startup check failed: {e}")
            time.sleep(1)

    REFRESH_SEC = 2.0
    # Highest selectable option, derived from the menu so the two can't drift.
    max_option = max(num for num, _ in MENU_OPTIONS)
    prompt_line = f"  {YELLOW}Enter an option (0-{max_option}):{RESET} "

    def read_choice_with_live_refresh():
        """Read an option number from the terminal with a live-refreshing menu.

        Returns an int in 0..max_option.  Returns None when the interface
        list changed while nothing had been typed, so the caller redraws and
        calls again.  When stdin is not a tty it falls back to prompt_int().

        Raises KeyboardInterrupt or EOFError when a literal Ctrl-C or Ctrl-D
        character is read.  Raises SystemExit(0) when stdin hits end-of-file.
        """
        fd = sys.stdin.fileno()
        last_snap = get_interfaces()
        try:
            old_attrs = termios.tcgetattr(fd)
        except termios.error:
            # Not a real tty (e.g., piped stdin) - fall back to blocking input.
            return prompt_int("Enter an option", 0, max_option)
        # Read single bytes straight from the fd and decode them here.
        # sys.stdin.read(1) can pull several typed bytes into Python's buffer,
        # where select() on the fd no longer sees them, so typeahead stalls.
        decoder = codecs.getincrementaldecoder(
            sys.stdin.encoding or "utf-8")(errors="replace")
        tty.setcbreak(fd)
        try:
            sys.stdout.write("\n" + prompt_line)
            sys.stdout.flush()
            buf = ""
            while True:
                ready, _, _ = select.select([fd], [], [], REFRESH_SEC)
                if not ready:
                    # Idle tick.  Only redraw when nothing has been typed yet:
                    # a redraw discards buf, and the user's remaining keys
                    # would then form a different number ("1", redraw, "4").
                    if not buf and get_interfaces() != last_snap:
                        return None
                    continue
                raw = os.read(fd, 1)
                if not raw:
                    # End-of-file on the terminal (e.g. hangup).  Nobody is
                    # left to answer, so end the session instead of redrawing
                    # the menu forever.
                    raise SystemExit(0)
                # Yields "" while in the middle of a multi-byte character.
                for ch in decoder.decode(raw):
                    if ch in ("\r", "\n"):
                        sys.stdout.write("\n")
                        sys.stdout.flush()
                        entry = buf.strip()
                        # isdecimal() (not isdigit()) guarantees int() accepts
                        # it; isdigit() also passes e.g. "²", which int()
                        # rejects with ValueError.
                        if entry.isdecimal():
                            n = int(entry)
                            if 0 <= n <= max_option:
                                return n
                        sys.stdout.write(f"  {RED}Invalid option.{RESET}\n")
                        sys.stdout.write(prompt_line)
                        sys.stdout.flush()
                        buf = ""
                    elif ch in ("\x7f", "\b"):
                        if buf:
                            buf = buf[:-1]
                            sys.stdout.write("\b \b")
                            sys.stdout.flush()
                    elif ch == "\x03":  # Ctrl-C
                        raise KeyboardInterrupt
                    elif ch == "\x04":  # Ctrl-D
                        raise EOFError
                    elif ch.isprintable():
                        buf += ch
                        sys.stdout.write(ch)
                        sys.stdout.flush()
        finally:
            try:
                termios.tcsetattr(fd, termios.TCSADRAIN, old_attrs)
            except termios.error:
                # Terminal already gone; nothing left to restore.
                pass

    while True:
        show_header()
        menu(MENU_OPTIONS)
        try:
            choice = read_choice_with_live_refresh()
        except (EOFError, KeyboardInterrupt):
            print(f"\n  {DIM}Use option 0 to logout, option 8 to open shell.{RESET}")
            time.sleep(1)
            continue
        if choice is None:
            continue

        handler = HANDLERS.get(choice)
        if handler is None:
            continue
        try:
            handler()
        except (EOFError, KeyboardInterrupt):
            print(f"\n  {DIM}Cancelled.  Returning to menu.{RESET}")
            time.sleep(0.5)
        except Exception as e:
            # Keep the console alive: show the error and return to the menu.
            print(f"\n  {RED}Error:{RESET} {e}")
            traceback.print_exc()
            pause()


if __name__ == "__main__":
    # The console manages interfaces, services and passwords: root only.
    if os.geteuid() == 0:
        main()
    else:
        print(f"{RED}wml-console must run as root.{RESET}")
        sys.exit(1)
