#!/usr/bin/env python3
"""Compare feature importance: XGBoost vs TabPFN vs TabICL."""
import warnings, argparse, json, os
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.inspection import permutation_importance
from sklearn.metrics import roc_auc_score
from scipy.stats import spearmanr, kendalltau

from tabpfn import TabPFNClassifier
from tabicl import TabICLClassifier
import xgboost as xgb

# Device string passed as ``device=`` to both TabPFNClassifier and
# TabICLClassifier in main().
# NOTE(review): hard-coded to CUDA, so the script presumably needs a GPU —
# confirm before running on a CPU-only host.
DEVICE = 'cuda'

# Maps the dataset name accepted by load_dataset() to the value handed to
# openml.datasets.get_dataset().
OPENML_MAP = {
    "credit-g": 31,
    "telco-churn": 42178,
    "default-of-credit-card-clients": 42477,
    "bank-marketing": 1461,
    "Credit_Card_Fraud_Classification": 46455,
}


def load_dataset(name):
    """Fetch a dataset from OpenML and return its features and target.

    Args:
        name: A key of ``OPENML_MAP``.

    Returns:
        ``(X, y)`` as produced by ``ds.get_data`` when the dataset's default
        target attribute is requested.

    Raises:
        KeyError: If ``name`` is not a key of ``OPENML_MAP``. The message
            lists the known dataset names.
    """
    try:
        dataset_id = OPENML_MAP[name]
    except KeyError:
        known = ', '.join(sorted(OPENML_MAP))
        raise KeyError(
            f'Unknown dataset {name!r}; expected one of: {known}') from None
    ds = openml.datasets.get_dataset(dataset_id)
    X, y, _, _ = ds.get_data(target=ds.default_target_attribute)
    return X, y


def preprocess(X_train, X_test, y_train):
    """Encode labels and turn raw train/test features into float32 matrices.

    Numeric columns are median-imputed and standardised with statistics
    computed on the training split only. Every other column is treated as
    categorical: missing values become the string ``"missing"`` and each
    distinct string is mapped to an integer code (codes are assigned in
    order of first appearance, training rows before test rows).

    The inputs are copied before any transformation, so the caller's
    objects are left untouched.

    Args:
        X_train: Training features (a DataFrame, or anything
            ``pd.DataFrame`` accepts).
        X_test: Test features with the same columns as ``X_train``.
        y_train: Training labels; must support ``.astype(str)``.

    Returns:
        Tuple ``(X_train_np, X_test_np, y_train_enc, le, feature_names)``:
        the two float32 feature matrices, the integer-encoded training
        labels, the fitted ``LabelEncoder``, and the column names in the
        same order as the matrix columns (numeric columns first, then
        categorical ones).
    """
    le = LabelEncoder()
    y_train_enc = le.fit_transform(y_train.astype(str))

    # Work on private copies: the frames handed in are usually slices of the
    # caller's data and must not be modified behind its back.
    X_train = pd.DataFrame(X_train).copy()
    X_test = pd.DataFrame(X_test).copy()

    num_cols = X_train.select_dtypes(include=[np.number]).columns.tolist()
    cat_cols = [c for c in X_train.columns if c not in num_cols]

    for col in num_cols:
        train_col = X_train[col].astype(np.float64)
        test_col = X_test[col].astype(np.float64)

        median = train_col.median()
        if pd.isna(median):
            # Column is entirely missing in the training split.
            median = 0.0
        train_col = train_col.fillna(median)
        test_col = test_col.fillna(median)

        mean = train_col.mean()
        std = train_col.std()
        if not np.isfinite(std):
            # A single training row gives a NaN sample std; treat the
            # column as constant instead of propagating NaN.
            std = 0.0
        std += 1e-8  # avoid division by zero on constant columns

        X_train[col] = (train_col - mean) / std
        X_test[col] = (test_col - mean) / std

    n_train = len(X_train)
    for col in cat_cols:
        # Go through ``object`` first: calling fillna("missing") directly on
        # a pandas ``category`` column raises because "missing" is not one
        # of its categories.
        train_col = X_train[col].astype(object)
        test_col = X_test[col].astype(object)
        train_col = train_col.where(train_col.notna(), "missing").astype(str)
        test_col = test_col.where(test_col.notna(), "missing").astype(str)

        # Factorise train and test together so both share one code table.
        # All values are strings at this point, so no code is ever -1.
        codes, _ = pd.factorize(
            pd.concat([train_col, test_col], ignore_index=True))
        X_train[col] = codes[:n_train]
        X_test[col] = codes[n_train:]

    # The matrices are laid out numeric-first, so the returned names must
    # follow that same order rather than the original column order.
    feature_names = num_cols + cat_cols
    X_train_np = X_train[feature_names].to_numpy(np.float32)
    X_test_np = X_test[feature_names].to_numpy(np.float32)
    return X_train_np, X_test_np, y_train_enc, le, feature_names


