"""Manage dataset directories locally and in R2.

This master command provides subcommands to list, upload, download, and
delete dataset directories. A "dataset directory" is any folder that contains
``data.jsonl`` (see examples under ``./temp``). By default, the command works
under the local root ``./temp`` and stores remote datasets under the R2 prefix
``datasets/<name>/`` in the configured bucket.

Defaults
--------
- Local root: ``./temp`` (override with ``--root``)
- Remote base prefix: ``datasets/`` (override with ``--remote-prefix``)
- R2 connection defaults (bucket/endpoint) match the R2 CLI defaults
  and can be overridden with flags or environment variables.
- Upload/download use multiple threads and always show progress by default.

Examples
--------
- List local datasets: ``s6 dataset list``
- List remote datasets: ``s6 dataset list --remote``
- Upload local dataset ``./temp/diverse_2`` to R2: ``s6 dataset upload diverse_2``
- Download remote dataset into ``./temp/fail_case``: ``s6 dataset download fail_case``
- Delete a remote dataset: ``s6 dataset delete fail_case --yes``
"""

from __future__ import annotations

import os
import os.path as osp
import sys
import argparse
from typing import List

from s6.app.r2._common import (
    add_r2_client_args,
    make_client,
    require_bucket_endpoint,
)


# Default local root directory that is scanned for dataset folders.
DEFAULT_ROOT = "temp"
# Default base prefix under which datasets are stored in the R2 bucket.
DEFAULT_REMOTE_BASE = "datasets/"


def _ensure_trailing_slash(p: str) -> str:
    return p if p.endswith("/") else p + "/"


def _dataset_local_path(root: str, name: str) -> str:
    return osp.join(root, name)


def _dataset_remote_prefix(base: str, name: str) -> str:
    """Return the remote key prefix '<base>/<name>/' for dataset *name*."""
    # Normalize the base so the joined prefix always has single separators.
    if not base.endswith("/"):
        base += "/"
    return base + name + "/"


def _find_local_datasets(root: str) -> List[str]:
    if not osp.isdir(root):
        return []
    names: List[str] = []
    for entry in os.listdir(root):
        path = osp.join(root, entry)
        if osp.isdir(path) and osp.isfile(osp.join(path, "data.jsonl")):
            names.append(entry)
    names.sort()
    return names


def _list_remote_datasets(client, base_prefix: str) -> List[str]:
    # One-level listing to get child prefixes under base
    _, prefixes = client.list(
        prefix=_ensure_trailing_slash(base_prefix), recursive=False
    )
    out = []
    for p in prefixes:
        # p is like 'datasets/name/' -> extract 'name'
        name = (
            p[len(base_prefix) :].strip("/")
            if p.startswith(base_prefix)
            else p.strip("/")
        )
        if name:
            out.append(name)
    return sorted(out)


