from typing import Iterator

import CellFrame
from CellFrame.Chain import ChainAtomPtr
from CellFrame.Network import Net
from CellFrame.Common import Datum, DatumTx, DatumToken, DatumEmission, DatumAnchor, DatumDecree
from DAP import Crypto
from DAP.Crypto import Cert, Sign
from CellFrame.Chain import Mempool, Wallet, Chain
from CellFrame.Consensus import DAG, Block
from CellFrame.Common import TxOut, TxIn, TxToken, TxSig, TxOutCondSubtypeSrvStakeLock, TxInCond, \
    TxOutExt
from DAP.Crypto import HashFast
from DAP.Core import logIt
from datetime import datetime
from CellFrame.Chain import ChainAddr
import hashlib
import sys
import json
from pycfhelpers.helpers import json_dump, find_tx_out, get_tx_items


class TSD:
    """Numeric type identifiers for TSD sections.

    The constants are used as the ``type`` argument of ``getTSD`` / ``addTSD``
    calls on emission datums elsewhere in this file.
    """

    # NOTE(review): the meaning of each identifier is only suggested by its
    # name; the authoritative definitions live outside this file -- confirm
    # before relying on them.
    TYPE_UNKNOWN = 0x0000
    TYPE_TIMESTAMP = 0x0001
    TYPE_ADDRESS = 0x0002
    TYPE_VALUE = 0x0003
    TYPE_CONTRACT = 0x0004
    TYPE_NET_ID = 0x0005
    TYPE_BLOCK_NUM = 0x0006
    TYPE_TOKEN_SYM = 0x0007
    TYPE_OUTER_TX_HASH = 0x0008
    TYPE_SOURCE = 0x0009
    TYPE_SOURCE_SUBTYPE = 0x000A
    TYPE_DATA = 0x000B
    TYPE_SENDER = 0x000C
    TYPE_TOKEN_ADDRESS = 0x000D
    # The spelling "SIGNATURS" is part of the public name; do not rename.
    TYPE_SIGNATURS = 0x000E
    TYPE_UNIQUE_ID = 0x000F
    TYPE_BASE_TX_HASH = 0x0010
    TYPE_EMISSION_CENTER_UID = 0x0011
    TYPE_EMISSION_CENTER_VER = 0x0012

class CellframeEmission:
    """Wrapper around an emission datum that adds a content-derived ``uid``.

    Attributes:
        datum: the wrapped datum; it must expose ``hash``, ``getTSD`` and
            ``addTSD``.
        hash: ``str(datum.hash)``.
        event: optional object supplied by the caller (``None`` by default).
        uid: hexadecimal SHA-256 digest built from selected TSD sections.
    """

    def __init__(self, datum, event=None):
        """Wrap ``datum`` and compute ``self.uid`` from its TSD sections.

        Args:
            datum: emission datum to wrap.
            event: optional object stored as ``self.event``.
        """
        self.datum = datum
        self.hash = str(self.datum.hash)
        self.event = event

        addr = datum.getTSD(TSD.TYPE_ADDRESS)
        btx = datum.getTSD(TSD.TYPE_BASE_TX_HASH)
        otx = datum.getTSD(TSD.TYPE_OUTER_TX_HASH)
        src = datum.getTSD(TSD.TYPE_SOURCE)
        stp = datum.getTSD(TSD.TYPE_SOURCE_SUBTYPE)
        ts = datum.getTSD(TSD.TYPE_TIMESTAMP)
        data = datum.getTSD(TSD.TYPE_DATA)
        uid = datum.getTSD(TSD.TYPE_UNIQUE_ID)

        # Each section is hashed as the UTF-8 encoding of str(section), so a
        # missing section contributes the text "None".
        # NOTE: ``data`` is fed into the digest twice.  This may have been
        # accidental, but the order and the repetition are kept exactly as
        # they were: changing either would change every uid computed so far.
        digest = hashlib.sha256()
        for section in (addr, btx, otx, src, stp, data, ts, uid, data):
            digest.update(str(section).encode("utf-8"))

        self.uid = digest.hexdigest()

    def getTSD(self, type):
        """Return the TSD section ``type`` decoded as UTF-8 text.

        Args:
            type: TSD type identifier (see ``TSD``).

        Returns:
            The decoded string, or ``None`` when the section is missing, empty
            or cannot be decoded.
        """
        tsd = self.datum.getTSD(type)
        if not tsd:
            return None
        try:
            return tsd.decode("utf-8")
        except Exception:
            # Best-effort decode: undecodable or non-bytes values yield None.
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            return None

    def setTSD(self, type, data):
        """Add a TSD section to the wrapped datum.

        Args:
            type: TSD type identifier (see ``TSD``).
            data: payload handed unchanged to ``datum.addTSD``.
        """
        self.datum.addTSD(type, data)


