#!/usr/bin/env python
# -*- coding: utf-8 -*-
import Queue
import threading
import traceback

import wx
from serial import SerialException

import stm32bl
import typ708


# Custom wx event type ids, one per event class defined below. The flash
# worker thread posts these events to the GUI; each wx.NewEventType() call
# allocates a new, unused event type id.
wxFLASH_CONNECTION_STATE_EVENT = wx.NewEventType()
wxFLASH_ERASE_BLOCKS_EVENT = wx.NewEventType()
wxFLASH_WRITE_BLOCKS_EVENT = wx.NewEventType()
wxFLASH_VERIFY_BLOCKS_EVENT = wx.NewEventType()
wxFLASH_REPAIR_BLOCKS_EVENT = wx.NewEventType()
wxFLASH_VERIFY_FINISHED_EVENT = wx.NewEventType()
wxFLASH_REPAIR_FINISHED_EVENT = wx.NewEventType()
# Address where the firmware image is written, read back for verification,
# repaired, and finally started with the boot-loader "go" command.
FLASH_BASE_ADDRESS = 0x08000000
# Page size in bytes used to split the image when verifying and repairing.
# NOTE(review): presumably the target MCU's flash page size -- confirm.
FLASH_PAGE_SIZE = 1024


class FlashThreadCommands(object):
    """Base class for commands executed on the flash worker thread.

    The worker invokes every queued command as ``cmd(flashThread)``, so
    subclasses override ``__call__(self, flashThread)``.  The base
    implementation does nothing; ``flashThread`` is optional so that the
    old zero-argument call keeps working.
    """
    def __call__(self, flashThread=None):
        pass


class FlashConnectionStateEvent(wx.PyEvent):
    """Event informing the GUI about a new connection state.

    Attributes:
        connected: whether the boot-loader connection is up.
        message: optional human-readable status text.
    """
    def __init__(self, connected, message=None):
        wx.PyEvent.__init__(self)
        self.SetEventType(wxFLASH_CONNECTION_STATE_EVENT)
        self.connected = connected
        self.message = message

    def Clone(self):
        """Return a copy of this event with the same payload and event id."""
        clone = self.__class__(self.connected, self.message)
        clone.SetId(self.GetId())
        return clone


class FlashEraseBlocksEvent(wx.PyEvent):
    """Event informing the GUI about erased flash blocks.

    Attributes:
        nErased: number of pages erased so far.
        nTotal: number of pages to erase in total.
        finished: True once all pages have been erased.
        message: optional human-readable progress text.
    """
    def __init__(self, nErased=0, nTotal=0, finished=False, message=None):
        wx.PyEvent.__init__(self)
        self.SetEventType(wxFLASH_ERASE_BLOCKS_EVENT)
        self.message = message
        self.nErased = nErased
        self.nTotal = nTotal
        self.finished = finished

    def Clone(self):
        """Return a copy of this event with the same payload and event id."""
        clone = self.__class__(nErased=self.nErased, nTotal=self.nTotal,
                               finished=self.finished, message=self.message)
        clone.SetId(self.GetId())
        return clone


class FlashWriteBlocksEvent(wx.PyEvent):
    """Event informing the GUI about written flash blocks.

    Attributes:
        nWritten: amount written so far (units as reported by the
            boot-loader progress callback).
        nTotal: total amount to write, in the same units.
        finished: True once writing is complete.
        message: optional human-readable progress text.
    """
    def __init__(self, nWritten=0, nTotal=0, finished=False, message=None):
        wx.PyEvent.__init__(self)
        self.SetEventType(wxFLASH_WRITE_BLOCKS_EVENT)
        self.message = message
        self.nWritten = nWritten
        self.nTotal = nTotal
        self.finished = finished

    def Clone(self):
        """Return a copy of this event with the same payload and event id."""
        clone = self.__class__(nWritten=self.nWritten, nTotal=self.nTotal,
                               finished=self.finished, message=self.message)
        clone.SetId(self.GetId())
        return clone


