#!/usr/bin/env python3
"""
Generate a standalone Plotly HTML page with an interactive t-SNE scatter plot.
Click any point to see the 4-channel waveform of that event.

Usage:
  python generate_tsne_viewer.py                  # use saved model_weights.pt
  python generate_tsne_viewer.py --train           # retrain from scratch
  python generate_tsne_viewer.py --output viewer.html
"""

import argparse
import base64
import json
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchmetrics
import lightning as L
from sklearn.manifold import TSNE


# ---------------------------------------------------------------------------
# Reproduce the exact architecture from the notebook
# ---------------------------------------------------------------------------

# Single fixed seed: handed to Lightning's seeding helper here and reused
# further down for the data split, the event sampling and t-SNE.
seed = 17052026
L.seed_everything(seed, verbose=False)
# Request deterministic cuDNN kernels and switch off the cuDNN benchmark
# autotuner so repeated runs are reproducible.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class ScintDataset(Dataset):
    """Torch dataset pairing feature rows with their label rows.

    Both NumPy arrays are converted to float32 tensors once, at construction
    time; indexing returns the ``(features, label)`` pair at that position.
    """

    def __init__(self, x, y):
        features = torch.from_numpy(x)
        targets = torch.from_numpy(y)
        self.x = features.float()
        self.y = targets.float()

    def __len__(self):
        # The label tensor defines the dataset length.
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target


class NNModel(nn.Module):
    """Fully connected network: two ReLU hidden layers and a linear output.

    ``forward`` returns the raw outputs of the last linear layer (no final
    activation); ``getlatent`` returns the activations of the second hidden
    layer, i.e. the input of that last linear layer.
    """

    def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, output_dim):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay as is
        # for previously saved weights to load.
        self.linear_1 = nn.Linear(input_dim, hidden_dim_1)
        self.linear_2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.linear_3 = nn.Linear(hidden_dim_2, output_dim)
        self.act = nn.ReLU()

    def getlatent(self, x):
        """Return the post-ReLU activations of the second hidden layer."""
        hidden = self.act(self.linear_1(x))
        hidden = self.act(self.linear_2(hidden))
        return hidden

    def forward(self, x):
        """Return the output-layer values for ``x``."""
        # Same two hidden layers as getlatent, followed by the output layer.
        return self.linear_3(self.getlatent(x))


class Classifier(L.LightningModule):
    """Lightning module that trains a wrapped network with BCE-with-logits.

    Args:
        model: Callable (typically an ``nn.Module`` class) that is invoked as
            ``model(**model_params)`` to build the network.
        model_params: Keyword arguments forwarded to ``model``.
        data_config: Mapping providing ``"learning_rate"`` (for AdamW) and
            ``"batch_size"`` (passed to ``self.log``).
    """

    def __init__(self, model, model_params, data_config):
        super().__init__()
        self.model = model(**model_params)
        self.data_config = data_config
        self.accuracy = torchmetrics.Accuracy(task="binary", num_classes=2)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        learning_rate = self.data_config["learning_rate"]
        return torch.optim.AdamW(self.parameters(), learning_rate)

    def _calculate_loss(self, batch, mode):
        """Return ``(loss, accuracy)`` for one batch; ``mode`` is not used here."""
        inputs, labels = batch
        logits = self(inputs)
        target = labels.float()
        loss = F.binary_cross_entropy_with_logits(logits, target)
        accuracy = self.accuracy(logits, target)
        return loss, accuracy

    def _step(self, batch, mode):
        """Compute the batch metrics, log them as ``{mode}_loss`` / ``{mode}_acc``."""
        loss, accuracy = self._calculate_loss(batch, mode)
        # Identical logging options for both metrics; per-step logging is
        # enabled only while training.
        log_options = {
            "batch_size": self.data_config["batch_size"],
            "prog_bar": True,
            "on_epoch": True,
            "on_step": mode == "train",
            "sync_dist": True,
        }
        self.log(f"{mode}_loss", loss, **log_options)
        self.log(f"{mode}_acc", accuracy, **log_options)
        return loss

    def training_step(self, batch):
        return self._step(batch, mode="train")

    def validation_step(self, batch):
        return self._step(batch, mode="val")

    def test_step(self, batch):
        return self._step(batch, mode="test")


