# Rolling, RRD-style accumulation of periodic system statistics, written out
# as JSON files (see accumulator() and main()).
import logging, time
from collections import deque
from atxstyle import control

log = logging.getLogger(__name__)

# Default output paths used by accumulator(): the brief JSON (only the two
# newest buckets per period) and the full JSON statistics.
# NOTE(review): /dev/shm is presumably a RAM-backed tmpfs, keeping the
# every-period rewrites off persistent storage -- confirm on the target.
STAT_SUMMARY_FILE = '/dev/shm/status/system-stats-brief.json'
STAT_FULL_FILE = '/dev/shm/status/system-stats.json'


def warn_on_load(accumulator_instance):
    """Placeholder for an hourly high-load warning; currently does nothing.

    The intent was to warn when load stays high over an hourly period, but
    it was never implemented: it ignores ``accumulator_instance`` and always
    returns None.  Nothing in this module calls it; high average CPU is
    reported by ``log_load_high`` through a ``Threshold`` on the 'load'
    accumulator instead.
    """
    # TODO: unfinished sketch -- walk accumulator_instance.accumulator.periods
    # and, for periods longer than 30 minutes (> 1800 s), check the current
    # average.
    return None


class Accumulator(object):
    """Defines a thing that collects status for a particular key

    Calling the accumulator with a flattened key tuple returns the (cached)
    AccumulatorInstance that stores the history for that key.

    key -- str name of the top-level record field handled; SummaryStats
           matches it against the first element of each flattened key
           (e.g. 'load' handles ('load', 'user'))
    periods -- iterable of periods, in seconds, over which to summarize data
    count -- number of summary buckets to retain for each period
    deltas -- bool store delta-to-last-value rather than raw value...
    rate -- bool store delta-over-time rather than raw value...
            NOTE(review): currently only reported by __json__;
            AccumulatorInstance does not divide by elapsed time
    thresholds -- callables (e.g. Threshold) run with the instance after each
                  stored value; defaults to no thresholds
    """

    def __init__(
        self,
        key,
        periods=(60, 60 * 60),
        count=100,
        deltas=False,
        rate=False,
        thresholds=None,
    ):
        self.key = key
        self.periods = periods
        self.count = count
        self.deltas = deltas
        self.rate = rate
        self.instances = {}  # flattened key tuple: AccumulatorInstance
        self.thresholds = thresholds or []

    def __call__(self, key):
        """Return the AccumulatorInstance for ``key``, creating it on first use."""
        instance = self.instances.get(key)
        if instance is None:
            instance = self.instances[key] = AccumulatorInstance(key, self)
        return instance

    def __json__(self):
        """Describe this definition (AccumulatorInstance replaces 'key')."""
        return {
            'periods': self.periods,
            'rate': self.rate,
            'delta': self.deltas,
            'key': self.key,
        }