class FlashRepairBlocksEvent(wx.PyEvent):
    """Event informing the GUI about rewritten flash blocks.

    Attributes:
        nWritten: amount rewritten so far (units as reported by the
            boot-loader progress callback).
        nTotal: total amount to rewrite, in the same units.
        finished: True once the repair is complete.
        message: optional human-readable progress text.
    """
    def __init__(self, nWritten=0, nTotal=0, finished=False, message=None):
        wx.PyEvent.__init__(self)
        self.SetEventType(wxFLASH_REPAIR_BLOCKS_EVENT)
        self.message = message
        self.nWritten = nWritten
        self.nTotal = nTotal
        self.finished = finished

    def Clone(self):
        """Return a copy of this event with the same payload and event id."""
        clone = self.__class__(nWritten=self.nWritten, nTotal=self.nTotal,
                               finished=self.finished, message=self.message)
        clone.SetId(self.GetId())
        return clone


class FlashVerifyBlocksEvent(wx.PyEvent):
    """Event informing the GUI about verified flash blocks.

    Attributes:
        nVerified: amount verified so far (units as reported by the
            boot-loader progress callback).
        nTotal: total amount to verify, in the same units.
        finished: True once verification is complete.
        message: optional human-readable progress text.
        ok: optional verification result (None when not known).
    """
    def __init__(self, nVerified=0, nTotal=0, finished=False, message=None,
                 ok=None):
        wx.PyEvent.__init__(self)
        self.SetEventType(wxFLASH_VERIFY_BLOCKS_EVENT)
        self.message = message
        self.nVerified = nVerified
        self.nTotal = nTotal
        self.finished = finished
        self.ok = ok

    def Clone(self):
        """Return a copy of this event with the same payload and event id."""
        clone = self.__class__(nVerified=self.nVerified, nTotal=self.nTotal,
                               finished=self.finished, message=self.message,
                               ok=self.ok)
        clone.SetId(self.GetId())
        return clone


class FlashVerifyFinishedEvent(wx.PyEvent):
    """Event informing the GUI about finished verification.

    Attributes:
        message: human-readable result text.
        ok: True if the flash contents matched the image.
        rewrite: repair decision when verification failed -- False (no
            differing page), True (rewrite everything) or a list of page
            indices to rewrite.
    """
    def __init__(self, message=None, ok=None, rewrite=False):
        wx.PyEvent.__init__(self)
        self.SetEventType(wxFLASH_VERIFY_FINISHED_EVENT)
        self.message = message
        self.ok = ok
        self.rewrite = rewrite

    def Clone(self):
        """Return a copy of this event with the same payload and event id."""
        clone = self.__class__(message=self.message, ok=self.ok,
                               rewrite=self.rewrite)
        clone.SetId(self.GetId())
        return clone


class FlashRepairFinishedEvent(wx.PyEvent):
    """Event informing the GUI about finished repair.

    Attributes:
        message: human-readable result text.
    """
    def __init__(self, message=None):
        wx.PyEvent.__init__(self)
        self.SetEventType(wxFLASH_REPAIR_FINISHED_EVENT)
        self.message = message

    def Clone(self):
        """Return a copy of this event with the same payload and event id."""
        clone = self.__class__(message=self.message)
        clone.SetId(self.GetId())
        return clone


