#!/usr/bin/env python3
"""Distill TabPFN's feature priorities into cheap models via engineered interactions."""
import warnings, argparse, json, time
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

import openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, PolynomialFeatures
from sklearn.metrics import roc_auc_score
from sklearn.inspection import permutation_importance

from tabpfn import TabPFNClassifier
import xgboost as xgb
from catboost import CatBoostClassifier
import lightgbm as lgb

# Device string handed to TabPFNClassifier in main(); 'cuda' as written, so a
# CUDA-capable GPU is presumably required to run the script unmodified.
DEVICE = 'cuda'

# Short dataset names accepted by --dataset, mapped to OpenML dataset IDs;
# load_dataset() looks the name up here and fetches that ID.
OPENML_MAP = dict([
    ("credit-g", 31),
    ("telco-churn", 42178),
    ("default-of-credit-card-clients", 42477),
    ("bank-marketing", 1461),
    ("Credit_Card_Fraud_Classification", 46455),
])


def load_dataset(name):
    """Fetch a dataset from OpenML by its short name.

    ``name`` must be a key of OPENML_MAP (an unknown name raises KeyError).
    Returns the ``(features, target)`` pair from ``get_data``, with the
    dataset's default target attribute split out as the target.
    """
    dataset_id = OPENML_MAP[name]
    dataset = openml.datasets.get_dataset(dataset_id)
    features, target, _third, _fourth = dataset.get_data(
        target=dataset.default_target_attribute)
    return features, target


def preprocess(X_train, X_test, y_train):
    """Impute, standardize and integer-encode features using training statistics only.

    Numeric columns: NaNs are filled with the training median, then the column
    is z-scored with the training mean/std. Non-numeric columns: missing values
    become the string "missing", values are cast to ``str`` and mapped to
    integer codes in order of first appearance in the training rows; categories
    that occur only in the test rows map to -1.

    Args:
        X_train, X_test: DataFrames (2-D arrays are converted to DataFrames)
            sharing the same columns. Neither input is modified.
        y_train: training labels; encoded via their ``str`` form.

    Returns:
        ``(X_train_arr, X_test_arr, y_train_enc, label_encoder, cols)``: the
        arrays are float32 with columns ordered numeric-first, then
        non-numeric, exactly as listed in ``cols``.
    """
    le = LabelEncoder()
    y_train_enc = le.fit_transform(y_train.astype(str))
    if not hasattr(X_train, "dtypes"):
        X_train, X_test = pd.DataFrame(X_train), pd.DataFrame(X_test)
    # Work on copies: the column assignments below would otherwise mutate the
    # caller's frames (or a view of them).
    X_train, X_test = X_train.copy(), X_test.copy()

    num_cols = X_train.select_dtypes(include=[np.number]).columns.tolist()
    num_set = set(num_cols)
    cat_cols = [c for c in X_train.columns if c not in num_set]

    for col in num_cols:
        median = X_train[col].median()
        if pd.isna(median):  # column is entirely missing in the training data
            median = 0.0
        X_train[col] = X_train[col].fillna(median)
        X_test[col] = X_test[col].fillna(median)
        mean, std = X_train[col].mean(), X_train[col].std()
        # std() is NaN for a single training row (ddof=1) and 0 for a constant
        # column; use unit scale there instead of emitting NaN / ~1e8 values.
        std = std + 1e-8 if pd.notna(std) and std > 0 else 1.0
        X_train[col] = (X_train[col] - mean) / std
        X_test[col] = (X_test[col] - mean) / std

    for col in cat_cols:
        # Cast to object before fillna: filling a pandas Categorical with a label
        # that is not already one of its categories raises TypeError.
        train_vals = X_train[col].astype(object).fillna("missing").astype(str)
        test_vals = X_test[col].astype(object).fillna("missing").astype(str)
        # Vocabulary comes from the training rows only (codes follow first
        # appearance); categories seen only at test time become -1.
        train_codes, uniques = pd.factorize(train_vals)
        mapping = {v: k for k, v in enumerate(uniques)}
        X_train[col] = train_codes.astype(int)
        X_test[col] = test_vals.map(mapping).fillna(-1).astype(int)

    cols = num_cols + cat_cols
    return X_train[cols].to_numpy(np.float32), X_test[cols].to_numpy(np.float32), y_train_enc, le, cols


