import typing, json, logging, hashlib, time, os
from functools import wraps
from . import dataserver, zap2it

log = logging.getLogger(__name__)

# Demo switch: any non-empty value of the DEMO_DATA_JITTER environment variable
# enables it (note that even "0" or "false" count as enabled, since only
# emptiness/absence is tested). When enabled, durations_for_fetch forces the
# refresh duration down to 60 seconds.
DEMO_JITTER = bool(os.getenv('DEMO_DATA_JITTER'))


def async_log_failures(function):
    """Decorate a coroutine function so that unhandled exceptions get logged.

    The exception is logged (with traceback) via ``log.exception`` and then
    re-raised unchanged, so callers still observe the original failure.
    """

    @wraps(function)
    async def with_log_on_failure(*args, **named):
        try:
            return await function(*args, **named)
        except Exception:
            log.exception('Unhandled failure in %s', function.__name__)
            raise

    return with_log_on_failure


def durations_for_fetch(fetch):
    """Return the ``(expire, refresh)`` durations (seconds) for an epgfetch record.

    'cache_duration' defaults to 72 hours and 'refresh_duration' to 12 hours.
    When DEMO_JITTER is enabled the refresh duration is forced to 60 seconds.
    """
    expire_seconds = fetch.get('cache_duration', 72 * 3600)
    refresh_seconds = fetch.get('refresh_duration', 12 * 3600)
    return expire_seconds, (60 if DEMO_JITTER else refresh_seconds)


async def mark_fresh(cache, key, duration):
    """Set the ``<key>:fresh`` marker so ``key`` is treated as fresh for ``duration``.

    The marker is stored with the value 'true' and an expiry of ``duration``.
    Returns ``key`` unchanged.
    """
    marker_key = '{}:fresh'.format(key)
    await cache.set(marker_key, 'true', expire=duration)
    return key


async def is_fresh(cache, key):
    """Return True when the ``<key>:fresh`` marker holds a truthy value."""
    marker = await cache.get('{}:fresh'.format(key))
    return True if marker else False


@async_log_failures
async def cache_schedules(
    cache, fetch, tmsids, timestamp=time.time
) -> typing.Dict[str, typing.Dict]:
    """Pull the schedules for ``tmsids`` through the redis cache.

    cache -- redis pool from aioredis
    fetch -- epgfetch configuration record; its 'format' picks the retrieval
             callback and its duration settings are fed to durations_for_fetch
    tmsids -- list of string tmsids to pull from the server
    timestamp -- callable forwarded to cache_epgdata (defaults to time.time)

    return {'tmsid': EPGDataStruct }

    Raises RuntimeError when 'format' is not one of epgdata, dsi or zap2it.
    """
    feed_format = fetch['format']
    if feed_format == 'zap2it':
        retrieve = zap2it.cache_callback
    elif feed_format in ('epgdata', 'dsi'):
        retrieve = dataserver.cache_callback
        if feed_format == 'dsi':
            log.warning("DSI fetch structure: ")
            # DSI feeds are downloaded from the epguploads service: work on a
            # copy so the caller's record is not mutated.
            dsi_fetch = fetch.copy()
            dsi_fetch['url'] = 'http://epguploads:8024/feeds/download/%d/' % (
                fetch['__pk__'],
            )
            fetch = dsi_fetch
    else:
        raise RuntimeError("%s epg data feed not yet supported" % (feed_format))

    expire, refresh = durations_for_fetch(fetch)
    return await cache_epgdata(
        fetch,
        tmsids,
        cache=cache,
        retrieve=retrieve,
        expire=expire,
        refresh=refresh,
        timestamp=timestamp,
    )


@async_log_failures
async def clear_schedules(cache, pattern='cache.epg.*'):
    """Test cleanup: delete every cache key matching ``pattern``.

    Keys are deleted one at a time, and each deletion is logged at INFO level.
    """
    matching_keys = await cache.keys(pattern)
    for doomed in matching_keys:
        log.info("Deleting: %s", doomed)
        await cache.delete(doomed)