class FlashThreadConnectCommand(FlashThreadCommands):
    """Command that opens or closes the boot-loader connection.

    Runs on the flash worker thread, updates ``connection.connected`` and
    reports the outcome to the GUI with a FlashConnectionStateEvent.
    """
    def __init__(self, connect=True, baudrate=9600):
        """
        Args:
            connect: True to open the connection, False to close it.
            baudrate: serial speed passed to the boot-loader driver.
        """
        FlashThreadCommands.__init__(self)
        self.connect = connect
        self.baudrate = baudrate

    def _connect(self, flashThread):
        """Open the boot-loader on the connection's port and subscribe the
        connection's progress handler."""
        message = "Connection to boot-loader established."
        flashThread.connection.connected = True
        try:
            flashThread.bl = \
                stm32bl.Stm32bl(flashThread.connection.port, self.baudrate)
            flashThread.bl.register_progress_subscriber(
                flashThread.connection.flash_progress)
        except (SerialException, stm32bl.Stm32BLException) as e:
            flashThread.connection.connected = False
            message = "Could not connect to the bootloader: " \
                + repr(e)
        event = FlashConnectionStateEvent(flashThread.connection.connected,
            message)
        flashThread.publish_event(event)

    def _disconnect(self, flashThread):
        """Close the boot-loader (if one was ever opened) and always report
        the disconnected state to the GUI."""
        message = "Connection closed."
        flashThread.connection.connected = False
        # bl stays None until a connection attempt succeeds; calling
        # disconnect() on None would raise AttributeError, which the worker
        # loop does not handle.
        if flashThread.bl is not None:
            try:
                flashThread.bl.disconnect()
            except (SerialException, stm32bl.Stm32BLException) as e:
                # Still publish the state change below so the GUI does not
                # keep showing a live connection.
                message = "Connection closed with error: " + repr(e)
        event = FlashConnectionStateEvent(flashThread.connection.connected,
            message)
        flashThread.publish_event(event)

    def __call__(self, flashThread):
        if self.connect:
            self._connect(flashThread)
        else:
            self._disconnect(flashThread)


class FlashThreadEraseBlocksCommand(FlashThreadCommands):
    """Command that erases flash pages 0 .. pages-1 one page at a time,
    posting a FlashEraseBlocksEvent after every page."""
    def __init__(self, pages):
        FlashThreadCommands.__init__(self)
        # Number of pages to erase, starting at page 0.
        self.pages = pages

    def publish_event(self, flashThread, erased):
        """Report that ``erased`` of ``self.pages`` pages have been erased."""
        # Completion is checked first so that a zero-page erase never
        # evaluates the percentage (which would divide by zero).
        if erased >= self.pages:
            message = "Erasing finished."
            finished = True
        else:
            message = u"Erasing %.0f%%..." % round(
                100.0 * erased / self.pages)
            finished = False
        event = FlashEraseBlocksEvent(erased, self.pages, finished, message)
        flashThread.publish_event(event)

    def __call__(self, flashThread):
        """Erase the pages; on the first error report a disconnected state
        and stop."""
        self.publish_event(flashThread, 0)
        for page in range(self.pages):
            try:
                flashThread.bl.erase_blocks([page])
            except (SerialException, stm32bl.Stm32BLException) as e:
                message = "Error: " + repr(e)
                event = FlashConnectionStateEvent(False, message)
                flashThread.publish_event(event)
                break
            self.publish_event(flashThread, page + 1)


class FlashThreadRepairBlocksCommand(FlashThreadCommands):
    """Command that rewrites selected flash pages from the firmware image and
    then posts a FlashRepairFinishedEvent."""
    def __init__(self, data, pages):
        FlashThreadCommands.__init__(self)
        self.data = data
        # Pages to rewrite; presumably the list produced by
        # FlashThreadVerifyBlocksCommand.determine_rewrite -- verify caller.
        self.pages = pages

    def __call__(self, flashThread):
        message = "Repairing firmware errors finished."
        try:
            flashThread.bl.update_memory(
                FLASH_BASE_ADDRESS, self.data, self.pages, FLASH_PAGE_SIZE)
        except (SerialException, stm32bl.Stm32BLException) as e:
            # Previously swallowed silently and reported as success. Keep
            # posting the finished event so the GUI can continue, but make
            # the failure visible.
            message = "Repairing firmware errors failed: " + repr(e)
        event = FlashRepairFinishedEvent(message=message)
        flashThread.publish_event(event)


