import os
import re
import xlrd
import pandas as pd
import numpy as np
import datetime as dt
import multiprocessing as mp
from pathlib import Path


class AvoidedCostOutputs:
    """Hourly avoided cost tables for a single program administrator.

    Two long-format tables are held: ``generation_costs`` (read from a 'Gen'
    workbook) and ``transmission_and_distribution_costs`` (read from a 'TD'
    workbook). Once loaded, each table is indexed by
    (ProgramAdministrator, ClimateZone, DataType) and carries the columns
    'Timestamp', 'Units' and 'Value'.
    """

    program_administrator = ''
    generation_costs_source = ''
    generation_costs = None
    transmission_and_distribution_costs_source = ''
    transmission_and_distribution_costs = None
    shift_hours = None
    shiftable = None

    # column layout of the long-format tables before they are indexed:
    _COLUMNS = (
        'ProgramAdministrator',
        'ClimateZone',
        'Timestamp',
        'DataType',
        'Units',
        'Value',
    )
    _INDEX_COLUMNS = ['ProgramAdministrator', 'ClimateZone', 'DataType']

    # sheets that are read from each kind of workbook:
    _SHEET_NAMES = {
        'Gen': (
            'Gen',
            'EmissionRate',
        ),
        'TD': (
            'CZ1', 'CZ2', 'CZ3A', 'CZ3B', 'CZ4',
            'CZ5', 'CZ6', 'CZ7', 'CZ8',
            'CZ9', 'CZ10', 'CZ11', 'CZ12',
            'CZ13', 'CZ14', 'CZ15', 'CZ16',
            'System',
        ),
    }

    # number of data rows taken from each sheet:
    _ROWS_PER_SHEET = 8760

    def __init__(self):
        self.generation_costs = pd.DataFrame(columns=self._COLUMNS)
        self.transmission_and_distribution_costs = pd.DataFrame(
            columns=self._COLUMNS)
        self.shift_hours = 0
        self.shiftable = False

    def read_xlsb(self, file_path, sheet=''):
        """Load a binary Excel workbook into the applicable cost table.

        The utility (PG&E, SCE, SCG or SDG&E) and the cost type ('Gen' or
        'TD') are parsed from the file name; the full path is only consulted
        when the file name itself does not contain them, so that directory
        names cannot override the file name.

        Args:
            file_path: path of the .xlsb workbook.
            sheet: optional name of a single sheet to read; when empty, or
                when the sheet is not found, every recognized sheet is read.

        An unparseable file name is reported with a printed error and
        leaves the object unchanged.
        """
        filename = Path(file_path).name
        administrator_pattern = r'(PG(\&)?E|SCE|SDG(\&)?E|S(o)?C(al)?G(as)?)'
        cost_type_pattern = r'(Gen|TD)'
        program_administrator_match = (
            re.search(administrator_pattern, filename)
            or re.search(administrator_pattern, str(file_path)))
        cost_type_match = (
            re.search(cost_type_pattern, filename)
            or re.search(cost_type_pattern, str(file_path)))

        if not (program_administrator_match and cost_type_match):
            print('ERROR: Unable to parse avoided cost output filename--must contain utility (PG&E, SCE, SCG, or SDG&E) and either \'Gen\' or \'TD\'')
            return

        self.program_administrator = program_administrator_match.group(0)
        cost_type = cost_type_match.group(0)
        valid_sheets = self._SHEET_NAMES[cost_type]

        frames = []
        # open the workbook once, by its full path, and close it when done:
        with pd.ExcelFile(file_path, engine='pyxlsb') as workbook:
            sheet_names = workbook.sheet_names

            if sheet in valid_sheets and sheet in sheet_names:
                sheet_names = [sheet]
            elif sheet != '':
                print('WARNING: The selected sheet was not found--reading full workbook')

            # retrieve data from each sheet with a valid name:
            for sheet_name in sheet_names:
                if sheet_name not in valid_sheets:
                    continue
                print('< Retrieving Sheet \'{}\' from file \'{}\' ... >'.format(
                    sheet_name, filename))
                data_in = workbook.parse(
                    sheet_name=sheet_name, header=1
                ).loc[range(self._ROWS_PER_SHEET), :]
                frames.extend(
                    self._sheet_to_long_frames(data_in, cost_type, sheet_name))

        if frames:
            data_out = pd.concat(frames)
        else:
            data_out = pd.DataFrame(columns=self._COLUMNS)
        data_out = data_out.set_index(self._INDEX_COLUMNS).sort_index()

        self.shift_hours = 0
        if cost_type == 'Gen':
            self.generation_costs = data_out
            self.generation_costs_source = filename
        elif cost_type == 'TD':
            self.transmission_and_distribution_costs = data_out
            self.transmission_and_distribution_costs_source = filename

    def _sheet_to_long_frames(self, data_in, cost_type, sheet_name):
        """Convert one sheet of column-wise annual data to long-format frames.

        Returns a list with one frame per year-valued column and, for the
        'EmissionRate' sheet, one additional frame per NOx / PM10 column.
        """
        data_in = data_in.rename(columns={data_in.columns[0]: 'timestamp'})
        # the first column holds Excel serial dates (1900 date system):
        data_in['timestamp'] = data_in['timestamp'].map(
            lambda d: xlrd.xldate.xldate_as_datetime(d, 0))
        data_in = data_in.dropna(axis='columns', how='all')

        # set values for indexing fields:
        if cost_type == 'Gen':
            climate_zone = 'N/A'
            if sheet_name == 'EmissionRate':
                data_type, units = 'CO2', 'tons/kWh'
            else:
                data_type, units = 'Gen', '$/MWh'
        else:
            climate_zone = sheet_name
            data_type, units = 'TD', '$/MWh'

        timestamp = data_in['timestamp']

        def long_frame(year, label, unit, values):
            # scalar entries are broadcast along the index of the series:
            return pd.DataFrame({
                'ProgramAdministrator': self.program_administrator,
                'ClimateZone': climate_zone,
                'Timestamp': timestamp.map(
                    lambda d: self._hour_in_year(d, year)),
                'DataType': label,
                'Units': unit,
                'Value': values,
            })

        # use only year-valued column labels:
        years = [c for c in data_in.columns
                 if isinstance(c, int) and c >= 2000]

        frames = [
            long_frame(year, data_type, units, data_in[year])
            for year in years
        ]

        # convert NOx and PM10 columns, stamped with the first year found
        # (or 1900 when the sheet has no year columns):
        if sheet_name == 'EmissionRate':
            first_year = years[0] if years else 1900
            for column in data_in.columns:
                pollutant_label = self._pollutant_label(column)
                if pollutant_label is not None:
                    frames.append(long_frame(
                        first_year, pollutant_label, 'lbs/MWh',
                        data_in[column]))
        return frames

    @staticmethod
    def _hour_in_year(moment, year):
        """Return `moment` moved to `year` and rounded to the nearest hour."""
        hours = int(round(
            moment.hour + moment.minute / 60 + moment.second / 3600))
        return (dt.datetime(int(year), moment.month, moment.day)
                + dt.timedelta(hours=hours))

    @staticmethod
    def _pollutant_label(column):
        """Return 'NOx' or 'PM10' for a pollutant column label, else None."""
        text = str(column)
        if re.search(r'(no|NO)(x|X)', text):
            return 'NOx'
        if re.search(r'(pm|PM)10', text):
            return 'PM10'
        return None

    def read_directory(self, directory_path, program_administrator):
        """Load the avoided cost workbooks of one administrator.

        Files in `directory_path` named '<administrator>[_ ]Gen.xlsb' or
        '<administrator>[_ ]TD.xlsb' (with or without the '&' in the
        administrator name) are passed to `read_xlsb`. A printed error
        reports the case where no file matches.
        """
        self.program_administrator = program_administrator
        names = '|'.join(
            re.escape(name) for name in (
                program_administrator,
                program_administrator.replace('&', ''),
            ))
        pattern = re.compile(r'^(' + names + r')(\s|_)*(Gen|TD)\.xlsb$')
        filenames = sorted(
            name for name in os.listdir(directory_path)
            if pattern.search(name))

        for filename in filenames:
            # the workbook lives in directory_path, not the working directory:
            self.read_xlsb(os.path.join(directory_path, filename))

        if not filenames:
            print('ERROR: No files found in \'{}\'for PA=\'{}\''.format(
                directory_path, program_administrator))

    def activate_shifting(self):
        """Allow `shift` to modify the timestamps."""
        self.shiftable = True
        print('NOTE: Avoided Cost Output Shifting is ON')

    def deactivate_shifting(self):
        """Restore the original timestamps and block further shifting."""
        # shifting must be enabled for the restoring call to take effect:
        self.shiftable = True
        self.shift(0)
        self.shiftable = False
        print('NOTE: Avoided Cost Output Shifting is OFF')

    def shift(self, shift_hours):
        """Shift the timestamps of both cost tables by `shift_hours`.

        A positive value of `shift_hours` causes the data to shift earlier,
        i.e., if shift_hours = +24, the avoided Gen and TD costs for Jan 2
        become the costs for Jan 1, and the costs for Jan 1 shift to Dec 31
        of the same year. The object's data is overwritten and the shift is
        tracked, so invoking shift(0) restores the original timestamps.

        Nothing is changed (a message is printed) when shifting is off or
        when `shift_hours` is not an integer. A shift that lands on Feb 29
        for a year without one raises ValueError.
        """
        if not self.shiftable:
            print('NOTE: Avoided Cost Output Shifting is OFF')
            return
        if not isinstance(shift_hours, (int, np.integer)):
            print('ERROR: Failed to shift timestamps by {} hours--only integers allowed'.format(shift_hours))
            return
        if shift_hours == self.shift_hours:
            return

        # only the difference to the shift already applied is needed:
        offset = int(shift_hours - self.shift_hours)

        if len(self.generation_costs.index) > 0:
            gen_out = self.generation_costs
            gen_out['Timestamp'] = self._shifted_within_year(
                gen_out['Timestamp'], offset)
            self.generation_costs = gen_out.sort_values(
                'Timestamp').sort_index()
        if len(self.transmission_and_distribution_costs.index) > 0:
            td_out = self.transmission_and_distribution_costs
            td_out['Timestamp'] = self._shifted_within_year(
                td_out['Timestamp'], offset)
            self.transmission_and_distribution_costs = td_out.sort_values(
                'Timestamp').sort_index()
        self.shift_hours = shift_hours

    @staticmethod
    def _shifted_within_year(timestamps, hours):
        """Move timestamps `hours` earlier, wrapping inside their own year.

        The month, day and hour of the moved timestamp are combined with the
        year of the original timestamp, so data leaving one end of a year
        re-enters at the other end. Returns a numpy datetime array in the
        order of the input.
        """
        original = pd.to_datetime(timestamps).reset_index(drop=True)
        moved = original - pd.Timedelta(hours=hours)
        parts = pd.DataFrame({
            'year': original.dt.year,
            'month': moved.dt.month,
            'day': moved.dt.day,
            'hour': moved.dt.hour,
        })
        return pd.to_datetime(parts).to_numpy()

    def chunkify(self):
        """Yield the combined 'Gen' and 'TD' tables one unique index at a time.

        Each chunk holds the rows of one (ProgramAdministrator, ClimateZone,
        DataType) combination; 'NOx' and 'PM10' rows are skipped. Nothing is
        yielded when no data has been loaded.
        """
        # tables that were never loaded are un-indexed placeholders and are
        # left out of the stack:
        tables = [
            table for table in (
                self.generation_costs,
                self.transmission_and_distribution_costs,
            ) if len(table.index) > 0
        ]
        if not tables:
            return

        # stack the generation and t&d tables:
        combined_table = pd.concat(tables)

        for index in combined_table.index.unique():
            if index[2].lower() in ('nox', 'pm10'):
                continue
            yield combined_table.loc[index, :]

    def first_timestamp(self):
        """Return the earliest timestamp in the generation or t&d tables.

        Raises:
            ValueError: if neither table contains any rows.
        """
        timestamps = []
        if len(self.generation_costs.index) > 0:
            timestamps.append(
                self.generation_costs.loc[:, 'Timestamp'].min())
        if len(self.transmission_and_distribution_costs.index) > 0:
            timestamps.append(
                self.transmission_and_distribution_costs.loc[
                    :, 'Timestamp'].min())
        if not timestamps:
            raise ValueError(
                'No avoided cost data loaded--no first timestamp available')
        return min(timestamps)