async def should_refresh(cache, key):
    """Check the cache to decide whether ``key`` is due for a refresh.

    returns (should_refresh, current) where ``current`` is the raw cached value
    (falsy when nothing is stored). A refresh is due when nothing is stored, or
    when the ``<key>:fresh`` marker is not set.
    """
    current = await cache.get(key)
    if not current:
        return True, current
    stale = not await is_fresh(cache, key)
    return stale, current


def with_cache(function):
    """Decorator adding redis-backed caching to an async fetch coroutine.

    The decorated coroutine is called as ``(cache, fetch, key, *args, **named)``.
    The returned wrapper does the following:

    * If ``key`` holds a value and its ``<key>:fresh`` marker is set, it decodes
      and returns the cached JSON without calling ``function``.
    * Otherwise it awaits ``function``. The result is JSON-encoded and stored
      under ``key`` for the fetch's *expire* duration. The key is then marked
      fresh for the fetch's *refresh* duration, and the result is returned.
    * If ``function`` or the JSON encoding raises, the failure is logged. The
      previously cached value is returned (decoded) if there is one; otherwise
      the falsy cache lookup result is returned.
    """

    @wraps(function)
    async def wrapped(cache, fetch, key, *args, **named):
        # durations_for_fetch returns (expire, refresh) -- unpack in that order
        # so the data outlives the freshness marker (stale data is then still
        # available as a fallback when a refresh attempt fails).
        expire_duration, refresh_duration = durations_for_fetch(fetch)
        refresh_required, existing = await should_refresh(cache, key)
        if existing and not refresh_required:
            log.info("Loaded key %s from cache", key)
            return json.loads(existing)
        try:
            # Keep the new result separate from ``existing`` so that a failure
            # in json.dumps can still fall back to the cached JSON string.
            updated = await function(cache, fetch, key, *args, **named)
            encoded = json.dumps(updated)
        except Exception:
            log.exception("Failed updating cache with %s", function)
            if existing:
                return json.loads(existing)
            return existing
        log.info("Refreshed key %s, storing for %s", key, expire_duration)
        await cache.set(key, encoded, expire=expire_duration)
        await mark_fresh(cache, key, refresh_duration)
        return updated

    return wrapped


async def cache_epgdata(
    fetch,
    tmsids: typing.List[str],
    cache,
    retrieve,
    expire=3600 * 24,
    refresh=3600 * 6,
    timestamp=time.time,
) -> typing.Dict[str, typing.Dict]:
    """Cache EPG data in redis

    fetch -- epgfetch datasource record describing how to fetch the data; its
             'url' is hashed into the cache keys
    tmsids -- iterable of tmsids to cache
    cache -- aioredis.create_pool() or aioredis.create_connection() result
    retrieve -- coroutine called as retrieve(tmsid, fetch, cache, expire,
                timestamp=timestamp) that does a live refresh attempt
    expire -- forwarded to ``retrieve`` (presumably a cache duration there --
              confirm against the callbacks)
    refresh -- not used by this function; the expiry/freshness durations of the
               cached keys are derived from ``fetch`` by the with_cache decorator
    timestamp -- forwarded to ``retrieve``

    returns {'tmsid': {}} of epgdataserver style flat records; tmsids for which
    neither fresh nor previously cached content is available are omitted
    """
    # The digest only namespaces the cache keys per feed URL; it is not used
    # for any security purpose.
    url_hash = hashlib.md5(fetch['url'].encode('utf-8')).hexdigest()

    @with_cache
    async def pull_tmsid(cache, fetch, key, tmsid):
        return await retrieve(tmsid, fetch, cache, expire, timestamp=timestamp)

    cached = {}
    for tmsid in tmsids:
        key = 'cache.epg.%s.sched.%s' % (url_hash, tmsid)
        content = await pull_tmsid(cache, fetch, key, tmsid)
        # content is fresh data, a stale cached copy (refresh failed), or falsy
        # when neither exists -- in the last case the tmsid is left out.
        if content:
            cached[tmsid] = content
    return cached
