#!/usr/bin/env python3
from contextlib import contextmanager
import sys
import re


def main():
    """Convert the selected input stream and write the result to the output.

    The streams come from ``io()``; the conversion itself is done by
    ``Processor.process_lines()``.
    """
    with io() as streams:
        source, sink = streams
        converter = Processor(source, sink)
        converter.process_lines()


@contextmanager
def io():
    """Yield the ``(input, output)`` text streams chosen on the command line.

    ``sys.argv[1]``, when present, is opened for reading and used instead of
    ``sys.stdin``; ``sys.argv[2]``, when present, is opened for writing and
    used instead of ``sys.stdout``.

    Only the files opened here are closed on exit; the standard streams are
    left untouched. Closing happens in a ``finally`` clause so the files are
    released even when the ``with`` body raises, or when opening the output
    file fails after the input file was already opened.
    """
    argc = len(sys.argv)
    source = sys.stdin
    sink = sys.stdout
    owned = []  # streams opened by this function, in opening order
    try:
        if argc > 1:
            source = open(sys.argv[1])
            owned.append(source)
        if argc > 2:
            sink = open(sys.argv[2], 'w')
            owned.append(sink)
        yield source, sink
    finally:
        for stream in owned:
            stream.close()


class Processor:
    """Drive the line-by-line conversion and hold the shared parsing state.

    Attributes set here and updated by the line processors:

    * ``processor_class`` -- class used to handle the next input line;
      starts as ``StartLineProcessor`` and is swapped by the processors
      themselves as the input advances.
    * ``sec_l`` / ``sec_p`` -- logical/physical sector size in bytes,
      ``None`` until the sector-size line has been seen.
    * ``partition_number`` -- index assigned to the most recently emitted
      partition row, ``None`` until the first row has been seen.
    """

    def __init__(self, input, output):
        self.input = input
        self.output = output
        self.processor_class = StartLineProcessor
        self.sec_l = None
        self.sec_p = None
        self.partition_number = None

    def process_lines(self):
        """Process every input line, then emit the trailing ``sgpt`` entry.

        Raises:
            ValueError: if no partition row was recognised, because the
                index of the ``sgpt`` entry is derived from the last one.
        """
        for line in self.input:
            self.line_processor(line)()
        if self.partition_number is None:
            # Without this check the addition below fails with an opaque
            # "NoneType + int" TypeError.
            raise ValueError(
                "no partition entries were recognised in the input; "
                "cannot emit the trailing 'sgpt' entry")
        # The emitter needs a line-processor-like parent to reach the output
        # stream and the sector size, so a throwaway one is created here.
        # NOTE(review): the sgpt sector range is hard-coded; presumably it is
        # only valid for the device this script was written for.
        PartitionEmitter.dispatch(
            PartitionLineProcessor(self, ''),
            name='sgpt',
            number=self.partition_number + 1,
            start_sector=8388480,
            end_sector=8388543
        )()

    def line_processor(self, line):
        """Return a processor instance of the current class for ``line``."""
        return self.processor_class(self, line)


class LineProcessor:
    """Common base of the per-line handlers.

    An instance is bound to one input line and to the object that created
    it (``parent``), whose ``input`` and ``output`` streams it copies.
    Subclasses provide ``match()``, which is run when the instance is called.
    """

    def __init__(self, parent, line):
        owner = parent
        self.parent = owner
        self.line = line
        self.input = owner.input
        self.output = owner.output

    def __call__(self):
        """Handle the bound line by delegating to ``match()``."""
        self.match()

    def print(self, obj):
        """Write ``obj`` followed by a newline to the output stream."""
        sink = self.output
        print(obj, file=sink)


class StartLineProcessor(LineProcessor):
    """Initial state: look for the sector-size line.

    When the line matches, both sizes are stored on the parent and the
    parent is switched to ``HeaderLineProcessor``. Non-matching lines are
    ignored.
    """

    def match(self):
        """Record the sector sizes; raise ``Exception`` if they differ."""
        found = re.match(
            r'^Sector size \(logical/physical\): (\d+)/(\d+) bytes$',
            self.line)
        if found is None:
            return
        state = self.parent
        logical = int(found.group(1))
        state.sec_l = logical
        physical = int(found.group(2))
        state.sec_p = physical
        if logical != physical:
            # Addresses are computed from a single sector size, so an
            # ambiguous one is treated as fatal.
            raise Exception(
                "logical/physical sector size "
                f"({logical}/{physical}) "
                "are not equal. Don't know which to pick! "
                "Aborting!")
        state.processor_class = HeaderLineProcessor