def _plot_path_for(out_path):
    """Return the PNG path that accompanies the JSON results file."""
    root, ext = os.path.splitext(out_path)
    if ext.lower() == '.png':
        # --out itself ends in .png: pick a different name so the plot does
        # not overwrite the JSON results.
        return root + '_plot.png'
    return root + '.png'


def main():
    """Compare permutation importance of XGBoost, TabPFN and TabICL.

    Loads the dataset named by ``--dataset``, makes a stratified 80/20
    split (optionally sub-sampling the training part to ``--max_train``
    rows), fits the three classifiers, computes ROC-AUC permutation
    importance on the test split, prints Spearman correlations and top-k
    overlaps between the models, writes everything to the JSON file given
    by ``--out`` and saves a bar plot next to it.

    Raises:
        ValueError: If the target does not have exactly two classes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', required=True, choices=sorted(OPENML_MAP))
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--max_train', type=int, default=None)
    parser.add_argument('--n_repeats', type=int, default=10)
    parser.add_argument('--out', required=True)
    args = parser.parse_args()

    print(f'Dataset: {args.dataset}, seed: {args.seed}')
    X_raw, y_raw = load_dataset(args.dataset)
    X_train_raw, X_test_raw, y_train_raw, y_test = train_test_split(
        X_raw, y_raw, test_size=0.2, random_state=args.seed, stratify=y_raw)

    # Optional stratified sub-sampling of the training split.
    if args.max_train and len(X_train_raw) > args.max_train:
        idx, _ = train_test_split(
            np.arange(len(X_train_raw)), train_size=args.max_train,
            random_state=args.seed, stratify=y_train_raw)
        if hasattr(X_train_raw, 'iloc'):
            X_train_raw = X_train_raw.iloc[idx]
            y_train_raw = y_train_raw.iloc[idx]
        else:
            X_train_raw = X_train_raw[idx]
            y_train_raw = y_train_raw[idx]

    X_tr, X_te, y_tr, le, feature_names = preprocess(X_train_raw, X_test_raw, y_train_raw)
    y_te = le.transform(y_test.astype(str))

    # The XGBoost objective and the ``[:, 1]`` probability column used below
    # are only meaningful for a two-class target.
    if len(le.classes_) != 2:
        raise ValueError(
            f'Expected a binary target, got {len(le.classes_)} classes')

    # Train models
    print('Training XGBoost...')
    xgc = xgb.XGBClassifier(
        n_estimators=200, max_depth=6, learning_rate=0.1,
        subsample=0.8, colsample_bytree=0.8, objective='binary:logistic',
        eval_metric='logloss', random_state=args.seed, n_jobs=4)
    xgc.fit(X_tr, y_tr)

    print('Training TabPFN...')
    pfn = TabPFNClassifier(device=DEVICE, random_state=args.seed, n_estimators=8)
    pfn.fit(X_tr, y_tr)

    print('Training TabICL...')
    icl = TabICLClassifier(device=DEVICE, random_state=args.seed, n_estimators=8)
    icl.fit(X_tr, y_tr)

    # Baseline AUC
    xgb_auc = roc_auc_score(y_te, xgc.predict_proba(X_te)[:, 1])
    pfn_auc = roc_auc_score(y_te, pfn.predict_proba(X_te)[:, 1])
    icl_auc = roc_auc_score(y_te, icl.predict_proba(X_te)[:, 1])
    print(f'Baseline AUCs: XGB={xgb_auc:.4f}, PFN={pfn_auc:.4f}, ICL={icl_auc:.4f}')

    # Permutation importance
    print(f'Computing permutation importance (n_repeats={args.n_repeats})...')
    results = {}
    for name, clf in [('XGBoost', xgc), ('TabPFN', pfn), ('TabICL', icl)]:
        print(f'  {name}...')
        r = permutation_importance(
            clf, X_te, y_te, n_repeats=args.n_repeats,
            random_state=args.seed, scoring='roc_auc', n_jobs=1)
        results[name] = {
            'importances_mean': r.importances_mean.tolist(),
            'importances_std': r.importances_std.tolist(),
            # Double argsort gives 0-based ascending ranks; subtracting from
            # the length turns that into rank 1 = most important.
            'rank': (len(r.importances_mean) - np.argsort(np.argsort(r.importances_mean))).tolist(),
        }

    # Rank correlations
    models = ['XGBoost', 'TabPFN', 'TabICL']
    corr_matrix = np.zeros((len(models), len(models)))
    for i, m1 in enumerate(models):
        for j, m2 in enumerate(models):
            rho, _ = spearmanr(results[m1]['importances_mean'], results[m2]['importances_mean'])
            corr_matrix[i, j] = rho

    print('\nSpearman rank correlations:')
    print('          XGBoost  TabPFN   TabICL')
    for i, m1 in enumerate(models):
        row = ' '.join([f'{corr_matrix[i,j]:>8.3f}' for j in range(len(models))])
        print(f'{m1:<9} {row}')

    # Top-k overlap
    n_features = len(feature_names)
    print('\nTop-k feature overlap:')
    for k in [3, 5, 10]:
        if k > n_features:
            # With fewer than k features the "top-k" set is every feature
            # and an overlap of n/k would be misleading.
            print(f'  (top-{k} skipped: only {n_features} features)')
            continue
        top_sets = {
            m: set(np.argsort(results[m]['importances_mean'])[-k:])
            for m in models
        }
        for i, m1 in enumerate(models):
            for m2 in models[i + 1:]:
                overlap = len(top_sets[m1] & top_sets[m2])
                print(f'  {m1} ∩ {m2} (top-{k}): {overlap}/{k}')

    # Save
    output = {
        'dataset': args.dataset,
        'seed': args.seed,
        'n_features': n_features,
        'feature_names': feature_names,
        'n_test': len(X_te),
        'baseline_auc': {'XGBoost': xgb_auc, 'TabPFN': pfn_auc, 'TabICL': icl_auc},
        'importance': results,
        'correlations': {
            'spearman': {
                f'{m1}_vs_{m2}': float(corr_matrix[i,j])
                for i, m1 in enumerate(models)
                for j, m2 in enumerate(models)
            }
        },
    }
    with open(args.out, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)
    print(f'\nSaved: {args.out}')

    # Plot. Each panel sorts the features by that model's own importance, so
    # the y axes must NOT be shared: a shared axis carries a single set of
    # tick labels and would label every panel with the last model's order.
    fig, axes = plt.subplots(1, 3, figsize=(18, max(4, n_features * 0.25)), sharey=False)
    colors = {'XGBoost': '#1f77b4', 'TabPFN': '#2ca02c', 'TabICL': '#d62728'}

    for ax, model in zip(axes, models):
        imp = np.array(results[model]['importances_mean'])
        std = np.array(results[model]['importances_std'])
        order = np.argsort(imp)
        ax.barh(range(n_features), imp[order], xerr=std[order],
                color=colors[model], alpha=0.8, capsize=2)
        ax.set_yticks(range(n_features))
        ax.set_yticklabels([feature_names[i] for i in order], fontsize=8)
        ax.set_xlabel('Permutation importance (Δ ROC-AUC)')
        ax.set_title(f'{model}\nAUC={output["baseline_auc"][model]:.4f}')
        ax.grid(axis='x', alpha=0.3)

    fig.suptitle(f'Feature importance: {args.dataset}', fontsize=14)
    fig.tight_layout()
    # Derive the plot name from the extension only, so an --out without
    # ".json" can never make the PNG overwrite the JSON results.
    plot_path = _plot_path_for(args.out)
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f'Saved plot: {plot_path}')


# Run the comparison only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