class AccumulatorInstance(object):
    """Rolling per-period summaries of the values for one flattened key.

    For every period of the owning Accumulator a bounded deque (at most
    ``accumulator.count`` entries) of buckets ``(start_ts, [min, max, total,
    count])`` is kept.  A bucket covers ``[start_ts, start_ts + period)``;
    buckets opened for periods without samples have count 0.
    """

    def __init__(self, key, accumulator):
        self.key = key
        self.accumulator = accumulator
        # Previous raw value and its timestamp (only used with ``deltas``)
        self.last = None
        self.last_ts = None
        # Most recent processed value; None until the first one is stored
        self.current = None
        # period: deque([(start, [min, max, total, count]), ...])
        self.historic = dict(
            (period, deque(maxlen=self.accumulator.count))
            for period in self.accumulator.periods
        )

    def append(self, value, ts):
        """Add a raw ``value`` observed at time ``ts`` (seconds), then run thresholds.

        With ``deltas`` the first value only primes ``last``/``last_ts``;
        later values are stored as the difference from the previous raw
        value.  NOTE(review): the accumulator's ``rate`` flag is not applied
        here (values are not divided by elapsed time).
        Exceptions raised by thresholds are logged, never propagated.
        """
        if self.accumulator.deltas:
            if self.last is None:
                self.last = value
                self.last_ts = ts
                return
            value, self.last = value - self.last, value
            # Keep last_ts paired with last.
            self.last_ts = ts
        self.current = value
        for period in self.accumulator.periods:
            self.add_to_period(period, value, ts)
        for threshold in self.accumulator.thresholds:
            try:
                threshold(self)
            except Exception:
                log.exception("Error on threshold %s", threshold)

    def add_to_period(self, period, value, ts):
        """Add a particular (processed) value to the given period

        Opens new (possibly empty) buckets until ``ts`` falls inside the
        newest one, then folds ``value`` into that bucket's min/max/total/
        count.  A ``ts`` before the newest bucket's start is folded into it.
        """
        history = self.history_for(period)
        if not history:
            history.append((ts, [0, 0, 0, 0]))
        else:
            newest = history[-1][0]
            limit = history.maxlen
            if limit:
                # Large forward jump (clock sync, suspend): only the last
                # ``limit`` buckets can survive in the deque, so jump to just
                # before them rather than appending one bucket per period.
                skipped = int((ts - newest) // period) - limit
                if skipped > 0:
                    history.append((newest + skipped * period, [0, 0, 0, 0]))
            # ">=" keeps buckets half-open: a sample exactly on a boundary
            # starts the next bucket (consistent with friendly_records, which
            # treats a fraction < 1 as still in progress).
            while ts - history[-1][0] >= period:
                history.append((history[-1][0] + period, [0, 0, 0, 0]))
        start, cur = history[-1]
        assert isinstance(cur, list), cur
        assert isinstance(start, (int, float)), start
        assert len(cur) == 4, cur
        if cur[3]:  # bucket already holds samples
            cur[:] = [
                min(cur[0], value),
                max(cur[1], value),
                cur[2] + value,
                cur[3] + 1,
            ]
        else:
            cur[:] = [value, value, value, 1]

    def history_for(self, period):
        """Get our current record for the given period (None if not collected)"""
        return self.historic.get(period)

    def friendly_records(self, records, brief=False):
        """Convert ``(period, history)`` pairs into JSON-friendly dicts.

        With ``brief`` only the two newest buckets per period are kept.  The
        newest bucket gets ``fractional_period`` (elapsed fraction of its
        period, measured against time.time()) while it is still in progress.
        """
        result = []
        for period, history in records:
            subset = list(history)
            if brief:
                subset = subset[-2:]
            values = [
                {
                    'start': ts,
                    'min': r[0],
                    'max': r[1],
                    'total': r[2],
                    'count': r[3],
                }
                for (ts, r) in subset
            ]
            if values:
                last = values[-1]
                fraction = (time.time() - last['start']) / float(period)
                if fraction < 1:
                    last['fractional_period'] = fraction
            result.append({'period': period, 'values': values})
        return result

    def _json(self, brief):
        """Accumulator description merged with this key and its history."""
        base = self.accumulator.__json__()
        base.update(
            {
                'key': self.key,
                'historic': self.friendly_records(
                    sorted(self.historic.items()), brief=brief
                ),
            }
        )
        return base

    def __json__(self):
        """Full history for every period."""
        return self._json(brief=False)

    def brief_json(self):
        """Only the two newest buckets for every period."""
        return self._json(brief=True)


class Threshold(object):
    """Call ``trigger`` when a period's running average is outside ``ok_range``."""

    def __init__(
        self,
        ok_range,
        trigger,
        period=3600,  # measurement per
        min_items=2,
    ):
        """Create a threshold calling trigger if outside ok_range

        ok_range -- (min,max) if either min or max are None, they are not checked
        trigger -- function to call if the range is exceeded; called as
                   ``trigger(value, threshold=self, instance=instance)`` and
                   its return value is returned from ``__call__``
        period -- data-collection period whose data-values we check
        min_items -- minimum number of items (in the newest bucket of that
                     period) before we will check the value
        """
        self.ok_range = ok_range
        self.trigger = trigger
        self.period = period
        self.min_items = min_items
        # Set once the "period not collected" warning has been logged so it
        # is not repeated for every appended sample.
        self._warned_missing_period = False

    def __call__(self, instance):
        """Check if the instance is out-of-spec

        The checked value is the average (total / count) of the newest bucket
        of ``self.period``.  Returns the trigger's result when out of range,
        otherwise None.
        """
        accumulator = instance.accumulator
        if self.period not in accumulator.periods:
            if not self._warned_missing_period:
                self._warned_missing_period = True
                log.warning(
                    "No %s period defined, so threshold %s will never run",
                    self.period,
                    self.trigger,
                )
            return None
        history = instance.history_for(self.period)
        if not history:
            return None
        _start, (_low, _high, total, count) = history[-1]
        if not count or count < self.min_items:
            return None
        value = total / float(count)
        min_test, max_test = self.ok_range
        if min_test is not None and value < min_test:
            log.debug("Value %r lower than threshold %r", value, min_test)
            return self.trigger(value, threshold=self, instance=instance)
        if max_test is not None and value > max_test:
            log.debug("Value %r higher than threshold %r", value, max_test)
            return self.trigger(value, threshold=self, instance=instance)
        return None


class SummaryStats(object):
    """In-memory RRD style accumulation of statistics

    ``add`` flattens each record into ``(key_tuple, number)`` pairs.  The
    first element of a key tuple selects the Accumulator definition (by its
    ``key``) and every distinct key tuple gets its own AccumulatorInstance.
    ``__json__``/``brief_json`` rebuild the record's nesting from the keys.
    """

    # Log missing keys a maximum of once per key.  Being a class attribute,
    # this set is shared by every SummaryStats in the process.
    logged_missing = set()

    def __init__(
        self,
        accumulators,
    ):
        self.accumulators = dict((a.key, a) for a in accumulators)
        self.accumulator_instances = {}  # key tuple: AccumulatorInstance
        self.historic = []
        self.current = None  # most recent record passed to add()

    def add(self, record, ts):
        """Feed every numeric leaf of ``record`` (observed at ``ts``) to its accumulator.

        Leaves whose top-level key has no Accumulator are skipped; each such
        key is logged once at debug level.
        """
        self.current = record
        for key, value in self.flatten(record):
            instance = self.accumulator_instances.get(key)
            if instance is None:
                definition = self.accumulators.get(key[0])
                if not definition:
                    if key not in self.logged_missing:
                        log.debug("No accumulator for: %s", key)
                        self.logged_missing.add(key)
                    continue
                instance = self.accumulator_instances[key] = definition(key)
            instance.append(value, ts)

    @classmethod
    def flatten(cls, record, parent_key=None):
        """Yield ``(key_tuple, value)`` pairs for the leaves of ``record``.

        Dicts and namedtuples are walked in sorted field order; int/float
        values are yielded directly; list/tuple items are yielded as-is under
        an integer index (not walked further).  Other values yield nothing.
        NOTE(review): a namedtuple nested inside a dict is a tuple, so it takes
        the list/tuple branch (positional indices); the ``_fields`` branch only
        applies to the top-level record -- confirm this is intended.
        """
        if isinstance(record, dict):
            sub_elements = sorted(record.items())
        elif hasattr(record, '_fields'):
            sub_elements = sorted(zip(record._fields, record))
        else:
            return
        for (key, value) in sub_elements:
            final_key = parent_key + (key,) if parent_key else (key,)
            if isinstance(value, (int, float)):
                yield final_key, value
            elif isinstance(value, (list, tuple)):
                for i, item in enumerate(value):
                    yield final_key + (i,), item
            else:
                for item in cls.flatten(value, final_key):
                    yield item

    def _nested(self, render):
        """Rebuild the nested record layout with ``render(instance)`` leaves.

        A level becomes a list when the next key fragment is an int (its
        items are appended in sorted-key order), otherwise a dict.
        """
        records = {}
        for key, instance in sorted(self.accumulator_instances.items()):
            record = records
            for position, fragment in enumerate(key[:-1]):
                try:
                    record = record[fragment]
                except (IndexError, KeyError):
                    # Container type is chosen from the following fragment.
                    child = [] if isinstance(key[position + 1], int) else {}
                    record[fragment] = child
                    record = child
            if isinstance(record, list):
                record.append(render(instance))
            else:
                record[key[-1]] = render(instance)
        return records

    def brief_json(self):
        """Nested structure of every instance's ``brief_json()``."""
        return self._nested(lambda instance: instance.brief_json())

    def __json__(self):
        """Nested structure of every instance's ``__json__()``."""
        return self._nested(lambda instance: instance.__json__())


@control.with_frequency_filter('high-cpu-percentage', 20 * 60.0)
def log_load_high(value, threshold, instance):
    """Threshold trigger: log an out-of-range average CPU percentage.

    A falsy ``value`` (e.g. 0) is ignored.  NOTE(review): the decorator
    presumably rate-limits calls to one per 20 minutes under the
    'high-cpu-percentage' name -- confirm in atxstyle.control.
    """
    if not value:
        return
    log.error('Average CPU Percentage at %0.1f%%', value)


@control.with_frequency_filter('high-memory-percentage', 20 * 60.0)
def log_memory_high(value, threshold, instance):
    """Threshold trigger: log an out-of-range average memory percentage.

    Uses its own frequency-filter name.  Reusing 'high-cpu-percentage' (as
    log_load_high does) would let a recent CPU alert suppress a memory alert
    and vice versa -- assuming the filter is keyed by this name, as the
    decorator's first argument suggests.
    A falsy ``value`` (e.g. 0) is ignored.
    """
    if value:
        log.error('Average Memory Percentage at %0.1f%%', value)


# Statistics collected by default by accumulator_iterator(); each entry's key
# names a top-level field of the sampled record.
DEFAULT_ACCUMULATORS = [
    # Counter-style values: 10 s / 1 min / 1 h buckets of per-sample deltas
    # (the rate flag is set as well), 20 buckets kept per period.
    Accumulator(
        'bandwidth',
        count=20,
        periods=(10, 60, 3600),
        deltas=True,
        rate=True,
    ),
    # Load: 1 min / 1 h buckets; log_load_high fires when the average of the
    # current hourly bucket exceeds 90 once that bucket holds >= 100 samples.
    Accumulator(
        'load',
        count=20,
        periods=(60, 3600),
        thresholds=[
            Threshold((None, 90), log_load_high, period=3600, min_items=100),
        ],
    ),
    # Memory: same layout and threshold as load, reported by log_memory_high.
    Accumulator(
        'memory',
        count=20,
        periods=(60, 3600),
        thresholds=[
            Threshold((None, 90), log_memory_high, period=3600, min_items=100),
        ],
    ),
]


def accumulator_iterator(function, accumulators=DEFAULT_ACCUMULATORS):
    """Endlessly sample ``function`` and yield ``(full_json, brief_json)``.

    Each record returned by ``function`` must carry its timestamp under
    'ts'; that key is popped from the record before accumulation.  Any
    failure is logged and reported as ``(None, None)`` so iteration goes on.
    NOTE(review): Accumulator objects cache their instances, so two iterators
    built from the same (default) accumulators share history.
    """
    stats = SummaryStats(accumulators)
    while True:
        try:
            sample = function()
            stamp = sample.pop('ts')
            stats.add(sample, stamp)
            full = stats.__json__()
            brief = stats.brief_json()
            yield full, brief
        except Exception:
            log.exception("Unable to compile stats")
            yield None, None


def accumulator(
    function, period=10, filename=STAT_FULL_FILE, brief_filename=STAT_SUMMARY_FILE
):
    """Sample ``function`` every ``period`` seconds forever, writing JSON stats.

    function -- zero-argument callable returning a dict record that carries
                its timestamp under 'ts' (see accumulator_iterator)
    period -- seconds between samples; sampling is aligned to multiples of
              ``period`` on the wall clock
    filename -- where the full JSON statistics are written
    brief_filename -- where the brief JSON statistics are written

    Never returns.  Write errors are logged and sampling continues.
    """
    from fussy import twrite
    import json

    delay = period - (time.time() % period)
    time.sleep(delay)
    start = time.time()
    for stats, brief in accumulator_iterator(function):
        # A failed collection yields (None, None) (already logged); keep the
        # previous files instead of overwriting them with "null".
        if stats is not None:
            try:
                twrite.twrite(filename, json.dumps(stats, separators=(',', ':')))
                twrite.twrite(
                    brief_filename, json.dumps(brief, separators=(',', ':'))
                )
            except Exception:
                log.exception("Unable to write stats")
        start += period
        now = time.time()
        if now - start >= period:
            # More than a whole period behind (suspend, forward clock jump,
            # very slow sample): resync to the next period boundary rather
            # than running a burst of back-to-back catch-up samples.
            start = now + period - (now % period)
        time.sleep(max(0.0000001, start - now))


def main():
    """Command-line entry point: set up logging, then accumulate stats forever."""
    from atxstyle import standardlog, sysinfo

    standardlog.debug('sysinfo-accumulate', do_console=True)
    # accumulator() rewrites its output files every period; keep fussy.twrite
    # from logging anything below WARNING.
    twrite_logger = logging.getLogger('fussy.twrite')
    twrite_logger.setLevel(logging.WARNING)
    accumulator(sysinfo.quick_summary)
