#!/usr/bin/env python3
"""
TabICL Fine-Tuning Benchmark

Adapt the official TabICL fine-tuning tutorial to real datasets.
Compares zero-shot TabICL vs fine-tuned TabICL vs TabPFN3.
Uses the same OpenML datasets as the rest of the benchmark suite.
"""

import os
import sys
import json
import traceback
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, log_loss, accuracy_score

# Make sure TABPFN_TOKEN is present in the environment before tabpfn is
# imported below: an existing value is kept, otherwise it becomes "".
os.environ.setdefault("TABPFN_TOKEN", "")
from tabicl import FinetunedTabICLClassifier, TabICLClassifier
from tabpfn import TabPFNClassifier

# JSON file holding every per-(dataset, seed) result record; main() reloads
# it on start-up so completed runs are skipped.
RESULTS_FILE = Path.home().joinpath("tabpfn-playground", "tabicl_finetune.json")
# Five consecutive random seeds: 42, 43, 44, 45, 46.
SEEDS = list(range(42, 47))


def load_openml_dataset(name):
    """Download one of the benchmark datasets from OpenML by short name.

    Args:
        name: One of "credit-g", "telco-churn", "default-credit",
            "bank-marketing" or "cc-fraud".

    Returns:
        ``(X, y)`` where ``X`` is the feature table returned by
        ``get_data`` and ``y`` is the dataset's default target converted to
        integer category codes.

    Raises:
        ValueError: if *name* is not a known short name.
    """
    import openml

    # Short benchmark name -> OpenML dataset id.
    dataset_ids = {
        "credit-g": 31,
        "telco-churn": 42178,
        "default-credit": 42477,
        "bank-marketing": 1461,
        "cc-fraud": 46455,
    }
    if name not in dataset_ids:
        raise ValueError(f"Unknown dataset: {name}")

    dataset = openml.datasets.get_dataset(dataset_ids[name])
    features, target, _, _ = dataset.get_data(target=dataset.default_target_attribute)
    target_codes = pd.Series(target).astype("category").cat.codes
    return features, target_codes.values


def encode_categorical(X_train, X_val, X_test):
    """Ordinal-encode categorical columns so FinetunedTabICLClassifier accepts them.

    The encoder is fitted on the training split only and then applied to the
    validation and test splits, so the mapping never sees held-out data.
    Categories that were not seen during fitting are encoded as ``-1``.
    The inputs are copied and never modified.

    Args:
        X_train: Training features (DataFrame or array-like).
        X_val: Validation features, or ``None`` when there is no validation split.
        X_test: Test features.

    Returns:
        Tuple ``(train, val, test)`` of NumPy arrays (``DataFrame.values``);
        ``val`` is ``None`` when ``X_val`` is ``None``.
    """
    df_train = pd.DataFrame(X_train).copy()
    df_val = None if X_val is None else pd.DataFrame(X_val).copy()
    df_test = pd.DataFrame(X_test).copy()

    def _as_arrays():
        # Convert the (possibly encoded) frames to plain arrays, keeping val optional.
        return (
            df_train.values,
            None if df_val is None else df_val.values,
            df_test.values,
        )

    # "string" picks up pandas' dedicated StringDtype columns, which the
    # "object" selector does not match; left unencoded they would reach the
    # fine-tuner as raw text.
    cat_cols = df_train.select_dtypes(
        include=["object", "category", "string"]
    ).columns.tolist()
    if not cat_cols:
        return _as_arrays()

    # Imported lazily: sklearn is only needed when there is something to encode.
    from sklearn.preprocessing import OrdinalEncoder

    enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
    df_train[cat_cols] = enc.fit_transform(df_train[cat_cols])
    if df_val is not None:
        df_val[cat_cols] = enc.transform(df_val[cat_cols])
    df_test[cat_cols] = enc.transform(df_test[cat_cols])

    return _as_arrays()


def _binary_metrics(y_true, proba):
    """Score ``predict_proba`` output against *y_true* for a binary task.

    Column 1 of *proba* is used as the positive-class score (assumes a 0/1
    target); accuracy uses the arg-max class. Values are plain Python floats.
    """
    return {
        "roc_auc": float(roc_auc_score(y_true, proba[:, 1])),
        "ap": float(average_precision_score(y_true, proba[:, 1])),
        "log_loss": float(log_loss(y_true, proba)),
        "accuracy": float(accuracy_score(y_true, np.argmax(proba, axis=1))),
    }