class FlashThreadWriteBlocksCommand(FlashThreadCommands):
    """Command that writes the firmware image to flash at FLASH_BASE_ADDRESS.

    Only failures are reported from here: a serial or boot-loader error is
    turned into a FlashConnectionStateEvent with ``connected=False``.
    """
    def __init__(self, data):
        super(FlashThreadWriteBlocksCommand, self).__init__()
        self.data = data

    def __call__(self, flashThread):
        try:
            flashThread.bl.write_memory(FLASH_BASE_ADDRESS, self.data)
        except (SerialException, stm32bl.Stm32BLException) as err:
            failure = FlashConnectionStateEvent(False, "Error: " + repr(err))
            flashThread.publish_event(failure)


class FlashThreadVerifyBlocksCommand(FlashThreadCommands):
    """Command that reads flash back and compares it with the firmware image.

    On a match it issues ``cmd_go(FLASH_BASE_ADDRESS)``; on a mismatch it
    works out which pages need repair. Either way the result is posted as a
    FlashVerifyFinishedEvent.
    """
    def __init__(self, data):
        FlashThreadCommands.__init__(self)
        self.data = data

    def determine_rewrite(self, mem):
        """Decide how to repair an image that failed verification.

        ``mem`` (the contents read back) is compared with ``self.data`` in
        FLASH_PAGE_SIZE chunks; a trailing partial chunk counts as a page.

        Returns:
            False if no page differs;
            True if more than a third of the image's pages differ (the
            caller should rewrite everything);
            otherwise the list of 0-based indices of the differing pages.
        """
        pageCount = (len(self.data) + FLASH_PAGE_SIZE - 1) // FLASH_PAGE_SIZE
        pages = []
        for page in range(pageCount):
            # Index slicing instead of repeatedly re-slicing the remaining
            # buffers, which copied them once per page (quadratic).
            start = page * FLASH_PAGE_SIZE
            end = start + FLASH_PAGE_SIZE
            if mem[start:end] != self.data[start:end]:
                pages.append(page)
        if not pages:
            return False
        # Use the same page count as the loop above, so the threshold also
        # includes a trailing partial page.
        if len(pages) > pageCount // 3:
            return True
        return pages

    def __call__(self, flashThread):
        try:
            mem = flashThread.bl.read_memory(FLASH_BASE_ADDRESS, len(self.data))
            ok = mem == self.data
            rewrite = False
            if ok:
                message = "Verification OK."
                flashThread.bl.cmd_go(FLASH_BASE_ADDRESS)
            else:
                message = "Verification ERROR."
                rewrite = self.determine_rewrite(mem)
            event = FlashVerifyFinishedEvent(message=message, ok=ok,
                                             rewrite=rewrite)
            flashThread.publish_event(event)
        except (SerialException, stm32bl.Stm32BLException) as e:
            message = "Error: " + repr(e)
            event = FlashConnectionStateEvent(False, message)
            flashThread.publish_event(event)


class FlashThread(threading.Thread):
    """Daemon worker thread that executes queued flash commands in order.

    A command is any callable taking this thread as its only argument.
    Commands use ``self.bl`` (the boot-loader handle, None until a connect
    command succeeds), ``self.connection`` and ``self.publish_event``.
    """
    def __init__(self, connection):
        threading.Thread.__init__(self)
        self.daemon = True
        self._endEvent = threading.Event()
        self.connection = connection
        self._commandQ = Queue.Queue()
        self.bl = None

    def publish_event(self, event):
        """Forward ``event`` to the owning connection."""
        self.connection.publish_event(event)

    def enqueue_command(self, command):
        """Queue ``command`` for execution; the queue is unbounded, so this
        never blocks."""
        self._commandQ.put_nowait(command)

    def run(self):
        """Execute queued commands until stop() is called.

        The queue is polled with a 1 s timeout so a stop request is noticed
        while idle. An exception escaping a command is logged with its
        traceback and the loop keeps going, so one failing command cannot
        leave every later command unexecuted.
        """
        while not self._endEvent.is_set():
            try:
                cmd = self._commandQ.get(timeout=1)
            except Queue.Empty:
                continue
            try:
                cmd(self)
            except Exception:
                # Top-level boundary of the worker thread: log and continue.
                traceback.print_exc()

    def stop(self):
        """Ask run() to exit after the current command (within ~1 s when
        idle)."""
        self._endEvent.set()


