#
#  Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an
# express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
import json
from dataclasses import dataclass
from enum import Enum
from typing import Any


@dataclass
class EventMetric:
    """
    A single named metric attached to an event.

    The metric holds one value. That value may itself be an instance of EventMetrics, so metrics can
    recursively contain other metrics as values. Otherwise primitive values are supported: numbers and strings.
    It is understood that the timestamp of the event is the timestamp of the metric. In OpenTelemetry,
    these metrics are reported as event attributes. For other backends they are serialized
    as a JSON object in the event. For spans they are normally attached to the span end event.
    """

    # Name of the metric; EventMetrics.add_metric also uses it as the dictionary key.
    name: str
    # Value of the metric. Annotated as Any, so no validation happens at construction time.
    value: Any


class EventMetrics(dict[str, EventMetric]):
    """
    A collection of event metrics, stored as a dictionary from metric name to EventMetric.

    The value held by an EventMetric may itself be an EventMetrics instance, so collections
    can be nested inside one another.
    """

    def add_metric(self, name: str, value: Any) -> EventMetric:
        """Store ``value`` under ``name`` (replacing any existing entry) and return the stored metric."""
        self[name] = EventMetric(name, value)
        return self[name]

    def to_json(self) -> dict[str, Any]:
        """Convert all metrics to a JSON-compatible dictionary.

        Nested EventMetrics values are converted recursively; every other value is passed
        through unchanged. Entries are keyed by each metric's own ``name`` attribute.

        Returns:
            dict: Mapping of metric name to the (possibly nested) metric value
        """
        encoded: dict[str, Any] = {}
        for metric in self.values():
            payload = metric.value
            if isinstance(payload, EventMetrics):
                # Nested collections are flattened into plain dictionaries.
                payload = payload.to_json()
            encoded[metric.name] = payload
        return encoded

    @classmethod
    def from_json(cls, data: dict) -> "EventMetrics":
        """Create an EventMetrics instance from a JSON-compatible dictionary.

        Dictionary values are turned into nested collections; these are always plain
        EventMetrics instances, even when this method is called on a subclass. String values
        that start with a double quote are parsed as JSON, and kept as-is when parsing fails.
        All other values are stored unchanged.

        Args:
            data: Dictionary containing the metrics data

        Returns:
            EventMetrics: New metrics collection (an instance of ``cls``) created from the data
        """

        def _restore(raw: Any) -> Any:
            # Choose the decoder that applies to this value, if any.
            if isinstance(raw, dict):
                decoder = EventMetrics.from_json
            elif isinstance(raw, str) and raw.startswith('"'):
                # NOTE(review): to_json() does not JSON-encode strings, so this branch presumably
                # targets data produced by another serializer - confirm.
                decoder = json.loads
            else:
                return raw
            try:
                return decoder(raw)
            except (json.JSONDecodeError, TypeError):
                # Decoding is best effort: fall back to the raw value.
                return raw

        restored = cls()
        for name, raw_value in data.items():
            restored.add_metric(name, _restore(raw_value))
        return restored

    @classmethod
    def merge(cls, metrics: "EventMetrics", other: "EventMetrics") -> "EventMetrics":
        """Merge two sets of metrics into a new instance of ``cls``.

        Metrics from ``other`` are added after those from ``metrics``, so on a name clash the
        value from ``other`` is the one kept. Neither input is modified.
        """
        combined = cls()
        for source in (metrics, other):
            for metric in source.values():
                combined.add_metric(metric.name, metric.value)
        return combined


class ApplicationMetrics(EventMetrics):
    """Application-level metrics."""

    @classmethod
    def create(
        cls,
        rank: int | None = None,
        world_size: int | None = None,
        node_name: str | None = None,
        timezone: str | None = None,
        total_iterations: int | None = None,
        checkpoint_enabled: bool | None = None,
        checkpoint_strategy: str | None = None,
    ) -> "ApplicationMetrics":
        """
        Build an ApplicationMetrics instance from the values that were supplied.

        Every argument is optional. An argument left as None is skipped; any other value is
        stored as a metric whose name is the argument name.

        Returns:
            ApplicationMetrics: New instance of ``cls`` holding the supplied metrics
        """
        # Listed in the order the metrics are inserted.
        supplied = {
            "rank": rank,
            "world_size": world_size,
            "node_name": node_name,
            "timezone": timezone,
            "total_iterations": total_iterations,
            "checkpoint_enabled": checkpoint_enabled,
            "checkpoint_strategy": checkpoint_strategy,
        }
        metrics = cls()
        for metric_name, metric_value in supplied.items():
            if metric_value is not None:
                metrics.add_metric(metric_name, metric_value)
        return metrics


