#!/usr/bin/env python3
"""
nukez_context.py — Cross-Agent Context Bridge

Uses the NukezAgent (backed by the real Nukez gateway) for persistent,
verifiable context management across any AI agent — Claude Code, Codex,
Cursor, custom agents. Same identity → same context → seamless handoff.

Usage:
    python nukez_context.py bootstrap          # First time: init storage
    python nukez_context.py persist <key> <val> [--tags t1,t2]
    python nukez_context.py recall [--key K] [--tag T]
    python nukez_context.py verify <receipt_id>
    python nukez_context.py status
    python nukez_context.py snapshot --source claude-code
    python nukez_context.py hydrate            # Load latest snapshot into any agent
    python nukez_context.py delegate "<task>"  # Natural-language delegation
"""
from __future__ import annotations

import argparse
import asyncio
import base64
import json
import os
import sys
import time
from pathlib import Path

# Add project root for imports: make the project's packages importable when
# this script is run directly. PROJECT_ROOT is the directory two levels above
# this script's own directory (three `.parent` hops from the resolved file).
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
# Each insert goes to the FRONT of sys.path, so "src" (inserted last) is
# searched before "bundles/pynukez".
sys.path.insert(0, str(PROJECT_ROOT / "bundles" / "pynukez"))
sys.path.insert(0, str(PROJECT_ROOT / "src"))

from nukez_aaap.agent import NukezAgent
from nukez_aaap.config import AAAPConfig

# ── Configuration ────────────────────────────────────────────────────────

# Local JSON cache of bridge state, kept next to this script.
STATE_FILE = Path(__file__).with_name(".nukez_context_state.json")


def load_state() -> dict:
    """Return the locally cached bridge state, or ``{}`` when no state file exists.

    The state is the JSON document stored in ``STATE_FILE`` by ``save_state``.
    """
    return json.loads(STATE_FILE.read_text()) if STATE_FILE.exists() else {}


def save_state(state: dict):
    """Write *state* to ``STATE_FILE`` as indented JSON, atomically.

    The document goes to a sibling ``<name>.tmp`` file first. That file is then
    moved over ``STATE_FILE`` with ``Path.replace``. An interrupted write
    therefore leaves the previous state intact instead of a truncated,
    unparseable file, which would make every later ``load_state`` call raise.
    """
    STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before touching the filesystem so a non-JSON-able state
    # never creates or clobbers anything.
    payload = json.dumps(state, indent=2)
    tmp_path = STATE_FILE.with_name(STATE_FILE.name + ".tmp")
    tmp_path.write_text(payload)
    tmp_path.replace(STATE_FILE)


def pp(data):
    """Pretty-print *data* to stdout as JSON indented by two spaces."""
    rendered = json.dumps(data, indent=2)
    print(rendered)


# ── Agent singleton ─────────────────────────────────────────────────────

# Process-wide NukezAgent: created lazily by get_agent() and reset to None by
# close_agent(). None before first use and after closing.
_agent: NukezAgent | None = None


async def get_agent() -> NukezAgent:
    """Return the shared NukezAgent, creating and bootstrapping it on first call.

    The agent is built from ``AAAPConfig.from_env()``. It is stored in the
    module-level singleton only after ``bootstrap()`` succeeds. If bootstrap
    raises, the new agent is closed and the exception propagates, so a later
    call starts over instead of receiving an agent that never bootstrapped.
    Prints a one-line readiness summary when a new agent comes up.
    """
    global _agent
    if _agent is None:
        agent = NukezAgent(AAAPConfig.from_env())
        try:
            boot = await agent.bootstrap()
        except BaseException:
            # Close the failed agent (as close_agent would), then re-raise.
            await agent.close()
            raise
        # Publish before printing, so a failure while formatting the summary
        # still leaves the agent reachable by close_agent().
        _agent = agent
        print(f"Agent ready: pubkey={boot.wallet_pubkey[:12]}... "
              f"network={boot.network} locker={boot.system_locker_id[:16]}... "
              f"entries={boot.index_entries}")
    return _agent


async def close_agent():
    """Close and forget the shared agent, if one exists. Safe to call repeatedly.

    The singleton is cleared *before* ``close()`` is awaited. A close that
    raises therefore never leaves a dead agent behind for get_agent() to hand
    out. The check is ``is not None`` rather than truthiness, so an agent
    object that happens to be falsy is still closed.
    """
    global _agent
    agent, _agent = _agent, None
    if agent is not None:
        await agent.close()


# ── Commands ─────────────────────────────────────────────────────────────


