#!/usr/bin/python3

import os
import numpy as np
import plotly.graph_objects as go
from argparse import ArgumentParser

def from_time (s):
    """ Convert a time string of the form '[[h:]m:]s' to seconds.

        Hours and minutes are optional integer fields, the seconds
        field is parsed as a float. More than three ':'-separated
        fields fail the assertion.
    """
    fields  = s.split (':')
    # Consume the leading fields from the left: the hours are only
    # present with three fields, after removing them the minutes are
    # the first of the remaining two.
    hours   = int (fields.pop (0)) if len (fields) == 3 else 0
    minutes = int (fields.pop (0)) if len (fields) == 2 else 0
    assert len (fields) == 1
    return float (fields [0]) + minutes * 60 + hours * 3600
# end def from_time

class Measurement:
    """ A single timing sample read from one measurement file.

        Attributes: variant ('Jensen' or 'NSGAII'), popsize, n_obj and
        idx (all taken from the file name) and the times t_user,
        t_system and t_elapsed in seconds. is_overall is True when the
        times come from a line with user/system/elapsed tokens and
        False when they come from a line starting with 'Timing'.
    """

    def __init__ \
        (self, variant, popsize, n_obj, idx, user, system, elapsed, is_overall):
        # The time parameters are strings in a format understood by
        # from_time, they are stored converted to seconds.
        self.variant, self.popsize = variant, popsize
        self.n_obj, self.idx       = n_obj, idx
        self.t_user     = from_time (user)
        self.t_system   = from_time (system)
        self.t_elapsed  = from_time (elapsed)
        self.is_overall = is_overall
        assert self.variant in ('Jensen', 'NSGAII')
    # end def __init__

    def __str__ (self):
        return \
            '%(variant)s: %(n_obj)s %(popsize)s %(idx)s: %(t_user)s' \
            % vars (self)
    # end def __str__
    __repr__ = __str__

    @classmethod
    def parse_file (cls, fn):
        """ Create a Measurement from the first line of file fn.

            Returns None for an empty file. The first line either
            starts with 'Timing' (the time is the text after the first
            ':') or consists of whitespace-separated tokens of which
            the first three end in 'user', 'system' and 'elapsed'
            (presumably the output of a 'time' command). The base name
            of fn must consist of five '_'-separated fields: a prefix
            ending in 't', a variant flag (empty for Jensen, '-n' for
            NSGAII), population size, number of objectives and index.
        """
        with open (fn, 'r') as f:
            first = next (f, None)
        if first is None:
            return None
        is_overall = not first.startswith ('Timing')
        if is_overall:
            user, system, elapsed, _ = first.split (None, 3)
            times = []
            # Verify and remove the unit suffix of each token
            for token, unit in \
                ((user, 'user'), (system, 'system'), (elapsed, 'elapsed')):
                assert token.endswith (unit)
                times.append (token [:-len (unit)])
            user, system, elapsed = times
        else:
            user   = first.split (':', 1)[1]
            system = elapsed = '0'
        prefix, flag, ps, objs, num = os.path.basename (fn).split ('_')
        ps, objs, num = map (int, (ps, objs, num))
        assert prefix.endswith ('t')
        # An empty flag selects Jensen, the only other option is '-n'
        assert not flag or flag == '-n'
        params = dict \
            ( variant    = 'NSGAII' if flag else 'Jensen'
            , popsize    = ps
            , n_obj      = objs
            , idx        = num
            , user       = user
            , system     = system
            , elapsed    = elapsed
            , is_overall = is_overall
            )
        return cls (**params)
    # end def parse_file

    def key (self):
        """ Tuple identifying this measurement. """
        identity = (self.variant, self.n_obj, self.popsize, self.idx)
        return identity
    # end def key

# end class Measurement