class ImpactProfile:
    """Annual impact profile of normalized electric usage or savings.

    The profile is held in ``data``, a pandas dataframe with one 'Impact'
    column indexed by 'Timestamp'; `read_csv` requires 8760 rows. All
    timestamps are placed in the arbitrary year 1900.
    """

    electric_target_sector = ''
    name = ''
    description = ''
    data = None
    shift_hours = 0
    shiftable = True

    def __init__(self):
        # the empty profile has the same layout as a loaded one:
        self.data = pd.DataFrame(
            columns=['Impact'],
            index=pd.DatetimeIndex([], name='Timestamp'))

    def read_csv(self, filename, column_name):
        """Read one hourly impact profile from a csv file.

        The file must have 8760 rows, an 'hour_of_year' column holding the
        hour offset from midnight of January 1, and a column labelled
        `column_name` holding the profile. Problems are reported with a
        printed error and leave the object unchanged.
        """
        data_in = pd.read_csv(filename)

        # check for 8760 rows:
        if data_in.index.size != 8760:
            print('ERROR: Impact profiles must have 8760 rows--the specified file has {} rows'.format(data_in.index.size))
            return
        for required in (column_name, 'hour_of_year'):
            if required not in data_in.columns:
                print('ERROR: No column labelled \'{}\' found'.format(required))
                return

        # convert impact profile hour numbers to timestamps:
        timestamps = pd.Timestamp(1900, 1, 1) + pd.to_timedelta(
            data_in['hour_of_year'], unit='h')

        # extract requested impact profile and sort by timestamp:
        data_out = pd.DataFrame({
            'Timestamp': timestamps,
            'Impact': data_in[column_name],
        })
        self.name = column_name
        self.shift_hours = 0
        self.data = data_out.set_index(['Timestamp']).sort_index()

    def activate_shifting(self):
        """Allow `shift` to modify the timestamps."""
        self.shiftable = True
        print('NOTE: Impact Profile Shifting is ON')

    def deactivate_shifting(self):
        """Restore the original timestamps and block further shifting."""
        # shifting must be enabled for the restoring call to take effect:
        self.shiftable = True
        self.shift(0)
        self.shiftable = False
        print('NOTE: Impact Profile Shifting is OFF')

    def shift(self, shift_hours):
        """Shift the timestamps of the impact profile by `shift_hours`.

        A positive value of `shift_hours` causes the profile to shift
        earlier, i.e., if shift_hours = +24, the profile for Jan 2 becomes
        the profile for Jan 1 and the profile for Jan 1 becomes Dec 31. The
        object's data is overwritten and the shift is tracked, so invoking
        shift(0) restores the original timestamps.

        Nothing is changed (a message is printed) when shifting is off or
        when `shift_hours` is not an integer.
        """
        if not self.shiftable:
            print('NOTE: Impact Profile Shifting is OFF')
            return
        if not isinstance(shift_hours, (int, np.integer)):
            print('WARNING: Failed to shift timestamps by {} hours, only integers allowed.'.format(shift_hours))
            return
        # if the input variable matches the current shift, there is no need
        # to run the calculations:
        if shift_hours == self.shift_hours:
            return

        if len(self.data.index) > 0:
            # only the difference to the shift already applied is needed:
            offset = int(shift_hours - self.shift_hours)
            data_out = self.data.reset_index()
            moved = (pd.to_datetime(data_out['Timestamp'])
                     - pd.Timedelta(hours=offset))
            # wrap hours that left the year back into 1900:
            data_out['Timestamp'] = pd.to_datetime(pd.DataFrame({
                'year': 1900,
                'month': moved.dt.month,
                'day': moved.dt.day,
                'hour': moved.dt.hour,
            }))
            self.data = data_out.set_index(['Timestamp']).sort_index()

        self.shift_hours = shift_hours

    def chunkify(self):
        """Return the profile as a list of four chunks, one per quarter."""
        # create list of bounding dates for quarters in arbitrary year of 1900:
        quarters = quarter_bounds(
            dt.datetime(1900, 1, 1, 0, 0), dt.datetime(1900, 12, 31, 23, 0, 0))

        # filter the impact profile by the bounds of each quarter:
        return [
            self.data.loc[
                (self.data.index >= quarter['start'])
                & (self.data.index <= quarter['end']), :]
            for quarter in quarters
        ]

    def first_timestamp(self):
        """Return the earliest timestamp in the profile.

        Should be the hour of midnight on January 1, 1900.
        """
        return self.data.index.min()

