#!/usr/bin/env python3
import csv
import itertools
import numpy
import math

def PlotArrays(title,plot_labels,dict_of_arrays,time,share_y_with=None):
    """Plot the columns of several 2-D arrays against *time* in stacked subplots.

    One subplot is created per entry of *plot_labels*.  Column ``k`` of every
    array in *dict_of_arrays* is drawn on subplot ``k`` with the dict key as
    its legend label, and subplot ``k`` gets ``plot_labels[k]`` as y label.
    All subplots share the x axis of the first one.  The figure is shown with
    ``plt.show()``.

    Args:
        title: Figure super-title.
        plot_labels: Sequence of y-axis labels, one per subplot/column (a
            string such as ``"XYZ"`` works: one character per subplot).
        dict_of_arrays: Mapping of legend label -> 2-D array with
            ``len(plot_labels)`` columns and one row per entry of *time*.
        time: x values used for every curve.
        share_y_with: Optional sequence with one entry per subplot.  Entry
            ``k`` (for ``k >= 1``) is ``None`` or an index into the subplots
            created before subplot ``k``, whose y axis subplot ``k`` shares.
            Entry 0 is ignored.

    Raises:
        ValueError: if *dict_of_arrays* is empty or any array is not 2-D with
            ``len(plot_labels)`` columns.
    """
    import matplotlib.pyplot as plt
    # Import gridspec explicitly instead of relying on pyplot's side effects.
    import matplotlib.gridspec

    numPlots = len(plot_labels)

    # Validate every input before creating a figure, so bad input does not
    # leave a half-built figure behind.
    if not dict_of_arrays:
        raise ValueError("dict_of_arrays must contain at least one array")
    for array_name, array in dict_of_arrays.items():
        if len(array.shape) != 2 or array.shape[1] != numPlots:
            raise ValueError(
                "array {!r} has shape {}, expected (n, {})".format(
                    array_name, array.shape, numPlots))

    if share_y_with is None:
        share_y_with = [None] * numPlots

    fig = plt.figure(figsize=(11,8.5),dpi=80)
    outer_grid = matplotlib.gridspec.GridSpec(numPlots, 1, wspace=0.05, hspace=0.2)

    axs = []
    for axIndex in range(numPlots):
        if axIndex == 0:
            ax = fig.add_subplot(outer_grid[0, 0])
        else:
            share_index = share_y_with[axIndex]
            sharey = None if share_index is None else axs[share_index]
            ax = fig.add_subplot(outer_grid[axIndex, 0], sharex=axs[0], sharey=sharey)
        axs.append(ax)

    for array_name, array in dict_of_arrays.items():
        for axId in range(numPlots):
            axs[axId].plot(time, array[:, axId], label=array_name)

    for ax, label in zip(axs, plot_labels):
        leg = ax.legend(loc='best', fancybox=True)
        if leg is not None:
            # Semi-transparent legend so it does not hide the curves.
            leg.get_frame().set_alpha(0.5)
        ax.set_ylabel(label)
    fig.suptitle(title)
    plt.show()