# Raises RuntimeError if the datum is not a base transaction (it has no TxToken item).
class CellframeBaseTransactionDatum:
    """View over a transaction datum that contains a ``TxToken`` item.

    Attributes:
        block: optional object supplied by the caller (``None`` by default).
        datum: the wrapped transaction datum.
        hash: ``str(datum.hash)``.
        created: ``datum.dateCreated``.
        net: network object; ``emission()`` calls ``net.getLedger()`` on it.
        to_address: ``str()`` of the first ``TxOut`` item's ``addr``.
        amount: ``value`` of the first ``TxOut`` item.
        emission_hash: ``str()`` of the ``TxToken`` item's ``tokenEmissionHash``.
        emission_: cached result of ``emission()`` (``None`` until resolved).
    """

    def __init__(self, datum, net, block=None):
        """Wrap ``datum``.

        Raises:
            RuntimeError: if the datum has no ``TxToken`` item.
        """
        self.block = block
        self.datum = datum
        self.hash = str(datum.hash)
        self.created = datum.dateCreated
        self.net = net

        # A base transaction is recognised by the presence of a TxToken item.
        token_item = self.tx_token()
        if not token_item:
            raise RuntimeError("Datum {} not base tx".format(self.datum))

        out_item = self.tx_out()
        self.to_address = str(out_item.addr)
        self.amount = out_item.value

        self.emission_hash = str(token_item.tokenEmissionHash)

        self.emission_ = None

    def _first_item(self, item_type):
        """Return the first item of ``item_type`` in the datum, or ``None``."""
        try:
            return next(
                (item for item in self.datum.getItems() if isinstance(item, item_type)),
                None,
            )
        except Exception:
            # Best-effort lookup: any failure while listing items means
            # "not found" (KeyboardInterrupt / SystemExit still propagate).
            return None

    def tx_out(self):
        """Return the first ``TxOut`` item, or ``None`` if there is none."""
        return self._first_item(TxOut)

    def tx_in(self):
        """Return the first ``TxIn`` item, or ``None`` if there is none."""
        return self._first_item(TxIn)

    def tx_token(self):
        """Return the first ``TxToken`` item, or ``None`` if there is none."""
        return self._first_item(TxToken)

    def tx_sig(self):
        """Return the first ``TxSig`` item, or ``None`` if there is none."""
        return self._first_item(TxSig)

    def emission(self):
        """Return the emission referenced by this transaction.

        The emission is looked up in the ledger on first use and cached in
        ``self.emission_``.

        Returns:
            A ``CellframeEmission``, or ``None`` when the ledger lookup finds
            nothing (that miss is not cached).
        """
        if self.emission_ is None:
            token_item = self.tx_token()
            # The ticker assignment had been commented out while the lookup
            # still used it, so every call failed with NameError.
            # NOTE(review): confirm that the installed SDK exposes
            # ``TxToken.ticker`` and the two-argument
            # ``tokenEmissionFind(ticker, hash)`` form.
            ticker = str(token_item.ticker)
            emission_hash = HashFast.fromString(str(token_item.tokenEmissionHash))
            ledger = self.net.getLedger()
            ems = ledger.tokenEmissionFind(ticker, emission_hash)
            if not ems:
                return None
            self.emission_ = CellframeEmission(ems)

        return self.emission_


