from datetime import datetime
from typing import Any
from atxstyle.sixishdj import gettext_lazy as _

from atxstyle.sixish import unicode
import logging, subprocess, os, re, json
from django import forms
from django.core import validators
from django.core.exceptions import PermissionDenied
from atxstyle.djmigrate import get_app
from django.forms import utils as forms_util
from django.forms import widgets
from media import bigfilewidget
from django.contrib import messages
import datetime
from . import utctime

from fussy import nbio

from django.conf import settings

log = logging.getLogger(__name__)

# Names exported by a star-import of this module; everything else defined
# below has to be imported explicitly by name.
__all__ = [
    'AlwaysValidate',
    'ArrayOfSize',
    'DateForm',
    'FirmwareForm',
    'LocationForm',
    'RebootForm',
    'exclude_range',
    'password_setting_form',
    'validate_fraction',
    'validate_mac',
    'validate_password',
    'validate_range',
    'colour_validator',
    'time_validator',
    'validate_ip_port_range',
    'validate_pid',
    'validate_pid_range',
    'validate_program_number_range',
]


def listify_string(value):
    """Turn a string submission into a list; pass anything else through.

    * a string starting with ``[`` is decoded as JSON
    * an empty string becomes an empty list
    * any other string is split on commas

    Non-string values are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    if value.startswith('['):
        return json.loads(value)
    if not value:
        return []
    return value.split(',')


def SelectMultiple_value_from_datadict(
    self, data, files, name, base=widgets.SelectMultiple.value_from_datadict
):
    """Replacement for SelectMultiple.value_from_datadict accepting FormData-style values

    The stock implementation (captured as ``base`` when this function is
    defined) is called first; its result is then normalised with
    listify_string so JSON-list and comma-separated strings become lists.

    When the result is a list/tuple, each *string* element is itself
    expanded with listify_string and the pieces are flattened into one list;
    elements that are not strings are dropped.
    """
    submitted = listify_string(base(self, data, files, name))
    if not isinstance(submitted, (tuple, list)):
        return submitted
    return [
        piece
        for entry in submitted
        if isinstance(entry, str)
        for piece in listify_string(entry)
    ]


# Monkey-patch applied at import time: once this module is loaded,
# widgets.SelectMultiple (and subclasses that do not override the method)
# use the replacement function defined above.
widgets.SelectMultiple.value_from_datadict = SelectMultiple_value_from_datadict


def CustomLabelModelChoiceField(label_function):
    """Build a ModelChoiceField subclass whose labels come from label_function

    label_function is installed as the class attribute
    ``label_from_instance``; if it is a plain function it is therefore
    looked up as a method (called with the field as first argument).

    Returns the new field *class*, not an instance.
    """

    class CustomLabelModelChoiceField(forms.ModelChoiceField):
        """Subclass of model choice field that uses label_function to label instances"""

        label_from_instance = label_function

    return CustomLabelModelChoiceField


class AlwaysValidateMixin(object):
    """Mixin whose validate_unique always runs the instance's uniqueness check

    Expects the class it is mixed into to provide ``self.instance`` and
    ``self._update_errors`` (presumably a ModelForm -- verify at use sites).
    """

    def validate_unique(self):
        # Run the instance's own uniqueness validation and fold any failure
        # into the form's errors instead of letting it propagate.
        try:
            self.instance.validate_unique()
        except forms.ValidationError as e:
            self._update_errors(e.message_dict)


class AlwaysValidate(forms.ModelForm):
    """Base ModelForm used by the configuration forms in this module

    NOTE(review): despite the name this class adds nothing to ModelForm and
    does not inherit from AlwaysValidateMixin -- confirm whether that is
    intentional.
    """

    pass


def validate_password(passwd):
    """Trivial password test

    A password is accepted when it is at least 6 characters long and
    contains both

    * at least one non-alphanumeric character, and
    * at least one alphabetic character.

    A missing password (``None`` or empty) is rejected rather than raising,
    because callers pass ``cleaned_data.get(...)`` results straight in.

    Returns True when the password is acceptable, False otherwise.
    """
    if not passwd or len(passwd) < 6:
        return False
    has_symbol = any(not char.isalnum() for char in passwd)
    has_letter = any(char.isalpha() for char in passwd)
    return has_symbol and has_letter


# Accepts non-empty keys made only of ASCII letters, digits, "_", "." and "-".
UNIQUE_KEY_VALIDATOR = validators.RegexValidator(
    r'^[a-zA-Z0-9_.-]+$',
    message=_("Need characters in a-z A-Z 0-9 _.- (no spaces or other punctuation)"),
)


# Accepts H:MM or HH:MM with hours 0-23 and minutes 00-59; leading and
# trailing non-word characters (e.g. whitespace) are tolerated.
time_validator = validators.RegexValidator(
    r'^\W*(([01]?\d|2[0-3]):[0-5]\d)\W*$',
    _('Enter a time in HH:MM form (24-hour clock)'),
)
# Accepts either an HH:MM time or a (optionally "+"-prefixed) number of
# minutes.  The two alternatives are wrapped in a non-capturing group so that
# the ^ and $ anchors apply to *both* of them; without the group the "|"
# split the pattern into "^...HH:MM" OR "+M...$", letting values with
# arbitrary leading/trailing text through.  Capturing groups 1 and 2 are
# unchanged.
reboot_offset_validator = validators.RegexValidator(
    r'^\W*(?:([012]?\d[:]\d{2})|([+]?\d+))\W*$',
    _('Enter a time in HH:MM form (24-hour clock) or +M minutes in the future'),
)
# NOTE: the name colour_validator is bound a second time further down in
# this module; after import it is that later binding that other modules see.
colour_validator = validators.RegexValidator(
    r'#([0-9a-fA-F]{6}|[0-9a-fA-F]{3})',
    _('Enter a colour in CSS style #FFFFFF notation'),
)


class RebootForm(forms.Form):
    """Form asking when a one-off reboot should happen

    The single optional field takes either an HH:MM time or a +minutes
    offset (checked by reboot_offset_validator).
    """

    permissions = ['config.shutdown']
    reboot_time = forms.CharField(
        label=_("Reboot Time/Delay"),
        help_text=_(
            "If specified, the reboot will occur at this time (specified in server local time as HH:MM (24-hour) or +MM (minutes in the future), default is 1 minute in the future)"
        ),
        required=False,
        validators=[reboot_offset_validator],
        widget=forms.widgets.TextInput(attrs={'placeholder': '+1'}),
    )

    def __init__(self, *args, **named):
        # An ``instance`` keyword is accepted but ignored: this is a plain
        # Form, which would reject the argument.
        named.pop('instance', None)
        super(RebootForm, self).__init__(*args, **named)


def PeriodicRebootForm(*args, **named):
    """Create a bound/unbound periodic-reboot form instance

    The form class is built on each call because the System model is looked
    up through get_app('config') at call time.  When no ``instance`` keyword
    is supplied, the first System row is edited.
    """
    System = get_app('config').System

    class PeriodicRebootForm(AlwaysValidate):
        """Form to schedule/update a periodic system reboot"""

        permissions = ['config.shutdown']
        # presumably consumed by the UI layer to tie the time field to the
        # periodic_reboot selection -- verify against the renderer
        dependencies = {
            ('periodic_reboot', (1, 2)): ['periodic_reboot_time'],
        }

        class Meta:
            model = System
            fields = ('periodic_reboot', 'periodic_reboot_time')

        def __init__(self, *args, **named):
            super(PeriodicRebootForm, self).__init__(*args, **named)
            # The time may be left blank; clean() supplies the default.
            self.fields['periodic_reboot_time'].required = False

        def clean(self):
            cleaned = self.cleaned_data
            if not cleaned.get('periodic_reboot_time'):
                cleaned['periodic_reboot_time'] = '00:00'
            return cleaned

        def save(self, *args, **named):
            system = super(PeriodicRebootForm, self).save(*args, **named)
            system.write_periodic_reboot()
            return system

    if 'instance' not in named:
        named['instance'] = System.objects.first()
    return PeriodicRebootForm(*args, **named)


def LocationForm(*args, **named):
    """Create a form instance editing the System timezone and location

    When no ``instance`` keyword is supplied, the first System row is edited.
    """
    System = get_app('config').System

    class FinalLocationForm(AlwaysValidate):
        """ModelForm over System.timezone / System.location"""

        permissions = ['config.location']

        class Meta:
            model = System
            fields = ('timezone', 'location')

        def save(self, *args, **named):
            saved = super(FinalLocationForm, self).save(*args, **named)
            from . import system

            # Keep the result of writing the timezone on the form so the
            # caller can inspect it.
            self.written = system.write_timezone(self.instance.timezone)
            saved.write_snmp_config()
            return saved

    if 'instance' not in named:
        named['instance'] = System.objects.first()
    return FinalLocationForm(*args, **named)


def SerialForm(*args, **named):
    """Create a form instance editing the factory serial fields of System

    The edited fields are whatever System.FACTORY_SERIAL_FORM_FIELDS lists.
    When no ``instance`` keyword is supplied, the first System row is edited.
    """
    System = get_app('config').System

    class SerialForm(forms.ModelForm):
        """ModelForm over the factory serial fields"""

        permissions = ['factory']

        class Meta:
            model = System
            fields = System.FACTORY_SERIAL_FORM_FIELDS

        def save(self, *args, **named):
            system = super(SerialForm, self).save(*args, **named)
            system.write_snmp_config()
            return system

    if 'instance' not in named:
        named['instance'] = System.objects.first()
    return SerialForm(*args, **named)


def handle_firmware_update(firmware_path):
    """Move an uploaded firmware image into place and trigger installation

    The file at ``firmware_path`` is moved to settings.FIRMWARE_PROMOTION
    (creating the target directory if needed).  When settings.MODIFY_SYSTEM
    is true, the ``promote-install`` helper from settings.SUDO_BIN_DIRECTORY
    is then run through ``sudo -n``:

    * if /usr/bin/systemd-run exists, synchronously, with failures reported;
    * otherwise, in the background through the shell, result ignored.

    Returns True on success (including when installation is skipped).

    Raises RuntimeError when the synchronous promote-install call fails, and
    subprocess.CalledProcessError when the fallback ``mv`` fails.
    """
    destination = settings.FIRMWARE_PROMOTION
    try:
        log.warning('Attempting installation from: %s', firmware_path)
        directory = os.path.dirname(destination)
        if not os.path.exists(directory):
            os.makedirs(directory)
        os.rename(firmware_path, destination)
    except Exception:
        # os.rename cannot cross filesystems (among other failures); fall
        # back to mv.  An argument list is used instead of a shell string:
        # the source path derives from an upload and must never be
        # interpreted by a shell (and may contain spaces).
        subprocess.check_call(['mv', firmware_path, destination])
    if not settings.MODIFY_SYSTEM:
        log.warning("Skipping installation, running from source")
        return True
    promote_command = [
        'sudo',
        '-n',
        os.path.join(settings.SUDO_BIN_DIRECTORY, 'promote-install'),
    ]
    if os.path.exists('/usr/bin/systemd-run'):
        try:
            subprocess.check_call(promote_command)
        except subprocess.CalledProcessError:
            log.exception('Unable to trigger upgrade of firmware')
            raise RuntimeError('Could not install firmware')
        return True
    # No systemd-run available: detach the installer with a shell "&" so this
    # call returns immediately.  The command line is built only from
    # settings, not from request data.
    subprocess.call(' '.join(promote_command + ['&']), shell=True)
    return True


class LegacyFirmwareForm(forms.Form):
    """Plain file-upload form that installs the uploaded firmware image"""

    firmware = forms.FileField(
        label=_("Firmware Image"),
        help_text=_("The firmware image to install"),
        required=True,
    )

    def __init__(self, *args, **named):
        # ``user`` and ``instance`` keywords are accepted but ignored.
        for ignored in ('user', 'instance'):
            named.pop(ignored, None)
        super(LegacyFirmwareForm, self).__init__(*args, **named)

    def save(self):
        """Install the uploaded image; returns False when nothing was uploaded"""
        upload = self.cleaned_data.get('firmware')
        if not upload:
            return False
        outcome = handle_firmware_update(upload.temporary_file_path())
        # The temporary file has been moved away; drop the reference to it.
        self.cleaned_data['firmware'] = None
        return outcome


class FirmwareForm(
    bigfilewidget.BigFileUploadMixIn, bigfilewidget.IgnoreInstance, forms.Form
):
    """Firmware upload form using the big-file upload mix-ins"""

    firmware = forms.FileField(
        label=_("Firmware Image"),
        help_text=_("The firmware image to install"),
        required=True,
    )
    big_file_fields = ['firmware']

    def save(self):
        """Install the uploaded image; returns False when nothing was uploaded

        Any exception is logged with its traceback and then re-raised.
        """
        try:
            upload = self.cleaned_data.get('firmware')
            if not upload:
                return False
            outcome = handle_firmware_update(upload.temporary_file_path())
            # The temporary file has been moved away; drop the reference.
            self.cleaned_data['firmware'] = None
            return outcome
        except Exception:
            log.exception('Failure during firmware upload save')
            raise


class DateTimeField(forms.DateTimeField):
    """DateTimeField that exchanges values with the client as timestamps

    Incoming all-digit strings are converted with utctime.from_timestamp;
    anything else is handled by the stock DateTimeField parsing.  Outgoing
    datetime values are rendered with utctime.as_timestamp.
    """

    def to_python(self, value):
        """Convert a submitted value to python, accepting integer timestamps"""
        if value and isinstance(value, (bytes, unicode)) and value.isdigit():
            return utctime.from_timestamp(int(value))
        return super(DateTimeField, self).to_python(value)

    def prepare_value(self, value):
        """Convert python to js/html compatible value

        Returns None for empty values, a timestamp for datetime instances
        and the stock prepared value for everything else.
        """
        if not value:
            return None
        if isinstance(value, datetime.datetime):
            if not value.tzinfo:
                # datetime has no ``as_timezone`` method; ``astimezone`` is
                # the real API.
                # NOTE(review): on Python 3.6+ astimezone() treats a naive
                # value as system-local time -- confirm that is the intended
                # interpretation of naive datetimes here.
                value = value.astimezone(utctime.UTC)
            return utctime.as_timestamp(value)
        return super(DateTimeField, self).prepare_value(value)


class DateForm(forms.Form):
    """Form that overrides the system clock with an operator-supplied date"""

    date = DateTimeField(
        label=_("Override System Date"),
        help_text=_(
            "Set the system's internal clock to the given date and time in ISO format YYYY-mm-dd HH:MM:SS"
        ),
        required=True,
        input_formats=[
            '%Y-%m-%d %H:%M',
            '%Y-%m-%d %H:%M:%S',
            '%Y-%m-%dT%H:%M',
            '%Y-%m-%dT%H:%M:%S',
        ],
    )

    def __init__(self, *args, **named):
        # ``instance`` is optional; it is only remembered when supplied.
        if 'instance' in named:
            self.instance = named.pop('instance')
        super(DateForm, self).__init__(*args, **named)

    def save(self, *args, **named):
        """Set the system date (when MODIFY_SYSTEM is on); returns the date"""
        date = self.cleaned_data['date']
        # MMDDhhmmYYYY.ss, the argument layout handed to ``date`` below
        formatted = date.strftime('%m%d%H%M%Y.%S')
        if settings.MODIFY_SYSTEM:
            command = 'sudo -n date %s' % (formatted,)
            nbio.Process(command)()
        return date


def password_setting_form(request):
    """Create a password-setting form class for the given request

    Only users for which ``request.user`` holds the
    ``config.change_<username>_password`` permission are offered as choices.

    Returns the form *class* (not an instance).
    """

    def can_change(user):
        return request.user.has_perm('config.change_%s_password' % (user.username,))

    from django.contrib.auth import models as auth_models

    users = [
        user
        for user in auth_models.User.objects.order_by('username').all()
        if can_change(user)
    ]

    class PasswordSettingForm(forms.Form):
        """Custom password setting form, following the UCrypt model"""

        user = forms.ChoiceField(
            choices=[(user.id, unicode(user)) for user in users],
            required=True,
        )
        password = forms.CharField(
            widget=forms.PasswordInput, label="Password", required=True
        )
        confirm = forms.CharField(
            widget=forms.PasswordInput, label="Confirm", required=True
        )

        def clean_user(self):
            """Map the submitted user id back onto one of the permitted users"""
            value = int(self.cleaned_data.get('user'))
            user = None
            for possible in users:
                if value == possible.id:
                    user = possible
                    break
            if not user:
                raise forms.ValidationError(
                    _("Unrecognized user id: %(id)s") % {'id': value}
                )
            return user

        def clean(self):
            """Check that password and confirmation match and are complex enough"""
            values = super(PasswordSettingForm, self).clean()
            password = values.get('password')
            confirm = values.get('confirm')
            if password is None or confirm is None:
                # One of the fields already failed its own validation (e.g.
                # it was left empty) and is absent from the cleaned data.
                # The field-level error is already recorded; running the
                # cross-field checks on None would crash in
                # validate_password.
                return values
            if password != confirm:
                raise forms.ValidationError(
                    _("Your Password and Confirmation are not the same")
                )
            if not validate_password(password):
                raise forms.ValidationError(
                    _("Your Password is not sufficiently complex to prevent intrusion")
                )
            return values

    return PasswordSettingForm


class RangeValidator(object):
    """Validator accepting values within an inclusive [start, stop] range

    Either bound may be None, in which case that side is unbounded.
    """

    def __init__(self, start, stop):
        self.start = start
        self.stop = stop

    def __call__(self, value):
        """Return value unchanged, or raise ValidationError when out of range"""
        if self.start is not None:
            if self.start > value:
                raise validators.ValidationError(
                    'Require a value greater than or equal to %s, received %s'
                    % (self.start, value)
                )
        if self.stop is not None:
            if self.stop < value:
                raise validators.ValidationError(
                    'Require a value less than or equal to %s, received %s'
                    % (self.stop, value)
                )
        return value

    def __eq__(self, other):
        """Two validators are equal when both bounds are equal"""
        if isinstance(other, self.__class__):
            # Compare as tuples; the un-parenthesised form built a 3-tuple,
            # which is always truthy, so every pair compared "equal".
            return (self.start, self.stop) == (other.start, other.stop)
        return False

    def deconstruct(self):
        """Return (import path, positional args, keyword args) for this validator"""
        return (
            '%s.%s' % (__name__, self.__class__.__name__),
            (self.start, self.stop),
            {},
        )

    def schedule_json(self):
        """Describe this validator as a JSON-friendly dict"""
        return {
            'type': self.__class__.__name__,
            'start': self.start,
            'stop': self.stop,
        }


class RangeExcluder(object):
    """Validator rejecting values within an inclusive [start, stop] range

    ``why`` is the phrase inserted into the error message to explain the
    exclusion.
    """

    def __init__(self, start, stop, why='are excluded'):
        self.start = start
        self.stop = stop
        self.why = why

    def __eq__(self, other):
        """Two excluders are equal when bounds and explanation are equal"""
        if isinstance(other, self.__class__):
            # Compare as tuples, and against the *other* instance's ``why``;
            # the un-parenthesised form built a tuple (always truthy) and
            # compared self.why with itself.
            return (self.start, self.stop, self.why) == (
                other.start,
                other.stop,
                other.why,
            )
        return False

    def __call__(self, value):
        """Return value unchanged, or raise ValidationError when inside the range"""
        if self.start <= value <= self.stop:
            raise validators.ValidationError(
                'Values in the range %s to %s %s, received %s'
                % (self.start, self.stop, self.why, value)
            )
        return value

    def deconstruct(self):
        """Return (import path, positional args, keyword args) for this validator

        NOTE(review): ``why`` is not part of the returned arguments, so an
        instance rebuilt from this description gets the default explanation.
        Left unchanged here because altering the description would alter
        what is serialised for existing users of this validator.
        """
        return (
            '%s.%s' % (__name__, self.__class__.__name__),
            (self.start, self.stop),
            {},
        )

    def schedule_json(self):
        """Describe this validator as a JSON-friendly dict"""
        return {
            'type': self.__class__.__name__,
            'start': self.start,
            'stop': self.stop,
            'why': self.why,
        }


class ArrayOfSize(object):
    """Validator constraining the length of a sequence value

    Either limit may be None, in which case that side is unconstrained.
    """

    def __init__(self, min_length=None, max_length=None):
        self.min_length = min_length
        self.max_length = max_length

    def deconstruct(self):
        """Return (import path, positional args, keyword args) for this validator"""
        return (
            '%s.%s' % (__name__, self.__class__.__name__),
            (self.min_length, self.max_length),
            {},
        )

    def __call__(self, value):
        """Raise ValidationError when len(value) is outside the limits

        A value of None is accepted without further checks.
        """
        # ``isinstance(value, None)`` is itself a TypeError (None is not a
        # type), which made every call fail; an identity test is what was
        # meant.
        if value is None:
            return
        length = len(value)
        if self.min_length is not None:
            if length < self.min_length:
                raise validators.ValidationError(
                    "Array must be at least %d elements, was %s"
                    % (self.min_length, value)
                )
        if self.max_length is not None:
            if length > self.max_length:
                raise validators.ValidationError(
                    "Array must be at most %d elements, was %s"
                    % (self.max_length, value)
                )


class EventLogClearForm(forms.Form):
    """Confirmation form that rotates (clears) the event log on save"""

    confirm = forms.BooleanField(
        label=_("Clear Events"),
        help_text=_("Confirm that you wish to clear the event log"),
        required=False,
    )

    def __init__(self, *args, **named):
        # ``instance`` is mandatory here; it is handed back by save().
        instance = named.pop('instance')
        self.instance = instance
        super(EventLogClearForm, self).__init__(*args, **named)

    def clean_confirm(self):
        """Require the confirmation box to be ticked"""
        confirmed = self.cleaned_data.get('confirm')
        if confirmed:
            return confirmed
        raise forms.ValidationError(_("You must confirm the clearing"))

    def save(self, *args, **named):
        """Run the log-rotation command and return the instance given at construction"""
        log.warning("Clearing event log")
        from fussy import nbio

        rotate = nbio.Process(['fussy-log-server-rotate'])
        rotate()
        return self.instance


def validate_range(start=None, stop=None):
    """Build a RangeValidator for the inclusive range start..stop

    Either bound may be omitted (None) to leave that side open.
    """
    validator = RangeValidator(start, stop)
    return validator


class PlaceHolder(object):
    """Stand-in for a value that is computed by calling ``function`` on demand

    Comparisons and repr() call the wrapped function each time, so the
    placeholder can be used where a concrete bound is compared against a
    value (e.g. ``placeholder > value``).
    """

    def __init__(self, function):
        self.function = function
        # Mirror the wrapped function's name on the placeholder itself.
        self.__name__ = function.__name__

    def __repr__(self):
        # The repr is the text form of the *current* computed value.
        return unicode(self.function())

    def __gt__(self, other):
        return self.function() > other

    def __lt__(self, other):
        return self.function() < other

    def __eq__(self, other):
        return self.function() == other

    def __cmp__(self, other):
        # NOTE(review): ``cmp`` is not imported in this module and is not a
        # builtin on Python 3 -- this method looks usable on Python 2 only.
        return cmp(self.function(), other)


def is_json_list(text):
    """Validator: non-empty text must be delimited like a JSON list

    Only the first and last characters are checked (``[`` ... ``]``); the
    content is not parsed.  Empty/None text is accepted.

    Raises ValidationError when either delimiter is missing.
    """
    # The parentheses matter: ``not a and b`` negates only ``a``, which let
    # through any text that merely failed to end with "]".
    if text and not (text.startswith('[') and text.endswith(']')):
        raise validators.ValidationError("Expected a json list")


def is_json_dict(text):
    """Validator: non-empty text must be delimited like a JSON object

    Only the first and last characters are checked (``{`` ... ``}``); the
    content is not parsed.  Empty/None text is accepted.

    Raises ValidationError when either delimiter is missing.
    """
    # The parentheses matter: ``not a and b`` negates only ``a``, which let
    # through any text that merely failed to end with "}".
    if text and not (text.startswith('{') and text.endswith('}')):
        raise validators.ValidationError("Expected a json object")


# Accepts #RGB or #RRGGBB (hex digits, either case), optionally surrounded by
# whitespace.  The pattern is anchored: RegexValidator *searches* the value,
# so an unanchored pattern accepted any text that merely contained a colour
# somewhere (e.g. "junk #fffzzz").
colour_validator = validators.RegexValidator(
    r'^\s*#([0-9a-fA-F]{6}|[0-9a-fA-F]{3})\s*$',
    _('Enter a colour in CSS style #FFFFFF notation'),
)


def lazy_value(function):
    """Wrap function in a PlaceHolder so its value is computed when compared"""
    placeholder = PlaceHolder(function)
    return placeholder


def exclude_range(start, stop, why='are not allowed'):
    """Build a validator rejecting values in the inclusive range start..stop

    ``why`` is the phrase used in the error message to explain the exclusion.
    """
    excluder = RangeExcluder(start, stop, why)
    return excluder


# Six pairs of hex digits, each pair after the first optionally preceded by a
# ":" or "-" separator.  Separators are only allowed *between* pairs; the
# earlier pattern also matched a trailing separator ("00:11:22:33:44:55:").
MAC_RE = re.compile(r'^[0-9a-fA-F]{2}(?:[:-]?[0-9a-fA-F]{2}){5}$')


def validate_mac(value):
    """Validator: value must be a 12-hex-digit MAC address

    ``:`` or ``-`` separators between the digit pairs are optional.

    Raises ValidationError when value does not match.
    """
    if not MAC_RE.match(value):
        raise validators.ValidationError(
            'Require a 12-hexadecimal digit MAC address (with : or - separators optional)'
        )


def validate_fraction(fraction):
    """Validator: fraction must lie within 0..1 (both ends accepted)

    Returns the value unchanged; raises ValidationError when it is below 0
    or above 1.
    """
    below_zero = fraction < 0
    if below_zero:
        raise validators.ValidationError(
            _("Require a value greater than 0.0"),
        )
    above_one = fraction > 1
    if above_one:
        raise validators.ValidationError(
            _("Require a value less than 1.0"),
        )
    return fraction


# Ready-made range validators.  All bounds are inclusive.
validate_ip_port_range = validate_range(start=1, stop=2**16 - 1)
# Same upper bound as above, but starting at 1025.
validate_unprivileged_port = validate_range(start=1025, stop=2**16 - 1)
validate_program_number_range = validate_range(start=1, stop=2**16 - 1)
validate_source_id_range = validate_range(start=1, stop=2**16 - 1)

# TODO: Versative/VMS 2.2 requires PIDs to be >= 32; it is unclear why that
# is the lower limit.  The previous definition allowed PIDs from 1:
# validate_pid_range = validate_range( start=1, stop=2**13-1)
validate_pid_range = validate_range(start=32, stop=2**13 - 1)
# PID 8177 alone is excluded, reported as "is reserved".
validate_pid_no8177 = exclude_range(8177, 8177, 'is reserved')
# A *list* of validators: callers must apply each one in turn.
validate_pid = [validate_pid_range, validate_pid_no8177]


def validate_pid_noproidium(value):
    """Reject PIDs 4176..4191 when the DQAM setting is enabled

    Returns True when settings.DQAM is absent/false; otherwise returns the
    value unchanged or raises ValidationError for a reserved PID.
    """
    if not getattr(settings, 'DQAM', False):
        return True
    reserved = exclude_range(4176, 4191, 'are reserved for ProIdiom')
    return reserved(value)


def pid_selection(component):
    """Convert/validate a single value pid selection

    * ``a|b|...`` alternatives are converted one by one and re-joined with
      ``|`` (empty alternatives dropped); None when nothing remains
    * an all-digit component becomes an int checked by every validator in
      validate_pid
    * an empty component gives None
    * anything else must be exactly 3 characters and is returned as-is

    Raises ValidationError for invalid PIDs or non-3-character codes.
    """
    component = component.strip()
    if '|' in component:
        alternatives = []
        for piece in component.split('|'):
            chosen = pid_selection(piece)
            if chosen:
                alternatives.append(str(chosen))
        return '|'.join(alternatives) if alternatives else None
    if component.isdigit():
        pid = int(component)
        for check in validate_pid:
            check(pid)
        return pid
    if not component:
        return None
    if len(component) != 3:
        raise validators.ValidationError(
            '%r is not a numeric PID or a 3-letter language code' % (component,)
        )
    # note: we purposely do *not* validate as a real language code, as it is
    # common to use non-standard language codes
    return component


def pid_selections(values):
    """Convert "selection of pids" into a list of int or strings defining pid or language to select

    ``values`` is a comma-separated string; empty items are skipped and each
    remaining item is converted with pid_selection.
    """
    rules = []
    for raw in values.split(','):
        if not raw:
            continue
        rule = pid_selection(raw.strip())
        if rule is not None:
            rules.append(rule)
    return rules


def validate_pid_selections(values):
    """Validator form of pid_selections: raises on invalid input

    Returns the converted selection list produced by pid_selections.
    """
    selections = pid_selections(values)
    return selections


def LicenseClientForm(*args, **named):
    """Create a form instance for the license-server setting and license upload

    The form class is built on each call because the System model is looked
    up through get_app('config') at call time.
    """
    get_app('config')

    class LicenseClientForm(AlwaysValidate):
        upload = forms.FileField(
            label=_("Upload License"),
            help_text=_(
                "Upload a licensing bundle to the server when it does not have access to a licensing server"
            ),
            required=False,
        )
        permissions = [
            'config.licenses',
        ]

        class Meta:
            model = get_app('config').System
            fields = ('atx_license_server',)

        def clean_upload(self):
            """Currently we actually import if the validation succeeds

            While the certificate import is relatively fast, it can
            take a very long time for servers with thousands of certificates.

            Always returns None, so the cleaned ``upload`` value is None
            whether or not a file was imported.
            """
            content = self.cleaned_data.get('upload')
            if content:
                path = content.temporary_file_path()
                logfile = '/var/firmware/log/license-load.log'
                try:
                    nbio.Process(
                        [
                            os.path.join(settings.BIN_DIRECTORY, 'license-client'),
                            '-f',
                            path,
                        ]
                    )()
                except nbio.ProcessError:
                    if os.path.exists(logfile):
                        # Read as bytes and decode explicitly: calling
                        # .decode() on a text-mode read fails on Python 3,
                        # and "replace" keeps a damaged log from masking
                        # the real error.  The handle is closed promptly.
                        with open(logfile, 'rb') as handle:
                            log_text = handle.read().decode('utf-8', 'replace')
                    else:
                        log_text = 'No log generated'
                    raise forms.ValidationError(
                        _("Unable to import certificate:\n%(log)s")
                        % {
                            'log': log_text,
                        }
                    )
            return None

        def save(self, *args, **named):
            """On saving the source, update the licenses"""
            base = super(LicenseClientForm, self).save(*args, **named)
            # NOTE(review): clean_upload always returns None, so this
            # condition appears to be true even after a file was imported --
            # confirm that re-running license-client is intended then.
            if not self.cleaned_data.get('upload'):
                try:
                    nbio.Process(
                        [
                            os.path.join(settings.BIN_DIRECTORY, 'license-client'),
                        ]
                    )()
                except nbio.ProcessError:
                    # License refresh failed: return without rewriting the
                    # SNMP configuration.
                    return base
            base.write_snmp_config()
            return base

    return LicenseClientForm(*args, **named)


def SNMPForm(*args, **named):
    """Create a form instance editing the SNMP settings of System

    The hidden ``send_test`` flag requests that test traps be sent after
    the settings have been saved.
    """

    class SNMPFormCls(AlwaysValidate):
        permissions = [
            'config.snmp',
        ]
        send_test = forms.BooleanField(
            widget=forms.HiddenInput,
            required=False,
            initial=False,
        )

        class Meta:
            model = get_app('config').System
            fields = (
                'snmp_community',
                #'snmp_rw_community',
                'snmp_trap_sink',
                'snmp_trap_sink_port',
                'snmp_trap_community',
            )

        def save(self, *args, **named):
            """Save the settings, rewrite the SNMP config, optionally send test traps

            Keyword arguments are forwarded to ModelForm.save like the
            other forms in this module do, so ``save(commit=...)`` works.
            """
            result = super(SNMPFormCls, self).save(*args, **named)
            result.write_snmp_config()
            if self.cleaned_data.get('send_test'):
                log.info("Sending test traps")
                from snmpagents import traps

                traps.test(settings.PRODUCT)
            return result

    return SNMPFormCls(*args, **named)


class CompoundForm(object):
    """Helper class that mimics a single form from multiple forms...

    The wrapped forms are kept in ``self.instances``; field iteration, error
    reporting, validation and saving are delegated to them.  Cross-form
    validation goes into ``clean()``, whose errors are attached to the first
    wrapped form.
    """

    def __init__(self, *forms):
        # NOTE: the parameter name shadows the module-level ``forms`` import
        # inside this method only.
        self.instances = forms
        self._errors = forms_util.ErrorDict()

    def hidden_fields(self):
        """Yield the hidden fields of every wrapped form, in form order"""
        for form in self.instances:
            for field in form.hidden_fields():
                yield field

    def visible_fields(self):
        """Yield the visible fields of every wrapped form, in form order"""
        for form in self.instances:
            for field in form.visible_fields():
                yield field

    @property
    def errors(self):
        """Merged error dict of all wrapped forms

        While the merged dict is empty it is rebuilt on every access; when
        the wrapped forms report no errors, ``clean()`` is run and any
        ValidationError it raises is recorded on the first form under the
        ``'__all__'`` key.
        """
        _errors = self._errors
        if not _errors:
            for form in self.instances:
                # dict.update: for a key present in several forms, the
                # errors of the later form replace those of the earlier one.
                _errors.update(form.errors)
            if not _errors:
                try:
                    self.clean()
                except validators.ValidationError as err:
                    self.instances[0].add_error('__all__', err)
                    _errors.update(self.instances[0].errors)
        return _errors

    def has_error(self, key, code=None):
        """True when any wrapped form, or the merged errors, has an error for key"""
        for form in self.instances:
            if form.has_error(key, code=code):
                return True
        if key in self.errors:
            # The loop variable is unused; the any() test below does not
            # depend on it.
            for err in self.errors[key]:
                if any(error.code == code for error in self._errors.as_data()[key]):
                    return True
        return False

    def add_error(self, *args, **named):
        """Record an error on the first wrapped form"""
        self.instances[0].add_error(*args, **named)

    def non_field_errors(self):
        """Collect the non-field errors of every wrapped form"""
        _errors = forms_util.ErrorList()
        for form in self.instances:
            _errors.extend(form.non_field_errors())
        # NOTE(review): the key here is '__all_' (one trailing underscore),
        # not the '__all__' key used in ``errors`` -- so this lookup appears
        # to always fall back to the empty list.  Cross-form errors are
        # already attached to the first form and so are collected by the
        # loop above; confirm before "fixing" the key, as that could report
        # them twice.
        _errors.extend(self._errors.get('__all_', []))
        return _errors

    @property
    def dependencies(self):
        """Merged ``dependencies`` of the wrapped forms

        Raises AttributeError when no wrapped form declares any, so
        getattr()/hasattr() callers see the attribute as missing.
        """
        base = {}
        for form in self.instances:
            base.update(getattr(form, 'dependencies', {}))
        if base:
            return base
        raise AttributeError('dependencies')

    @property
    def field_sets(self):
        """Concatenated ``field_sets`` of the wrapped forms

        Raises AttributeError when no wrapped form declares any, so
        getattr()/hasattr() callers see the attribute as missing.
        """
        base = []
        for form in self.instances:
            base.extend(getattr(form, 'field_sets', []))
        if base:
            return base
        raise AttributeError('field_sets')

    def __iter__(self):
        """Iterate the fields of every wrapped form, in form order"""
        for form in self.instances:
            for field in form:
                yield field

    def is_valid(self, *args, **named):
        """True when the merged error dict is empty (arguments are ignored)"""
        return not self.errors

    def clean(self, *args, **named):
        """Just a point to do cross-subform validation"""

    def save(self, *args, **named):
        """Save every wrapped form and return the list of results

        Forms are saved in *reverse* order of construction, and the results
        are returned in that same reversed order.
        """
        result = []
        for form in self.instances[::-1]:
            result.append(form.save(*args, **named))
        return result


class EULAForm(forms.Form):
    """Field-less form whose save() records acceptance of the EULA"""

    needs_request = True

    def save(self, *args, **named):
        """Stamp the current System as having accepted the EULA

        Raises PermissionDenied when the requesting user lacks the
        ``config.accept_eula`` permission.  Returns the updated System.
        """
        user = self.request.user
        if not user.has_perm('config.accept_eula'):
            raise PermissionDenied("User does not have permission to accept the EULA.")
        from atxstyle import models, utctime
        from fussy import twrite

        current = models.System.get_current()
        current.eula_accepted = utctime.current_utc()
        current.save()
        current.write_snmp_config()
        # Also record the acceptance timestamp at models.EULA_ACCEPTED.
        twrite.twrite(models.EULA_ACCEPTED, str(current.eula_accepted))
        return current

    def __init__(self, *args, **named):
        # ``instance`` is accepted but ignored; ``request`` is mandatory.
        named.pop('instance', None)
        self.request = named.pop('request')
        super(EULAForm, self).__init__(*args, **named)


class DirectSmallFileWidget(widgets.Input):
    """Direct upload via AJAX for small data-files"""

    input_type = 'file'
    needs_multipart_form = False

    def render(self, name, value, attrs=None, renderer=None):
        """Render the file input, never echoing a current value

        ``renderer`` is accepted because Django 2.1+ passes it to every
        widget's render(); it is forwarded only when supplied, so the base
        class is called exactly as before when it is omitted.
        """
        extra = {}
        if renderer is not None:
            extra['renderer'] = renderer
        # rendering in a django template...
        return super(DirectSmallFileWidget, self).render(
            name, None, attrs=attrs, **extra
        )

    def value_omitted_from_data(self, data, files, name):
        """The value counts as omitted when no file was uploaded under name"""
        return name not in files

    def value_from_datadict(self, data, files, name):
        "File widgets take data from FILES, not POST"
        return files.get(name)


class MultiJSONSelectWidget(forms.widgets.SelectMultiple):
    """Multi-select that allows for JSON formats"""

    def value_from_datadict(self, data, files, name):
        """Decode a JSON-list submission, else defer to SelectMultiple

        A bytes or text value starting with ``[`` is parsed as JSON; every
        other submission goes through the base class implementation.
        """
        raw = data.get(name)
        if isinstance(raw, bytes):
            looks_like_json = raw.startswith(b'[')
        elif isinstance(raw, unicode):
            looks_like_json = raw.startswith(u'[')
        else:
            looks_like_json = False
        if looks_like_json:
            return json.loads(raw)
        return super(MultiJSONSelectWidget, self).value_from_datadict(
            data,
            files,
            name,
        )


from atxstyle.sixishdj import gettext as _