class OptitrackFile:
    """Parser for an OptiTrack CSV export, with helpers to group and plot it.

    After construction the instance has:
      * ``general_header`` -- dict built from the first CSV row, read as
        alternating key, value cells (cells with an empty key are skipped).
      * ``raw_data`` -- dict mapping, for each column, the tuple of its
        non-empty header cells (e.g. ``('Time (Seconds)',)`` or
        ``('Rigid Body', name, 'Position', 'X')``) to a 1-D float numpy array
        of that column's values; empty cells become NaN.

    ``GroupData`` must be called before ``FindMarkerTriangles``,
    ``PlotRigidData`` or ``PlotMarkerData``.
    """

    def __init__(self,filename):
        """Read and parse the CSV file at *filename*.

        Layout consumed: row 0 is a key/value header; row 1 must be empty;
        the next five rows are column-header rows (the third one is skipped);
        every remaining non-blank row is one sample.  Sample rows shorter
        than the header are padded with NaN so all columns keep one entry
        per sample row.

        Raises:
            ValueError: if row 1 is not empty, the column-header rows are
                missing, or a sample row has more cells than there are columns.
        """
        # The csv module documents newline='' for file objects it reads.
        with open(filename, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')

            # Row 0: alternating key, value cells.  A trailing key with no
            # value cell (odd number of cells) maps to "".
            row0 = next(reader, [])
            self.general_header = {
                key: value
                for key, value in itertools.zip_longest(row0[0::2], row0[1::2], fillvalue='')
                if key
            }

            empty_row = next(reader, [])
            if any(empty_row):
                raise ValueError(
                    "{}: expected an empty second row, got {!r}".format(filename, empty_row))

            # Each column's name is the list of its non-empty header cells.
            col_names = None
            for iRow, row in enumerate(itertools.islice(reader, 5)):
                if iRow == 2:  # Skip the hash info for now
                    continue
                if col_names is None:
                    col_names = [[] for _ in row]
                for iCol, col_info in enumerate(row):
                    if col_info:
                        col_names[iCol].append(col_info)
            if col_names is None:
                raise ValueError("{}: missing column header rows".format(filename))

            num_cols = len(col_names)
            nan = float('NaN')
            data = [[] for _ in range(num_cols)]
            for row in reader:
                if not row:
                    # csv yields [] for a blank line (e.g. a trailing newline);
                    # it is not a sample.
                    continue
                if len(row) > num_cols:
                    raise ValueError(
                        "{}: line {} has {} cells but the header defines {} columns".format(
                            filename, reader.line_num, len(row), num_cols))
                # Pad short rows so every column stays aligned with the others.
                padded = row + [''] * (num_cols - len(row))
                for column, cell in zip(data, padded):
                    column.append(float(cell) if cell else nan)

        # Later columns with an identical name tuple overwrite earlier ones.
        self.raw_data = {
            tuple(col_name): numpy.array(values)
            for col_name, values in zip(col_names, data)
        }

    def _stack_columns(self, keys):
        """Stack the ``raw_data`` columns named by *keys* into a (len(self.time), len(keys)) array.

        Raises:
            KeyError: if a key is not in ``raw_data``.
            ValueError: if the columns do not have one entry per time sample.
        """
        keys = list(keys)
        stacked = numpy.stack([self.raw_data[key] for key in keys], axis=1)
        expected = (len(self.time), len(keys))
        if stacked.shape != expected:
            raise ValueError(
                "columns {} stacked to shape {}, expected {}".format(keys, stacked.shape, expected))
        return stacked

    def GroupData(self,rigid_body_name, marker_names):
        """Collect one rigid body's columns into per-quantity arrays.

        Sets:
          * ``time`` -- the ``('Time (Seconds)',)`` column.
          * ``rigid_body_XYZ`` -- (n_samples, 3) array of Position X, Y, Z.
          * ``rigid_body_q_WXYZ`` -- (n_samples, 4) array of Rotation W, X, Y, Z.
          * ``rigid_body_markers_XYZ`` -- list with one (n_samples, 3)
            Position array per entry of *marker_names*, in the same order.

        Raises:
            KeyError: if a required column is missing from ``raw_data``.
            ValueError: if a grouped array does not have one row per time sample.
        """
        self.time = self.raw_data[('Time (Seconds)',)]
        self.rigid_body_XYZ = self._stack_columns(
            [('Rigid Body', rigid_body_name, 'Position', axis) for axis in 'XYZ'])
        self.rigid_body_q_WXYZ = self._stack_columns(
            [('Rigid Body', rigid_body_name, 'Rotation', axis) for axis in 'WXYZ'])
        self.rigid_body_markers_XYZ = [
            self._stack_columns(
                [('Rigid Body Marker', marker_name, 'Position', axis) for axis in 'XYZ'])
            for marker_name in marker_names
        ]

    def FindMarkerTriangles(self):
        """Print and return area statistics for every triangle of grouped markers.

        Useful to check that the area of each marker triangle is constant over
        time.  For every 3-combination of ``rigid_body_markers_XYZ`` the
        per-sample triangle area is computed, and ``(ia, ib, ic), mean_area,
        max_abs_deviation`` is printed.  Samples where a marker coordinate is
        NaN (an empty CSV cell) are ignored by the mean and max.

        Returns:
            List of ``((ia, ib, ic), mean_area, max_abs_deviation)`` tuples in
            the order printed (empty when fewer than three markers are grouped).
        """
        markers = self.rigid_body_markers_XYZ
        results = []
        for ia, ib, ic in itertools.combinations(range(len(markers)), 3):
            a, b, c = markers[ia], markers[ib], markers[ic]
            # Area = half the norm of the cross product of two edge vectors,
            # from http://mathworld.wolfram.com/TriangleArea.html
            area = 0.5 * numpy.linalg.norm(numpy.cross(b - a, a - c, axis=1), axis=1)
            area_mean = numpy.nanmean(area)
            area_maxDiff = numpy.nanmax(numpy.abs(area - area_mean))
            print((ia, ib, ic), area_mean, area_maxDiff)
            results.append(((ia, ib, ic), area_mean, area_maxDiff))
        return results

    def PlotRigidData(self):
        """Plot the rigid-body position (X, Y, Z) and quaternion (W, X, Y, Z) against time."""
        PlotArrays("rigid body position", "XYZ",
                   {"optitrack_rigid": self.rigid_body_XYZ}, time=self.time)
        PlotArrays("rigid body quaternion", "WXYZ",
                   {"optitrack_quat": self.rigid_body_q_WXYZ}, time=self.time)

    def PlotMarkerData(self):
        """Plot X, Y, Z of every grouped marker against time, one curve per marker."""
        dict_of_arrays = {
            "Marker {}".format(iMarker): marker_XYZ
            for iMarker, marker_XYZ in enumerate(self.rigid_body_markers_XYZ)
        }
        PlotArrays("Marker positions", "XYZ", dict_of_arrays, time=self.time)



def main(filename="me133a_6nov_wand_250_270m.csv", rigid_body_name="ME133_Wand", marker_names=None):
    """Load an OptiTrack CSV export, group one rigid body's data and plot it.

    Called with no arguments this reproduces the original hard-coded run.

    Args:
        filename: Path of the CSV export to read.
        rigid_body_name: Rigid-body name used in the CSV column headers.
        marker_names: Marker names used in the CSV column headers.  Defaults to
            ``"<rigid_body_name>:Marker1"`` .. ``"<rigid_body_name>:Marker5"``,
            the pattern of the original take -- confirm it for other files.
    """
    if marker_names is None:
        marker_names = ["{}:Marker{}".format(rigid_body_name, i) for i in [1, 2, 3, 4, 5]]
    optitrack = OptitrackFile(filename)
    optitrack.GroupData(rigid_body_name, marker_names)
    optitrack.PlotRigidData()
    optitrack.PlotMarkerData()
    # Uncomment to print per-triangle marker-area diagnostics:
    #optitrack.FindMarkerTriangles()

# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