class CellframeNetwork:
    main: Chain
    zerochain: Chain

    def __init__(self, name, chains, group_alias=None, commision_wallet=None):
        """Bind this object to the network called ``name``.

        Args:
            name: network name, resolved with ``Net.byName``.
            chains: iterable of chain names; each one becomes an attribute of
                this object holding ``net.getChainByName(chain_name)``.
            group_alias: stored as ``self.group_alias``; falls back to ``name``
                when falsy.
            commision_wallet: stored unchanged as ``self.commision_wallet``.

        Raises:
            RuntimeError: if ``Net.byName`` returns a falsy value.
        """
        network = Net.byName(name)

        self.name = name
        self.net = network
        self.group_alias = group_alias if group_alias else name
        self.commision_wallet = commision_wallet

        if not network:
            raise RuntimeError("No such net: {}".format(name))

        for chain_name in chains:
            chain_obj = self.net.getChainByName(chain_name)
            setattr(self, chain_name, chain_obj)

    @staticmethod
    def wallet_from_signature(sigbytes):
        """Return the address derived from a serialized signature.

        Args:
            sigbytes: serialized signature, parsed with ``Sign.fromBytes``.

        Returns:
            The result of ``getAddr()`` on the parsed signature.
        """
        return Sign.fromBytes(sigbytes).getAddr()

    @staticmethod
    def netid_from_wallet(wallet):
        """Return the network id encoded in a wallet address.

        Args:
            wallet: address; it is converted with ``str()`` before parsing.

        Returns:
            ``long()`` of the net id taken from the parsed ``ChainAddr``.
        """
        address = ChainAddr.fromStr(str(wallet))
        net_id = address.getNetId()
        return net_id.long()

    def tx_sender_wallet(self, tx):
        """Return the address of the transaction's first signature.

        Args:
            tx: transaction whose ``TxSig`` items are inspected.

        Returns:
            ``sign.getAddr(self.net)`` of the first ``TxSig`` item, or ``None``
            when the transaction has no such items.
        """
        signatures = get_tx_items(tx, TxSig)
        if signatures:
            # The first signature is treated as the sender's signature.
            sender_signature = signatures[0]
            return sender_signature.sign.getAddr(self.net)
        return None

    def ledger_tx_by_hash(self, txh):
        """Look a transaction up in this network's ledger.

        Args:
            txh: transaction hash string, parsed with ``HashFast.fromString``.

        Returns:
            Whatever ``ledger.txFindByHash`` returns for that hash.
        """
        tx_hash = HashFast.fromString(txh)
        # Only the ledger is queried, so only ledger-accepted transactions
        # can be found this way.
        return self.net.getLedger().txFindByHash(tx_hash)

    def netid(self):
        """Return ``self.net.id.long()``."""
        net_id = self.net.id
        return net_id.long()

    def set_mempool_notification_callback(self, chain, callback):
        """Register ``callback`` as a mempool notifier on ``chain``.

        Args:
            chain: chain object; ``chain.addMempoolNotify`` is called on it.
            callback: called as ``callback(op_code, group, key, value,
                net_name, self, chain)`` for every notification.
        """
        notifier_name = f"{self.name}"
        logIt.notice("New mempool notifier for {}".format(notifier_name))

        def forward_mempool_event(op_code, group, key, value, net_name):
            # Per the original author's notes:
            #   op_code - "a" or "d"
            #   group   - table name
            #   key     - hash of the record
            callback(op_code, group, key, value, net_name, self, chain)

        chain.addMempoolNotify(forward_mempool_event, notifier_name)

    def set_gdbsync_notification_callback(self, callback):
        """Register ``callback`` as a GDB notifier on this network.

        Per the original author's note the notifier covers any table in GDB.

        Args:
            callback: called as ``callback(self, op_code, group, key, value,
                net_name)`` for every notification.
        """
        notifier_name = f"{self.name}"
        logIt.notice("New gdb notifier for {}".format(notifier_name))

        def forward_gdb_event(op_code, group, key, value, net_name):
            callback(self, op_code, group, key, value, net_name)

        self.net.addNotify(forward_gdb_event, self.name)

    def set_atom_notification_callback(self, chain, callback):
        """Register ``callback`` as a new-atom notifier on ``chain``.

        Per the original author's note an atom is a block or an event.

        Args:
            chain: chain object; ``chain.addAtomNotify`` is called on it.
            callback: called as ``callback(atom, size, callback_name, self,
                chain)`` for every notification.
        """
        notifier_name = f"{self.name}"
        logIt.notice("New atom notifier for {}".format(notifier_name))

        def forward_atom_event(atom, size, callback_name):
            callback(atom, size, callback_name, self, chain)

        chain.addAtomNotify(forward_atom_event, notifier_name)

    def set_ledger_tx_notification_callback(self, callback):
        """Register ``callback`` as a transaction notifier on the ledger.

        Args:
            callback: called as ``callback(ledger, tx, argv, self)`` for every
                notification.
        """

        def forward_ledger_tx(ledger, tx, argv):
            callback(ledger, tx, argv, self)

        # ``self.net`` is handed over as the second argument of txAddNotify.
        self.net.getLedger().txAddNotify(forward_ledger_tx, self.net)

    def set_ledger_bridge_tx_notification_callback(self, callback):
        """Register ``callback`` as a bridged-transaction notifier on the ledger.

        Args:
            callback: called as ``callback(ledger, tx, argv, self)`` for every
                notification.
        """
        net_ledger = self.net.getLedger()

        def forward_bridged_tx(ledger, tx, argv):
            callback(ledger, tx, argv, self)

        logIt.notice("New bridgedTxNotify for {}".format(self.net))
        # ``self.net`` is handed over as the second argument of the registration.
        net_ledger.bridgedTxNotifyAdd(forward_bridged_tx, self.net)

    @staticmethod
    def load_cert(certname):
        """Load a certificate by name.

        The function takes no ``self``; without ``@staticmethod`` it could only
        be called on the class, and calling it on an instance raised TypeError.
        Both call styles work now.

        Args:
            certname: certificate name handed to ``Crypto.Cert.load``.

        Returns:
            Whatever ``Crypto.Cert.load`` returns.
        """
        return Crypto.Cert.load(certname)

    def extract_emission_from_mempool_nofitication(self, chain, value):
        """Extract an emission from a mempool notification value.

        Args:
            chain: chain handed to ``Mempool.emissionExtract``.
            value: notification value handed to ``Mempool.emissionExtract``.

        Returns:
            A ``CellframeEmission`` wrapping the extracted emission, or
            ``None`` when nothing was extracted.
        """
        extracted = Mempool.emissionExtract(chain, value)
        return CellframeEmission(extracted) if extracted else None

    def create_base_transaction(self, emission, certs, fee, native_tw=None):
        """Create a base transaction for ``emission`` via ``Mempool.baseTxCreate``.

        Args:
            emission: object whose ``datum`` provides ``hash``, ``value``,
                ``ticker`` and ``addr``.
            certs: passed as the last argument when ``native_tw`` is falsy.
            fee: fee value passed through to ``Mempool.baseTxCreate``.
            native_tw: optional wallet file; when truthy it is opened with
                ``Wallet.openFile`` and the wallet is passed instead of
                ``certs``.

        Returns:
            Whatever ``Mempool.baseTxCreate`` returns.
        """
        # The last argument is either the opened wallet or the certificates.
        signer = Wallet.openFile(native_tw) if native_tw else certs
        datum = emission.datum
        return Mempool.baseTxCreate(
            self.main,
            datum.hash,
            self.zerochain,
            datum.value,
            datum.ticker,
            datum.addr,
            fee,
            signer,
        )

    def get_emission_by_tsd(self, tsd_dict):
        """Find emissions in ``self.zerochain`` whose TSD sections match.

        Args:
            tsd_dict: mapping of TSD type -> expected text.  An expected value
                of ``None`` matches a missing or empty section; any other
                value must equal the section decoded as UTF-8.  An empty
                mapping matches every emission.

        Returns:
            dict mapping ``str(event.datum.hash)`` to a ``CellframeEmission``
            (created with the event attached) for every matching emission.
        """

        def section_matches(token_emission, tsd_type, expected):
            raw = token_emission.getTSD(tsd_type)
            if not raw:
                # A missing/empty section only matches an explicit None.
                return expected is None
            try:
                return raw.decode("utf-8") == expected
            except Exception:
                # Undecodable section: treat as "no match".
                return False

        atom_count = self.zerochain.countAtom()
        atoms = self.zerochain.getAtoms(atom_count, 1, True)

        emissions = {}

        # ``or ()`` guards against getAtoms returning None for an empty chain.
        for atom in atoms or ():
            event = DAG.fromAtom(atom)
            if not event.datum.isDatumTokenEmission():
                continue

            token_emission = event.datum.getDatumTokenEmission()
            if not token_emission:
                continue

            if all(
                section_matches(token_emission, tsd_type, expected)
                for tsd_type, expected in tsd_dict.items()
            ):
                emissions[str(event.datum.hash)] = CellframeEmission(token_emission, event)

        return emissions

    def get_emission_from_mempool_by_tsd(self, tsd_dict):
        """Find emissions in the ``self.zerochain`` mempool whose TSD sections match.

        Args:
            tsd_dict: mapping of TSD type -> expected text.  An expected value
                of ``None`` matches a missing or empty section; any other
                value must equal the section decoded as UTF-8.  An empty
                mapping matches every emission.

        Returns:
            dict mapping ``str(datum.hash)`` to a ``CellframeEmission`` for
            every matching emission datum in the mempool.
        """

        def section_matches(token_emission, tsd_type, expected):
            raw = token_emission.getTSD(tsd_type)
            if not raw:
                # A missing/empty section only matches an explicit None.
                return expected is None
            try:
                return raw.decode("utf-8") == expected
            except Exception:
                # Undecodable section: treat as "no match".
                return False

        # ``or {}`` guards against Mempool.list returning None.
        mempool = Mempool.list(self.net, self.zerochain) or {}

        emissions = {}

        for datum in mempool.values():
            if not datum.isDatumTokenEmission():
                continue

            token_emission = datum.getDatumTokenEmission()
            if not token_emission:
                continue

            if all(
                section_matches(token_emission, tsd_type, expected)
                for tsd_type, expected in tsd_dict.items()
            ):
                emissions[str(datum.hash)] = CellframeEmission(token_emission)

        return emissions

    def base_transactions_from_blocks(self, emission_hash=None):
        """Yield base transactions found in the blocks of ``self.main``.

        Args:
            emission_hash: when truthy, only transactions whose
                ``emission_hash`` equals this string are yielded.

        Yields:
            ``CellframeBaseTransactionDatum`` objects; datums that are not
            base transactions are skipped.
        """
        # ``createAtomIter`` is the spelling used by get_datums_from_chains in
        # this file; ``createAtomItem`` here was a typo.
        iterator = self.main.createAtomIter(False)
        ptr = self.main.atomIterGetFirst(iterator)

        if not ptr:
            logIt.error("Can't iterate over blocks in {}!".format(self.name))
            return

        # A falsy result or a falsy atom pointer ends the iteration.
        while ptr:
            aptr, size = ptr
            if not aptr:
                break

            # Atoms with a non-positive size are skipped.
            if size > 0:
                # NOTE(review): other code in this file calls
                # Block.fromAtom(atom) with a single argument -- confirm which
                # signature the installed SDK expects.
                block = Block.fromAtom(aptr, size)

                for datum in block.datums or ():
                    if not datum.isDatumTX():
                        continue
                    try:
                        # The constructor signature is (datum, net, block); the
                        # block used to be passed in the ``net`` position.
                        basedatum = CellframeBaseTransactionDatum(datum.getDatumTX(), self.net, block)
                    except Exception:
                        # Not a base transaction (or an unreadable one): skip.
                        continue

                    # The yield stays outside the try block so that closing
                    # the generator is never swallowed by the handler.
                    if emission_hash and basedatum.emission_hash != emission_hash:
                        continue
                    yield basedatum

            ptr = self.main.atomIterGetNext(iterator)

    def get_transactions_to_wallet_from_blocks(self, address):
        """Collect transactions in ``self.main`` whose first output goes to ``address``.

        Args:
            address: address string compared with ``str(addr)`` of the first
                ``TxOut`` item of each transaction.

        Returns:
            list of objects with ``datum`` (the transaction) and ``block``
            (the block it was found in) attributes.
        """

        class DatumWithBlock:
            """Pair of a transaction datum and the block that contains it."""

            def __init__(self, datum, block):
                self.datum = datum
                self.block = block

        def iter_block_datums():
            # The iterator is created on ``self.main`` (the network object has
            # no createAtomItem/createAtomIter of its own); ``createAtomIter``
            # is the spelling used by get_datums_from_chains in this file.
            iterator = self.main.createAtomIter(False)
            ptr = self.main.atomIterGetFirst(iterator)
            # A falsy result or a falsy atom pointer ends the iteration.
            while ptr:
                aptr, size = ptr
                if not aptr:
                    break
                # Atoms with a non-positive size are skipped.
                if size > 0:
                    # NOTE(review): other code in this file calls
                    # Block.fromAtom(atom) with a single argument -- confirm
                    # which signature the installed SDK expects.
                    block = Block.fromAtom(aptr, size)
                    if block.datums:
                        yield block, block.datums
                ptr = self.main.atomIterGetNext(iterator)

        def is_sent_to_address(datum_with_block):
            try:
                for item in datum_with_block.datum.getItems():
                    if isinstance(item, TxOut):
                        # Only the first TxOut item is considered.
                        return str(item.addr) == address
            except Exception:
                pass
            return False

        transactions_to_wallet = []

        for block, datums in iter_block_datums():
            for datum in datums:
                if not datum.isDatumTX():
                    continue
                candidate = DatumWithBlock(datum.getDatumTX(), block)
                if is_sent_to_address(candidate):
                    transactions_to_wallet.append(candidate)

        return transactions_to_wallet

    def create_emission(self, wallet, token_symbol, value, tsd):
        """Build a new emission datum with the given TSD sections.

        Args:
            wallet: address string, parsed with ``ChainAddr.fromStr``.
            token_symbol: token ticker passed to ``DatumEmission``.
            value: emission value; passed to ``DatumEmission`` as ``str(value)``.
            tsd: mapping of TSD type -> payload.  dict and list payloads are
                serialized with ``json_dump``; everything else with ``str()``.
                Every payload is stored UTF-8 encoded.

        Returns:
            A ``CellframeEmission`` wrapping the new datum.
        """
        target_addr = ChainAddr.fromStr(wallet)
        ems = DatumEmission(str(value), token_symbol, target_addr)

        for tsd_type, payload in tsd.items():
            if isinstance(payload, (dict, list)):
                text = json_dump(payload)
            else:
                text = str(payload)
            ems.addTSD(tsd_type, text.encode("utf-8"))

        return CellframeEmission(ems)

    def place_emission(self, ems, chain):
        """Place ``ems.datum`` via ``Mempool.emissionPlace(chain, ...)`` and return its result."""
        emission_datum = ems.datum
        return Mempool.emissionPlace(chain, emission_datum)

    def place_datum(self, datum, chain):
        """Add ``datum`` via ``Mempool.addDatum(chain, datum)`` and return its result."""
        placed = Mempool.addDatum(chain, datum)
        return placed

    def remove_key_from_mempool(self, key, chain):
        """Call ``Mempool.remove(chain, key)``; the result is discarded."""
        Mempool.remove(chain, key)
        return None

    def mempool_list(self, chain):
        """Return ``Mempool.list`` for this network and ``chain``."""
        listing = Mempool.list(self.net, chain)
        return listing

    def mempool_get_emission(self, key):
        """Return ``Mempool.emissionGet`` for ``key`` on ``self.zerochain``."""
        zerochain = self.zerochain
        return Mempool.emissionGet(zerochain, key)

    def mempool_proc(self, hash, chain):
        """Call ``Mempool.proc(hash, chain)``; the result is discarded.

        Note the argument order: the hash comes first here, unlike the other
        ``Mempool`` calls in this class, which take the chain first.
        """
        Mempool.proc(hash, chain)
        return None

    def all_tx_from_ledger(self):
        """Fetch the transactions of this network's ledger.

        Returns:
            A ``(transactions, ledger)`` tuple; ``transactions`` is an empty
            list when the ledger returns nothing.
        """
        ledger = self.net.getLedger()
        total = ledger.count()
        transactions = ledger.getTransactions(total, 1, False)
        return (transactions if transactions else []), ledger

    @staticmethod
    def get_datums_from_atom(chain: Chain, atom: ChainAtomPtr) -> list[Datum]:
        """Return the datums carried by ``atom``, based on the chain's consensus name.

        Args:
            chain: chain the atom belongs to; ``chain.getCSName()`` selects
                how the atom is parsed.
            atom: atom pointer to parse.

        Returns:
            ``block.datums`` for an "esbocs" chain, a one-element list with
            ``event.datum`` for a "dag_poa" chain, and ``None`` for any other
            consensus name.
        """
        consensus = chain.getCSName()

        if consensus == "esbocs":
            return Block.fromAtom(atom).datums

        if consensus == "dag_poa":
            return [DAG.fromAtom(atom).datum]

        return None

    def get_datums_from_chains(self, chains: tuple[str, ...] = ("main", "zerochain")) -> Iterator[Datum]:
        """Yield every datum stored in the given chains, chain by chain.

        Args:
            chains: names of attributes of this object that hold chain
                objects (set up by ``__init__``).

        Yields:
            Datums as returned by ``get_datums_from_atom`` for each atom with
            a positive size.
        """
        logIt.warning("get_datums_from_chains()")
        for chain_name in chains:
            chain: Chain = getattr(self, chain_name)

            iterator = chain.createAtomIter(False)
            ptr = chain.atomIterGetFirst(iterator)

            if not ptr:
                logIt.message("not ptr")
                # An empty chain must not stop the whole generator: move on
                # to the next requested chain.
                continue

            logIt.message("...")
            # A falsy result or a falsy atom ends this chain's iteration.
            while ptr:
                atom, size = ptr
                if not atom:
                    break

                # Atoms with a non-positive size are skipped.
                if size > 0:
                    yield from self.get_datums_from_atom(chain, atom) or ()

                ptr = chain.atomIterGetNext(iterator)

    def all_tx_from_blocks(self) -> Iterator[DatumTx]:
        """Yield every transaction datum found in the "main" chain.

        Yields:
            ``datum.getDatumTX()`` for each datum of the chain for which
            ``datum.isDatumTX()`` is truthy.
        """
        logIt.warning("all_tx_from_blocks()")
        main_datums = self.get_datums_from_chains(chains=("main",))
        yield from (datum.getDatumTX() for datum in main_datums if datum.isDatumTX())