"""EPG Live-Query cache server"""
import asyncio, logging, time, typing, os
from atxstyle.standardlog import LoggerAdapter
from epgfetch import cache as epgfetch_cache, zap2it, dataserver
from shogunservice.shogunserver import async_log_failures
from shogunservice import cepraconverters

# Module-level logger; also used as the default ``log`` argument of
# EPGLiveServer.refresh_epg_data.
log = logging.getLogger(__name__)


class EPGLiveServer(object):
    """Live cache for epg data (run-time pulling and processing)

    Computes per-client EPG configuration hashes on demand and stores the
    hashes plus the serialized schedule resources in an async ``cache``.
    """

    # Upper bound on the expiry passed to ``cache.set`` for individual
    # schedule resources (see ``recalculate_schedules``).
    SCHEDULE_DURATION = 12 * 3600

    def __init__(self, cache, timestamp=time.time):
        """
        :param cache: async cache exposing ``await cache.set(key, value, expire=...)``
            and usable by the ``epgfetch.cache`` helpers
        :param timestamp: zero-argument callable returning the current time
            (defaults to ``time.time``)
        """
        self.cache = cache
        # serialises populate_epg_cache so concurrent clients don't recalculate
        # the same configuration at the same time
        self.cache_lock = asyncio.Lock()
        self.timestamp = timestamp

    @async_log_failures
    async def refresh_epg_data(
        self,
        ipg_config,
        log: typing.Union[logging.Logger, LoggerAdapter] = log,
    ):
        """Pull epg data from configured server for each station in ipg_config

        This can be thought of as:

            - given the IPGClient configuration
            - see if we have a fresh configuration for it, if so, apply that
            - if not, use the configured epgfetch to get any epg data updates
            - then apply the updated configuration

        :param ipg_config: IPGClient configuration mapping; reads the
            ``epgfetch``, ``__key__``, ``ts`` and ``tuning`` entries
        :param log: logger to report progress on (defaults to the module logger)
        :returns: ``(epg_config_hash, delay)``; ``epg_config_hash`` is ``None``
            when there is no epgfetch or no channel mapping. ``delay`` is taken
            from the epgfetch refresh period (capped at 2 hours when some
            stations are missing), 60 in demo-jitter mode, or 20 in the
            fallback cases -- presumably seconds until the next refresh.
        """
        fetch = ipg_config.get('epgfetch')
        if not fetch:
            log.info('No epgfetch in the ipg_config, sending null epgconfig-hash')
            return None, 20

        demo_jitter = bool(os.environ.get('DEMO_DATA_JITTER'))
        expire_period, refresh_period = epgfetch_cache.durations_for_fetch(fetch)
        if demo_jitter:
            # DEMO_DATA_JITTER forces a short refresh period
            refresh_period = 60
        # short circuit all checking based on identity and time-published
        key = 'cache.epg.ipgclient.%s.%s' % (
            ipg_config['__key__'],
            ipg_config.get('ts', 0),
        )

        channel_mapping = ipg_config.get('tuning')
        if not channel_mapping:
            log.info('No channel_mapping in the ipg config')
            return None, 20

        tmsids = self._station_tmsids(channel_mapping)
        log.info("Calculating epg hash for: %s", sorted(tmsids))
        epg_config_hash, missing = await self.populate_epg_cache(fetch, tmsids)
        if missing:
            log.warning(
                'Some TMSIDs could not be cached with Data Source #%s: %s',
                fetch.get('__pk__'),
                # ids may be non-strings (see _station_tmsids), so stringify
                ", ".join(str(tmsid) for tmsid in missing),
                extra={
                    'service_id': fetch.get('__pk__'),
                },
            )
            # retry sooner (at most 2 hours) when some stations could not be cached
            delay = min(2 * 3600, refresh_period)
        else:
            delay = refresh_period

        log.info('Storing ipg config cache at %s for %ss', key, expire_period)
        await self.cache.set(key, epg_config_hash, expire=expire_period)
        return epg_config_hash, delay

    @staticmethod
    def _station_tmsids(channel_mapping):
        """Return the first station id of each channel in ``channel_mapping``.

        Channels without a station, or whose station id (as a string) starts
        with ``_``, are skipped. Order is preserved; duplicates are kept.
        Ids are returned unconverted, so they may be non-strings.
        """
        tmsids = []
        for channel in channel_mapping.get('channels', ()):
            station = channel.get('station')
            if station and not str(station[0]).startswith('_'):
                tmsids.append(station[0])
        return tmsids

    @async_log_failures
    async def populate_epg_cache(
        self, fetch, tmsids, force=False
    ) -> typing.Tuple[bytes, list]:
        """Pull cache schedules, process them for the client, initiate update

        returns the resource hash of the configuration, i.e. the
        SetConfigRequest.EpgKey value to set, which then will trigger
        a dozen other requests for data...

        The hash is cached under a key derived from ``fetch['url']`` and the
        de-duplicated tmsids. It is recalculated only when
        ``epgfetch_cache.should_refresh`` reports it stale, when ``force`` is
        set, or when nothing is cached yet. Calls are serialised by
        ``self.cache_lock``.

        :param fetch: epgfetch configuration mapping (``url`` is required)
        :param tmsids: station ids; duplicates are ignored
        :param force: recalculate even when a fresh cached value exists
        :returns: ``(sha(bytes), missing)`` where ``missing`` lists tmsids for
            which no schedule was produced (empty when the cache was reused)
        """
        # ``async with`` releases the lock on every exit path, including an
        # exception while normalising ``tmsids`` below.
        async with self.cache_lock:
            # uniqueness for the tmsids so we're not sending N copies of the same schedule...
            tmsids = sorted(set(tmsids))
            key = 'cache.epg.config.%s' % (
                cepraconverters.hash_to_hex(
                    cepraconverters.do_sha(
                        fetch['url'] + '#' + ','.join(sorted(str(x) for x in tmsids))
                    )
                ),
            )
            refresh_required, existing = await epgfetch_cache.should_refresh(
                self.cache, key
            )
            log.info(
                "Should we refresh? %s Existing: %s",
                refresh_required,
                bool(existing),
            )
            if existing and not (refresh_required or force):
                log.info('Already have tmsids %s for %s cached', tmsids, fetch['url'])
                return existing, []

            log.info(
                'Caculating cache of epg data for %s',
                ", ".join(str(tmsid) for tmsid in tmsids),
            )
            config_hash, last, missing = await self.recalculate_schedules(
                fetch, tmsids
            )
            log.info('Last schedule is at %s', cepraconverters.isoformat(last))
            expire_period, refresh_period = epgfetch_cache.durations_for_fetch(fetch)
            await self.cache.set(key, config_hash, expire=expire_period)
            await epgfetch_cache.mark_fresh(self.cache, key, refresh_period)
            return config_hash, missing

    async def recalculate_schedules(self, fetch, tmsids):
        """Pull and process schedules into a config-set for use among all clients using these tmsids

        Every generated resource (including the serialized config itself) is
        stored in the cache under ``cache.epg.resources.<hex hash>``.

        :returns: ``(config_hash, final, missing)`` -- sha of the serialized
            config, the ``final`` value from ``generate_config`` (logged by the
            caller as the last schedule time), and the tmsids for which
            ``cache_schedules`` returned no source
        """
        log.info("Recalculate schedules for %s", tmsids)
        # epgfetch pull-into-standard format...
        sources = await epgfetch_cache.cache_schedules(
            self.cache,
            fetch,
            tmsids,
            timestamp=self.timestamp,
        )
        # now process the results into cepra-specific formats...
        missing = [tmsid for tmsid in tmsids if tmsid not in sources]
        expire, refresh = epgfetch_cache.durations_for_fetch(fetch)
        config, resources, _initial, final = cepraconverters.generate_config(
            sources.values(),
            duration=max(refresh, 6 * 3600),
            total_duration=expire,
            single_tranche=fetch.get('format') == 'zap2it',
        )
        serial = config.SerializeToString()
        config_hash = cepraconverters.do_sha(serial)
        resources[config_hash] = serial

        # Keep resources until ``final``, but at least 3600 and never longer
        # than SCHEDULE_DURATION; computed once so all resources of this
        # config share the same expiry.
        resource_expiry = min(
            max(3600, int(final - self.timestamp())), self.SCHEDULE_DURATION
        )
        for resource_hash, resource in resources.items():
            # Note: not using a hash-map because we want individual keys to time out...
            key = 'cache.epg.resources.%s' % (
                cepraconverters.hash_to_hex(resource_hash),
            )
            await self.cache.set(key, resource, expire=resource_expiry)

        return config_hash, final, missing

    @classmethod
    async def clear_schedules(cls, cache):
        """Test cleanup code to remove cached schedules"""
        return await epgfetch_cache.clear_schedules(cache)
