#
#  Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an
# express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
import importlib
import os
from types import ModuleType
from typing import Any, Optional

__torch_import_attempted: bool = False
__torch_distributed_module: Optional[ModuleType] = None


def import_torch_distributed() -> Optional[ModuleType]:
    """Return ``torch.distributed``, importing it lazily on the first call.

    Importing lazily keeps torch an optional dependency for applications that
    never use it. The import is attempted only once. Its outcome (the module,
    or None when the import raised ImportError) is memoized in module-level
    globals, and every later call returns the same value.

    Returns:
        The ``torch.distributed`` module, or None if it could not be imported.
    """
    global __torch_distributed_module, __torch_import_attempted
    if not __torch_import_attempted:
        try:
            __torch_distributed_module = importlib.import_module("torch.distributed")
        except ImportError:
            # torch is not installed: the memoized module stays None.
            pass
        finally:
            # Mark the attempt even if the import raised, so it is never retried.
            __torch_import_attempted = True
    return __torch_distributed_module


def get_rank(group: Optional[Any] = None) -> int:
    """Get the rank of the current worker process.

    If ``torch.distributed`` is importable, available and initialized, the rank
    comes from ``torch.distributed.get_rank(group)``. Otherwise it falls back to
    the ``RANK`` environment variable, which defaults to 0 when unset.

    Args:
        group: Optional process group, forwarded to ``torch.distributed.get_rank``.
            It is ignored when falling back to the ``RANK`` environment variable.

    Returns:
        rank (int): The rank of the worker.
    """
    dist = import_torch_distributed()
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank(group)
    # Read the environment only when torch.distributed cannot answer, so a
    # malformed RANK value cannot break a job with an initialized process group.
    return int(os.getenv("RANK", "0"))


def get_world_size(group: Optional[Any] = None) -> int:
    """Get the world size, i.e. the number of workers participating in this job.

    If ``torch.distributed`` is importable, available and initialized, the value
    comes from ``torch.distributed.get_world_size(group)``. Otherwise it falls
    back to the ``WORLD_SIZE`` environment variable, which defaults to 1 when
    unset.

    Args:
        group: Optional process group, forwarded to
            ``torch.distributed.get_world_size``. It is ignored when falling back
            to the ``WORLD_SIZE`` environment variable.

    Returns:
        world_size (int): The number of workers in this job (or in ``group``).
    """
    dist = import_torch_distributed()
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_world_size(group)
    # Read the environment only when torch.distributed cannot answer, so a
    # malformed WORLD_SIZE value cannot break an initialized process group.
    return int(os.getenv("WORLD_SIZE", "1"))


def is_rank0() -> bool:
    """Report whether the calling process is the primary (rank 0) worker.

    The rank is obtained from ``get_rank()`` without an explicit process group.

    Returns:
        (bool): True if the current rank is 0, otherwise False.
    """
    current_rank = get_rank()
    return current_rank == 0


def start_monitoring_flops() -> Any:
    """Start counting floating point operations performed by PyTorch, if possible.

    This creates a ``torch.utils.flop_counter.FlopCounterMode`` with
    ``display=False`` and enters it manually. Counting stays active until
    ``end_monitoring_flops`` is called with the returned object.

    This is a best-effort helper. If torch (or its flop counter) is not
    available, or entering the counter fails for any reason, None is returned
    instead of raising.

    Returns:
        flop_counter (Any): The active flop counter object if successful, else None.
    """
    try:
        # Local import keeps torch an optional dependency.
        from torch.utils.flop_counter import FlopCounterMode

        flop_counter = FlopCounterMode(display=False)
        flop_counter.__enter__()
    except Exception:
        # Deliberately broad: FLOP monitoring is optional and must never crash the caller.
        return None
    return flop_counter


def end_monitoring_flops(flop_counter: Any) -> float:
    """Stop a flop counter started by ``start_monitoring_flops`` and report its total.

    Args:
        flop_counter (Any): The object returned by ``start_monitoring_flops``.
            It may be None, which is what that function returns when monitoring
            could not be started.

    Returns:
        (float): The total number of floating point operations recorded by the
        counter, or 0.0 when ``flop_counter`` is None.
    """
    if flop_counter is None:
        # Monitoring was never enabled; return a float to match the annotation.
        return 0.0
    # Exit the context that start_monitoring_flops entered manually, then read the total.
    flop_counter.__exit__(None, None, None)
    return float(flop_counter.get_total_flops())