class AvoidedCostElectric:
    """One section of the avoided cost electric table.

    Represents a section of the avoided cost electric table in the cost
    effectiveness tool, corresponding to a single combination of program
    administrator, target sector, and end use. It combines an impact
    profile with avoided cost outputs into quarterly costs (``data``) and
    emissions rates (``emissions``).
    """

    impact_profile = None
    avoided_cost_outputs = None
    shift_table = ''
    data = None
    emissions = None

    # column layout of the quarterly cost table:
    _DATA_COLUMNS = [
        'ProgramAdministrator',
        'ElectricEndUse',
        'ElectricTargetSector',
        'ClimateZone',
        'DataType',
        'YearQuarter',
        'Cost',
    ]

    def __init__(self):
        self.impact_profile = ImpactProfile()
        self.avoided_cost_outputs = AvoidedCostOutputs()
        self.set_shift_table('impact profile')
        self.proc_threads = 4
        # the empty tables use the same column names as the generated ones,
        # so the save methods address the same columns in both cases:
        self.data = pd.DataFrame(columns=self._DATA_COLUMNS)
        self.emissions = pd.DataFrame(
            columns=[
                'ProgramAdministrator',
                'ElectricEndUse',
                'ElectricTargetSector',
                'ClimateZone',
                'DataType',
                'EmissionsRate',
            ]
        )

    def set_shift_table(self, shift_table_name):
        """Select which table may have its timestamps shifted.

        `shift_table_name` is 'impact profile' or 'avoided cost outputs'
        (case-insensitive); shifting is switched on for that table and off
        for the other. Any other name only prints a warning.
        """
        requested = shift_table_name.lower()
        if requested == 'impact profile':
            self.impact_profile.activate_shifting()
            self.avoided_cost_outputs.deactivate_shifting()
            self.shift_table = 'impact profile'
        elif requested == 'avoided cost outputs':
            self.impact_profile.deactivate_shifting()
            self.avoided_cost_outputs.activate_shifting()
            self.shift_table = 'avoided cost outputs'
        else:
            print('WARNING: shift table \'{}\' not found. Please use either \'impact profile\' or \'avoided cost outputs\''.format(shift_table_name))

    def set_impact_profile(self, impact_profile):
        """Replace the impact profile used for the calculations."""
        self.impact_profile = impact_profile

    def set_avoided_cost_outputs(self, avoided_cost_outputs):
        """Replace the avoided cost outputs used for the calculations."""
        self.avoided_cost_outputs = avoided_cost_outputs

    def generate_avoided_cost_electric_table(self):
        """Calculate the quarterly avoided costs and store them in `data`.

        For every index group of the avoided cost outputs and every quarter
        it covers, the hourly values are multiplied with the impact profile
        of that quarter and summed (see `mp_mapper`).

        NOTE(review): the weekday alignment always shifts the impact
        profile; `shift_table` is not consulted here, so with
        'avoided cost outputs' selected the profile is left unshifted --
        confirm whether that is intended.
        """
        # align weekdays for impact profile and avoided cost outputs:
        shift_hours = (
            self.avoided_cost_outputs.first_timestamp().weekday()
            - self.impact_profile.first_timestamp().weekday()
        ) * 24
        self.impact_profile.shift(shift_hours)

        impact_profile_chunks = self.impact_profile.chunkify()

        # setup list of data chunks for multiprocessing:
        mp_chunks = []
        for avoided_cost_outputs_chunk in self.avoided_cost_outputs.chunkify():
            timestamps = avoided_cost_outputs_chunk.loc[:, 'Timestamp']

            # one work item per quarter between the earliest and latest
            # avoided cost dates of the chunk:
            for quarter in quarter_bounds(timestamps.min(), timestamps.max()):
                in_quarter = (
                    (timestamps >= quarter['start'])
                    & (timestamps <= quarter['end'])
                ).to_numpy()
                mp_chunks.append({
                    'avoided_cost_outputs_chunk':
                        avoided_cost_outputs_chunk.loc[in_quarter],
                    'impact_profile_chunk':
                        impact_profile_chunks[quarter['quarter']],
                })

        # calculate avoided cost table using a single worker pool:
        rows = []
        if mp_chunks:
            with mp.Pool(self.proc_threads) as mp_pool:
                rows = mp_pool.map(self.mp_mapper, mp_chunks)

        self.data = pd.DataFrame(rows, columns=self._DATA_COLUMNS)

    def _sector_and_end_use(self):
        """Split the impact profile name into (target sector, end use)."""
        name = self.impact_profile.name
        if 'NonRes' in name or 'Non_Res' in name:
            return 'NonRes', name[name.find('Res_') + 4:]
        if 'Res' in name:
            return 'Res', name[name.find('Res_') + 4:]
        return 'N/A', name

    def generate_emissions_table(self):
        """Calculate the NOx and PM10 emissions rates into `emissions`.

        Each rate is 0.001 times the dot product of the hourly pollutant
        values in the generation table and the impact profile.
        """
        electric_target_sector, electric_end_use = self._sector_and_end_use()

        program_administrator = self.avoided_cost_outputs.program_administrator
        impact_profile = self.impact_profile.data.loc[:, 'Impact'].values
        generation_costs = self.avoided_cost_outputs.generation_costs
        nox_rates = generation_costs.loc[
            (program_administrator, 'N/A', 'NOx'), 'Value'].values
        pm10_rates = generation_costs.loc[
            (program_administrator, 'N/A', 'PM10'), 'Value'].values
        coeff = 0.001

        self.emissions = pd.DataFrame({
            'ProgramAdministrator': program_administrator,
            'ElectricEndUse': [electric_end_use] * 2,
            'ElectricTargetSector': [electric_target_sector] * 2,
            'DataType': ['NOx', 'PM10'],
            'EmissionsRate': [
                coeff * np.dot(nox_rates, impact_profile),
                coeff * np.dot(pm10_rates, impact_profile),
            ],
        })

    def _cost_table(self, data_type, with_climate_zone):
        """Return the rows of one data type, renamed and indexed for output.

        The 'Cost' column is renamed to `data_type`; the climate zone is
        either kept as index level 'CZ' or dropped.
        """
        renames = {
            'ProgramAdministrator': 'PA',
            'ElectricEndUse': 'EU',
            'ElectricTargetSector': 'TS',
            'YearQuarter': 'Qtr',
            'Cost': data_type,
        }
        if with_climate_zone:
            renames['ClimateZone'] = 'CZ'
            dropped = ['DataType']
            keys = ['PA', 'EU', 'TS', 'CZ', 'Qtr']
        else:
            dropped = ['DataType', 'ClimateZone']
            keys = ['PA', 'EU', 'TS', 'Qtr']

        selected = self.data.loc[(self.data.loc[:, 'DataType'] == data_type), :]
        return selected.rename(columns=renames).drop(
            columns=dropped).set_index(keys)

    def save_avoided_cost_electric_table(self, filename):
        """Save the avoided cost table as a csv file.

        The table is formatted for the Cost Effectiveness Tool: one row per
        (PA, EU, TS, CZ, Qtr) of the 'TD' data with the 'TD', 'Gen' and
        'CO2' costs side by side.
        """
        # split the dataframe into td, gen, and co2 tables:
        data_td = self._cost_table('TD', with_climate_zone=True)
        data_gen = self._cost_table('Gen', with_climate_zone=False)
        data_co2 = self._cost_table('CO2', with_climate_zone=False)

        # join tables horizontally on applicable indices:
        join_keys = ['PA', 'EU', 'TS', 'Qtr']
        data_out = data_td.join(
            data_gen, on=join_keys).join(data_co2, on=join_keys)

        # save output table to csv:
        data_out.to_csv(filename)

    def save_emissions_table(self, filename):
        """Save the emissions table as a csv file.

        The table is formatted for the Cost Effectiveness Tool: the 'NOx'
        and 'PM10' rates side by side, indexed by (PA, TS, EU).
        """
        data_types = self.emissions.loc[:, 'DataType']
        data_out = pd.DataFrame({
            'PA': self.emissions.loc[0, 'ProgramAdministrator'],
            'TS': self.emissions.loc[0, 'ElectricTargetSector'],
            'EU': self.emissions.loc[0, 'ElectricEndUse'],
            'NOx': self.emissions.loc[
                (data_types == 'NOx'), 'EmissionsRate'].values,
            'PM10': self.emissions.loc[
                (data_types == 'PM10'), 'EmissionsRate'].values,
        }).set_index(['PA', 'TS', 'EU'])
        data_out.to_csv(filename)

    def save(self):
        """Save both tables to file names derived from the table attributes."""
        program_administrator = self.avoided_cost_outputs.program_administrator
        filename_ace = 'AvoidedCostElecCO2SeqCO2_{}_{}.csv'.format(
            program_administrator,
            self.impact_profile.name
        )
        filename_em = 'Emissions_{}_{}.csv'.format(
            program_administrator,
            self.impact_profile.name
        )
        self.save_avoided_cost_electric_table(filename_ace)
        self.save_emissions_table(filename_em)

    def mp_mapper(self, chunk):
        """Calculate the avoided cost of one quarter of one index group.

        This is a helper function mapped across chunks of data and
        distributed across multiple processes.

        Args:
            chunk: dictionary with 'avoided_cost_outputs_chunk' (rows of one
                index group within one quarter) and 'impact_profile_chunk'
                (impact profile rows of the same quarter); both must have
                the same number of rows.

        Returns:
            pandas Series holding one row of the avoided cost table.
        """
        # extract avoided cost and impact profile data from data chunk:
        avoided_cost_chunk = chunk['avoided_cost_outputs_chunk']
        impact_profile_chunk = chunk['impact_profile_chunk']

        program_administrator, climate_zone, data_type = (
            avoided_cost_chunk.index[0][:3])
        first_timestamp = avoided_cost_chunk.loc[:, 'Timestamp'].min()
        # months 1-3 -> 0, 4-6 -> 1, 7-9 -> 2, 10-12 -> 3:
        quarter = (first_timestamp.month - 1) // 3
        electric_target_sector, electric_end_use = self._sector_and_end_use()

        # 'Gen' and 'TD' values are scaled by 0.001 (presumably $/MWh to
        # $/kWh -- TODO confirm); other data types are used as they are:
        coeff = 0.001 if data_type in ('Gen', 'TD') else 1.000

        return pd.Series({
            'ProgramAdministrator': program_administrator,
            'ElectricEndUse': electric_end_use,
            'ElectricTargetSector': electric_target_sector,
            'ClimateZone': climate_zone,
            'DataType': data_type,
            'YearQuarter': '{}Q{}'.format(first_timestamp.year, quarter + 1),
            'Cost': coeff * np.dot(
                avoided_cost_chunk.loc[:, 'Value'].values,
                impact_profile_chunk.loc[:, 'Impact'].values),
        })