def _add_global_args(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by every subcommand on *parser*."""
    # R2 connection flags (bucket/endpoint/region and env overrides).
    add_r2_client_args(parser)
    parser.add_argument("--root", default=DEFAULT_ROOT,
                        help=f"Local root directory (default: {DEFAULT_ROOT})")
    parser.add_argument("--remote-prefix", default=DEFAULT_REMOTE_BASE,
                        help=f"Remote base prefix (default: {DEFAULT_REMOTE_BASE})")
    parser.add_argument("-w", "--workers", type=int, default=8,
                        help="Parallel workers for up/download (default: 8)")


def cmd_list(args: argparse.Namespace) -> None:
    """List datasets locally and/or remotely.

    Prints the remote section first (when requested and not suppressed by
    ``--local-only``), then the local section.

    Fix: the original connected to R2 and listed remote datasets even when
    ``--local-only`` suppressed their display; the network round trip is now
    skipped entirely in that case.
    """
    local = _find_local_datasets(args.root)

    show_remote = args.remote and not args.local_only
    if show_remote:
        bucket, endpoint = require_bucket_endpoint(args)
        client = make_client(bucket=bucket, endpoint=endpoint, region=args.region)
        remote = _list_remote_datasets(client, args.remote_prefix)
        print("Remote datasets:")
        for n in remote:
            print(n)
        # Blank separator line only when a local section follows.
        if args.local:
            print()
    if args.local:
        print("Local datasets:")
        for n in local:
            print(n)


def cmd_upload(args: argparse.Namespace) -> None:
    """Upload the local dataset directory <root>/<name> to R2."""
    src = _dataset_local_path(args.root, args.name)
    # Validate the local dataset before touching the network.
    if not osp.isdir(src):
        raise SystemExit(f"Local dataset not found: {src}")
    if not osp.isfile(osp.join(src, "data.jsonl")):
        raise SystemExit(
            f"Not a dataset directory (missing data.jsonl): {src}"
        )

    bucket, endpoint = require_bucket_endpoint(args)
    client = make_client(bucket=bucket, endpoint=endpoint, region=args.region)
    dest_prefix = _dataset_remote_prefix(args.remote_prefix, args.name)

    client.upload_directory(
        src,
        dest_prefix,
        overwrite=args.overwrite,
        max_workers=args.workers,
        progress=True,
    )


def cmd_download(args: argparse.Namespace) -> None:
    """Download the remote dataset <remote-prefix>/<name>/ into <root>/<name>."""
    bucket, endpoint = require_bucket_endpoint(args)
    client = make_client(bucket=bucket, endpoint=endpoint, region=args.region)
    src_prefix = _dataset_remote_prefix(args.remote_prefix, args.name)

    target = _dataset_local_path(args.root, args.name)
    os.makedirs(target, exist_ok=True)
    client.download_directory(
        src_prefix,
        target,
        max_workers=args.workers,
        progress=True,
        overwrite=args.overwrite,
    )

    # Sanity check: a dataset directory is expected to contain data.jsonl.
    if not osp.isfile(osp.join(target, "data.jsonl")):
        print(f"Warning: downloaded content under '{target}' lacks data.jsonl")


def cmd_delete(args: argparse.Namespace) -> None:
    """Delete all objects of a remote dataset; optionally the local copy too.

    Requires ``--yes``; refuses (SystemExit) otherwise.

    Fixes: the confirmation check now runs before any client/connection setup
    (fail fast on a destructive command), and the local-removal handler
    catches ``OSError`` — what ``shutil.rmtree`` actually raises — instead of
    a blanket ``Exception``.
    """
    # Fail fast before any network work: deletion is destructive.
    if not args.yes:
        raise SystemExit("Refusing to delete without --yes confirmation")

    bucket, endpoint = require_bucket_endpoint(args)
    client = make_client(bucket=bucket, endpoint=endpoint, region=args.region)
    remote_prefix = _dataset_remote_prefix(args.remote_prefix, args.name)

    # Recursive listing yields every object under the dataset prefix.
    objs, _ = client.list(prefix=remote_prefix, recursive=True)
    count = 0
    for o in objs:
        client.delete_object(o.key, missing_ok=True)
        count += 1
    print(f"Deleted {count} object(s) from '{remote_prefix}'")

    if args.local:
        # Optionally remove the local dataset directory as well (best-effort).
        import shutil

        local_dir = _dataset_local_path(args.root, args.name)
        try:
            shutil.rmtree(local_dir)
            print(f"Removed local dataset '{local_dir}'")
        except OSError as e:
            print(f"Warning: failed to remove local '{local_dir}': {e}")


def main() -> None:
    """Build the CLI parser, normalize shared options, and dispatch."""
    parser = argparse.ArgumentParser(
        description="Manage dataset directories locally and on R2"
    )
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # 'list' subcommand
    list_p = subparsers.add_parser("list", help="List local and/or remote datasets")
    _add_global_args(list_p)
    list_p.add_argument("--remote", action="store_true", help="List remote datasets")
    list_p.add_argument("--local", action="store_true", help="List local datasets")
    list_p.add_argument(
        "--local-only",
        action="store_true",
        help="Only list local datasets (ignore remote)",
    )
    list_p.set_defaults(func=cmd_list, local=True)

    # 'upload' subcommand
    up_p = subparsers.add_parser("upload", help="Upload a local dataset directory to R2")
    _add_global_args(up_p)
    up_p.add_argument("name", help="Dataset directory name under --root")
    up_p.add_argument(
        "--overwrite", action="store_true", help="Allow overwriting existing objects"
    )
    up_p.set_defaults(func=cmd_upload)

    # 'download' subcommand
    down_p = subparsers.add_parser(
        "download", help="Download a remote dataset into local --root"
    )
    _add_global_args(down_p)
    down_p.add_argument("name", help="Dataset name to download")
    down_p.add_argument(
        "--overwrite", action="store_true", help="Allow overwriting existing local files"
    )
    down_p.set_defaults(func=cmd_download)

    # 'delete' subcommand
    del_p = subparsers.add_parser(
        "delete", help="Delete a remote dataset; optional local removal"
    )
    _add_global_args(del_p)
    del_p.add_argument("name", help="Dataset name to delete")
    del_p.add_argument(
        "--local", action="store_true", help="Also delete the local dataset directory"
    )
    del_p.add_argument(
        "-y", "--yes", action="store_true", help="Confirm deletion without prompt"
    )
    del_p.set_defaults(func=cmd_delete)

    args = parser.parse_args()
    # Normalize shared options once so subcommand handlers can rely on them.
    args.remote_prefix = _ensure_trailing_slash(args.remote_prefix)
    args.root = args.root or DEFAULT_ROOT
    args.workers = int(args.workers or 8)

    # Dispatch to the selected subcommand handler.
    args.func(args)


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