class FlashConnection(object):
    """GUI-facing front end of the flash worker.

    Every public operation is queued as a command on a private FlashThread.
    Results and progress are posted back to ``parent`` as wx events.
    """
    def __init__(self, parent, port):
        self.parent = parent
        self.port = port
        # Set before the worker starts so the attribute exists before any
        # command can touch it.
        self.connected = False
        self.flashThread = FlashThread(self)
        self.flashThread.start()
        self.connect()

    def publish_event(self, event):
        """Post ``event`` to the parent window, if there is one."""
        if self.parent is not None:
            wx.PostEvent(self.parent, event)

    def connect(self):
        """Queue opening the boot-loader connection."""
        cmd = FlashThreadConnectCommand(connect=True)
        self.flashThread.enqueue_command(cmd)

    def disconnect(self):
        """Queue closing the boot-loader connection."""
        cmd = FlashThreadConnectCommand(connect=False)
        self.flashThread.enqueue_command(cmd)

    def erase_blocks(self, pages):
        """Queue erasing flash pages 0 .. pages-1."""
        cmd = FlashThreadEraseBlocksCommand(pages)
        self.flashThread.enqueue_command(cmd)

    def write(self):
        """Queue writing the bundled firmware image."""
        data = bytearray(typ708.bin)
        cmd = FlashThreadWriteBlocksCommand(data)
        self.flashThread.enqueue_command(cmd)

    def rewrite(self, pages):
        """Queue rewriting the given pages from the bundled firmware image."""
        # NOTE(review): write() passes bytearray(typ708.bin) while this
        # passes the raw object -- confirm update_memory accepts both.
        cmd = FlashThreadRepairBlocksCommand(typ708.bin, pages)
        self.flashThread.enqueue_command(cmd)

    def verify(self):
        """Queue verifying flash against the bundled firmware image."""
        cmd = FlashThreadVerifyBlocksCommand(typ708.bin)
        self.flashThread.enqueue_command(cmd)

    @staticmethod
    def _percent(done, total):
        """Return ``done``/``total`` as a rounded percentage; 100 when
        ``total`` is 0 (nothing to do)."""
        if not total:
            return 100.0
        return round(100.0 * done / total)

    def flash_progress(self, kind, done, total, finished):
        """Translate boot-loader progress callbacks into GUI events.

        Registered with the boot-loader in FlashThreadConnectCommand.
        Write, verify and update (repair) progress become the matching
        event; other ``kind`` values are ignored. The percentage is only
        computed for unfinished progress, so ``total == 0`` cannot raise
        ZeroDivisionError.
        """
        if kind == stm32bl.Stm32bl.PROGRESS_WRITE:
            eventClass = FlashWriteBlocksEvent
            runningText = u"Writing %.0f%%..."
            finishedText = "Writing finished."
        elif kind == stm32bl.Stm32bl.PROGRESS_VERIFY:
            eventClass = FlashVerifyBlocksEvent
            runningText = u"Verifying %.0f%%..."
            finishedText = "Verifying finished."
        elif kind == stm32bl.Stm32bl.PROGRESS_UPDATE:
            eventClass = FlashRepairBlocksEvent
            runningText = u"Repairing errors %.0f%%..."
            finishedText = "Repairing errors finished."
        else:
            return
        if finished:
            message = finishedText
        else:
            message = runningText % self._percent(done, total)
        self.publish_event(eventClass(done, total, finished, message))