async def cmd_bootstrap():
    """Bring the agent up via get_agent() and print its status summary.

    According to the original intent, bootstrapping finds or creates the
    system locker and loads its index; that work happens inside NukezAgent.
    Prints a blank line, then wallet, balance (4 decimals, SOL), locker, index
    entry count and stage. Returns 0.
    """
    info = await (await get_agent()).status()
    print()
    print(f"  wallet: {info.wallet_pubkey}")
    print(f"  balance: {info.balance_sol:.4f} SOL")
    print(f"  locker: {info.system_locker_id}")
    print(f"  entries: {info.index_entries}")
    print(f"  stage: {info.stage}")
    return 0


async def cmd_persist(key: str, value: str, tags: list[str] | None = None):
    """Persist one context entry remotely and record its receipt locally.

    The entry is JSON-encoded and stored as ``context/<key>.json``. Its tags
    are ``context`` followed by any extra *tags*; duplicates are dropped and
    first-seen order is kept. The receipt is then recorded in the local state
    file under ``entries[key]``.

    Args:
        key: Context key; also forms the remote filename.
        value: Context value, stored verbatim.
        tags: Optional extra tags.

    Returns:
        0 on success.
    """
    agent = await get_agent()
    # "context" always comes first so an unfiltered `recall` finds the entry.
    # dict.fromkeys removes duplicates (e.g. `--tags context`) and keeps order.
    all_tags = list(dict.fromkeys(["context", *(tags or [])]))
    filename = f"context/{key}.json"

    entry = {
        "key": key,
        "value": value,
        "tags": all_tags,
        # Stored as the entry's source, which recall/hydrate display.
        "source": os.environ.get("NUKEZ_CONTEXT_SOURCE", "cli"),
        "timestamp": int(time.time()),
    }
    data = json.dumps(entry).encode()

    result = await agent.persist(data, filename=filename, tags=all_tags,
                                  metadata={"context_key": key})
    print(f"✓ Persisted: {key}")
    print(f"  receipt_id: {result.receipt_id}")
    print(f"  verified: {result.verified}")
    if result.merkle_root:
        print(f"  merkle_root: {result.merkle_root[:24]}...")

    # Mirror the receipt in the local state file; cmd_snapshot merges these
    # entries into its snapshot.
    state = load_state()
    state.setdefault("entries", {})[key] = {
        "receipt_id": result.receipt_id,
        "filename": filename,
        "tags": all_tags,
        "timestamp": entry["timestamp"],
    }
    save_state(state)
    return 0


async def cmd_recall(key: str | None = None, tag: str | None = None):
    """Recall context entries and print them.

    Lookup precedence: an exact *key* (``context/<key>.json``) wins over *tag*.
    With neither, up to 50 entries tagged ``context`` are listed.

    Returns:
        0, including when nothing is found.
    """
    agent = await get_agent()

    if key:
        result = await agent.recall(filename=f"context/{key}.json", include_data=True)
    else:
        # A falsy tag (None or "") falls back to the default "context" tag.
        result = await agent.recall(tag=tag or "context", include_data=True, limit=50)

    if not result.items:
        print("(no context entries found)")
        return 0

    print(f"Found {len(result.items)} entries:\n")
    for item in result.items:
        print("\n".join(_format_recalled_item(item)))
        print()
    return 0


def _format_recalled_item(item) -> list[str]:
    """Return the display lines for one recalled item.

    Every line is built before anything is printed. A malformed entry (for
    example a non-numeric timestamp) therefore collapses to the single
    raw-data line, instead of leaving a half-printed entry followed by the
    fallback line.
    """
    if not item.data:
        return [f"  [{item.filename}] (metadata only)"]
    try:
        raw = item.data if isinstance(item.data, bytes) else item.data.encode()
        entry = json.loads(raw)
        lines = [
            f"  [{entry.get('key', item.filename)}]",
            f"    value: {entry.get('value', '(raw)')}",
            f"    source: {entry.get('source', '?')}",
        ]
        ts = entry.get("timestamp")
        if ts:
            lines.append(f"    when: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))}")
        if item.verified:
            lines.append("    verified: ✓")
        return lines
    except Exception:
        # Best-effort display: any decode/render failure degrades to a summary.
        return [f"  [{item.filename}] (raw data, {len(item.data)} bytes)"]