def quarter_bounds(start_date, end_date):
    """List the calendar quarters from `start_date` through `end_date`.

    This is a helper function to generate a list of quarters; each quarter
    is bounded by the first and last hour contained therein, and is indexed
    relative to the quarter containing `start_date` as well as by the
    quarter of the year (Jan-Mar : 0, Apr-Jun : 1, Jul-Sep : 2, Oct-Dec : 3).

    Args:
        start_date: datetime inside the first quarter to list.
        end_date: datetime inside the last quarter to list.

    Returns:
        List of dictionaries with the keys 'index', 'quarter', 'start' and
        'end'; empty when `end_date` lies in a quarter before `start_date`.
    """
    # number the quarters consecutively across years, so that the count
    # and the labels follow from the calendar rather than an estimate:
    first_quarter = start_date.year * 4 + (start_date.month - 1) // 3
    last_quarter = end_date.year * 4 + (end_date.month - 1) // 3

    def first_hour(absolute_quarter):
        # midnight on the first day of the given consecutive quarter
        year, quarter_of_year = divmod(absolute_quarter, 4)
        return dt.datetime(year, quarter_of_year * 3 + 1, 1, 0, 0)

    one_hour = dt.timedelta(hours=1)

    # the last hour of a quarter is one hour before the next one starts:
    return [{
        'index': position,
        'quarter': absolute_quarter % 4,
        'start': first_hour(absolute_quarter),
        'end': first_hour(absolute_quarter + 1) - one_hour,
    } for position, absolute_quarter in enumerate(
        range(first_quarter, last_quarter + 1))]