# ---------------------------------------------------------------------------
# Data pipeline
# ---------------------------------------------------------------------------

def load_data(path="data.npz"):
    """Load the ``Li6`` and ``Po`` arrays and build model inputs and labels.

    Args:
        path: Location of an ``.npz`` archive containing arrays named
            ``"Li6"`` and ``"Po"``. Both are summed over axis 1 and later
            indexed along axis 2, so they are presumably shaped
            (events, channels, time samples) — confirm against the data file.

    Returns:
        A tuple ``(x, x_orig, y)`` with the Li6 events first, then the Po ones:
        ``x`` is each event summed over axis 1 and divided by its own total,
        ``x_orig`` is the untouched concatenated input, and ``y`` holds one-hot
        rows (column 0 = Li6, column 1 = Po).
    """
    # np.load returns a lazily-read NpzFile that keeps the archive open; the
    # context manager closes it as soon as both arrays have been read.
    with np.load(path) as data:
        Li6 = data["Li6"]
        Po = data["Po"]

    Li6_labels = np.zeros((len(Li6), 2))
    Li6_labels[:, 0] = 1
    Po_labels = np.zeros((len(Po), 2))
    Po_labels[:, 1] = 1

    y = np.concatenate([Li6_labels, Po_labels])
    x = np.concatenate([Li6.sum(axis=1), Po.sum(axis=1)])
    x_orig = np.concatenate([Li6, Po])
    # Normalise every summed event to unit total.
    # NOTE(review): an event whose total is 0 produces NaN/inf here.
    x = x / x.sum(axis=1, dtype=np.float32)[:, np.newaxis]

    return x, x_orig, y