async def cmd_verify(receipt_id: str):
    """Verify the integrity of the entry behind *receipt_id* and print the result.

    ``valid`` is always printed. ``merkle_root``, ``attestation`` and ``chain``
    are printed only when present (truthy).

    Returns:
        0 when the agent reports the entry as valid, 1 otherwise. Scripts can
        then gate on the exit code (e.g. ``verify ID && hydrate``).
    """
    agent = await get_agent()
    result = await agent.verify(receipt_id)
    print(f"  valid: {result.valid}")
    if result.merkle_root:
        print(f"  merkle_root: {result.merkle_root}")
    if result.attestation:
        print(f"  attestation: {result.attestation}")
    if result.chain:
        print(f"  chain: {result.chain}")
    # A failed integrity check must not look like success to the shell.
    return 0 if result.valid else 1


async def cmd_status():
    """Print the agent's wallet, balance, locker, index size and stage.

    Adds a warning line when the status reports ``low_balance``. Returns 0.
    """
    info = await (await get_agent()).status()
    print(f"  pubkey: {info.wallet_pubkey}")
    print(f"  balance: {info.balance_sol:.4f} SOL")
    print(f"  locker: {info.system_locker_id}")
    print(f"  entries: {info.index_entries}")
    print(f"  stage: {info.stage}")
    if info.low_balance:
        print("  ⚠ LOW BALANCE")
    return 0


async def cmd_snapshot(source: str = "unknown"):
    """Persist a full context snapshot for cross-agent handoff.

    Merges two sets of entries; live entries win on key collisions:
    - the locally tracked entries from the state file;
    - the live ``context``-tagged entries recalled from the agent.
    The result is persisted as ``snapshots/snapshot_<source>_<unix-time>.json``
    with tags ``snapshot`` and ``source:<source>``.

    Args:
        source: Name of the agent taking the snapshot.

    Returns:
        0 on success.
    """
    agent = await get_agent()
    local_entries = load_state().get("entries", {})
    live_entries = await _collect_live_context(agent)
    merged = {**local_entries, **live_entries}
    # One clock read so the snapshot's timestamp and its filename agree.
    now = int(time.time())

    snapshot = {
        "type": "context_snapshot",
        "source": source,
        "timestamp": now,
        "entries": merged,
        "live_context": live_entries,
        "summary": {
            # Count distinct keys. `persist` records each key both locally and
            # in the live index, so summing the two sizes double-counts.
            "total_entries": len(merged),
            # Sorted for a deterministic snapshot; key=str tolerates odd
            # non-string sources read from stored JSON.
            "sources": sorted(
                {e.get("source") or "?" for e in live_entries.values()}, key=str
            ),
        },
    }

    filename = f"snapshots/snapshot_{source}_{now}.json"
    data = json.dumps(snapshot).encode()

    result = await agent.persist(data, filename=filename,
                                  tags=["snapshot", f"source:{source}"],
                                  metadata={"snapshot_source": source})
    print(f"✓ Snapshot persisted from {source}")
    print(f"  receipt_id: {result.receipt_id}")
    print(f"  entries: {snapshot['summary']['total_entries']}")
    print(f"  verified: {result.verified}")
    return 0


async def _collect_live_context(agent) -> dict:
    """Return ``{key: {value, tags, source, timestamp}}`` for live context entries.

    Best effort. Items that are empty or fail to decode are skipped. If the
    recall itself (or iterating its result) fails, a warning goes to stderr
    and ``{}`` is returned, so the snapshot still captures the local entries.
    """
    collected: dict = {}
    try:
        live = await agent.recall(tag="context", include_data=True, limit=100)
        for item in live.items:
            if not item.data:
                continue
            try:
                raw = item.data if isinstance(item.data, bytes) else item.data.encode()
                entry = json.loads(raw)
                collected[entry.get("key", item.filename)] = {
                    "value": entry.get("value"),
                    "tags": entry.get("tags", []),
                    "source": entry.get("source"),
                    "timestamp": entry.get("timestamp"),
                }
            except Exception:
                continue  # a malformed entry must not block the snapshot
    except Exception as exc:
        print(f"⚠ Could not recall live context ({exc}); "
              f"snapshot uses local state only.", file=sys.stderr)
        return {}
    return collected