def _best_epoch_info(history):
    """Summarise the logged epoch with the highest validation ROC-AUC.

    Returns ``{}`` when no epoch was logged. When every logged AUC is NaN
    (e.g. the logger payload lacked ``val/roc_auc``), ``best_epoch`` and
    ``best_val_auc`` are ``None`` instead of letting ``np.nanargmax`` raise
    and discard an otherwise completed fine-tuning run.
    """
    if not history["epoch"]:
        return {}
    aucs = np.asarray(history["val_roc_auc"], dtype=float)
    if np.isnan(aucs).all():
        best_epoch, best_val_auc = None, None
    else:
        best_idx = int(np.nanargmax(aucs))
        best_epoch = history["epoch"][best_idx]
        best_val_auc = history["val_roc_auc"][best_idx]
    return {
        "best_epoch": best_epoch,
        "best_val_auc": best_val_auc,
        "total_epochs": len(history["epoch"]),
        "history": history,
    }


def run_single_seed(dataset_name, seed):
    """Compare zero-shot TabICL, fine-tuned TabICL and TabPFN on one seeded split.

    The dataset is split 65/15/20 (train/val/test), stratified on the target
    and seeded with *seed*. Only the fine-tuned model receives the validation
    split (passed as ``X_val``/``y_val`` with early stopping enabled).

    Returns:
        A dict with split sizes, the test positive rate, per-model metrics
        (``roc_auc``, ``ap``, ``log_loss``, ``accuracy``) and the
        fine-tuning best-epoch summary.
    """
    print(f"\n=== {dataset_name} | seed={seed} ===")
    X, y = load_openml_dataset(dataset_name)

    # 20% test first, then 0.15 / 0.80 of the remainder (= 15% overall) for validation.
    X_tmp, X_test, y_tmp, y_test = train_test_split(
        X, y, test_size=0.20, random_state=seed, stratify=y
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_tmp, y_tmp, test_size=0.15 / 0.80,
        random_state=seed, stratify=y_tmp
    )
    print(f"  train={len(y_train)} val={len(y_val)} test={len(y_test)} pos={y_test.mean():.4f}")

    # Pre-encode categoricals for fine-tuning (FinetunedTabICLClassifier needs numeric input)
    X_train_ft, X_val_ft, X_test_ft = encode_categorical(X_train, X_val, X_test)

    # --- Zero-shot TabICL (raw features) ---
    print("  Fitting zero-shot TabICL...")
    icl_zs = TabICLClassifier(device="cuda", random_state=seed, n_estimators=4, verbose=False)
    icl_zs.fit(X_train, y_train)
    zs = _binary_metrics(y_test, icl_zs.predict_proba(X_test))
    print(f"    ZS: AUC={zs['roc_auc']:.4f} AP={zs['ap']:.4f} acc={zs['accuracy']:.4f}")

    # --- Fine-tuned TabICL ---
    print("  Fine-tuning TabICL...")

    history = {
        "epoch": [], "val_roc_auc": [], "val_log_loss": [],
        "val_accuracy": [], "train_loss": [],
    }

    class _HistoryLogger:
        """Experiment-logger stand-in that records per-epoch metrics into ``history``."""

        def setup(self, config):
            del config

        def log_step(self, metrics, step):
            del metrics, step

        def log_epoch(self, metrics, step):
            del step
            history["epoch"].append(int(metrics.get("train/epoch", len(history["epoch"]))) + 1)
            history["val_roc_auc"].append(float(metrics.get("val/roc_auc", np.nan)))
            history["val_log_loss"].append(float(metrics.get("val/log_loss", np.nan)))
            history["val_accuracy"].append(float(metrics.get("val/accuracy", np.nan)))
            history["train_loss"].append(float(metrics.get("train/mean_loss", np.nan)))

        def finish(self):
            pass

    icl_ft = FinetunedTabICLClassifier(
        epochs=60,
        learning_rate=1e-5,
        n_estimators_finetune=2,
        n_estimators_validation=2,
        n_estimators_inference=4,
        early_stopping=True,
        patience=20,
        eval_metric="roc_auc",
        device="cuda",
        random_state=seed,
        verbose=False,
    )
    # NOTE: overrides a private (underscore) hook to capture per-epoch metrics;
    # re-check this when upgrading tabicl.
    icl_ft._make_experiment_logger = lambda: _HistoryLogger()
    icl_ft.fit(X_train_ft, y_train, X_val=X_val_ft, y_val=y_val)
    ft = _binary_metrics(y_test, icl_ft.predict_proba(X_test_ft))
    print(
        f"    FT: AUC={ft['roc_auc']:.4f} AP={ft['ap']:.4f} acc={ft['accuracy']:.4f}"
        f"  ΔAUC={ft['roc_auc'] - zs['roc_auc']:+.4f}"
    )

    best_epoch_info = _best_epoch_info(history)

    # --- TabPFN3 for reference (raw features) ---
    print("  Fitting TabPFN3...")
    pfn = TabPFNClassifier(device="cuda", random_state=seed, n_estimators=4)
    pfn.fit(X_train, y_train)
    pfn_metrics = _binary_metrics(y_test, pfn.predict_proba(X_test))
    print(f"    PFN: AUC={pfn_metrics['roc_auc']:.4f} AP={pfn_metrics['ap']:.4f} acc={pfn_metrics['accuracy']:.4f}")

    return {
        "dataset": dataset_name,
        "seed": seed,
        "n_train": len(y_train),
        "n_val": len(y_val),
        "n_test": len(y_test),
        "pos_rate": float(y_test.mean()),
        "zero_shot": zs,
        "finetuned": ft,
        "tabpfn": pfn_metrics,
        "best_epoch_info": best_epoch_info,
    }