def engineer_features(X, top_indices, degree=2, max_ratio_features=5):
    """Append interaction, ratio and power features built from the top-ranked columns.

    With ``t = top_indices`` and ``k = len(t)``, the result is
    ``[X | products | ratios | powers]`` where

    * products: ``X[:, t[i]] * X[:, t[j]]`` for every position pair ``i < j``;
    * ratios: ``X[:, t[i]] / (|X[:, t[j]]| + 1e-6)`` for every ordered pair
      ``i != j`` among the first ``max_ratio_features`` entries of ``t``
      (the epsilon avoids division by zero);
    * powers: ``X[:, t[i]] ** p`` for ``p = 2..degree`` (grouped by power).

    Column order is deterministic, so calling this on train and test with the
    same ``top_indices`` yields aligned matrices.

    Args:
        X: 2-D array of shape ``(n_samples, n_features)``.
        top_indices: column indices of ``X``, most important first.
        degree: highest power term to add; 2 adds squares, values below 2 add
            no power terms.
        max_ratio_features: how many leading indices take part in ratios.

    Returns:
        A new array. With fewer than two indices, a copy of ``X`` unchanged.
    """
    from itertools import combinations, permutations

    top = list(top_indices)
    k = len(top)
    if k < 2:
        return X.copy()

    # Pairwise products (interactions), i < j by position.
    new_cols = [X[:, a] * X[:, b] for a, b in combinations(top, 2)]

    # Ordered-pair ratios over a small leading pool to bound the feature count.
    ratio_pool = top[:max(0, max_ratio_features)]
    new_cols += [X[:, a] / (np.abs(X[:, b]) + 1e-6)
                 for a, b in permutations(ratio_pool, 2)]

    # Polynomial terms of each top feature up to `degree`.
    for power in range(2, degree + 1):
        new_cols += [X[:, a] ** power for a in top]

    return np.hstack([X, np.column_stack(new_cols)])


def train_xgb(X, y, depth, seed=42, extra_feats=None):
    """Fit a 200-round XGBoost binary classifier of the given depth and time the fit.

    Args:
        X: base feature matrix.
        y: encoded training labels.
        depth: ``max_depth`` of each tree.
        seed: ``random_state`` for the booster.
        extra_feats: optional columns appended to ``X`` before fitting.

    Returns:
        ``(fitted classifier, fit wall-clock seconds, number of columns fitted on)``.
    """
    features = X if extra_feats is None else np.hstack([X, extra_feats])
    params = dict(
        n_estimators=200,
        max_depth=depth,
        learning_rate=0.1,
        subsample=0.8,
        colsample_bytree=0.8,
        objective='binary:logistic',
        eval_metric='logloss',
        random_state=seed,
        n_jobs=4,
    )
    model = xgb.XGBClassifier(**params)
    start = time.perf_counter()
    model.fit(features, y)
    elapsed = time.perf_counter() - start
    return model, elapsed, features.shape[1]


def train_catboost(X, y, seed=42, extra_feats=None):
    """Fit a 200-iteration, depth-6 CatBoost classifier and time the fit.

    Args:
        X: base feature matrix.
        y: encoded training labels.
        seed: ``random_state`` for CatBoost.
        extra_feats: optional columns appended to ``X`` before fitting.

    Returns:
        ``(fitted classifier, fit wall-clock seconds, number of columns fitted on)``.
    """
    if extra_feats is not None:
        X = np.hstack([X, extra_feats])
    clf = CatBoostClassifier(
        iterations=200, depth=6, learning_rate=0.1,
        verbose=False, random_state=seed, loss_function='Logloss',
        # Same 4-thread budget as the XGBoost/LightGBM trainers (n_jobs=4), so
        # the recorded fit times compare like with like; CatBoost otherwise
        # uses every available core.
        thread_count=4,
        # Don't leave a catboost_info/ training-log directory in the cwd
        # (it also collides when several runs share a working directory).
        allow_writing_files=False)
    t0 = time.perf_counter()
    clf.fit(X, y)
    fit_t = time.perf_counter() - t0
    return clf, fit_t, X.shape[1]