async def cmd_hydrate():
    """Recall the latest context snapshot and adopt it as local state.

    Intended for agent startup. Prints the snapshot's origin and its live
    context entries. Then it replaces ``entries`` in the local state file with
    the snapshot's merged entries and records ``hydrated_from`` and
    ``hydrated_at``.

    Returns:
        0 on success or when no snapshot exists; 1 when the snapshot payload
        cannot be decoded into a JSON object.
    """
    agent = await get_agent()
    # NOTE(review): "latest" relies on recall() returning the newest snapshot
    # first when limit=1 — confirm against NukezAgent.recall ordering.
    result = await agent.recall(tag="snapshot", include_data=True, limit=1)

    if not result.items:
        print("No snapshots found — starting fresh.")
        return 0

    item = result.items[0]
    try:
        raw = item.data if isinstance(item.data, bytes) else item.data.encode()
        snapshot = json.loads(raw)
    except Exception:
        snapshot = None
    # A payload that decodes to a non-object (list, string, number) would
    # crash on .get() below. Treat it like any other undecodable snapshot.
    if not isinstance(snapshot, dict):
        print("✗ Could not decode snapshot.")
        return 1

    print("✓ Hydrated from snapshot")
    print(f"  source: {snapshot.get('source', '?')}")
    ts = snapshot.get("timestamp")
    if ts:
        print(f"  taken: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))}")

    live = snapshot.get("live_context", {})
    print(f"  entries: {len(live)}")
    print()

    if live:
        print("Context entries:")
        for key, entry in live.items():
            print(f"  [{key}] = {entry.get('value', '?')}")
            print(f"    tags: {entry.get('tags', [])} | source: {entry.get('source', '?')}")

    # Adopt the snapshot as this agent's local view.
    state = load_state()
    state["entries"] = snapshot.get("entries", {})
    state["hydrated_from"] = snapshot.get("source")
    state["hydrated_at"] = int(time.time())
    save_state(state)
    return 0


async def cmd_delegate(task: str):
    """Hand a free-form *task* to the agent's delegate() and print its report.

    Prints status and summary, plus receipts when any are returned. Returns 0.
    """
    agent = await get_agent()
    outcome = await agent.delegate(task)
    print(f"  status: {outcome.status}")
    print(f"  summary: {outcome.summary}")
    receipts = outcome.receipts
    if receipts:
        print(f"  receipts: {receipts}")
    return 0


# ── CLI ──────────────────────────────────────────────────────────────────


async def async_main():
    """Parse the command line, run the chosen sub-command and return its exit code.

    The shared agent is always closed afterwards, even when the command raises.

    Returns:
        The sub-command's exit code, or 1 when no sub-command was given.
    """
    parser = argparse.ArgumentParser(description="Nukez Cross-Agent Context Bridge")
    sub = parser.add_subparsers(dest="command")

    sub.add_parser("bootstrap", help="Init agent + storage")

    p = sub.add_parser("persist", help="Persist a context entry")
    p.add_argument("key", help="Context key")
    p.add_argument("value", help="Context value")
    p.add_argument("--tags", default="", help="Comma-separated tags")

    p = sub.add_parser("recall", help="Recall context")
    p.add_argument("--key", help="Recall specific key")
    p.add_argument("--tag", help="Recall by tag")

    p = sub.add_parser("verify", help="Verify entry integrity")
    p.add_argument("receipt_id", help="Receipt ID")

    sub.add_parser("status", help="Agent status")

    p = sub.add_parser("snapshot", help="Persist context snapshot")
    # Default to the NUKEZ_CONTEXT_SOURCE value that `persist` writes on each
    # entry, so a single env var labels everything this agent produces.
    p.add_argument(
        "--source",
        default=os.environ.get("NUKEZ_CONTEXT_SOURCE", "cli"),
        help="Source agent name (default: $NUKEZ_CONTEXT_SOURCE or 'cli')",
    )

    sub.add_parser("hydrate", help="Recall latest snapshot")

    p = sub.add_parser("delegate", help="Natural-language delegation")
    p.add_argument("task", help="Task description")

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return 1

    try:
        if args.command == "bootstrap":
            return await cmd_bootstrap()
        elif args.command == "persist":
            # "a, ,b" -> ["a", "b"]; an empty --tags yields [].
            tags = [t.strip() for t in args.tags.split(",") if t.strip()]
            return await cmd_persist(args.key, args.value, tags)
        elif args.command == "recall":
            return await cmd_recall(args.key, args.tag)
        elif args.command == "verify":
            return await cmd_verify(args.receipt_id)
        elif args.command == "status":
            return await cmd_status()
        elif args.command == "snapshot":
            return await cmd_snapshot(args.source)
        elif args.command == "hydrate":
            return await cmd_hydrate()
        elif args.command == "delegate":
            return await cmd_delegate(args.task)
    finally:
        await close_agent()


def main():
    """Synchronous console entry point: run ``async_main`` on a fresh event loop.

    Returns whatever ``async_main`` returns (an exit code, or None).
    """
    exit_code = asyncio.run(async_main())
    return exit_code


if __name__ == "__main__":
    # main() may return None; that is mapped to exit status 0.
    raise SystemExit(main() or 0)