# Preamble of the generated scatter file; HeaderLineProcessor.emit_header()
# prints it once, before any partition entry. The platform/project values
# are hard-coded (see the warnings inside the text itself).
# The backslashes right after the opening quotes and on the last content
# line suppress the adjacent newlines, so the string neither starts nor ends
# with a blank line; print() supplies the final newline when it is emitted.
FIXED_HEADER = '''\
##### Fixed Header, MAYBE NOT CORRECT for all Planet Computers Cosmo Communicator devices ####
##### CERTAINLY NOT CORRECT FOR A DIFFERENT DEVICE, USE YOUR BRAIN! ####
#### USE AT YOUR OWN RISK ####
#########################################__gdisk2scatter.py__###################################################
#
#  General Setting
#
#########################################__gdisk2scatter.py__###################################################
- general: MTK_PLATFORM_CFG
  info:
    - config_version: V1.1.2
      platform: MT6771
      project: k71v1_64_bsp
      storage: EMMC
      boot_channel: MSDC_0
      block_size: 0x20000
############################################################################################################
#
#  Layout Setting
#
############################################################################################################
##### END OF FIXED HEADER ####\
'''


class HeaderLineProcessor(LineProcessor):
    """Second state: wait for the column-heading line of the partition table.

    On a match the fixed header and the two fixed entries (``preloader`` and
    ``pgpt``) are written, and the parent is switched to
    ``PartitionLineProcessor``. Non-matching lines are ignored.
    """

    def match(self):
        """Emit the file preamble once the table heading is reached."""
        heading = (
            r'^Number +Start \(sector\) +End \(sector\) +Size +Code +Name'
        )
        if re.match(heading, self.line) is None:
            return
        self.emit_header()
        # (name, index, first sector, last sector) of the entries that are
        # always written ahead of the rows parsed from the input.
        fixed_entries = (
            ('preloader', 0, 0, 511),
            ('pgpt', 1, 0, 63),
        )
        for name, index, first, last in fixed_entries:
            emitter = PartitionEmitter.dispatch(
                self,
                name=name,
                number=index,
                start_sector=first,
                end_sector=last,
            )
            emitter()
        self.parent.processor_class = PartitionLineProcessor

    def emit_header(self):
        """Write the fixed scatter-file header to the output."""
        self.print(FIXED_HEADER)


class PartitionLineProcessor(LineProcessor):
    """Final state: turn each partition-table row into a scatter entry.

    Lines that do not look like a table row are ignored.
    """

    # Compiled once for the class instead of once per input line.
    # Columns: number, start sector, end sector, size ("<value> <unit>"),
    # type code, name.
    # * The type code is matched as four hexadecimal digits: gdisk type
    #   codes such as "EF00" contain letters, and a digits-only pattern
    #   silently dropped those rows.
    # * The name class spells out "A-Za-z"; the range "A-z" also covers the
    #   punctuation characters that sit between "Z" and "a" in ASCII.
    _PART_RE = re.compile(
        r'^ +(?P<number>\d+) +(?P<start>\d+) +(?P<end>\d+) '
        r'+([.0-9]+ \w+) +([0-9A-Fa-f]{4}) '
        r'+(?P<name>[A-Za-z_0-9]+)$'
    )

    def match(self):
        """Emit a partition entry if the bound line is a table row."""
        part_match = self._PART_RE.match(self.line)
        if part_match:
            self.emit_partition(part_match)

    def emit_partition(self, part_match):
        """Write the entry described by ``part_match``.

        The emitted index is the parsed number plus one (indices 0 and 1 are
        used by the fixed ``preloader`` and ``pgpt`` entries written by
        ``HeaderLineProcessor``). It is also stored on the parent as
        ``partition_number`` so later entries can continue the numbering.
        """
        number = int(part_match.group('number')) + 1
        self.parent.partition_number = number
        PartitionEmitter.dispatch(
            self,
            name=part_match.group('name'),
            number=number,
            start_sector=int(part_match.group('start')),
            end_sector=int(part_match.group('end')))()


