#!/usr/bin/env python3
"""
Train a small LLM to learn integer addition, one digit per token.
Used as a benchmark of training throughput under different GPU power limits on an RTX 5090.

Usage:
    python train_addition_llm.py --epochs 3 --power-limit 575
"""

import argparse
import json
import subprocess
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class DigitAdditionDataset(Dataset):
    """Dataset of random integer additions, encoded as digit tokens.

    Each item is a 1-D ``torch.long`` tensor of length ``3 * num_digits + 1``:
    the zero-padded decimal digits of ``a``, then of ``b``, then of ``a + b``
    (padded to ``num_digits + 1`` digits because the sum may carry). Digits are
    most-significant first, so every token lies in ``0..9``.

    Operands are drawn uniformly from ``[0, 10**num_digits)`` by a private
    ``torch.Generator`` seeded with ``seed``, so the data is reproducible
    without touching PyTorch's global RNG state.

    Args:
        n_samples: Number of (a, b, a + b) examples to generate.
        num_digits: Width of each operand in digits, between 1 and 18.
        seed: Seed for the private sampling generator.

    Raises:
        ValueError: If ``num_digits`` is outside ``[1, 18]``.
    """

    # torch.randint samples int64, so the exclusive bound 10**num_digits must
    # stay below 2**63; 10**18 is the largest power of ten that does.
    _MAX_DIGITS = 18

    def __init__(self, n_samples: int = 50000, num_digits: int = 8, seed: int = 42):
        if not 1 <= num_digits <= self._MAX_DIGITS:
            raise ValueError(
                f"num_digits must be between 1 and {self._MAX_DIGITS}, got {num_digits}"
            )
        self.num_digits = num_digits
        generator = torch.Generator().manual_seed(seed)
        high = 10 ** num_digits
        self.samples = []
        for _ in range(n_samples):
            # One draw for a, then one for b: a fixed draw order keeps a given
            # seed mapping to the same operands.
            a = torch.randint(0, high, (1,), generator=generator).item()
            b = torch.randint(0, high, (1,), generator=generator).item()
            self.samples.append((a, b, a + b))

        # Encode every sample once up front so per-item access (and DataLoader
        # collation each epoch) does no string/int conversion work.
        rows = [self._encode(a, b, result) for a, b, result in self.samples]
        self.tokens = torch.tensor(rows, dtype=torch.long).reshape(
            len(rows), 3 * num_digits + 1
        )

    def __len__(self):
        return len(self.samples)

    def _to_digits(self, n: int, length: int) -> list[int]:
        """Return the decimal digits of ``n``, left-padded with zeros to ``length``."""
        s = str(n).zfill(length)
        return [int(ch) for ch in s]

    def _encode(self, a: int, b: int, result: int) -> list[int]:
        """Concatenate digits of a, b (num_digits each) and result (num_digits + 1)."""
        width = self.num_digits
        return (
            self._to_digits(a, width)
            + self._to_digits(b, width)
            + self._to_digits(result, width + 1)
        )

    def __getitem__(self, idx):
        # Clone so callers that mutate an item cannot corrupt the cached table.
        return self.tokens[idx].clone()