def train_lgbm(X, y, seed=42, extra_feats=None):
    """Fit a 200-tree, depth-6 LightGBM classifier and time the fit.

    Args:
        X: base feature matrix.
        y: encoded training labels.
        seed: ``random_state`` for LightGBM.
        extra_feats: optional columns appended to ``X`` before fitting.

    Returns:
        ``(fitted classifier, fit wall-clock seconds, number of columns fitted on)``.
    """
    if extra_feats is not None:
        X = np.hstack([X, extra_feats])
    clf = lgb.LGBMClassifier(
        n_estimators=200, max_depth=6, learning_rate=0.1,
        # LightGBM only bags rows when subsample_freq > 0 (its default is 0),
        # so without it subsample=0.8 is silently ignored. Re-sample every
        # iteration to match XGBoost's per-tree subsample=0.8.
        subsample=0.8, subsample_freq=1, colsample_bytree=0.8,
        random_state=seed, n_jobs=4, verbosity=-1)
    t0 = time.perf_counter()
    clf.fit(X, y)
    fit_t = time.perf_counter() - t0
    return clf, fit_t, X.shape[1]


def evaluate(name, clf, X, y, extra_feats=None, is_proba=True):
    """Score a fitted classifier by ROC AUC and time its prediction call.

    Args:
        name: variant label; accepted for call-site symmetry, not used here.
        clf: fitted classifier.
        X: base feature matrix to score.
        y: true labels for ``X``.
        extra_feats: optional columns appended to ``X`` before predicting.
        is_proba: if True, score with ``predict_proba(...)[:, 1]``; otherwise
            the raw ``predict`` output is used as the score.

    Returns:
        ``(auc as a Python float, prediction wall-clock seconds)``.
    """
    features = X if extra_feats is None else np.hstack([X, extra_feats])
    start = time.perf_counter()
    scores = clf.predict_proba(features)[:, 1] if is_proba else clf.predict(features)
    elapsed = time.perf_counter() - start
    return float(roc_auc_score(y, scores)), elapsed