# Partition name -> image file name written in the entry's ``file_name``
# field. Partitions listed here get FileNameMixin (see
# PartitionEmitter.make_class); all others use the emitter's
# ``<partition name>.img`` form.
FILE_NAMES = {
    'preloader': 'preloader_k71v1_64_bsp.bin',
    'md1img': 'md1rom.img',
    'scp1': 'tinysys-scp1.bin',
    'scp2': 'tinysys-scp2.bin',
    'sspm_1': 'sspm.img',
    'lk': 'lk.bin',
    'lk2': 'lk2.bin',
    'logo': 'logo.bin',
    'tee1': 'trustzone1.bin',
    'tee2': 'trustzone2.bin'
}


class PartitionEmitter:
    """Write one partition entry of the scatter file.

    ``parent`` must provide ``print(obj)``, an ``output`` stream and a
    ``parent`` attribute whose ``sec_l`` is the sector size in bytes.
    Calling the instance writes the entry: every name in ``keys`` is written
    in order, using an ``emit_<key>`` method when one exists and the value
    from ``defaults`` otherwise.

    Instances are normally created through ``dispatch()``, which first
    builds a subclass with the mixins that apply to the partition name.
    """

    def __init__(self, parent, name, number, start_sector, end_sector):
        self.parent = parent
        # Note: this instance attribute hides the ``part_name`` static
        # method below on instances.
        self.part_name = name
        self.part_no = number
        self.start_sector = start_sector
        self.end_sector = end_sector

    def __call__(self):
        """Write the complete entry."""
        self.emit_start()
        for key in self.keys:
            self.emit_key(key)
        self.emit_end()

    # Field names of an entry, in output order.
    keys = (
        'partition_index',
        'partition_name',
        'file_name',
        'is_download',
        'type',
        'linear_start_addr',
        'physical_start_addr',
        'partition_size',
        'region',
        'storage',
        'boundary_check',
        'is_reserved',
        'operation_type',
        'is_upgradable',
        'empty_boot_needed',
        'reserve',
    )

    def emit_key(self, k):
        """Write field ``k`` via ``emit_<k>`` or, failing that, its default."""
        handler = getattr(self, f'emit_{k}', None)
        if handler is None:
            self.print(f'  {k}: {self.defaults[k]}')
            return
        handler()

    @property
    def defaults(self):
        """Values for the fields that have no dedicated ``emit_*`` method.

        Mixins override this property and layer their changes on top of
        ``super().defaults``.
        """
        return dict(
            is_download='true',
            type='NORMAL_ROM',
            region='EMMC_USER',
            storage='HW_STORAGE_EMMC',
            boundary_check='true',
            is_reserved='false',
            operation_type='UPDATE',
            is_upgradable='false',
            empty_boot_needed='false',
            reserve='0x00',
        )

    @classmethod
    def dispatch(cls, parent, name, *args, **kwargs):
        """Create an emitter of the class tailored to partition ``name``."""
        emitter_class = cls.make_class(name)
        return emitter_class(parent, name, *args, **kwargs)

    @classmethod
    def make_class(cls, part_name):
        """Build a subclass of ``cls`` with the mixins for ``part_name``.

        Mixins are placed before ``cls`` in the bases, in table order, so
        their overrides take precedence.
        """
        mixin_table = (
            ({'pgpt', 'boot_para', 'para', 'expdb', 'frp', 'nvdata',
              'metadata', 'seccfg', 'sec1', 'gz1', 'gz2'},
             InvisibleMixin),
            ({'preloader', 'recovery', 'md1img', 'md1dsp', 'spmfw',
              'scp1', 'scp2', 'sspm_1', 'sspm_2',
              'cam_vpu1', 'cam_vpu2', 'cam_vpu3',
              'lk', 'lk2', 'boot', 'dtbo', 'tee1', 'tee2', 'vendor',
              'system'},
             UpgradableMixin),
            ({'nvcfg', 'protect1', 'protect2', 'persist', 'proinfo'},
             ProtectedMixin),
            ({'nvram'},
             BinRegionMixin),
            ({'lk', 'logo', 'tee1'},
             EmptyBootMixin),
            ({'otp', 'flashinfo', 'sgpt'},
             ReservedMixin),
            ({'preloader'},
             PreloaderMixin),
            (set(FILE_NAMES),
             FileNameMixin),
        )
        selected = tuple(
            mixin for names, mixin in mixin_table if part_name in names
        )
        class_name = f'{part_name.capitalize()}PartitionEmitter'
        return type(class_name, selected + (cls,), {})

    @staticmethod
    def part_name(part_match):
        """Return the ``name`` group of a partition-row match."""
        return part_match.group('name')

    def print(self, obj):
        """Write ``obj`` as one line through the parent."""
        self.parent.print(obj)

    def emit_start(self):
        """Write a blank line and the list marker that opens the entry."""
        self.print('')
        # No newline here: the first field continues on the marker's line.
        print('-', end=' ', file=self.parent.output)

    def emit_partition_index(self):
        index = self.part_no
        self.print(f'partition_index: SYS{index}')

    def emit_partition_name(self):
        name = self.part_name
        self.print(f'  partition_name: {name}')

    def emit_file_name(self):
        name = self.part_name
        self.print(f'  file_name: {name}.img')

    def emit_linear_start_addr(self):
        address = self.to_start_addr(self.start_sector)
        self.print(f'  linear_start_addr: {address}')

    def to_start_addr(self, sector):
        """Return the byte offset of ``sector`` as ``0x``-prefixed hex."""
        return '0x' + format(self.sector_byte(sector), 'X')

    def sector_byte(self, sector):
        """Return the byte offset at which ``sector`` begins."""
        sector_size = self.parent.parent.sec_l
        return sector * sector_size

    def emit_physical_start_addr(self):
        address = self.to_start_addr(self.start_sector)
        self.print(f'  physical_start_addr: {address}')

    def emit_partition_size(self):
        """Write the size in bytes; the end sector is inclusive."""
        first_byte = self.sector_byte(self.start_sector)
        byte_past_end = self.sector_byte(self.end_sector + 1)
        size = byte_past_end - first_byte
        self.print(f'  partition_size: 0x{size:X}')

    def emit_end(self):
        """Hook run after the last field; writes nothing."""
        return None