class IterationMetrics(EventMetrics):
    """
    Metrics covering one or more iterations of a model training, validation, or testing loop.

    For testing and validation, these are normally reported at the end of the loop, or span.
    For training, these are normally reported using the TRAINING_ITERATIONS event.
    """

    @classmethod
    def create(
        cls,
        current_iteration: int | None = None,
        num_iterations: int | None = None,
        interval: int | None = None,
        average_iteration_time: float | None = None,
        average_forward_time: float | None = None,
        average_backward_time: float | None = None,
        average_dataloader_time: float | None = None,
        average_optimizer_update_time: float | None = None,
        tflops: float | None = None,
        tokens_per_second: float | None = None,
        loss: float | None = None,
        batch_size: int | None = None,
    ) -> "IterationMetrics":
        """
        Build an IterationMetrics instance from the values that were supplied.

        Every argument is optional. An argument left as None is skipped; any other value is
        stored as a metric whose name is the argument name.

        Args:
            current_iteration: The current iteration number of the loop, the last one if reported at the end of the loop
            num_iterations: The total number of iterations in the loop, or since last reporting in the case
                of TRAINING_ITERATIONS
            interval: The interval between the current and previous iteration where a similar event or span
                was reported. Note that for TRAINING_ITERATIONS, this is normally the same as num_iterations.
            average_iteration_time: The average time per iteration
            average_forward_time: The average model forward time per iteration
            average_backward_time: The average model backward time per iteration
            average_dataloader_time: The average dataloader time per iteration
            average_optimizer_update_time: The average optimizer update time per iteration
            tflops: The number of tera-floating point operations per second for each iteration
            tokens_per_second: The number of tokens processed per second for each iteration,
                needed if tflops cannot be calculated
            loss: The current loss value
            batch_size: The number of samples or tokens processed per iteration, also known as the batch size

        Returns:
            IterationMetrics: New instance of ``cls`` holding the supplied metrics
        """
        # Listed in the order the metrics are inserted.
        supplied = {
            "current_iteration": current_iteration,
            "num_iterations": num_iterations,
            "interval": interval,
            "average_iteration_time": average_iteration_time,
            "average_forward_time": average_forward_time,
            "average_backward_time": average_backward_time,
            "average_dataloader_time": average_dataloader_time,
            "average_optimizer_update_time": average_optimizer_update_time,
            "tflops": tflops,
            "tokens_per_second": tokens_per_second,
            "loss": loss,
            "batch_size": batch_size,
        }
        metrics = cls()
        for metric_name, metric_value in supplied.items():
            if metric_value is not None:
                metrics.add_metric(metric_name, metric_value)
        return metrics


class CheckPointType(str, Enum):
    """The type of checkpoint: global or local.

    Members also subclass str, so each member compares equal to its string value.
    """

    # Global checkpoint: saved to remote storage, and persistent.
    GLOBAL = "global"
    # Local checkpoint: saved to local storage, and non-persistent.
    LOCAL = "local"


class CheckpointMetrics(EventMetrics):
    """A set of metrics for checkpoint save or load events."""

    @classmethod
    def create(
        cls,
        checkpoint_type: CheckPointType | None = None,
        current_iteration: int | None = None,
        num_iterations: int | None = None,
        interval: int | None = None,
        checkpoint_size: int | None = None,
        checkpoint_directory: str | None = None,
    ) -> "CheckpointMetrics":
        """
        Build a CheckpointMetrics instance from the values that were supplied.

        Every argument is optional. An argument left as None is skipped; any other value is
        stored as a metric whose name is the argument name.

        Args:
            checkpoint_type: The checkpoint type, either global or local
            current_iteration: The iteration number at the checkpoint save or load
            num_iterations: The number of iterations since the previous checkpoint save,
                not needed for checkpoint load
            interval: The number of iterations between the current and previous checkpoint save,
                normally the same as num_iterations unless some iterations were excluded from checkpointing
            checkpoint_size: The size of the checkpoint in bytes
            checkpoint_directory: The directory where the checkpoint is saved or loaded from

        Returns:
            CheckpointMetrics: New instance of ``cls`` holding the supplied metrics
        """
        # Listed in the order the metrics are inserted.
        supplied = {
            "checkpoint_type": checkpoint_type,
            "current_iteration": current_iteration,
            "num_iterations": num_iterations,
            "interval": interval,
            "checkpoint_size": checkpoint_size,
            "checkpoint_directory": checkpoint_directory,
        }
        metrics = cls()
        for metric_name, metric_value in supplied.items():
            if metric_value is not None:
                metrics.add_metric(metric_name, metric_value)
        return metrics


