import openpyxl
import sys
import statistics
import collections
import pprint
import xlsxwriter
import itertools
import matplotlib
import matplotlib.pyplot as plt
import numpy

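# entries_raw: key = (frontrtt, frontrate, buf) -> list of per-file records, each
# {
#    "Type": "http/dash",
#    "Seq": 0,
#    "Back RTT": "10ms",
#    "Back Rate": "8Mbps",
#    "Stall Rate": 0.0,         # a per-row list for IPFSMPD files
#    "Average Quality": 9.6     # a per-row list for IPFSMPD files
# }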

# Shared accumulator: every missing key starts out as an empty list of records.
entries_raw = collections.defaultdict(list)

def quality_convert(x):
    """Map ``x`` onto a two-segment linear scale.

    For ``x < 2`` the result is ``1500 * x - 500``; otherwise it is
    ``2500 * x - 2500``.
    """
    slope, offset = (1500, 500) if x < 2 else (2500, 2500)
    return slope * x - offset

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

def draw7_200ms_mbit_caddy(titles, saveas, unroll, unroll2):
    """Draw a grouped bar chart of QoE per cache level and save it to a file.

    Four series are plotted: "Telescope" and "ABR" for every cache level,
    plus "Dtube" and "HTTP" which only carry a value in the first group.

    Args:
        titles: Plot title.
        saveas: Path handed to ``plt.savefig``.
        unroll: Rows searched for the "DTUBE" and "CADDY" entries at
            ``"BUF:0/11"``; the first match of each is used.
        unroll2: Rows searched for the "UNIFORM" (Telescope) and "UNCHANGE"
            (ABR) entries.

    Each row is indexed positionally: ``[0]`` buffer label, ``[5]`` type,
    ``[12]`` bar height and ``[13]`` error-bar size (in the rows built by this
    script these are the mean and stdev of QoE with penalty 6).

    Raises:
        IndexError: if ``unroll`` holds no matching "DTUBE" or "CADDY" row.

    NOTE(review): the Dtube/HTTP series and the tick labels are hard-coded to
    four entries, so this presumably expects exactly four "UNIFORM" and four
    "UNCHANGE" rows in ``unroll2`` -- confirm against the input data.
    """
    print("draw7_200ms_mbit_caddy")
    barwidth = 0.20


    # Bar heights / error bars, in the order the rows appear in unroll2.
    telescope   = [x[12] for x in listkeep(unroll2, lambda y: y[5] == "UNIFORM")]
    telescoperr = [x[13] for x in listkeep(unroll2, lambda y: y[5] == "UNIFORM")]
    baseline    = [x[12] for x in listkeep(unroll2, lambda y: y[5] == "UNCHANGE")]
    baselinerr  = [x[13] for x in listkeep(unroll2, lambda y: y[5] == "UNCHANGE")]


    # An earlier (disabled) variant picked the "ABRTHU|UNIFORM" and
    # "ABRTHU|BASELINE" rows from `unroll` at BUF:0/11, 3/11, 6/11 and 9/11
    # instead of filtering unroll2.

    # Dtube and HTTP only have a measurement for the first group; the other
    # three slots are zero-height bars.
    d = listkeep(unroll, lambda x: x[5] == "DTUBE" and x[0] == "BUF:0/11")[0]
    dtubeo = [d[12], 0, 0, 0]
    dtubeorr = [d[13], 0, 0, 0]

    c = listkeep(unroll, lambda x: x[5] == "CADDY" and x[0] == "BUF:0/11")[0]
    caddy = [c[12], 0, 0, 0]
    caddyorr = [c[13], 0, 0, 0]
    print(dtubeo, caddy)


    #index = [(y[3] + "\n" + y[4]) for y in listkeep(unroll, lambda x: x[5] == "ABRTHU|BASELINE")]
    index = ["Level 0", "Level 3", "Level 6", "Level 9"]

    # Left-to-right x positions of the four bars of each group, centred on the
    # integer tick: i-0.3, i-0.1, i+0.1, i+0.3.
    br1 = numpy.arange(len(telescope)) - barwidth - barwidth/2
    br2 = [x + barwidth for x in br1]
    br3 = [x + barwidth for x in br2]
    br4 = [x + barwidth for x in br3]

    # Dashed vertical line at x=0.55, spanning y=0..12.
    x1, y1 = [0.55, 0.55], [0, 12]
    plt.plot(x1, y1, "--")
    # Every group after the first is moved right by one bar width, so its
    # Telescope/ABR pair sits at i-0.1 / i+0.1, i.e. centred on the tick.
    br1new = [br1[0]]
    br1new.extend([x + barwidth for x in br1[1:]])
    plt.bar(br1new, telescope, width = barwidth, yerr = telescoperr,
            error_kw=dict(capsize=5),
            label ="Telescope")
    br2new = [br2[0]]
    br2new.extend([x + barwidth for x in br2[1:]])
    plt.bar(br2new, baseline, width = barwidth, yerr = baselinerr,
            error_kw=dict(capsize=5),
            label ="ABR")
    plt.bar(br3, dtubeo, width = barwidth, yerr = dtubeorr,
            error_kw=dict(capsize=5),
            edgecolor ="green", color = "white", label ="Dtube", hatch = '/')
    plt.bar(br4, caddy, width = barwidth, yerr = caddyorr,
            error_kw=dict(capsize=5),
            edgecolor ="pink", color = "white", label ="HTTP", hatch = 'x')

    plt.xlabel("Cache level", fontsize=14)
    plt.ylabel("QoE", fontsize=14)
    plt.ylim(bottom=0, top=12)

    plt.title(titles, fontsize=16)
    plt.xticks(numpy.arange(len(index)), index)#, rotation="vertical")

    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles[:], labels[:])
    plt.tight_layout()
    plt.savefig(saveas)
    # Close every figure so later plots start from a clean pyplot state.
    plt.close("all")

def draw():
    """Save a two-series bar chart ("Telescope" vs "ABR") as ``WAN.png``.

    NOTE(review): ``telescope``, ``telescope_e``, ``abr`` and ``abr_e`` are
    neither parameters nor locals. They must exist as module-level names when
    this function is called; otherwise the first ``ax.bar`` call raises
    NameError.
    """
    labels = ["Level 0", "Level 3", "Level 6", "Level 9"]

    x = np.arange(len(labels))  # the label locations
    width = 0.35  # the width of the bars

    fig, ax = plt.subplots()
    # Two bars per label, placed half a bar width left and right of the tick.
    rects1 = ax.bar(x - width/2, telescope, width, yerr=telescope_e, label="Telescope", capsize=3)
    rects2 = ax.bar(x + width/2, abr, width, yerr=abr_e, label="ABR", capsize=3)

    ax.set_ylabel("QoE")
    ax.set_title("QoE Comparison over WAN")
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend()

    fig.tight_layout()

    plt.savefig("WAN.png")
    #plt.show()

def parse_ipfs_baseline(xlsx, read_row_start, read_row_end):
    """Record one baseline (DTUBE/CADDY) run in ``entries_raw``.

    The run parameters are taken from the file name, which is split on ``_``
    after rewriting ``mbit`` to ``Mbps``, e.g.
    ``DTUBE__SEQ_1_BUF__CAP_10ms_25mbit_BACKRATE_10ms_25mbit.10min_4k_dash_25000.html.xlsx``.

    Args:
        xlsx: Workbook path; its name must follow the layout above.
        read_row_start: First worksheet row to read (``min_row``).
        read_row_end: Last worksheet row to read (``max_row``).

    The record is appended under the key ``(frontrtt, frontrate, "BUF:0/11")``.
    Its "Stall Rate" is the mean of ``(col1 - col2) / col2`` over the data
    rows; its "Average Quality" is fixed at 11.

    Raises:
        statistics.StatisticsError: if the row range holds no data row.
        ZeroDivisionError: if the third column of a data row is 0.
    """
    sheet = openpyxl.load_workbook(xlsx).active

    fields = xlsx.replace("mbit", "Mbps").split("_")
    testtype = fields[0]
    seq = int(fields[3])
    frontrtt = fields[7]
    frontrate = fields[8]
    backrtt = fields[10]
    # The rate is glued to the rest of the name with a dot ("25Mbps.10min").
    backrate = fields[11].split(".")[0]

    stall_rates = []
    for row in sheet.iter_rows(min_row=read_row_start, max_row=read_row_end):
        values = [cell.value for cell in row]
        # Skip empty rows and rows whose first cell is the literal "t"
        # (presumably a header row -- confirm against the sheets).
        if values[0] is None or values[0] == "t":
            continue
        stall_rates.append((values[1] - values[2]) / values[2])

    entries_raw[(frontrtt, frontrate, "BUF:0/11")].append({
        "Type": testtype,
        "Seq": seq,
        "Back RTT": backrtt,
        "Back Rate": backrate,
        "Stall Rate": statistics.mean(stall_rates),
        "Average Quality": 11,
    })

def parse_ipfs_internet(xlsx, read_row_start, read_row_end):
    """Record one INTERNET2 run in ``entries_raw``.

    The run parameters are taken from the file name, which is split on ``_``
    after rewriting ``mbit`` to ``Mbps``, e.g.
    ``INTERNET2_IPFS_UNIFORM__SEQ_3_BUF_9_CAP_1ms_20mbit_BACKRATE_1ms_20mbit.second_10min_4k_dash.html.xlsx``.

    Args:
        xlsx: Workbook path; its name must follow the layout above.
        read_row_start: First worksheet row to read (``min_row``).
        read_row_end: Last worksheet row to read (``max_row``).

    The record is appended under the key ``(frontrtt, frontrate, buf)`` where
    ``buf`` is ``"<level>/11"`` (no ``"BUF:"`` prefix here). "Stall Rate" is
    the mean of ``(col1 - col2) / col2`` and "Average Quality" the mean of
    the seventh column, both over the data rows.

    Raises:
        statistics.StatisticsError: if the row range holds no data row.
        ZeroDivisionError: if the third column of a data row is 0.
    """
    sheet = openpyxl.load_workbook(xlsx).active

    fields = xlsx.replace("mbit", "Mbps").split("_")
    testtype = fields[2]
    buf = fields[7] + "/11"
    seq = int(fields[5])
    frontrtt = fields[9]
    frontrate = fields[10]
    backrtt = fields[12]
    # The rate is glued to the rest of the name with a dot ("20Mbps.second").
    backrate = fields[13].split(".")[0]

    stall_rates = []
    qualities = []
    for row in sheet.iter_rows(min_row=read_row_start, max_row=read_row_end):
        values = [cell.value for cell in row]
        # Skip empty rows and rows whose first cell is the literal "t"
        # (presumably a header row -- confirm against the sheets).
        if values[0] is None or values[0] == "t":
            continue
        stall_rates.append((values[1] - values[2]) / values[2])
        qualities.append(values[6])

    entries_raw[(frontrtt, frontrate, buf)].append({
        "Type": testtype,
        "Seq": seq,
        "Back RTT": backrtt,
        "Back Rate": backrate,
        "Stall Rate": statistics.mean(stall_rates),
        "Average Quality": statistics.mean(qualities)
    })

def parse_ipfs_mpd(xlsx, read_row_start, read_row_end):
    """Record the per-row results of one IPFSMPD run in ``entries_raw``.

    The run parameters are taken from the file name, which is split on ``_``
    after rewriting ``mbit`` to ``Mbps``, e.g.
    ``IPFSMPD7_SANTHUONLY-CONT_PLCY_UNIFORM_NPF_SEQ_3_BUF_0_CAP_10ms_25mbit_BACKRATE_400ms_8mbit.second_10min_4k_dash.html.xlsx``.

    Only files whose second name field is ``SANTHUONLY-CONT`` are recorded;
    every other file is skipped without opening the workbook.

    Args:
        xlsx: Workbook path; its name must follow the layout above.
        read_row_start: First worksheet row to read (``min_row``).
        read_row_end: Last worksheet row to read (``max_row``).

    Unlike the other parsers, "Stall Rate" and "Average Quality" are stored
    as lists with one entry per worksheet row in the range; rows without
    data contribute the string ``"nan"``.

    Raises:
        ZeroDivisionError: if the third column of a data row is 0.
    """
    fields = xlsx.replace("mbit", "Mbps").split("_")

    # Other variants are deliberately ignored. The original code returned
    # here too, leaving an unreachable `extra = "DYN|"` tag behind it.
    if fields[1] != "SANTHUONLY-CONT":
        return
    extra = "ABRTHU|"

    sheet = openpyxl.load_workbook(xlsx).active

    buflevel = "BUF:" + fields[8] + "/11"
    testtype = extra + fields[3]
    seq = int(fields[6])
    frontrtt = fields[10]
    frontrate = fields[11]
    backrtt = fields[13]
    # The rate is glued to the rest of the name with a dot ("8Mbps.second").
    backrate = fields[14].split(".")[0]

    stall_rates = []
    qualities = []
    for row in sheet.iter_rows(min_row=read_row_start, max_row=read_row_end):
        values = [cell.value for cell in row]
        if values[0] is None or values[0] == "t":
            # Placeholder keeps list positions tied to worksheet rows.
            stall_rates.append("nan")
            qualities.append("nan")
        else:
            stall_rates.append((values[1] - values[2]) / values[2])
            qualities.append(values[6])

    entries_raw[(frontrtt, frontrate, buflevel)].append({
        "Type": testtype,
        "Seq": seq,
        "Back RTT": backrtt,
        "Back Rate": backrate,
        "Stall Rate": stall_rates,
        "Average Quality": qualities,
    })


def rounder(x):
    """Convert ``x`` to float and round it to two decimal places."""
    as_float = float(x)
    return round(as_float, 2)

def filterz(l):
    """Return a new list holding ``rounder(x)`` for every ``x`` in ``l``."""
    return list(map(rounder, l))

def nonzero(l):
    """Return True if at least one element of ``l`` is not equal to 0."""
    return any(x != 0 for x in l)

# Worksheet row range handed to every parser as (min_row, max_row).
start = 2
end = 4

# Dispatch each .xlsx command-line argument to the parser that matches the
# prefix of its name; arguments with no known prefix are only printed.
for xlsx in sys.argv:
    is_workbook = ".xlsx" in xlsx and ".json" not in xlsx
    if not is_workbook:
        continue

    print(xlsx)
    if xlsx.startswith("IPFSMPD"):
        parse_ipfs_mpd(xlsx, start, end)
    elif xlsx.startswith(("DTUBE", "CADDY")):
        parse_ipfs_baseline(xlsx, start, end)
    elif xlsx.startswith("INTERNET2"):
        parse_ipfs_internet(xlsx, start, end)

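# Aggregation: merge the per-file ("Seq") records that share Type / Back RTT /
# Back Rate into one record holding the individual samples.
#
# from
# key = (frontrtt, frontrate, buf) -> [{
#    "Type": "http/dash",
#    "Seq": 0,
#    "Back RTT": "10ms",
#    "Back Rate": "8Mbps",
#    "Stall Rate": 0.0,          # or a per-row list
#    "Average Quality": 9.6      # or a per-row list
# }, ...]
# to
# key = (frontrtt, frontrate, buf) -> [{
#    "Type": "http/dash",
#    "Back RTT": "10ms",
#    "Back Rate": "8Mbps",
#    "Stall Rate": [],           # one sample per file, or per file and row
#    "Average Quality": [],
# }, ...]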

# Group the raw records of every key by (Type, Back RTT, Back Rate) and
# flatten each group's measurements into plain sample lists.
entries = collections.defaultdict(list)
for cond_key, raw_runs in entries_raw.items():
    runs_by_backend = collections.defaultdict(list)
    for run in raw_runs:
        backend_key = (run["Type"], run["Back RTT"], run["Back Rate"])
        runs_by_backend[backend_key].append(run)

    for (test_type, back_rtt, back_rate), runs in runs_by_backend.items():
        first = runs[0]

        if type(first["Stall Rate"]) is float:
            # One scalar per run.
            stall_samples = [run["Stall Rate"] for run in runs]
        else:
            # Per-row lists: walk the row positions of the first run and take
            # that position from every run, dropping "nan" placeholders.
            stall_samples = [
                run["Stall Rate"][pos]
                for pos in range(len(first["Stall Rate"]))
                for run in runs
                if run["Stall Rate"][pos] != "nan"
            ]

        quality_kind = type(first["Average Quality"])
        if quality_kind is float or quality_kind is int:
            quality_samples = [run["Average Quality"] for run in runs]
        else:
            print(runs)
            quality_samples = [
                run["Average Quality"][pos]
                for pos in range(len(first["Average Quality"]))
                for run in runs
                if run["Average Quality"][pos] != "nan"
            ]

        entries[cond_key].append({
            "Type": test_type,
            "Back RTT": back_rtt,
            "Back Rate": back_rate,
            "Stall Rate": stall_samples,
            "Average Quality": quality_samples,
        })


#pprint.pprint(entries)

# Output workbook for the aggregated scores; all rows go to its first sheet.
workbook = xlsxwriter.Workbook("score.xlsx")
worksheet = workbook.add_worksheet()

# Index of the next worksheet row to be written.
curser = 0

def devl(l):
    """Return the sample standard deviation of ``l``, or 0 if it cannot be computed.

    Any exception from ``statistics.stdev`` (for example fewer than two
    values) is turned into a result of 0.
    """
    try:
        spread = statistics.stdev(l)
    except Exception:
        spread = 0
    return spread

# Full-precision result rows, later used for plotting. Row layout:
# [buf, front rtt, front rate, back rtt, back rate, type,
#  quality mean, quality stdev, stall mean, stall stdev,
#  QoE1 mean, QoE1 stdev, QoE6 mean, QoE6 stdev, QoE11 mean, QoE11 stdev]
unroll = []
for k, dlist in entries.items():
    dlist.sort(key=lambda x: x["Type"])

    # NOTE(review): typelist/backendlist are only consumed by the disabled
    # print below.
    typelist = set()
    backendlist = set()
    for entry in dlist:
        typelist.add(entry["Type"])
        backendlist.add((entry["Back RTT"], entry["Back Rate"]))

    typelist = sorted(list(typelist))
    backendlist = sorted(list(backendlist))

#    print(k, typelist, backendlist)
    # One header row per key, starting at the second column.
    worksheet.write_row(curser, 1, ["BUF", "Front RTT", "Front Rate", "Back RTT", "Back Rate", "Type", "Quality Average", "Quality Stdev", "StallRate Average", "StallRate Stdev", "QoE(pnty=1)", "QoE(pnty=6)", "QoE(pnty=11)"])
    curser += 1

    grouped_dlist = []
    for e in dlist:
        # Linear QoE per sample: quality minus penalty * stall rate, for
        # penalties 1, 6 and 11. Samples are paired by list position.
        QoE_lin1 = []
        QoE_lin6 = []
        QoE_lin11 = []

        for i in range(len(e["Average Quality"])):
#            QoE_lin.append((quality_convert(e["Average Quality"][i] * 127) - 25000 * (e["Stall Rate"][i]) * 634.567) / 1024)
            QoE_lin1.append(e["Average Quality"][i] - 1 * (e["Stall Rate"][i]))
            QoE_lin6.append(e["Average Quality"][i] - 6 * (e["Stall Rate"][i]))
            QoE_lin11.append(e["Average Quality"][i] - 11 * (e["Stall Rate"][i]))

        try:
            # Rounded row for the spreadsheet (13 columns, matching the header).
            grouped_dlist.append([k[2], k[0], k[1],
                                  e["Back RTT"], e["Back Rate"], e["Type"],
                                  rounder(statistics.mean(e["Average Quality"])), rounder(devl(e["Average Quality"])),
                                  rounder(statistics.mean(e["Stall Rate"])), rounder(devl(e["Stall Rate"])),
                                  rounder(statistics.mean(QoE_lin1)),
                                  rounder(statistics.mean(QoE_lin6)),
                                  rounder(statistics.mean(QoE_lin11))])#,
#                                  rounder(statistics.mean(QoE_log))])
            # Unrounded row with a stdev next to every QoE mean.
            unroll.append([k[2], k[0], k[1],
                           e["Back RTT"], e["Back Rate"], e["Type"],
                           statistics.mean(e["Average Quality"]), devl(e["Average Quality"]),
                           statistics.mean(e["Stall Rate"]), devl(e["Stall Rate"]),
                           statistics.mean(QoE_lin1), devl(QoE_lin1),
                           statistics.mean(QoE_lin6), devl(QoE_lin6),
                           statistics.mean(QoE_lin11), devl(QoE_lin11)])

        except Exception as ex:
            # Print the offending key and entry, then re-raise.
            print(k, e)
            raise ex
    grouped_dlist = sorted(grouped_dlist, key=lambda e: (e[0], e[1], e[2], e[3], e[4]))

    # Write the rows, leaving one empty row whenever the
    # (buf, front rtt, front rate, back rtt, back rate) combination changes.
    e = grouped_dlist[0]
    nowtoken = (e[0], e[1], e[2], e[3], e[4])
    for x in grouped_dlist:
        xtoken = (x[0], x[1], x[2], x[3], x[4])
        if xtoken != nowtoken:
            curser += 1
            nowtoken = xtoken
        worksheet.write_row(curser, 1, x)
        curser += 1

workbook.close()

def listkeep(l, func):
    """Return a list of the elements of ``l`` for which ``func`` is truthy."""
    return list(filter(func, l))


# 0 = buf,
# 1 = front rtt,
# 2 = front rate,
# 3 = back rtt,
# 4 = back rate,
# 5 = type,
# 6 = Average Quality,
# 7 = dev AQ,
# 8 = Stall Rate,
# 9 = dev SR,
# 10 = r1,
# 11 = r1err,
# 12 = r6,
# 13 = r6err,
# 14 = r11,
# 15 = r11err,

# Plot the QoE comparison from every row whose type does not contain "DYN".
alldata = sorted(unroll)
thudata = listkeep(alldata, lambda x: "DYN" not in x[5])
# The DTUBE/CADDY bars come from rows with a 25Mbps back end at 400ms or 1ms
# RTT. Every parser rewrites "mbit" to "Mbps" before splitting the file name,
# so the former `x[4] == "25mbit"` test could never match a row.
# NOTE(review): "25Mbps" is inferred from the 400ms clause -- confirm intent.
draw7_200ms_mbit_caddy("QoE Comparison in IPFS Network",
                       "QoEallbufCaddy.png",
                       listkeep(thudata,
                                lambda x: (x[3] == "400ms" and x[4] == "25Mbps") or
                                (x[3] == "1ms" and x[4] == "25Mbps")),
                       unroll)