def split_data(x, x_orig, y, train_frac=0.8, rng_seed=None):
    """Shuffle the events and split them into a train part and a test part.

    Args:
        x: Model inputs, indexed by event along axis 0.
        x_orig: Original waveforms, in the same event order as ``x``.
        y: Labels, in the same event order as ``x``.
        train_frac: Fraction of events (rounded down) that goes to training.
        rng_seed: Seed for the shuffle. ``None`` (the default) uses the
            module-level ``seed``, which reproduces the original behaviour.

    Returns:
        Dict with keys ``"xtrain"``, ``"ytrain"``, ``"xtest"``, ``"ytest"``
        and ``"xtest_orig"``; the three test entries share one event order.

    Raises:
        ValueError: If ``train_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be within [0, 1], got {train_frac!r}")

    rng = np.random.default_rng(seed if rng_seed is None else rng_seed)
    indices = rng.permutation(len(x))
    k = int(len(x) * train_frac)
    train_idx, test_idx = indices[:k], indices[k:]
    return {
        "xtrain": x[train_idx],
        "ytrain": y[train_idx],
        "xtest": x[test_idx],
        "ytest": y[test_idx],
        "xtest_orig": x_orig[test_idx],
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point: build the interactive t-SNE viewer page.

    Steps: load and split the data, train the classifier (``--train``) or
    load ``model_weights.pt``, draw a seeded sample of test events, embed
    their last-hidden-layer activations with t-SNE and write the
    self-contained HTML file.

    Exits with status 1 when the weights file is missing and ``--train`` was
    not given. Out-of-range ``--downsample`` / ``--n-samples`` values are
    rejected through ``argparse``.
    """
    parser = argparse.ArgumentParser(description="Generate interactive t-SNE + waveform viewer (Plotly HTML)")
    parser.add_argument("--train", action="store_true", help="Retrain model (default: load model_weights.pt)")
    parser.add_argument("--output", default="tsne_viewer.html", help="Output HTML file")
    parser.add_argument("--data", default="data.npz", help="Path to data.npz")
    parser.add_argument("--n-samples", type=int, default=1000, help="Number of t-SNE samples")
    parser.add_argument("--downsample", type=int, default=8,
                        help="Downsample factor for waveforms (4000 / factor = timepoints)")
    args = parser.parse_args()

    # Fail early with a readable message instead of a traceback deep inside
    # the slicing / t-SNE code.
    if args.downsample < 1:
        parser.error("--downsample must be a positive integer")
    if args.n_samples < 2:
        parser.error("--n-samples must be at least 2")

    weights_path = "model_weights.pt"
    data_config = {"batch_size": 32, "learning_rate": 1e-3}

    print("[1/6] Loading data …")
    x, x_orig, y = load_data(args.data)
    split = split_data(x, x_orig, y)

    print("[2/6] Setting up model …")
    cl = Classifier(
        model=NNModel,
        model_params={"input_dim": 4000, "hidden_dim_1": 250, "hidden_dim_2": 50, "output_dim": 2},
        data_config=data_config,
    )

    if args.train:
        print("[3/6] Training …")
        dset = ScintDataset(split["xtrain"], split["ytrain"])
        train_size = int(np.around(0.8 * len(dset)))
        val_size = len(dset) - train_size
        dset_tr, dset_val = random_split(dset, [train_size, val_size])

        # One batch size for the loaders and for the metric logging.
        batch_size = data_config["batch_size"]
        t_dl = DataLoader(dset_tr, batch_size=batch_size, shuffle=True, num_workers=0)
        v_dl = DataLoader(dset_val, batch_size=batch_size, shuffle=False, num_workers=0)

        trainer = L.Trainer(max_epochs=10, accelerator="auto", devices=1, strategy="auto")
        trainer.fit(model=cl, train_dataloaders=t_dl, val_dataloaders=v_dl)
        torch.save(cl.model.state_dict(), weights_path)
        print(f"       Weights saved to {weights_path}")
    else:
        if not os.path.exists(weights_path):
            print("ERROR: model_weights.pt not found. Run with --train first.", file=sys.stderr)
            sys.exit(1)
        cl.model.load_state_dict(torch.load(weights_path, weights_only=True))
        print(f"       Weights loaded from {weights_path}")

    print("[4/6] Extracting latent features and probabilities …")
    cl.eval()
    rng = np.random.default_rng(seed)
    # Sampling without replacement cannot draw more events than exist.
    n_available = len(split["xtest"])
    n = min(args.n_samples, n_available)
    if n < args.n_samples:
        print(f"       Only {n_available} test events available; using all of them",
              file=sys.stderr)
    sample_idx = rng.choice(n_available, n, replace=False)

    # Keep the inputs on the model's device (a no-op on CPU) and bring the
    # results back to the CPU before converting to NumPy.
    device = next(cl.model.parameters()).device
    x_sampled = torch.from_numpy(split["xtest"][sample_idx]).to(device=device, dtype=torch.float)
    with torch.no_grad():
        latents = cl.model.getlatent(x_sampled).cpu().numpy()
        probs = F.softmax(cl.model(x_sampled), dim=1).cpu().numpy()

    probs_li6 = probs[:, 0]
    # One-hot column 0 marks Li6 (see load_data), so column 0 == 1 is a Li6 event.
    labels = (split["ytest"][sample_idx][:, 0] == 1).astype(int)  # 1 = Li6, 0 = Po
    sample_labels = np.where(labels == 1, "Li6", "Po")
    waveforms_orig = split["xtest_orig"][sample_idx]

    print(f"       Latents: {latents.shape}, Waveforms: {waveforms_orig.shape}")

    print("[5/6] Running t-SNE …")
    # scikit-learn requires perplexity < number of samples; 30 is used
    # whenever the sample is large enough.
    tsne = TSNE(n_components=2, perplexity=min(30, n - 1), random_state=seed,
                max_iter=1000, learning_rate="auto", init="pca")
    tsne_results = tsne.fit_transform(latents)

    print("[6/6] Generating HTML …")
    html = build_html(
        tsne_results=tsne_results,
        probs_li6=probs_li6,
        labels=labels,
        sample_labels=sample_labels,
        sample_idx=sample_idx,
        waveforms_orig=waveforms_orig,
        downsample=args.downsample,
    )

    # The page declares <meta charset="utf-8"> and contains non-ASCII text,
    # so the file must be written as UTF-8 regardless of the platform default.
    with open(args.output, "w", encoding="utf-8") as f:
        f.write(html)
    size_mb = os.path.getsize(args.output) / 1e6
    print(f"       Wrote {args.output} ({size_mb:.1f} MB)")


# ---------------------------------------------------------------------------
# HTML generation
# ---------------------------------------------------------------------------

