from atxstyle.sixish import unicode
import logging
from django.core.exceptions import FieldDoesNotExist
from django.conf import settings

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
# Mirrors settings.DEBUG, falling back to False when the setting is missing.
# QueryFilter only emits its verbose log lines when this flag AND the
# per-class ``debug`` attribute are both true.
DEBUG = getattr(settings, 'DEBUG', False)


class QueryFilter(object):
    """Mix-in that turns request parameters into Django queryset filters.

    Request keys of the form ``field``, ``field__comparison`` or
    ``relation__field[__comparison]`` are converted by :meth:`filter_def`
    into small dicts ``{'field': ..., 'f': ..., 'v': ...}`` ("filter
    definitions"), which :meth:`filter_query` then applies to a queryset.

    Class attributes (override on subclasses or instances):

    * ``debug`` -- enables verbose logging; it only has an effect when the
      module-level ``DEBUG`` flag is true as well.
    * ``model`` -- model whose ``_meta`` is used to validate field names
      whenever no explicit query is supplied.
    * ``default_filters`` -- list of filter definitions, or a callable
      invoked as ``default_filters(self, request, type_key)`` returning one,
      used when the request yields no usable filters.

    Optional attributes, looked up dynamically:

    * ``allowed_traversals`` -- container of ``relation__field`` paths that
      requests are permitted to filter on.
    * ``base_filters`` -- callable invoked as
      ``base_filters(self, request, type_key)`` returning a mapping of
      ``key -> value`` that is always appended to the filters.
    """

    debug = False
    model = None
    default_filters = None

    # Maps every accepted filter name to (ORM lookup suffix, negate result?).
    _FIELD_LOOKUPS = {
        'eq': ('', False),
        'filter': ('', False),
        '': ('', False),
        'ne': ('', True),
        'exclude': ('', True),
        'gte': ('__gte', False),
        'after': ('__gte', False),
        'lte': ('__lte', False),
        'before': ('__lte', False),
        'gt': ('__gt', False),
        'more': ('__gt', False),
        'lt': ('__lt', False),
        'less': ('__lt', False),
        'contains': ('__icontains', False),
        'does_not_contain': ('__icontains', True),
        'in': ('__in', False),
        'not_in': ('__in', True),
    }

    @staticmethod
    def _split_values(value):
        """Split a comma-separated string into a list; pass other values through.

        The separator matches the type of ``value`` so that ``bytes`` input
        also works on Python 3 (``b'a,b'.split(',')`` raises ``TypeError``).
        """
        if isinstance(value, bytes):
            return value.split(b',')
        if isinstance(value, unicode):
            return value.split(u',')
        return value

    @staticmethod
    def _isnull_flag(value):
        """Interpret ``value`` as the boolean argument of an ``isnull`` lookup.

        The strings ``'true'``, ``'True'`` and ``'1'`` as well as a real
        boolean ``True`` (e.g. supplied programmatically rather than parsed
        from a query string) count as true; everything else is false.
        """
        return value is True or value in ('true', 'True', '1')

    def _debug_log(self, message, *args):
        """Log at INFO level, but only when both debug switches are on."""
        if self.debug and DEBUG:
            log.info(message, *args)

    @classmethod
    def field_filter(cls, filter):
        """Build the ``Q`` object for a single filter definition.

        :param filter: dict with the keys ``'field'`` (lookup path), ``'f'``
            (filter name, see ``_FIELD_LOOKUPS``) and ``'v'`` (value).
        :returns: a ``django.db.models.Q`` instance, negated for the
            ``ne``/``exclude``/``does_not_contain``/``not_in`` filters.
        :raises KeyError: if one of the three keys is missing.
        :raises ValueError: if the filter name is not recognized.
        """
        field = filter['field']
        filt = filter['f']
        value = filter['v']

        try:
            suffix, negate = cls._FIELD_LOOKUPS[filt]
        except (KeyError, TypeError):
            # TypeError covers unhashable filter names; both mean "unknown".
            raise ValueError("Unrecognized filter type %r" % (filt,))

        # Imported lazily, as in the rest of this method's history, and only
        # once the filter name is known to be valid.
        from django.db.models import Q

        if suffix == '__in':
            value = cls._split_values(value)
        condition = Q(**{field + suffix: value})
        return ~condition if negate else condition

    DEFAULT_COMPARISONS = [
        'gt',
        'lt',
        'gte',
        'lte',
        'eq',
        'ne',
        'contains',
        'does_not_contain',
        'isnull',
        'in',
        'not_in',
    ]
    # Hierarchy comparisons; filter_query resolves them by calling
    # ``get_<name>()`` on the matched row (presumably django-mptt node
    # methods -- confirm against the models this is used with).
    MPTT_COMPARISONS = ['family', 'descendants', 'ancestors', 'children']

    def filter_def(
        self, key, value, comparisons=DEFAULT_COMPARISONS + MPTT_COMPARISONS, query=None
    ):
        """Translate one request ``key``/``value`` pair into a filter definition.

        :param key: request parameter name, e.g. ``name``, ``name__contains``
            or ``owner__name__gt``.
        :param value: raw parameter value.
        :param comparisons: comparison suffixes recognized at the end of
            ``key``.
        :param query: optional queryset; when given, its ``model`` is used to
            validate the base field instead of ``self.model``.
        :returns: ``{'field': ..., 'f': ..., 'v': ...}``, or ``None`` when the
            base field does not exist on the model or the key traverses a
            relation path that is not listed in ``allowed_traversals``.
        """
        self._debug_log("Check key %r for filter", key)
        if key in self.MPTT_COMPARISONS:
            # Hierarchy keys are not model fields; pass them through as-is.
            return {'field': key, 'v': value, 'f': key}

        fragments = key.split('__')
        base, rest = fragments[0], fragments[1:]
        model = query.model if query is not None else self.model
        try:
            field = model._meta.get_field(base)
        except FieldDoesNotExist:
            self._debug_log("  %r field not present on %s", base, model.__name__)
            return None
        self._debug_log("  %r field %s", key, field)

        if not rest:
            # Bare field name: plain equality.
            return {'field': base, 'f': 'eq', 'v': value}

        if len(rest) == 1 and rest[0] in comparisons:
            # Local field followed by just a comparison.
            self._debug_log("  %r field %s %s", key, rest[0], value)
            return {'field': base, 'f': rest[0], 'v': value}

        # Anything else walks a relation; peel off a trailing comparison and
        # check the remaining path against the whitelist.
        comparison = 'eq'
        if rest[-1] in comparisons:
            comparison = rest[-1]
            rest = rest[:-1]
        trav = '__'.join([base] + rest)
        if trav not in getattr(self, 'allowed_traversals', {}):
            log.info('Disallowed traversal: %s', key)
            return None
        if comparison in ('in', 'not_in'):
            value = self._split_values(value)
        return {'field': trav, 'f': comparison, 'v': value}

    def filtered_traversals(self, filters, strict=True, query=None):
        """Convert a mapping of request parameters into filter definitions.

        :param filters: mapping of ``key -> value`` (anything with
            ``.items()``).
        :param strict: accepted for interface compatibility; not used.
        :param query: forwarded to :meth:`filter_def`.
        :returns: list of the definitions :meth:`filter_def` accepted; keys it
            rejected (``None``) are dropped.
        """
        definitions = (
            self.filter_def(key, value, query=query)
            for (key, value) in filters.items()
        )
        return [definition for definition in definitions if definition]

    def get_query_filters(self, request, type_key, query=None):
        """Collect the filter definitions that apply to ``request``.

        Filters come from ``request.GET`` and ``request.POST``; when neither
        yields any, ``default_filters`` is used instead. Filters produced by
        the optional ``base_filters`` hook are always appended.

        :param request: object exposing ``GET`` and ``POST`` mappings.
        :param type_key: opaque value handed to the ``default_filters`` and
            ``base_filters`` callables.
        :param query: forwarded to :meth:`filtered_traversals`.
        :returns: a new list of filter definitions.
        """
        filters = []
        for source in (request.GET, request.POST):
            filters += self.filtered_traversals(source, query=query)
        self._debug_log("Filters: %s", filters)

        if not filters:
            defaults = self.default_filters or []
            if callable(defaults):
                defaults = defaults(self, request, type_key)
            # Copy so the shared defaults are never mutated by ``+=`` below;
            # list() also tolerates a None or tuple result from the callable.
            filters = list(defaults or ())

        if hasattr(self, 'base_filters'):
            filters += self.filtered_traversals(
                self.base_filters(self, request, type_key),
                query=query,
            )
        self._debug_log("Final filters: %s", filters)
        return filters

    def filter_query(self, query, request, type_key):
        """Apply the filters requested by ``request`` to ``query``.

        Hierarchy filters (``MPTT_COMPARISONS``) are resolved first: the row
        whose ``id`` equals the filter value is looked up in ``query`` and the
        query is replaced by the result of its ``get_<comparison>()`` method.
        ``isnull`` filters are then applied one by one, and all remaining
        filters are AND-ed together into a single ``Q`` and applied last.

        :param query: queryset to narrow down.
        :param request: object exposing ``GET`` and ``POST`` mappings.
        :param type_key: forwarded to :meth:`get_query_filters`; also logged.
        :returns: the filtered queryset.
        :raises KeyError: if a hierarchy filter names an id with no match.
        :raises ValueError: if a filter definition has an unknown filter name.
        """
        if self.model is None:
            # No model configured on the mix-in: validate field names against
            # the query's own model instead of failing on ``None._meta``.
            filters = self.get_query_filters(request, type_key, query=query)
        else:
            filters = self.get_query_filters(request, type_key)
        log.info("Final filters on %s: %s", type_key, filters)

        for spec in filters:
            if spec['f'] in self.MPTT_COMPARISONS:
                # special case for hierarchic members of the given family...
                # TODO: review why this was done originally, potentially go back to
                # re-starting the query on hier-check
                root = query.filter(id=spec['v']).first()
                if root:
                    query = getattr(root, 'get_%s' % (spec['f']))()
                else:
                    raise KeyError("Unknown %s" % (query.model.__name__), spec['v'])

        final_filter = None
        for spec in filters:
            if spec['f'] in self.MPTT_COMPARISONS:
                continue
            if spec['f'] == 'isnull':
                query = query.filter(
                    **{spec['field'] + '__isnull': self._isnull_flag(spec['v'])}
                )
                continue
            q = self.field_filter(spec)
            if final_filter is None:
                final_filter = q
            else:
                final_filter &= q
        if final_filter is not None:
            query = query.filter(final_filter)
        return query


class DRFFilter(QueryFilter):
    """QueryFilter bound to one request / queryset / view triple.

    The queryset's model becomes the model used for field validation, and
    the view may customise filtering through two optional attributes:
    ``allowed_traversals`` (defaults to an empty list) and ``base_filters``
    (copied onto the instance only when the view defines it).
    """

    def __init__(self, request, queryset, view):
        self.request = request
        self.view = view
        self.query = queryset
        self.model = queryset.model
        self.allowed_traversals = getattr(view, 'allowed_traversals', [])
        try:
            view_base_filters = view.base_filters
        except AttributeError:
            # View has no base filters: leave the attribute unset so that
            # hasattr(self, 'base_filters') stays false.
            pass
        else:
            self.base_filters = view_base_filters
