#
#  Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an
# express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
import os
import tempfile
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tzlocal import get_localzone

from training_telemetry.config import BackendType
from training_telemetry.config_loader import load_config
from training_telemetry.context import checkpoint_save, get_recorder, running, timed_span, training
from training_telemetry.events import Event, EventName
from training_telemetry.metrics import (
    ApplicationMetrics,
    CheckpointMetrics,
    CheckPointType,
    DeviceMetrics,
    IterationMetrics,
    NVLinkMetrics,
)
from training_telemetry.provider import Provider
from training_telemetry.spans import SpanColor, SpanName
from training_telemetry.utils import get_rank_count, get_rank_index
from training_telemetry.verbosity import Verbosity

# Build a reproducible synthetic binary-classification dataset for this example.
torch.manual_seed(42)
num_epochs = 100
# 1000 samples with 10 float32 features each; a sample is labelled 1.0 when the sum
# of its features is positive and 0.0 otherwise.
X = torch.randn(1000, 10, dtype=torch.float32)
y = X.sum(dim=1).gt(0).float()
dataset = TensorDataset(X, y)
# Shuffled mini-batches of 32 samples, with no worker subprocesses (num_workers=0).
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

# Set up the telemetry provider. The configuration is loaded from the
# example_config.yaml file that sits next to this script, with the application
# values below passed as `defaults` and `override_from_env` switched off.
config = load_config(
    config_file=Path(__file__).with_name("example_config.yaml"),
    defaults={
        "application": {
            "job_name": "torch_example",
            "job_id": "1234567890",
            "environment": "test",
        },
    },
    override_from_env=False,
)
Provider.set_provider(config)