def build_html(tsne_results, probs_li6, labels, sample_labels, sample_idx,
               waveforms_orig, downsample):
    """Render the standalone viewer page and return it as one HTML string.

    The scatter metadata is embedded as JSON and the downsampled waveforms as
    base64-encoded little-endian float32 values, laid out in C order as
    (sample, channel, time); the page's JavaScript decodes both.

    Args:
        tsne_results: Array of shape (n, 2) with the t-SNE coordinates.
        probs_li6: Per-sample P(Li6) values, used for colour and hover text.
        labels: Per-sample integer labels; the page shows truthy as "Li6" and
            falsy as "Po".
        sample_labels: Accepted for interface compatibility; not used here
            (the page derives the label text from ``labels``).
        sample_idx: Per-sample index shown as "Sample index".
        waveforms_orig: 3-D array (sample, channel, time) of raw waveforms.
        downsample: Keep every ``downsample``-th time point (must be >= 1).

    Returns:
        The complete HTML document as a ``str``.

    Raises:
        ValueError: If ``downsample`` is smaller than 1 or ``waveforms_orig``
            is not 3-dimensional.
    """
    if downsample < 1:
        raise ValueError(f"downsample must be a positive integer, got {downsample!r}")
    waveforms = np.asarray(waveforms_orig)
    if waveforms.ndim != 3:
        raise ValueError(
            f"waveforms_orig must be 3-D (sample, channel, time), got shape {waveforms.shape}"
        )

    n = len(tsne_results)

    # Explicit little-endian float32: the page reads the bytes back through a
    # Float32Array, independent of the byte order of the generating machine.
    wf_down = waveforms[:, :, ::downsample].astype("<f4")
    # Take channel/time counts from the array that is actually serialised.
    # shape[2] // downsample undercounts when the length is not a multiple of
    # downsample, which would shift every waveform offset in the page.
    n_channels, t = wf_down.shape[1], wf_down.shape[2]
    wf_b64 = base64.b64encode(wf_down.tobytes()).decode("ascii")

    tsne_arr = np.asarray(tsne_results)
    meta = {
        "tsne_x": tsne_arr[:, 0].round(4).tolist(),
        "tsne_y": tsne_arr[:, 1].round(4).tolist(),
        "probs": np.asarray(probs_li6).round(4).tolist(),
        "labels": np.asarray(labels).tolist(),
        "indices": np.asarray(sample_idx).tolist(),
        "n": n, "ch": n_channels, "t": t,
        "downsample": downsample,
    }

    meta_json = json.dumps(meta)

    # Use a plain string template (not f-string) to avoid JS/Python brace conflicts
    html = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<title>t-SNE visualisation of NN latent space and waveform signals - </title>
<script src="https://cdn.plot.ly/plotly-2.35.2.min.js"></script>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
       background: #f8f9fa; color: #222; }
#container { max-width: 1400px; margin: 0 auto; padding: 20px; }
h1 { font-size: 1.4rem; text-align: center; margin-bottom: 4px; }
.subtitle { text-align: center; color: #666; font-size: 0.9rem; margin-bottom: 16px; }
#main { display: flex; gap: 20px; flex-wrap: wrap; }
#tsne-col { flex: 1 1 620px; min-width: 400px; }
#info-col { flex: 1 1 480px; min-width: 360px; display: flex; flex-direction: column; gap: 12px; }
#sample-card { background: #fff; border-radius: 8px; padding: 14px 18px;
                box-shadow: 0 1px 3px rgba(0,0,0,.08); }
#sample-card h3 { font-size: 1.05rem; margin-bottom: 8px; }
.stat { display: flex; justify-content: space-between; padding: 2px 0; font-size: 0.9rem; }
.stat .lbl { color: #777; }
.stat .val { font-weight: 600; }
#waveform-card { background: #fff; border-radius: 8px; padding: 10px;
                  box-shadow: 0 1px 3px rgba(0,0,0,.08); flex: 1; min-height: 320px; }