def _save_results(path, results):
    """Write *results* to *path* as indented JSON, atomically.

    The JSON goes to a sibling ``.tmp`` file that is then moved over *path*
    with ``os.replace``, so an interruption mid-write cannot leave a
    truncated results file that the next run fails to ``json.load``.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, path)


def main():
    """Run every (dataset, seed) pair not yet completed, then print a summary.

    Results are reloaded from ``RESULTS_FILE`` and re-saved after every run,
    so the sweep can be resumed. Only successful records count as completed:
    a pair whose previous attempt recorded an error is retried, and its stale
    error record is replaced by the new outcome.
    """
    RESULTS_FILE.parent.mkdir(parents=True, exist_ok=True)
    all_results = []
    if RESULTS_FILE.exists():
        with open(RESULTS_FILE) as f:
            all_results = json.load(f)

    datasets = ["credit-g", "telco-churn", "default-credit", "bank-marketing", "cc-fraud"]

    def _is_pair(record, ds, seed):
        # True when *record* belongs to this (dataset, seed) pair.
        return record.get("dataset") == ds and record.get("seed") == seed

    for ds in datasets:
        for seed in SEEDS:
            done = any(_is_pair(r, ds, seed) and "error" not in r for r in all_results)
            if done:
                print(f"Skip {ds} seed={seed} (exists)")
                continue
            # Drop earlier failed attempts so the file keeps one record per pair.
            all_results = [r for r in all_results if not _is_pair(r, ds, seed)]
            try:
                res = run_single_seed(ds, seed)
            except Exception as e:
                traceback.print_exc()
                print(f"ERROR {ds} seed={seed}: {e}")
                res = {"dataset": ds, "seed": seed, "error": str(e), "traceback": traceback.format_exc()}
            all_results.append(res)
            _save_results(RESULTS_FILE, all_results)

    print("\n=== SUMMARY ===")
    valid = [r for r in all_results if "error" not in r]
    for ds in datasets:
        vals = [r for r in valid if r["dataset"] == ds]
        if not vals:
            continue
        # Mean ± population std (np.std default ddof=0) of per-seed AUC deltas.
        ft_deltas = [r["finetuned"]["roc_auc"] - r["zero_shot"]["roc_auc"] for r in vals]
        pfn_vs_zs = [r["tabpfn"]["roc_auc"] - r["zero_shot"]["roc_auc"] for r in vals]
        print(f"{ds:15s}: FT-ZS={np.mean(ft_deltas):+.4f}±{np.std(ft_deltas):.4f}  PFN-ZS={np.mean(pfn_vs_zs):+.4f}±{np.std(pfn_vs_zs):.4f}  (n={len(vals)})")


if __name__ == "__main__":
    # Run the benchmark sweep only when executed as a script, not on import.
    main()
