"""Thin Cloudflare R2 client (S3‑compatible).

Provides a minimal wrapper around ``boto3`` for listing, uploading,
downloading, and deleting objects in Cloudflare R2 using the S3 API. It is
deliberately small and explicit, with sensible defaults and safeguards
against accidental overwrites.

Key concepts
------------
- Uses standard S3 credentials and custom ``endpoint_url`` for R2.
- Environment variables can supply credentials/region when arguments are
  omitted: ``R2_ACCESS_KEY_ID``/``AWS_ACCESS_KEY_ID``,
  ``R2_SECRET_ACCESS_KEY``/``AWS_SECRET_ACCESS_KEY``, and optionally
  ``R2_REGION_NAME``/``AWS_REGION``/``AWS_DEFAULT_REGION``.
- Upload helpers default to no‑overwrite behaviour; pass ``overwrite=True``
  where supported to replace existing objects.

Motivation
----------
Experiments often produce large datasets and intermediate artifacts that need
to be shared between local development machines, remote teammates, and lab
hardware. Cloudflare R2 provides inexpensive S3‑compatible object storage
without egress fees. Wrapping it behind a tiny, documented interface lets the
pipeline move data in and out of object storage without depending on any
vendor‑specific SDK beyond the ubiquitous S3 API.

This module aims to be:
- Minimal: only the operations we actually use (list, get/put, delete).
- Predictable: no silent overwrites by default; explicit ``overwrite=True``
  when you intend to replace data.
- Portable: credentials via env or args; works with R2 and any S3‑compatible
  endpoint (e.g., MinIO) by changing ``endpoint_url``.

Design
------
The :class:`R2Client` is a thin façade over a ``boto3`` S3 client:

- Construction resolves credentials/region from arguments or environment
  variables and creates a dedicated ``boto3`` S3 client pointed at a custom
  ``endpoint_url``.
- Listings use ``ListObjectsV2`` with pagination (``ContinuationToken``) and
  optional delimiter semantics when ``recursive=False``. Results are returned
  as a list of :class:`R2Object` plus any child prefixes.
- Upload helpers intentionally refuse to overwrite existing keys unless you
  pass ``overwrite=True`` (for directory uploads). Byte and file uploads
  raise ``FileExistsError`` if the key is present.
- ``download_directory`` recreates the key structure beneath a prefix inside a
  local directory and refuses to overwrite existing local files unless
  ``overwrite=True``.
- ``_exists`` issues a ``HEAD`` request and interprets common not‑found error
  codes; other errors are re‑raised to the caller.
- To keep documentation builds importable without cloud dependencies, ``boto3``
  imports are guarded. Instantiating :class:`R2Client` fails fast with a clear
  ``ImportError`` if ``boto3`` is unavailable.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Tuple
import os
import os.path as osp
import threading
import time
import sys

try:  # optional at import time to keep docs importable
    import boto3  # type: ignore
    from botocore.exceptions import ClientError  # type: ignore
except Exception:  # pragma: no cover
    boto3 = None  # type: ignore

    class ClientError(Exception):  # type: ignore
        pass


@dataclass
class R2Object:
    """Metadata describing a single object, as produced by listings.

    Attributes
    ----------
    key : str
        Full object key, i.e. the object's path inside the bucket.
    size : int
        Size of the object in bytes.
    last_modified : Any
        Last-modification timestamp as handed back by ``boto3`` (a
        ``datetime``). Annotated loosely so the public API does not depend on
        boto3 types.
    etag : str
        Entity tag, usually an MD5 digest for objects that were not uploaded
        in multiple parts.
    """

    # Field order defines the positional constructor order; keep it stable.
    key: str
    size: int
    # Typically a datetime from boto3; typed as Any on purpose.
    last_modified: Any
    etag: str


class R2Client:
    """Minimal S3‑compatible client for Cloudflare R2.

    Parameters
    ----------
    access_key_id : str, optional
        Access key ID. When ``None``, resolved from environment variables
        ``R2_ACCESS_KEY_ID`` or ``AWS_ACCESS_KEY_ID``.
    secret_access_key : str, optional
        Secret access key. When ``None``, resolved from ``R2_SECRET_ACCESS_KEY``
        or ``AWS_SECRET_ACCESS_KEY``.
    bucket_name : str, optional
        Target bucket name.
    endpoint_url : str, optional
        R2 endpoint URL, e.g. ``"https://<account>.r2.cloudflarestorage.com"``.
        An empty string (the default) is treated as "not set" and is passed to
        boto3 as ``None``.
    region_name : str, optional
        Region name. If left at ``"auto"`` (default), it is replaced by
        ``R2_REGION_NAME``/``AWS_REGION``/``AWS_DEFAULT_REGION`` when one of
        those is set; otherwise ``"auto"`` is passed through unchanged.
    session : boto3.session.Session, optional
        Existing boto3 session to reuse. A new session is created by default.

    Notes
    -----
    No-overwrite uploads check for the key and then upload in a separate
    request, so a concurrent writer can still slip in between the two calls.

    Examples
    --------
    Create a client and perform common operations::

        client = R2Client(
            access_key_id="...",
            secret_access_key="...",
            bucket_name="my-bucket",
            endpoint_url="https://<account>.r2.cloudflarestorage.com",
        )

        client.upload_file("local.txt", "folder/remote.txt")
        client.download_file("folder/remote.txt", "downloaded.txt")
        client.delete_object("folder/remote.txt")
        objects, prefixes = client.list(prefix="folder/", recursive=False)
    """

    def __init__(
        self,
        access_key_id: Optional[str] = None,
        secret_access_key: Optional[str] = None,
        bucket_name: str = "",
        endpoint_url: str = "",
        region_name: str = "auto",
        session: Optional[boto3.session.Session] = None,
    ) -> None:
        """Initialize the client and underlying S3 client.

        See the class docstring for environment variable resolution rules.

        Raises
        ------
        ImportError
            If ``boto3`` is not installed.
        """
        if boto3 is None:  # pragma: no cover - clearer error when missing dependency
            raise ImportError(
                "boto3 and botocore are required to use R2Client; install with 'pip install boto3'"
            )
        self.bucket_name = bucket_name

        if access_key_id is None:
            access_key_id = os.getenv("R2_ACCESS_KEY_ID") or os.getenv("AWS_ACCESS_KEY_ID")
        if secret_access_key is None:
            secret_access_key = (
                os.getenv("R2_SECRET_ACCESS_KEY") or os.getenv("AWS_SECRET_ACCESS_KEY")
            )

        # Only the default "auto" is overridable from the environment; any
        # other explicitly passed region always wins.
        if region_name == "auto":
            env_region = (
                os.getenv("R2_REGION_NAME")
                or os.getenv("AWS_REGION")
                or os.getenv("AWS_DEFAULT_REGION")
            )
            if env_region:
                region_name = env_region

        if session is None:
            session = boto3.session.Session()

        self._s3 = session.client(
            "s3",
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
            # "" is not a usable endpoint URL; None means "not provided".
            endpoint_url=endpoint_url or None,
            region_name=region_name,
        )

    # -----------------------------
    # Listing
    # -----------------------------
    def list(
        self,
        prefix: str = "",
        recursive: bool = True,
    ) -> tuple[List[R2Object], List[str]]:
        """List objects and pseudo‑directories under a prefix.

        Parameters
        ----------
        prefix : str, optional
            Key prefix to list under, e.g., ``"folder/sub/"``.
        recursive : bool, optional
            If ``True`` (default), returns all objects under ``prefix``.
            If ``False``, treats ``"/"`` as a delimiter and returns only
            one level, also returning sub‑prefixes.

        Returns
        -------
        (list[R2Object], list[str])
            A pair ``(objects, prefixes)``. ``objects`` contains metadata for
            each object; ``prefixes`` lists child "directories" (only when
            ``recursive=False``). A folder-marker object whose key equals
            ``prefix`` (and ends in ``"/"``) is omitted.

        Raises
        ------
        RuntimeError
            If the server reports a truncated page without a continuation
            token, rather than looping forever or returning a partial listing.
        """
        delimiter = None if recursive else "/"

        objects: List[R2Object] = []
        prefixes: List[str] = []
        continuation_token: Optional[str] = None

        while True:
            params: Dict[str, Any] = {"Bucket": self.bucket_name, "Prefix": prefix}
            if delimiter:
                params["Delimiter"] = delimiter
            if continuation_token:
                params["ContinuationToken"] = continuation_token

            response = self._s3.list_objects_v2(**params)

            for obj in response.get("Contents", []):
                key = obj["Key"]
                # Skip the folder marker for the listed prefix itself.
                if key == prefix and key.endswith("/"):
                    continue
                objects.append(
                    R2Object(
                        key=key,
                        size=obj["Size"],
                        last_modified=obj["LastModified"],
                        etag=obj["ETag"].strip('"'),
                    )
                )

            # "Directories" (only present when a delimiter is set).
            prefixes.extend(cp["Prefix"] for cp in response.get("CommonPrefixes", []))

            if not response.get("IsTruncated"):
                break
            continuation_token = response.get("NextContinuationToken")
            if not continuation_token:
                # Without a token the next request would restart from the
                # beginning, so the loop would never terminate.
                raise RuntimeError(
                    f"Listing of bucket '{self.bucket_name}' under prefix '{prefix}' was "
                    "truncated but returned no NextContinuationToken"
                )

        return objects, prefixes

    # -----------------------------
    # Download
    # -----------------------------
    def download_file(self, key: str, local_path: str, *, progress: bool = False) -> None:
        """Download an object to a local file.

        Parameters
        ----------
        key : str
            Object key in the bucket.
        local_path : str
            Local filesystem path to write to. Missing parent directories are
            created.
        progress : bool, optional
            If ``True``, print a byte-count progress line to stderr. The total
            size comes from a ``HEAD`` request; if that request fails, no
            progress is printed but the download is still attempted.
        """
        os.makedirs(osp.dirname(local_path) or ".", exist_ok=True)

        config = _make_transfer_config(max_concurrency=4)
        extra: Dict[str, Any] = {}
        if progress:
            try:
                head = self._s3.head_object(Bucket=self.bucket_name, Key=key)
                total = int(head.get("ContentLength", 0) or 0)
            except Exception:
                # Progress is best-effort; a failed HEAD must not block the download.
                total = 0
            extra["Callback"] = _ProgressPrinter(total, enabled=True, label="Downloading")

        self._s3.download_file(self.bucket_name, key, local_path, Config=config, **extra)

    def download_directory(
        self,
        prefix: str,
        local_dir: str,
        *,
        max_workers: int = 8,
        progress: bool = False,
        overwrite: bool = False,
    ) -> None:
        """Download all objects under a prefix into a local directory, optionally in parallel.

        Preserves the relative structure underneath ``prefix``. Keys that name
        no file (the prefix itself, or directory markers ending in ``"/"``) are
        skipped. The prefix is matched literally, so ``"data"`` also matches
        ``"data2/..."``; pass a trailing ``"/"`` to restrict to one folder.

        Parameters
        ----------
        prefix : str
            Key prefix to download (e.g., ``"folder/sub/"``).
        local_dir : str
            Destination local directory. Created if missing; ``""`` means the
            current directory.
        max_workers : int, optional
            Number of files downloaded concurrently (default ``8``).
        progress : bool, optional
            If ``True``, prints an aggregate progress indicator to stderr.
        overwrite : bool, optional
            If ``False`` (default), refuse to start when any destination file
            already exists.

        Raises
        ------
        FileExistsError
            If ``overwrite`` is ``False`` and any destination file exists.
        ValueError
            If an object key would resolve outside ``local_dir``.
        """
        norm_prefix = prefix or ""
        objects, _ = self.list(prefix=norm_prefix, recursive=True)

        os.makedirs(local_dir or ".", exist_ok=True)

        # Map keys to destinations up front so validation happens before any I/O.
        tasks: List[Tuple[str, str]] = []  # (key, dst_path)
        total_bytes = 0
        for obj in objects:
            dst_path = self._download_target(local_dir, norm_prefix, obj.key)
            if dst_path is None:
                continue
            tasks.append((obj.key, dst_path))
            total_bytes += int(obj.size)

        # Pre-flight existence check to avoid partial downloads when overwrite=False.
        if not overwrite:
            collisions = [p for _, p in tasks if osp.exists(p)]
            if collisions:
                sample = ", ".join(collisions[:5])
                more = f" and {len(collisions)-5} more" if len(collisions) > 5 else ""
                raise FileExistsError(
                    f"Local files already exist for download targets (e.g., {sample}{more}). "
                    "Use overwrite=True to replace."
                )

        # One shared printer aggregates progress across all files.
        progress_cb = _ProgressPrinter(total_bytes, enabled=progress, label="Downloading")
        per_transfer_conc = max(4, min(16, int(max_workers)))
        config = _make_transfer_config(max_concurrency=per_transfer_conc)
        extra: Dict[str, Any] = {"Callback": progress_cb} if progress_cb.enabled else {}

        def _download_one(key: str, dst_path: str) -> None:
            os.makedirs(osp.dirname(dst_path) or ".", exist_ok=True)
            self._s3.download_file(self.bucket_name, key, dst_path, Config=config, **extra)

        self._run_jobs(_download_one, tasks, max_workers)

    def get_object_bytes(self, key: str) -> bytes:
        """Download an object into memory (as bytes).

        Parameters
        ----------
        key : str
            Object key.

        Returns
        -------
        bytes
            Object data.
        """
        resp = self._s3.get_object(Bucket=self.bucket_name, Key=key)
        return resp["Body"].read()

    # -----------------------------
    # Upload (no overwrite)
    # -----------------------------
    def upload_file(
        self,
        local_path: str,
        key: str,
        *,
        progress: bool = False,
    ) -> None:
        """Upload a local file without overwriting existing objects.

        Parameters
        ----------
        local_path : str
            Path to the local file.
        key : str
            Destination object key in the bucket.
        progress : bool, optional
            If ``True``, print a byte-count progress line to stderr.

        Raises
        ------
        FileExistsError
            If an object already exists at ``key``.
        """
        if self._exists(key):
            raise FileExistsError(
                f"Object already exists at key '{key}' in bucket '{self.bucket_name}'"
            )

        extra: Dict[str, Any] = {}
        if progress:
            extra["Callback"] = _ProgressPrinter(os.path.getsize(local_path), enabled=True)

        # Modest concurrency within a single (possibly multipart) transfer.
        config = _make_transfer_config(max_concurrency=4)
        self._s3.upload_file(local_path, self.bucket_name, key, Config=config, **extra)

    def upload_directory(
        self,
        local_dir: str,
        prefix: str,
        *,
        overwrite: bool = False,
        max_workers: int = 8,
        progress: bool = False,
    ) -> None:
        """Upload a directory tree under a prefix, optionally in parallel.

        Preserves the relative path structure under ``local_dir``. Keys are
        normalised to use ``/`` separators.

        Parameters
        ----------
        local_dir : str
            Path to the local directory whose contents to upload.
        prefix : str
            Destination key prefix (e.g., ``"folder/sub/"``).
        overwrite : bool, optional
            If ``False`` (default), raise ``FileExistsError`` when a destination
            key already exists. If ``True``, existing objects are replaced.
        max_workers : int, optional
            Number of worker threads for parallel uploads (default ``8``).
        progress : bool, optional
            If ``True``, prints a simple aggregate progress indicator to stderr.

        Raises
        ------
        NotADirectoryError
            If ``local_dir`` is not a directory.
        FileExistsError
            If ``overwrite`` is ``False`` and any destination key exists.
        """
        if not osp.isdir(local_dir):
            raise NotADirectoryError(f"Not a directory: {local_dir}")
        norm_prefix = prefix.rstrip("/")  # avoid accidental double slashes

        # Collect (local_path, key, size) for every file in the tree.
        file_entries: List[Tuple[str, str, int]] = []
        for root, _, files in os.walk(local_dir):
            for fname in files:
                local_path = osp.join(root, fname)
                rel_key = osp.relpath(local_path, start=local_dir).replace(os.sep, "/")
                key = f"{norm_prefix}/{rel_key}" if norm_prefix else rel_key
                file_entries.append((local_path, key, osp.getsize(local_path)))

        # Pre-flight existence check with a single prefix listing instead of
        # one HEAD request per file.
        if not overwrite:
            # norm_prefix never ends in "/", so add one to list just this
            # subtree; an empty prefix means the whole bucket.
            listing_prefix = f"{norm_prefix}/" if norm_prefix else ""
            existing_objects, _ = self.list(prefix=listing_prefix, recursive=True)
            existing_keys = {obj.key for obj in existing_objects}
            collisions = [key for _, key, _ in file_entries if key in existing_keys]
            if collisions:
                sample = ", ".join(collisions[:5])
                more = f" and {len(collisions)-5} more" if len(collisions) > 5 else ""
                raise FileExistsError(
                    f"Remote objects already exist at destination (e.g., {sample}{more}). "
                    "Use overwrite=True to replace."
                )

        total_bytes = sum(size for _, _, size in file_entries)
        progress_cb = _ProgressPrinter(total_bytes, enabled=progress)

        # Per-transfer concurrency; file-level parallelism is handled by _run_jobs.
        per_transfer_conc = max(4, min(16, int(max_workers)))
        config = _make_transfer_config(max_concurrency=per_transfer_conc)
        extra: Dict[str, Any] = {"Callback": progress_cb} if progress_cb.enabled else {}

        def _upload_one(local_path: str, key: str) -> None:
            self._s3.upload_file(local_path, self.bucket_name, key, Config=config, **extra)

        self._run_jobs(_upload_one, [(lp, k) for lp, k, _ in file_entries], max_workers)

    def upload_bytes(
        self,
        data: bytes,
        key: str,
    ) -> None:
        """Upload raw bytes as an object without overwriting.

        Parameters
        ----------
        data : bytes
            Bytes to upload.
        key : str
            Destination object key.

        Raises
        ------
        FileExistsError
            If an object already exists at ``key``.
        """
        if self._exists(key):
            raise FileExistsError(
                f"Object already exists at key '{key}' in bucket '{self.bucket_name}'"
            )

        self._s3.put_object(Bucket=self.bucket_name, Key=key, Body=data)

    # -----------------------------
    # Delete
    # -----------------------------
    def delete_object(self, key: str, missing_ok: bool = True) -> None:
        """Delete an object.

        Parameters
        ----------
        key : str
            Object key to delete.
        missing_ok : bool, optional
            If ``True`` (default), succeed even if the object is missing. If
            ``False``, raise an error when the object does not exist.

        Raises
        ------
        FileNotFoundError
            If ``missing_ok`` is ``False`` and the object does not exist.
        """
        if not missing_ok and not self._exists(key):
            raise FileNotFoundError(
                f"Object '{key}' does not exist in bucket '{self.bucket_name}'"
            )

        self._s3.delete_object(Bucket=self.bucket_name, Key=key)

    # -----------------------------
    # Internal helpers
    # -----------------------------
    def _exists(self, key: str) -> bool:
        """Return whether an object exists using a HEAD request.

        Parameters
        ----------
        key : str
            Object key to check.

        Returns
        -------
        bool
            ``True`` if the object exists, ``False`` for not-found error codes.

        Raises
        ------
        ClientError
            For any other error (e.g. permission problems).
        """
        try:
            self._s3.head_object(Bucket=self.bucket_name, Key=key)
            return True
        except ClientError as e:
            code = e.response.get("Error", {}).get("Code")
            if code in ("404", "NoSuchKey", "NotFound"):
                return False
            # Something else went wrong; bubble it up.
            raise

    @staticmethod
    def _download_target(local_dir: str, prefix: str, key: str) -> Optional[str]:
        """Map an object key to its destination path under ``local_dir``.

        Returns ``None`` for keys that do not name a file: the prefix itself
        or directory markers ending in ``"/"``. Writing those to disk would
        target a directory path and fail.

        Raises
        ------
        ValueError
            If the key would resolve outside ``local_dir`` (e.g. it contains
            ``..`` segments). Keys come from the bucket and must not be able
            to write anywhere on the local filesystem.
        """
        rel = key[len(prefix):] if prefix and key.startswith(prefix) else key
        rel = rel.lstrip("/")  # guard against accidental leading slash
        if not rel or rel.endswith("/"):
            return None
        dst_path = osp.join(local_dir, rel)
        root = osp.abspath(local_dir or ".")
        if osp.commonpath([root, osp.abspath(dst_path)]) != root:
            raise ValueError(
                f"Refusing to download key '{key}': it resolves outside '{local_dir}'"
            )
        return dst_path

    @staticmethod
    def _run_jobs(func: Any, jobs: List[Tuple[Any, ...]], max_workers: int) -> None:
        """Call ``func(*job)`` for every job, using a thread pool when worthwhile.

        Runs sequentially unless ``max_workers > 1`` and there is more than
        one job. The first failure is re-raised. Jobs that have not started
        yet are cancelled; jobs already running are allowed to finish.
        """
        if max_workers and max_workers > 1 and len(jobs) > 1:
            from concurrent.futures import ThreadPoolExecutor, as_completed

            with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
                futs = [ex.submit(func, *job) for job in jobs]
                try:
                    for fut in as_completed(futs):
                        fut.result()
                except BaseException:
                    # Without this, leaving the `with` block would wait for
                    # every queued transfer even though the operation failed.
                    for fut in futs:
                        fut.cancel()
                    raise
        else:
            for job in jobs:
                func(*job)


# -----------------------------
# Internal progress + transfer utilities
# -----------------------------
class _ProgressPrinter:
    """Simple thread-safe byte counter with periodic stderr output.

    Used as a boto3 callback for uploads. Instantiate with ``enabled=True`` to
    produce output; otherwise it is a no-op.
    """

    def __init__(self, total: int, enabled: bool = False, label: str = "Uploading") -> None:
        self.total = int(total)
        self.enabled = bool(enabled) and self.total > 0
        self.label = label
        self._lock = threading.Lock()
        self._seen = 0
        self._last = 0.0

    def __call__(self, bytes_amount: int) -> None:
        if not self.enabled:
            return
        with self._lock:
            self._seen += int(bytes_amount)
            now = time.time()
            if (now - self._last) >= 0.2 or self._seen >= self.total:
                self._last = now
                pct = (self._seen / self.total) * 100.0 if self.total else 100.0
                msg = f"\r{self.label}: {self._seen}/{self.total} bytes ({pct:.1f}%)"
                print(msg, end="", file=sys.stderr, flush=True)
                if self._seen >= self.total:
                    print(file=sys.stderr)  # newline


def _make_transfer_config(max_concurrency: int):  # pragma: no cover - thin wrapper
    """Create a boto3 TransferConfig if available; otherwise return ``None``.

    This keeps the module importable without boto3 installed (tests provide stubs).
    """
    try:
        from boto3.s3.transfer import TransferConfig  # type: ignore

        return TransferConfig(max_concurrency=int(max_concurrency or 4))
    except Exception:
        return None
