#!/usr/bin/env python3

import argparse
import json
import xml.etree.ElementTree as ET
from urllib.parse import unquote, urljoin, urlparse

import matplotlib.pyplot as plt
import numpy as np
import requests

# Global figure styling: usetex routes all text through an external LaTeX
# installation (needed for the $m(D^0)\;...$ axis labels to be typeset by
# LaTeX), and 14 pt is the base font size.
plt.rcParams.update({"text.usetex": True, "font.size": 14})


# ---------------------------------------------------------
def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Namespace with a single attribute, ``cernbox``: the public CERNBOX
        folder link to browse (a built-in share link when not given).
    """
    # Implicit concatenation keeps the long share URL readable and avoids
    # backslash line-continuations inside the string literal.
    default_link = (
        "https://cernbox.cern.ch/files/link/public/J7AVIaolTO4L2ej"
        "?openWithDefaultApp=true"
        "&fileId=eoshome-c%24F5SW64ZPOVZWK4RPMMXWG33HMFXA%3D%3D%3D%3D%21206836114"
        "&items-per-page=100"
        "&files-public-link-view-mode=resource-table"
        "&tiles-size=24"
    )
    cli = argparse.ArgumentParser(
        description=(
            "Empilement cumulatif des histogrammes de masse D0 depuis des "
            "fichiers JSON CERNBOX avec navigation clavier"
        )
    )
    cli.add_argument(
        "--cernbox",
        default=default_link,
        help="Lien public vers un dossier CERNBOX",
    )
    parsed = cli.parse_args()
    return parsed


# ---------------------------------------------------------
def get_webdav_base_url(public_link):
    """Map a CERNBOX public-link URL to the matching WebDAV collection URL.

    The share token is everything after the last ``/public/`` in the URL
    path (the query string is ignored), with surrounding slashes removed.

    Parameters
    ----------
    public_link : str
        Public share URL, e.g.
        ``https://cernbox.cern.ch/files/link/public/<token>?...``.

    Returns
    -------
    str
        ``https://cernbox.cern.ch/remote.php/dav/public-files/<token>/``.

    Raises
    ------
    RuntimeError
        If the path has no ``/public/`` segment or nothing after it.
    """
    path = urlparse(public_link).path
    _, marker, token = path.rpartition("/public/")
    # rpartition leaves `marker` empty when "/public/" is absent (and then
    # `token` is the whole path); reject that instead of building a bogus URL.
    # Stripping slashes avoids a doubled "//" for links with a trailing slash.
    token = token.strip("/")
    if not marker or not token:
        raise RuntimeError("Impossible d'extraire le token WebDAV depuis l'URL")
    return f"https://cernbox.cern.ch/remote.php/dav/public-files/{token}/"


# ---------------------------------------------------------
def list_json_files_webdav(base_url, timeout=30):
    """List the ``.json`` files found directly inside a WebDAV collection.

    Parameters
    ----------
    base_url : str
        URL of the WebDAV collection to list. A trailing ``/`` is appended
        when missing; otherwise ``urljoin`` would resolve file names next to
        the collection instead of inside it.
    timeout : float, optional
        Seconds to wait for the server, passed to ``requests``. Without a
        timeout a stalled server would hang the script indefinitely.

    Returns
    -------
    list of str
        Absolute URLs of the JSON files, sorted lexicographically.

    Raises
    ------
    requests.HTTPError
        If the PROPFIND request returns an error status.
    RuntimeError
        If the collection lists no ``.json`` entry.
    """
    if not base_url.endswith("/"):
        base_url += "/"

    # Depth: 1 asks for the collection itself plus its immediate children.
    response = requests.request(
        "PROPFIND", base_url, headers={"Depth": "1"}, timeout=timeout
    )
    response.raise_for_status()

    # Parse the raw bytes so the XML parser decides the document encoding.
    root = ET.fromstring(response.content)
    namespace = {"d": "DAV:"}
    files = []
    for elem in root.findall(".//d:href", namespace):
        # An empty <d:href/> has .text None; treat it as "not a JSON file".
        href = (elem.text or "").strip()
        if not href.lower().endswith(".json"):
            continue
        # Keep only the (still percent-encoded) file name. The "./" prefix
        # stops urljoin from reading a name containing ":" as a URL scheme.
        files.append(urljoin(base_url, "./" + href.split("/")[-1]))

    if not files:
        raise RuntimeError("Aucun fichier JSON trouvé dans le dossier WebDAV CERNBOX")
    return sorted(files)


# ---------------------------------------------------------
def load_masses_from_url(url, timeout=30):
    """Download a JSON file of masses and return it as a float array.

    Parameters
    ----------
    url : str
        Direct URL of the JSON file. Its content must be convertible by
        ``np.asarray(..., dtype=float)``, e.g. a flat list of numbers.
    timeout : float, optional
        Seconds to wait for the server, passed to ``requests``. Without a
        timeout a stalled download would hang the script indefinitely.

    Returns
    -------
    numpy.ndarray
        The decoded values as a float64 array.

    Raises
    ------
    requests.HTTPError
        If the download returns an error status.
    RuntimeError
        If the decoded array is empty.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    # json.loads accepts bytes and detects UTF-8/16/32 itself. This avoids
    # the charset guessing over the whole body that resp.text may trigger
    # when the server sends no charset.
    masses = np.asarray(json.loads(resp.content), dtype=float)
    if masses.size == 0:
        raise RuntimeError(f"Fichier JSON vide : {url}")
    return masses