class Plot_Measurement:
    """ Aggregate the timing measurements of a directory and plot them.

        Measurements are grouped by number of objectives and by variant
        ('Jensen' or 'NSGAII'). The aggregated values are computed by
        compute (), see there for the resulting attributes. The plot
        methods either write an html file (when an output file prefix
        was given) or show the figure.
    """

    # Layout defaults shared by all figures. This is class-level state
    # and is never modified, each figure gets its own copy with titles
    # filled in (see _figure_update).
    line_default = dict \
        ( layout = dict
            ( xaxis = dict
                ( gridcolor = '#B0B0B0'
                , showgrid  = True
                , zeroline  = False
                )
            , yaxis = dict
                ( gridcolor = '#B0B0B0'
                , showgrid  = True
                , zeroline  = False
                )
            , paper_bgcolor = 'white'
            , plot_bgcolor  = 'white'
            , title         = {}
            )
        )

    def __init__ (self, args):
        """ Read all measurement files in args.directory and aggregate.

            args must provide the attributes 'directory' and
            'output_file' (the latter may be None or empty).
        """
        self.dir    = args.directory
        self.ofile  = args.output_file
        self.by_obj = {}
        # Sorted so that the result does not depend on the order in
        # which the operating system lists the directory
        for fn in sorted (os.listdir (self.dir)):
            path = os.path.join (self.dir, fn)
            # Sub-directories etc. cannot be measurement files
            if not os.path.isfile (path):
                continue
            m = Measurement.parse_file (path)
            if m is None:
                continue
            by_variant = self.by_obj.setdefault (m.n_obj, {})
            by_variant.setdefault (m.variant, []).append (m)
        self.compute ()
    # end def __init__

    def _times_by_popsize (self, measurements, popsizes):
        """ Group the user times of measurements by population size.

            Every population size seen is added to the set popsizes.
            The first measurement seen determines self.is_overall, all
            further measurements must agree with it.
        """
        times = {}
        for m in measurements:
            if self.is_overall is None:
                self.is_overall = m.is_overall
            elif self.is_overall != m.is_overall:
                raise ValueError \
                    ('Mixed overall and sort-only measurements: %s' % (m,))
            popsizes.add (m.popsize)
            times.setdefault (m.popsize, []).append (m.t_user)
        return times
    # end def _times_by_popsize

    def compute (self):
        """ Compute statistics from self.by_obj.

            For each number of objectives and each population size that
            was measured for both variants this computes mean and
            standard deviation of the user time. Results, all keyed by
            number of objectives:
            - self.jensen, self.nsgaii: arrays with the four rows
              population size, mean, population size, standard deviation
            - self.jenstd, self.nsgstd: dictionaries mapping population
              size to standard deviation
            - self.ratio: array with the two rows population size and
              NSGAII mean divided by Jensen mean
            In addition self.popsize is the sorted array of all
            population sizes seen and self.is_overall is the common
            is_overall flag of the measurements (None without
            measurements). A number of objectives without any population
            size measured for both variants is left out of the results.
            Raises ValueError if measurements differ in is_overall.
        """
        self.is_overall = None
        popsiz = set ()
        self.jensen = {}
        self.jenstd = {}
        self.nsgaii = {}
        self.nsgstd = {}
        self.ratio  = {}
        for obkey, by_variant in self.by_obj.items ():
            # A variant without measurements yields an empty grouping
            jen = self._times_by_popsize \
                (by_variant.get ('Jensen', ()), popsiz)
            nsg = self._times_by_popsize \
                (by_variant.get ('NSGAII', ()), popsiz)
            # Only population sizes measured for both can be compared
            common = sorted (set (jen) & set (nsg))
            if not common:
                continue
            x     = np.array (common, dtype = float)
            jmean = np.array ([np.average (jen [ps]) for ps in common])
            jstd  = np.array ([np.std     (jen [ps]) for ps in common])
            nmean = np.array ([np.average (nsg [ps]) for ps in common])
            nstd  = np.array ([np.std     (nsg [ps]) for ps in common])
            self.jensen [obkey] = np.array ([x, jmean, x, jstd])
            self.jenstd [obkey] = dict (zip (common, jstd))
            self.nsgaii [obkey] = np.array ([x, nmean, x, nstd])
            self.nsgstd [obkey] = dict (zip (common, nstd))
            self.ratio  [obkey] = np.array ([x, nmean / jmean])
        self.popsize = np.array (sorted (popsiz))
    # end def compute

    def _figure_update (self, title, ytitle):
        """ Return the figure update for the given titles.

            The result is built from copies of the relevant parts of
            line_default, so the shared defaults stay untouched.
        """
        default = self.line_default ['layout']
        layout  = dict (default)
        layout ['xaxis'] = dict (default ['xaxis'], title = 'Population size')
        layout ['yaxis'] = dict (default ['yaxis'], title = ytitle)
        layout ['title'] = dict (default ['title'], text = title)
        return dict (self.line_default, layout = layout)
    # end def _figure_update

    def _output (self, fig, title, ytitle, suffix):
        """ Apply layout to fig and write it to a file or show it.

            With an output file prefix the figure is written to the
            prefix with suffix appended, otherwise it is shown.
        """
        cfg = dict (displaylogo = False)
        fig.update (self._figure_update (title, ytitle))
        if self.ofile:
            d = dict (include_plotlyjs = 'directory', config = cfg)
            fig.write_html (self.ofile + suffix, **d)
        else:
            fig.show (config = cfg)
    # end def _output

    def plot (self):
        """ Plot the NSGAII/Jensen time ratio over the population size,
            one trace for each number of objectives.
        """
        fig = go.Figure ()
        for obkey in sorted (self.ratio):
            x, y = self.ratio [obkey]
            d = dict (x = x, y = y, name = str (obkey))
            fig.add_trace (go.Scatter (**d))
        if self.is_overall:
            t = 'Overall speedup Jensen vs. NSGA-II'
        else:
            t = 'Non-dominated sort speedup Jensen vs. NSGA-II'
        self._output (fig, t, 'Speedup factor', '_speedup.html')
    # end def plot

    def plot_absdiff (self):
        """ Plot the difference of the mean times (NSGAII minus Jensen)
            over the population size, one trace for each number of
            objectives.
        """
        fig = go.Figure ()
        for obkey in sorted (self.nsgaii):
            nx, ny, _, _ = self.nsgaii [obkey]
            jx, jy, _, _ = self.jensen [obkey]
            # Both variants must share the same population sizes
            assert (jx == nx).all ()
            d = dict (x = nx, y = ny - jy, name = 'Difference: %s' % obkey)
            fig.add_trace (go.Scatter (**d))
        if self.is_overall:
            t = 'Overall absolute difference Jensen vs. NSGA-II'
        else:
            t = 'Non-dominated sort absolute difference Jensen vs. NSGA-II'
        self._output (fig, t, 'Difference (seconds)', '_absolute_diff.html')
    # end def plot_absdiff

    def plot_n_obj (self, n_obj):
        """ Plot the mean times of both variants for n_obj objectives
            over the population size, the standard deviation is shown
            as error bars. Raises KeyError if there are no results for
            n_obj.
        """
        fig = go.Figure ()
        for label, data in (('nsgaii', self.nsgaii), ('jensen', self.jensen)):
            x, y, _, e = data [n_obj]
            e = go.scatter.ErrorY (type = 'data', array = e)
            name = '%s: %s' % (label, n_obj)
            d = dict (x = x, y = y, error_y = e, name = name)
            fig.add_trace (go.Scatter (**d))
        if self.is_overall:
            t = 'Overall time %d objectives' % n_obj
        else:
            t = 'Non-dominated sort time %d objectives' % n_obj
        self._output (fig, t, 'CPU Time (seconds)', '_time_%s.html' % n_obj)
    # end def plot_n_obj

# end class Plot_Measurement

if __name__ == '__main__':
    # Command-line entry point: read the measurements from the given
    # directory and create one time plot for each of a fixed set of
    # objective counts, followed by the absolute difference and the
    # speedup plot.
    cmd = ArgumentParser ()
    cmd.add_argument \
        ( '-d', '--directory'
        # argparse substitutes %(default)s with the default value
        , help    = 'Directory containing measurements, default=%(default)s'
        , default = 'm'
        )
    cmd.add_argument \
        ( '-o', '--output-file'
        , help = 'Output file prefix for html graphics'
        )
    args = cmd.parse_args ()
    m = Plot_Measurement (args)
    for n_obj in (2, 3, 5, 10, 20):
        m.plot_n_obj (n_obj)
    m.plot_absdiff ()
    m.plot ()
