#
#  Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an
# express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
import os
import tempfile
from pathlib import Path
from typing import Any

import lightning as L
import torch
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset

from training_telemetry.config_loader import load_config
from training_telemetry.torch.lightning.callback import TelemetryCallback


# Define a Lightning model
class SimpleModel(L.LightningModule):
    """Small binary classifier: an MLP over 10 input features ending in a sigmoid.

    Args:
        learning_rate: Learning rate passed to the Adam optimizer.
        weight_decay: Weight decay (L2 penalty) passed to the Adam optimizer.
    """

    def __init__(self, learning_rate: float = 0.001, weight_decay: float = 0.0001) -> None:
        super().__init__()
        # Stores the constructor arguments in ``self.hparams`` (and in checkpoints).
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # 10 input features -> 64 hidden units (ReLU) -> 1 output squashed into (0, 1).
        self.layers = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())

        # BCELoss expects probabilities, which is what the trailing Sigmoid produces.
        self.criterion = nn.BCELoss()

    def forward(self, x: torch.Tensor) -> Any:
        """Return sigmoid probabilities with a trailing size-1 dim for inputs whose last dim is 10."""
        return self.layers(x)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Any:
        """Compute the BCE loss and accuracy for one batch and log both.

        Args:
            batch: ``(inputs, targets)`` pair. Targets must be 0/1 floats with the
                same shape as the model output after dropping its trailing dim.
            batch_idx: Index of this batch within the epoch (unused).

        Returns:
            The scalar loss tensor that Lightning back-propagates.
        """
        inputs, targets = batch

        # Drop only the trailing size-1 output dim. A bare ``squeeze()`` would also
        # collapse the batch dim when the batch holds a single sample, producing a
        # 0-d tensor that BCELoss rejects as a size mismatch against ``targets``.
        probs = self(inputs).squeeze(-1)
        loss = self.criterion(probs, targets)

        # Threshold probabilities at 0.5 to get hard 0/1 predictions.
        predictions = (probs > 0.5).float()
        accuracy = (predictions == targets).float().mean()

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self) -> Any:
        """Return an Adam optimizer configured from the constructor hyperparameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer


def main() -> None:
    """Train ``SimpleModel`` on synthetic data with telemetry and periodic checkpoints.

    Checkpoints go to a temporary directory that is deleted when training ends.
    Telemetry ``on_app_start``/``on_app_end`` bracket the whole run. An exception
    raised after a Trainer has been built is reported via ``on_exception``. Every
    exception is then re-raised with its original traceback.
    """
    # Set random seed for reproducibility
    L.seed_everything(42)

    # Synthetic binary-classification data: label is 1 when the 10 features sum to > 0.
    X = torch.randn(1000, 10, dtype=torch.float32)
    y = (X.sum(dim=1) > 0).float()
    dataset = TensorDataset(X, y)

    log_interval = 100  # log every 100 steps
    batch_size = 32

    # Create DataLoader
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Initialize model
    model = SimpleModel(learning_rate=0.001, weight_decay=0.0001)

    # Build the telemetry config from the example YAML plus these application defaults.
    # Environment-variable overrides are explicitly disabled.
    config = load_config(
        config_file=Path(__file__).parent / "../example_config.yaml",
        defaults={
            "application": {
                "job_name": "lightning_example",
                "job_id": "1234567890",
                "environment": "test",
                "log_interval": log_interval,
            }
        },
        override_from_env=False,
    )
    telemetry_callback = TelemetryCallback(config=config)
    telemetry_callback.on_app_start()

    # Bound before the try so the except clause never hits an UnboundLocalError
    # (which would mask the real error) when setup fails before the Trainer exists.
    trainer: L.Trainer | None = None
    try:
        # The context manager deletes the checkpoint directory deterministically
        # instead of leaving cleanup to garbage collection.
        with tempfile.TemporaryDirectory() as temp_dir:
            checkpoint_dir = os.path.join(temp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Keep the 3 lowest-train_loss checkpoints, evaluated every 20 epochs.
            checkpoint_callback = ModelCheckpoint(
                dirpath=checkpoint_dir,
                filename="lightning-{epoch:02d}-{train_loss:.2f}",
                every_n_epochs=20,
                save_top_k=3,
                monitor="train_loss",
                mode="min",
            )

            # Configure trainer with callbacks
            trainer = L.Trainer(
                max_epochs=100,
                enable_checkpointing=True,
                enable_progress_bar=False,
                enable_model_summary=True,
                log_every_n_steps=log_interval,
                callbacks=[telemetry_callback, checkpoint_callback],
                logger=False,
            )

            # Train the model
            trainer.fit(model, train_dataloader)
    except Exception as e:
        # NOTE(review): Lightning's Trainer also dispatches ``on_exception`` to its
        # callbacks for errors raised inside ``fit()``, so such errors may reach the
        # telemetry callback twice — confirm TelemetryCallback tolerates this.
        if trainer is not None:
            telemetry_callback.on_exception(trainer, model, e)
        raise
    finally:
        telemetry_callback.on_app_end()


if __name__ == "__main__":
    # Only run the training example when executed as a script, never on import.
    main()