# ---------------------------------------------------------
class CumulativeHistogramNavigator:
    """Interactive viewer stacking D0 mass histograms one file at a time.

    The left panel shows the histogram of the file at ``self.index``. The
    right panel shows the bin-wise sum over files ``0..self.index``. The
    right arrow key adds the next file to the sum; the left arrow key
    removes the current one. Successfully loaded histograms are cached, so
    moving back and forth never downloads a file twice.

    Creating an instance opens the figure and blocks in ``plt.show()``
    until the window is closed.

    Parameters
    ----------
    json_urls : sequence of str
        URLs of the JSON mass files, in display order.
    nbins : int, optional
        Number of equal-width bins between ``mass_min`` and ``mass_max``.
    mass_min, mass_max : float, optional
        Histogram range; the axis labels state MeV/c^2.

    Raises
    ------
    ValueError
        If ``json_urls`` is empty.
    """

    # Text-mode LaTeX replacements for characters that make a usetex render
    # fail. File names often contain "_", and "%" once URL-encoded.
    _LATEX_TABLE = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "&": r"\&",
            "%": r"\%",
            "$": r"\$",
            "#": r"\#",
            "_": r"\_",
            "{": r"\{",
            "}": r"\}",
            "~": r"\textasciitilde{}",
            "^": r"\textasciicircum{}",
        }
    )

    def __init__(self, json_urls, nbins=25, mass_min=1810, mass_max=1920):
        if not json_urls:
            raise ValueError("Aucune URL de fichier JSON fournie")
        self.json_urls = json_urls
        self.nbins = nbins
        self.mass_min = mass_min
        self.mass_max = mass_max
        self.index = 0

        self.bin_edges = np.linspace(self.mass_min, self.mass_max, self.nbins + 1)
        self.bin_centers = 0.5 * (self.bin_edges[:-1] + self.bin_edges[1:])
        self.bin_widths = np.diff(self.bin_edges)

        # One lazily filled slot per file (None = not loaded yet).
        self.counts_cache = [None] * len(self.json_urls)

        # Load the first file before creating the figure, so a download
        # error surfaces before any window is opened.
        first_counts = self._get_counts(self.index)
        # Copy: the running sum is updated in place and must not alias the
        # cached counts of file 0.
        self.cumulative_counts = first_counts.copy()

        self.fig, (self.ax_single, self.ax_cum) = plt.subplots(
            ncols=2, figsize=(13, 6), sharex=True
        )
        self.fig.canvas.mpl_connect("key_press_event", self.on_key)

        self._redraw(single_counts=first_counts)
        # Non-interactive show: runs the GUI event loop (and thus on_key)
        # until the window is closed.
        plt.show()

    def _get_counts(self, i: int) -> np.ndarray:
        """Return the histogram counts of file *i*, downloading it on first use."""
        if self.counts_cache[i] is None:
            masses = load_masses_from_url(self.json_urls[i])
            counts, _ = np.histogram(masses, bins=self.bin_edges)
            self.counts_cache[i] = counts.astype(int, copy=False)
        return self.counts_cache[i]

    @classmethod
    def _latex_escape(cls, text: str) -> str:
        """Escape LaTeX special characters so *text* renders verbatim."""
        return text.translate(cls._LATEX_TABLE)

    @classmethod
    def _file_label(cls, url: str, usetex: bool) -> str:
        """Return the display name of *url*: its last path segment, decoded.

        WebDAV hrefs are percent-encoded, so the name is unquoted first. When
        *usetex* is true it is also LaTeX-escaped, so a name such as
        ``run_1.json`` does not make the LaTeX render fail.
        """
        name = unquote(url.split("/")[-1])
        return cls._latex_escape(name) if usetex else name

    def _draw_panel(self, ax, counts, color, ylabel, title):
        """Redraw *ax* as a bar histogram of *counts* over the shared bins."""
        ax.clear()
        ax.bar(
            self.bin_centers,
            counts,
            width=self.bin_widths,
            align="center",
            color=color,
            edgecolor="black",
        )
        ax.set_xlabel(r"$m(D^0)\;[\mathrm{MeV}/c^2]$")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(alpha=0.3)

    def _redraw(self, single_counts: np.ndarray):
        """Refresh both panels and the "X/Y fichiers" super-title."""
        title = self._file_label(
            self.json_urls[self.index], plt.rcParams["text.usetex"]
        )

        # Left: histogram of the current file only.
        self._draw_panel(self.ax_single, single_counts, "tab:orange", "Candidats", title)
        # Right: running sum over files 0..index.
        self._draw_panel(
            self.ax_cum,
            self.cumulative_counts,
            "tab:blue",
            "Candidats cumulés",
            "Distribution cumulative",
        )

        # Global title: position within the file list only.
        self.fig.suptitle(
            f"{self.index + 1}/{len(self.json_urls)} fichiers",
            y=0.98
        )

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def on_key(self, event):
        """Handle arrow keys: right adds the next file, left removes the current one.

        Counts are fetched before any state is mutated. If a download fails,
        ``index`` and ``cumulative_counts`` therefore stay consistent with
        what is displayed.
        """
        if event.key == "right" and self.index < len(self.json_urls) - 1:
            counts = self._get_counts(self.index + 1)
            self.index += 1
            self.cumulative_counts += counts
            self._redraw(single_counts=counts)
        elif event.key == "left" and self.index > 0:
            counts_current = self._get_counts(self.index)
            counts_previous = self._get_counts(self.index - 1)
            self.cumulative_counts -= counts_current
            self.index -= 1
            self._redraw(single_counts=counts_previous)


# ---------------------------------------------------------
def main():
    """Resolve the CERNBOX share, list its JSON files and open the viewer."""
    cli_args = parse_arguments()
    urls = list_json_files_webdav(get_webdav_base_url(cli_args.cernbox))
    print(f"{len(urls)} fichiers JSON trouvés.")
    # The navigator's constructor blocks until its window is closed.
    CumulativeHistogramNavigator(urls)


# ---------------------------------------------------------
# Run the viewer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