def _parse_args():
    """Build the command-line interface and validate the numeric options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--max_train', type=int, default=None)
    parser.add_argument('--n_repeats', type=int, default=5)
    parser.add_argument('--top_k', type=int, default=10)
    parser.add_argument(
        '--val_frac', type=float, default=0.2,
        help='fraction of the training split held out for ranking features '
             '(the test split is never used for feature selection)')
    parser.add_argument('--out', required=True)
    args = parser.parse_args()
    if args.top_k < 0:
        parser.error('--top_k must be non-negative')
    if not 0.0 < args.val_frac < 1.0:
        parser.error('--val_frac must be strictly between 0 and 1')
    return args


def _subsample_train(X_train_raw, y_train_raw, max_train, seed):
    """Stratified-subsample the training split to at most ``max_train`` rows (no-op if unset)."""
    if not max_train or len(X_train_raw) <= max_train:
        return X_train_raw, y_train_raw
    idx, _ = train_test_split(
        np.arange(len(X_train_raw)), train_size=max_train,
        random_state=seed, stratify=y_train_raw)
    if hasattr(X_train_raw, 'iloc'):
        return X_train_raw.iloc[idx], y_train_raw.iloc[idx]
    return X_train_raw[idx], y_train_raw[idx]


def _holdout_importance(X_tr, y_tr, args):
    """Return TabPFN permutation importances measured on a holdout of the TRAINING data.

    A separate TabPFN is fit on the remaining training rows, so feature
    selection never sees test labels (which would bias the engineered variants'
    test AUC upwards).
    """
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_tr, y_tr, test_size=args.val_frac, random_state=args.seed, stratify=y_tr)
    if len(np.unique(y_val)) < 2:
        raise ValueError('Importance holdout has only one class; increase --val_frac.')
    ranker = TabPFNClassifier(device=DEVICE, random_state=args.seed, n_estimators=8)
    ranker.fit(X_fit, y_fit)
    r = permutation_importance(
        ranker, X_val, y_val, n_repeats=args.n_repeats,
        random_state=args.seed, scoring='roc_auc', n_jobs=1)
    return r.importances_mean


def main():
    """Run the distillation experiment end to end and write a JSON report.

    Steps: stratified 80/20 train/test split (seeded by --seed), optional
    training subsample, preprocessing with training statistics, a TabPFN
    baseline fit on the full training split, feature ranking by TabPFN
    permutation importance on a holdout of the training split, interaction
    features from the top-k, then XGBoost/CatBoost/LightGBM with and without
    those features, all scored by test ROC AUC.
    """
    args = _parse_args()

    print(f'Dataset: {args.dataset}, seed: {args.seed}, top_k: {args.top_k}')
    X_raw, y_raw = load_dataset(args.dataset)
    X_train_raw, X_test_raw, y_train_raw, y_test = train_test_split(
        X_raw, y_raw, test_size=0.2, random_state=args.seed, stratify=y_raw)
    X_train_raw, y_train_raw = _subsample_train(
        X_train_raw, y_train_raw, args.max_train, args.seed)

    X_tr, X_te, y_tr, le, feature_names = preprocess(X_train_raw, X_test_raw, y_train_raw)
    y_te = le.transform(y_test.astype(str))
    if len(np.unique(y_te)) < 2:
        print('Test set has only one class, aborting.')
        return

    # Train TabPFN on the full training split as the reference model.
    print('Training TabPFN...')
    pfn = TabPFNClassifier(device=DEVICE, random_state=args.seed, n_estimators=8)
    t0 = time.perf_counter()
    pfn.fit(X_tr, y_tr)
    pfn_fit = time.perf_counter() - t0
    pfn_auc, pfn_pred = evaluate('TabPFN', pfn, X_te, y_te)
    print(f'  TabPFN baseline AUC: {pfn_auc:.4f}')

    print(f'Computing permutation importance (n_repeats={args.n_repeats})...')
    imp = _holdout_importance(X_tr, y_tr, args)
    # Descending order, then take the first k: unlike `[-k:]`, this yields an
    # empty selection (not every feature) when --top_k is 0.
    top_k_idx = np.argsort(imp)[::-1][:args.top_k]
    top_features = [feature_names[i] for i in top_k_idx]
    print(f'  Top-{args.top_k} features: {top_features}')

    print('Engineering interaction features from TabPFN top features...')
    n_base = X_tr.shape[1]
    # engineer_features returns [X | new columns]; keep only the new columns.
    extra_tr = engineer_features(X_tr, top_k_idx)[:, n_base:]
    extra_te = engineer_features(X_te, top_k_idx)[:, n_base:]
    n_eng = extra_tr.shape[1]
    print(f'  Added {n_eng} engineered features ({n_base + n_eng} total)')

    results = {
        'dataset': args.dataset,
        'seed': args.seed,
        'top_k': args.top_k,
        'top_features': top_features,
        'n_base_features': n_base,
        'n_engineered': n_eng,
        'TabPFN_baseline': float(pfn_auc),
        'TabPFN_fit_time': pfn_fit,
        'TabPFN_pred_time': pfn_pred,
        'importance_source': 'train_holdout',
        'val_frac': args.val_frac,
        'variants': {},
    }

    models = [
        ('XGB_d6', train_xgb, {'depth': 6}),
        ('XGB_d12', train_xgb, {'depth': 12}),
        ('CatBoost', train_catboost, {}),
        ('LightGBM', train_lgbm, {}),
    ]
    sections = [
        ('\n--- Raw features only ---', '', None, None),
        ('\n--- Raw + engineered features ---', '_eng', extra_tr, extra_te),
    ]
    for header, suffix, tr_extra, te_extra in sections:
        print(header)
        for base_name, trainer, kwargs in models:
            name = base_name + suffix
            clf, fit_t, n_feats = trainer(
                X_tr, y_tr, seed=args.seed, extra_feats=tr_extra, **kwargs)
            auc, pred_t = evaluate(name, clf, X_te, y_te, te_extra)
            results['variants'][name] = {'auc': auc, 'fit_time': fit_t, 'pred_time': pred_t, 'n_features': n_feats}
            print(f'  {name}: AUC={auc:.4f}, fit={fit_t:.1f}s')

    print('\n=== Summary ===')
    baseline = results['TabPFN_baseline']
    print(f"TabPFN baseline: {baseline:.4f}")
    for name in sorted(results['variants']):
        v = results['variants'][name]
        delta = v['auc'] - baseline
        print(f"  {name:20s}: {v['auc']:.4f}  (Δ {delta:+.4f}, {v['n_features']} feats)")

    with open(args.out, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f'\nSaved: {args.out}')


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()