class NVLinkMetrics(EventMetrics):
    """
    Metrics for the NVLink communication of the GPU device, normally sent as part of the device metrics.
    """

    @classmethod
    def create(
        cls,
        tx_data_payload: float | None = None,
        rx_data_payload: float | None = None,
        crc_errors: int | None = None,
        replay_errors: int | None = None,
        recovery_errors: int | None = None,
    ) -> "NVLinkMetrics":
        """
        Build an NVLinkMetrics instance from the values that were supplied.

        Every argument is optional. An argument left as None is skipped; any other value is
        stored as a metric whose name is the argument name.

        Args:
            tx_data_payload: Tx data payload in KiB
            rx_data_payload: Rx data payload in KiB
            crc_errors: The number of CRC errors on all links over the past sample period
            replay_errors: The number of replay errors on all links over the past sample period
            recovery_errors: The number of recovery errors on all links over the past sample period

        Returns:
            NVLinkMetrics: New instance of ``cls`` holding the supplied metrics
        """
        # Listed in the order the metrics are inserted.
        supplied = {
            "tx_data_payload": tx_data_payload,
            "rx_data_payload": rx_data_payload,
            "crc_errors": crc_errors,
            "replay_errors": replay_errors,
            "recovery_errors": recovery_errors,
        }
        metrics = cls()
        for metric_name, metric_value in supplied.items():
            if metric_value is not None:
                metrics.add_metric(metric_name, metric_value)
        return metrics


class DeviceMetrics(EventMetrics):
    """
    Metrics for the properties of the GPU device, normally sent as part of the training loop
    to report device metrics using the DEVICE_PROPERTIES event.
    """

    @classmethod
    def create(
        cls,
        current_iteration: int | None = None,
        gpu_utilization: float | None = None,
        memory_usage: float | None = None,
        power_draw: float | None = None,
        temperature: float | None = None,
        clock_rate: float | None = None,
        nvlink_metrics: NVLinkMetrics | None = None,
    ) -> "DeviceMetrics":
        """
        Build a DeviceMetrics instance from the values that were supplied.

        Every argument is optional. An argument left as None is skipped; any other value is
        stored as a metric whose name is the argument name.

        Args:
            current_iteration: The current iteration number of the loop;
                this allows correlating the device metrics with the iteration metrics.
            gpu_utilization: The percent of time over the past sample period
                during which one or more kernels was executing on the GPU,
                as reported by the NVIDIA NVML library.
            memory_usage: The percent of time over the past sample period
                during which global (device) memory was being read or written,
                as reported by the NVIDIA NVML library.
            power_draw: The average power draw of the GPU sensor in milliwatts.
            temperature: The average temperature of the GPU sensor in degrees Celsius.
            clock_rate: The clock speed of the GPU SM in MHz (megahertz)
                over the past sample period, as reported by the NVIDIA NVML library.
            nvlink_metrics: The NVLink metrics for the GPU device, if available.
                They are stored as a nested metrics collection.

        Returns:
            DeviceMetrics: New instance of ``cls`` holding the supplied metrics
        """
        # Listed in the order the metrics are inserted.
        supplied = {
            "current_iteration": current_iteration,
            "gpu_utilization": gpu_utilization,
            "memory_usage": memory_usage,
            "power_draw": power_draw,
            "temperature": temperature,
            "clock_rate": clock_rate,
            "nvlink_metrics": nvlink_metrics,
        }
        metrics = cls()
        for metric_name, metric_value in supplied.items():
            if metric_value is not None:
                metrics.add_metric(metric_name, metric_value)
        return metrics