# A small feed-forward binary classifier used by this example
class SimpleModel(nn.Module):
    """Two-layer perceptron: 10 inputs -> 64 hidden units (ReLU) -> 1 sigmoid output."""

    def __init__(self) -> None:
        super().__init__()
        # The positional layout (Linear, ReLU, Linear, Sigmoid) determines the
        # parameter names in the state dict ("layers.0.*" and "layers.2.*").
        stack = [
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> Any:
        """Pass ``x`` through the layer stack and return its output."""
        result = self.layers(x)
        return result


def get_application_metrics() -> ApplicationMetrics:
    """Build the ApplicationMetrics describing this example run.

    The rank and world size come from the telemetry utility helpers, the timezone
    is the local zone reported by tzlocal, and the total iteration count is the
    number of epochs times the number of batches in the module-level dataloader.
    The node name and the checkpoint settings are fixed values for the example.
    """
    rank = get_rank_index()
    world_size = get_rank_count()
    local_timezone = str(get_localzone())
    iterations_per_epoch = len(dataloader)
    return ApplicationMetrics.create(
        rank=rank,
        world_size=world_size,
        node_name="localhost",
        timezone=local_timezone,
        total_iterations=num_epochs * iterations_per_epoch,
        checkpoint_enabled=True,
        checkpoint_strategy="sync",
    )


def _mean_metric(events: "list[Event]", key: str) -> float:
    """Return the arithmetic mean of metric ``key`` over ``events`` (which must be non-empty)."""
    return sum(e.metrics[key].value for e in events) / len(events)


@running(metrics=get_application_metrics())
def main() -> None:
    """Train SimpleModel on the synthetic dataset while emitting telemetry.

    The whole loop runs inside a ``training()`` span. For every batch:

    * the forward, zero-grad, backward and optimizer steps are each wrapped in a
      ``timed_span`` at PROFILING verbosity, nested in an ITERATION span;
    * an IterationMetrics event (with placeholder timing/tflops values) is recorded
      on the training span at TRACING verbosity and also buffered locally.

    Every 50 iterations the buffered events are aggregated into a single
    IterationMetrics event that is sent to the LOGGER backend only, the buffer is
    cleared, and a DeviceMetrics event (placeholder values) is recorded without a
    backend restriction.

    On every 5th epoch (starting with the first) the progress is printed and a
    checkpoint is written to a temporary directory inside a ``checkpoint_save()``
    span, to which CheckpointMetrics are attached. The temporary directory is
    removed again when its ``with`` block exits.
    """
    log_every = 50  # iterations between aggregated-log / device-metrics events
    checkpoint_every = 5  # epochs between checkpoints

    # Initialize model, loss function and optimizer
    model = SimpleModel()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    events = []
    with training() as training_span:
        current_iteration = 0
        # NaN placeholders so the names are bound before the first batch is processed
        accuracy = torch.tensor(float("nan"))
        loss = torch.tensor(float("nan"))

        for epoch in range(num_epochs):
            for batch_idx, (inputs, targets) in enumerate(dataloader):
                with timed_span(SpanName.ITERATION, color=SpanColor.RED, verbosity=Verbosity.PROFILING):
                    # Forward pass
                    with timed_span(SpanName.MODEL_FORWARD, color=SpanColor.RED, verbosity=Verbosity.PROFILING):
                        outputs = model(inputs)
                        # Drop only the trailing feature dimension. A bare squeeze() would
                        # also drop the batch dimension for a batch of one sample, leaving a
                        # 0-d tensor that no longer matches the shape of `targets`.
                        probabilities = outputs.squeeze(-1)
                        loss = criterion(probabilities, targets)

                    # Backward pass and optimize
                    with timed_span(SpanName.ZERO_GRAD, color=SpanColor.GREEN, verbosity=Verbosity.PROFILING):
                        optimizer.zero_grad()
                    with timed_span(SpanName.MODEL_BACKWARD, color=SpanColor.BLUE, verbosity=Verbosity.PROFILING):
                        loss.backward()
                    with timed_span(SpanName.OPTIMIZER_UPDATE, color=SpanColor.YELLOW, verbosity=Verbosity.PROFILING):
                        optimizer.step()

                    # Calculate accuracy, thresholding the sigmoid outputs at 0.5
                    predictions = (probabilities > 0.5).float()
                    accuracy = (predictions == targets).float().mean()

                    current_iteration += 1

                # Record per-iteration metrics on the training span at TRACING verbosity
                metrics = IterationMetrics.create(
                    current_iteration=current_iteration,
                    num_iterations=1,
                    average_iteration_time=0.5,  # dummy value, in real life the elapsed time for the iteration
                    average_forward_time=0.1,  # dummy value, in real life the elapsed time for the forward pass
                    average_backward_time=0.2,  # dummy value, in real life the elapsed time for the backward pass
                    average_dataloader_time=0.3,  # dummy value, in real life the elapsed time for the dataloader
                    tflops=123.4,  # dummy value, in real life the estimated tflops for the iteration
                    loss=loss.item(),
                )
                # Keep a separate copy of the event for the periodic aggregation below
                events.append(Event.create(EventName.TRAINING_ITERATIONS, metrics=metrics))
                get_recorder().event(
                    Event.create(EventName.TRAINING_ITERATIONS, metrics=metrics),
                    training_span,
                    verbosity=Verbosity.TRACING,
                )

                # Every `log_every` iterations, log the aggregated metrics to the LOGGER backend
                # only and reset the buffered events. Also record device metrics, without
                # restricting the backends.
                if current_iteration % log_every == 0:
                    metrics = IterationMetrics.create(
                        current_iteration=current_iteration,
                        num_iterations=sum(e.metrics["num_iterations"].value for e in events),
                        average_iteration_time=_mean_metric(events, "average_iteration_time"),
                        average_forward_time=_mean_metric(events, "average_forward_time"),
                        average_backward_time=_mean_metric(events, "average_backward_time"),
                        average_dataloader_time=_mean_metric(events, "average_dataloader_time"),
                        tflops=_mean_metric(events, "tflops"),
                        loss=loss.item(),  # loss of the most recent batch, not an average
                    )
                    get_recorder().event(
                        Event.create(EventName.TRAINING_ITERATIONS, metrics=metrics),
                        training_span,
                        verbosity=Verbosity.INFO,
                        backend_types={BackendType.LOGGER},
                    )
                    events = []

                    device_metrics = DeviceMetrics.create(
                        current_iteration=current_iteration,
                        gpu_utilization=0.5,  # dummy value
                        memory_usage=0.5,  # dummy value
                        power_draw=0.5,  # dummy value
                        temperature=0.5,  # dummy value
                        clock_rate=0.5,  # dummy value
                        nvlink_metrics=NVLinkMetrics.create(
                            tx_data_payload=1000,  # dummy value
                            rx_data_payload=1000,  # dummy value
                            crc_errors=0,  # dummy value
                            replay_errors=0,  # dummy value
                            recovery_errors=0,  # dummy value
                        ),
                    )
                    get_recorder().event(
                        Event.create(EventName.DEVICE_PROPERTIES, metrics=device_metrics),
                        training_span,
                        verbosity=Verbosity.INFO,
                    )

            # Save a checkpoint every `checkpoint_every` epochs, starting with the first epoch
            if epoch % checkpoint_every == 0:
                print(
                    f"Epoch [{epoch+1}/{num_epochs}], "
                    f"Batch [{batch_idx+1}/{len(dataloader)}], "
                    f"Loss: {loss.item():.4f}, "
                    f"Accuracy: {accuracy.item():.4f}"
                )
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "iteration": current_iteration,
                    "loss": loss.item(),
                    "accuracy": accuracy.item(),
                }
                with checkpoint_save() as checkpoint_save_span:
                    # The checkpoint only lives for the duration of this block: the temporary
                    # directory and its contents are deleted when the block exits.
                    with tempfile.TemporaryDirectory() as temp_dir:
                        checkpoint_file_name = os.path.join(temp_dir, f"checkpoint_iter_{current_iteration}.pt")
                        torch.save(checkpoint, checkpoint_file_name)
                        # NOTE(review): num_iterations reports one epoch's worth of iterations even
                        # though checkpoints are `checkpoint_every` epochs apart - confirm which
                        # interval this field is meant to cover.
                        checkpoint_save_span.add_metrics(
                            CheckpointMetrics.create(
                                checkpoint_type=CheckPointType.LOCAL,
                                current_iteration=current_iteration,
                                num_iterations=len(dataloader),
                                checkpoint_directory=temp_dir,
                            )
                        )


# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