class PreloaderMixin:
    """Field overrides applied to the ``preloader`` entry."""

    @property
    def defaults(self):
        """Inherited defaults with type, region and operation type replaced."""
        merged = dict(super().defaults)
        merged.update(
            type='SV5_BL_BIN',
            region='EMMC_BOOT_1',
            operation_type='BOOTLOADERS',
        )
        return merged


class InvisibleMixin:
    """Marks an entry's operation type as ``INVISIBLE``."""

    @property
    def defaults(self):
        """Inherited defaults with ``operation_type`` replaced."""
        merged = dict(super().defaults)
        merged['operation_type'] = 'INVISIBLE'
        return merged


class UpgradableMixin:
    """Marks an entry as upgradable."""

    @property
    def defaults(self):
        """Inherited defaults with ``is_upgradable`` set to ``'true'``."""
        merged = dict(super().defaults)
        merged['is_upgradable'] = 'true'
        return merged


class ProtectedMixin:
    """Marks an entry's operation type as ``PROTECTED``."""

    @property
    def defaults(self):
        """Inherited defaults with ``operation_type`` replaced."""
        merged = dict(super().defaults)
        merged['operation_type'] = 'PROTECTED'
        return merged


class BinRegionMixin:
    """Marks an entry's operation type as ``BINREGION``."""

    @property
    def defaults(self):
        """Inherited defaults with ``operation_type`` replaced."""
        merged = dict(super().defaults)
        merged['operation_type'] = 'BINREGION'
        return merged


class EmptyBootMixin:
    """Sets ``empty_boot_needed`` for an entry."""

    @property
    def defaults(self):
        """Inherited defaults with ``empty_boot_needed`` set to ``'true'``."""
        merged = dict(super().defaults)
        merged['empty_boot_needed'] = 'true'
        return merged


class ReservedMixin:
    """Field overrides for reserved entries."""

    @property
    def defaults(self):
        """Inherited defaults marked reserved, with boundary check off."""
        merged = dict(super().defaults)
        merged.update(
            boundary_check='false',
            is_reserved='true',
            operation_type='RESERVED',
        )
        return merged

class FileNameMixin:
    """Uses the ``FILE_NAMES`` table for an entry's ``file_name`` field."""

    def emit_file_name(self):
        """Write the file name registered for this partition."""
        image = FILE_NAMES[self.part_name]
        self.print(f'  file_name: {image}')


if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