class AdditionTransformer(nn.Module):
    """Causally masked Transformer over digit tokens for next-token prediction.

    Built from ``nn.TransformerEncoder`` layers, but every forward pass applies a
    causal attention mask so position ``i`` attends only to positions ``<= i``.
    Without the mask, teacher-forced next-token training would let each position
    see the very token it is asked to predict.

    Args:
        vocab_size: Number of token ids; inputs must lie in ``[0, vocab_size)``.
        d_model: Embedding / hidden width.
        n_layers: Number of stacked encoder layers.
        n_heads: Attention heads per layer (must divide ``d_model``).
        max_len: Longest sequence the learned position embedding supports.
    """

    def __init__(
        self,
        vocab_size: int = 20,
        d_model: int = 512,
        n_layers: int = 6,
        n_heads: int = 8,
        max_len: int = 32,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            batch_first=True,
            dtype=torch.float32,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return logits of shape ``(batch, seq, vocab_size)`` for token ids ``x``.

        ``x`` is indexed as ``(batch, seq)`` (``batch_first=True``).

        Raises:
            ValueError: If ``seq`` exceeds the position-embedding table size.
        """
        seq_len = x.size(1)
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)
        h = self.embedding(x) + self.pos_embedding(pos)
        # True above the diagonal = "may not attend": token i never sees j > i.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        h = self.transformer(h, mask=causal_mask)
        return self.head(h)


def set_power_limit(watts: int):
    """Set the GPU power limit to ``watts`` via ``sudo nvidia-smi -pl``.

    Requires sudo rights for ``nvidia-smi``. The command's output is captured;
    on failure its stderr is printed so the reason is not lost.

    Raises:
        ValueError: If ``watts`` is not positive.
        subprocess.CalledProcessError: If ``nvidia-smi`` exits non-zero.
    """
    if watts <= 0:
        raise ValueError(f"power limit must be positive, got {watts}W")
    try:
        subprocess.run(
            ["sudo", "nvidia-smi", "-pl", str(watts)],
            check=True,
            capture_output=True,
        )
    except subprocess.CalledProcessError as err:
        # capture_output swallows nvidia-smi's explanation (e.g. value outside
        # the card's allowed range); surface it before propagating.
        detail = (err.stderr or err.stdout or b"").decode(errors="replace").strip()
        print(f"Failed to set power limit to {watts}W: {detail}")
        raise
    print(f"Power limit set to {watts}W")


def train(args: argparse.Namespace):
    """Train AdditionTransformer on DigitAdditionDataset and write timing JSON.

    Reads from ``args``:
        epochs: Number of passes over the dataset.
        power_limit: Watts to apply via ``set_power_limit``; falsy leaves the
            current limit untouched.
        output: Path of the JSON results file.
        batch_size: Optional; missing/None/0 means full-batch (one optimizer
            step per epoch, the benchmark default).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gpu_name = torch.cuda.get_device_name(0) if device.type == "cuda" else None
    print(f"Device: {device}")
    if gpu_name is not None:
        print(f"GPU: {gpu_name}")

    if args.power_limit:
        set_power_limit(args.power_limit)

    dataset = DigitAdditionDataset(n_samples=50000, num_digits=8)
    # Full-batch training as used in the benchmark unless a batch size is given.
    # NOTE(review): a 50k-sequence full batch through a 6-layer d_model=512
    # encoder likely needs tens of GB of activations — verify it fits the GPU.
    batch_size = getattr(args, "batch_size", None) or len(dataset)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Seed model initialisation explicitly so runs are reproducible.
    torch.manual_seed(42)
    model = AdditionTransformer(
        vocab_size=20,
        d_model=512,
        n_layers=6,
        n_heads=8,
        max_len=32,
    ).to(device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()

    model.train()
    # CUDA launches are asynchronous; synchronize so the timed region starts
    # and ends with the GPU idle.
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()

    final_loss = None
    for epoch in range(args.epochs):
        # Accumulate on-device to avoid a host sync per batch.
        loss_sum = torch.zeros((), device=device)
        n_seen = 0
        for batch in loader:
            batch = batch.to(device)
            # Teacher forcing: predict next token
            input_ids = batch[:, :-1]
            targets = batch[:, 1:]

            optimizer.zero_grad()
            logits = model(input_ids)
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            loss.backward()
            optimizer.step()

            loss_sum += loss.detach() * batch.size(0)
            n_seen += batch.size(0)

        # Sample-weighted mean over the epoch (equals the single batch loss
        # when training full-batch).
        final_loss = (loss_sum / max(n_seen, 1)).item()
        print(f"Epoch {epoch + 1}/{args.epochs}, loss: {final_loss:.4f}")

    if device.type == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print(f"Training complete in {elapsed:.1f}s")

    # Save results
    results = {
        "power_limit_w": args.power_limit,
        "epochs": args.epochs,
        "wall_time_s": round(elapsed, 1),
        "device": str(device),
        "gpu_name": gpu_name,
        "final_loss": None if final_loss is None else round(final_loss, 4),
    }
    out_path = Path(args.output)
    out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"Results written to {out_path}")


if __name__ == "__main__":
    # CLI entry point: parse benchmark options and run a single training job.
    parser = argparse.ArgumentParser(description="Train addition LLM for power scaling benchmark")
    parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
    parser.add_argument(
        "--power-limit",
        type=int,
        default=None,
        help="GPU power limit in watts (applied via sudo nvidia-smi); omit to keep current",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="bench_result.json",
        help="path of the JSON results file",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="mini-batch size; default trains full-batch as in the benchmark",
    )
    args = parser.parse_args()
    train(args)