#waveform-card h3 { font-size: 1rem; margin-bottom: 4px; padding-left: 6px; }
#click-prompt { color: #999; text-align: center; padding: 60px 20px; font-size: 0.95rem; }
.footer { text-align: center; color: #aaa; font-size: 0.8rem; margin-top: 16px; }
</style>
</head>
<body>
<div id="container">
  <h1>t-SNE of Last Hidden Layer</h1>
  <p class="subtitle">Colored by P(Li6) — click any point to see its 4-channel waveform</p>
  <div id="main">
    <div id="tsne-col"><div id="tsne-plot"></div></div>
    <div id="info-col">
      <div id="sample-card">
        <h3 id="card-title">Click a point</h3>
        <div class="stat"><span class="lbl">Sample index</span><span class="val" id="v-idx">—</span></div>
        <div class="stat"><span class="lbl">True label</span><span class="val" id="v-label">—</span></div>
        <div class="stat"><span class="lbl">P(Li6)</span><span class="val" id="v-prob">—</span></div>
      </div>
      <div id="waveform-card">
        <h3>Waveform (4 channels)</h3>
        <div id="waveform-plot"><p id="click-prompt">Click a t-SNE point to view its waveform</p></div>
      </div>
    </div>
  </div>
  <p class="footer">PhyNuBe4 notebook — Kaciel Béraud, LPC Caen</p>
  <p class="footer">interface vibe-coded by A. Vacheret</p>
</div>
<script>
const CH_COLORS = ['#75bbfd','#380282','#fe2c54','#b40b2c'];
const META = PLACEHOLDER_META;
const WF = (() => {
    const bin = atob('PLACEHOLDER_WF');
    const buf = new ArrayBuffer(bin.length);
    const view = new Uint8Array(buf);
    for (let i = 0; i < bin.length; i++) view[i] = bin.charCodeAt(i);
    return new Float32Array(buf);
})();
function getWf(s, ch) {
    const off = (s * META.ch + ch) * META.t;
    return WF.subarray(off, off + META.t);
}
const hoverText = META.labels.map((l, i) =>
    'Sample ' + META.indices[i] + '<br/>Label: ' + (l ? 'Li6' : 'Po') +
    '<br/>P(Li6): ' + META.probs[i].toFixed(4)
);
Plotly.newPlot('tsne-plot', [{
    x: META.tsne_x, y: META.tsne_y,
    mode: 'markers', type: 'scattergl',
    marker: {
        color: META.probs, colorscale: 'RdYlBu',
        colorbar: {title:'P(Li6)'}, size: 6, showscale: true,
        line: {width:0.5, color:'#888'}
    },
    text: hoverText, hoverinfo: 'text',
    selected: {marker: {color:'#000', size:10, opacity:1}},
    unselected: {marker: {opacity:0.4}}
}], {
    xaxis: {title:'t-SNE 1', zeroline:false},
    yaxis: {title:'t-SNE 2', zeroline:false},
    hovermode: 'closest', dragmode: 'zoom',
    margin: {l:50, r:20, t:10, b:50}, height: 520,
    paper_bgcolor: '#f8f9fa', plot_bgcolor: '#fff'
}, {responsive: true}).then(gd => {
    gd.on('plotly_click', data => {
        const i = data.points[0].pointIndex;
        document.getElementById('v-idx').textContent = META.indices[i];
        document.getElementById('v-label').textContent = META.labels[i] ? 'Li6' : 'Po';
        document.getElementById('v-prob').textContent = META.probs[i].toFixed(4);
        document.getElementById('card-title').textContent =
            'Sample ' + META.indices[i] + ' — ' + (META.labels[i] ? 'Li6' : 'Po');
        const time = Array.from({length: META.t}, (_,k) => k * META.downsample);
        const traces = [];
        for (let ch = 0; ch < META.ch; ch++) {
            traces.push({
                x: time, y: Array.from(getWf(i, ch)),
                type: 'scatter', mode: 'lines',
                name: 'Ch ' + ch,
                line: {color: CH_COLORS[ch % CH_COLORS.length], width: 1.2}
            });
        }
        Plotly.react('waveform-plot', traces, {
            xaxis: {title:'Time sample', range:[0, Math.min(1500, META.t * META.downsample)]},
            yaxis: {title:'Amplitude'},
            showlegend: true,
            legend: {orientation:'h', y:1.02, x:0.5, xanchor:'center'},
            margin: {l:50, r:10, t:10, b:40}, height: 300,
            paper_bgcolor: '#fff', plot_bgcolor: '#fafafa'
        });
    });
});
</script>
</body>
</html>"""

    # Base64 text has no underscore, so the injected waveform payload can
    # never contain the second placeholder.
    html = html.replace("'PLACEHOLDER_WF'", "'" + wf_b64 + "'")
    html = html.replace("PLACEHOLDER_META", meta_json)
    return html


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
