import json, logging, os
from fussy import twrite
from atxstyle.sixish import as_unicode
from atxstyle import encryption, utctime, uniquekey, standardlog, sysinfo

# ``type`` value a decoded license bundle must carry (checked in verify_bundle).
LICENSE_BUNDLE = 'licenses'
# ``type`` value of the license-server certificate embedded in a client bundle.
LICENSE_SERVER = 'licenseserver'
log = logging.getLogger(__name__)

# Raw bundle, exactly as received, once it has verified successfully.
CURRENT_CERTIFICATE_FILE = '/etc/atxlicense/current.gpg'
# Decoded form of that bundle (indented JSON); read back by current().
CURRENT_LICENSES_FILE = '/etc/atxlicense/current.json'
# Locally configured licensing-server URL template (must contain ``%s``).
LICENSE_SERVER_URL = '/var/firmware/atxlicense/config.url'
# Optional JSON merged into the metadata built by get_current_metadata().
LICENSE_SERVER_METADATA = '/var/firmware/atxlicense/metadata.json'
# Verified server configuration written by verify_and_write_config().
CURRENT_SERVER_CONFIG = '/etc/atxlicense/config.json'


def set_license_server_url(value):
    """Persist ``value`` as the licensing-server URL template.

    The template must contain a ``%s`` placeholder; pull() fills it in with
    this machine's base key.

    :param value: URL template, e.g. ``https://host/certs/client/%s/``
    :returns: the path of the file that was written
    :raises ValueError: if ``value`` has no ``%s`` placeholder
    """
    if '%s' not in value:
        raise ValueError('Need a URL with a %s substitution')
    target = LICENSE_SERVER_URL
    twrite.twrite(target, value + '\n')
    return target


def current_license_server_url():
    """Return the locally configured licensing-server URL template, or None.

    Reads the template stored by set_license_server_url() and appends
    ``v2/`` when it is not already there, so pull() targets the v2 API first.

    :returns: the URL template, or None when no (non-blank) URL is configured,
        which lets get_license_server_url() fall back to its built-in default.
    """
    if not os.path.exists(LICENSE_SERVER_URL):
        return None
    # ``with`` closes the handle promptly instead of leaking it until GC.
    with open(LICENSE_SERVER_URL) as handle:
        base = handle.read().strip()
    if not base:
        # A blank file would otherwise become the unusable template 'v2/'.
        return None
    if not base.endswith('v2/'):
        base = base + 'v2/'
    return base


def get_current_metadata():
    """Build the metadata dict that pull() POSTs when requesting licenses.

    Starts from the JSON in LICENSE_SERVER_METADATA (when that file exists).
    It then sets the timestamp, firmware link target, OS status, server keys
    and public key. Network status is added when ``netconfig`` is importable.

    :returns: a JSON-serialisable dict (presumably; it is passed to json.dumps
        by pull())
    """
    base = {}
    if os.path.exists(LICENSE_SERVER_METADATA):
        # ``with`` closes the handle instead of leaking it.
        with open(LICENSE_SERVER_METADATA) as handle:
            base.update(json.loads(as_unicode(handle.read())))
    FIRMWARE_LINK = '/opt/firmware/current'
    if os.path.exists(FIRMWARE_LINK) and os.path.islink(FIRMWARE_LINK):
        firmware = os.readlink(FIRMWARE_LINK)
    else:
        # No installed-firmware symlink (presumably a source/dev install).
        firmware = 'SOURCE'
    base.update(
        {
            'timestamp': utctime.format(utctime.current_utc()),
            'firmware': firmware,
            'system': sysinfo.os_status(),
            'server_id': uniquekey.get_keys(),
            'server_pk': encryption.get_our_public_key(),
        }
    )
    try:
        from netconfig import status
    except ImportError:
        # netconfig is optional; the metadata is sent without network status.
        pass
    else:
        # Kept separate from the import guard so an ImportError raised inside
        # get_status() is reported rather than silently ignored.
        try:
            base['network'] = status.get_status()
        except Exception:
            log.warning(
                'Unable to retrieve network status for configuration metadata',
                exc_info=True,
            )
    return base


def get_license_server_url():
    """Return the licensing-server URL template that pull() should use.

    Prefers the locally configured template; when none is configured (or it
    is empty), falls back to the built-in ATX endpoint.
    """
    configured = current_license_server_url()
    if configured:
        return configured
    return 'https://digistreamepgdata.atxnetworks.com/certs/client/%s/certificates/v2/'


def current():
    """Return the currently installed, decoded license bundle.

    :returns: the JSON stored in CURRENT_LICENSES_FILE by
        verify_and_write_licenses(), or ``{}`` when no bundle is installed.
    """
    if not os.path.exists(CURRENT_LICENSES_FILE):
        return {}
    # ``with`` closes the handle instead of leaking it until GC.
    with open(CURRENT_LICENSES_FILE) as handle:
        return json.loads(as_unicode(handle.read()))


def pull_config(server=None):
    """Main function to pull configuration.

    Runs pull() in configuration mode.

    :param server: optional URL-template override, as accepted by pull()
    :returns: pull()'s 0/1 exit status
    """
    status = pull(server=server, config=True)
    return status


def pull_options():
    """Construct the options for the pull operation.

    :returns: an ``argparse.ArgumentParser`` accepting ``-s/--server``
        (URL-template override, default None) and ``-c/--config``
        (flag: download configuration instead of licenses, default False)
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Pulls licenses/configurations from upstream server',
    )
    option_specs = (
        (
            ('-s', '--server'),
            dict(help='Override the server specified in the GUI', default=None),
        ),
        (
            ('-c', '--config'),
            dict(
                help='Download configuration, rather than licenses',
                action='store_true',
                default=False,
            ),
        ),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser


def pull_main():
    """Command-line entry point: parse ``sys.argv`` and run pull().

    :returns: pull()'s 0/1 exit status
    """
    parsed = pull_options().parse_args()
    server, want_config = parsed.server, parsed.config
    return pull(server, want_config)


def pull(server=None, config=None, timeout=60):
    """Pull our licenses from a licensing server.

    With ``config`` set, pull the certified server configuration instead.

    Candidate URLs are tried in order: the v2 endpoint first and, when the
    template ends in ``v2/``, the original v1 endpoint as a fallback. Only
    connection failures/timeouts and HTTP error responses move on to the
    next URL. Any other answer from a server ends the attempt.

    :param server: URL template containing a ``%s`` placeholder that is filled
        with this machine's base key; defaults to get_license_server_url().
    :param config: truthy to download configuration rather than licenses.
    :param timeout: seconds handed to ``requests.post`` so an unresponsive
        server cannot block the pull indefinitely (None waits forever).
    :returns: 0 on success, 1 on any failure (usable as an exit status).
    """
    import requests

    standardlog.debug('pull-licenses', 'firmware')
    server = server or get_license_server_url()
    urls = [server]
    if server.endswith('v2/'):
        # The original (v1) endpoint is the same URL without the 'v2/' suffix.
        urls.append(server[:-3])
    if config:
        urls = [u + 'config/' for u in urls]
        server = urls[0]
        metadata = {}
    else:
        metadata = get_current_metadata()
    if os.path.exists(uniquekey.KEY_STORE):
        for url in urls:
            final_url = url % uniquekey.get_base_key()
            if url != server and not config:
                # second iteration, to an original v1 server, which is sent
                # only the 'original' entry of the server_id keys
                metadata['server_id'] = metadata['server_id']['original']
                log.info('Pulling from v1 licensing server')
            else:
                log.info('Attempting to pull from v2 licensing server')
            log.info('Downloading from %s', url)
            log.info('Metadata: %s', metadata)
            try:
                # NOTE(review): verify=False disables TLS certificate checks.
                # The payload is signature-checked before anything is written,
                # but confirm that skipping transport verification is intended.
                response = requests.post(
                    final_url,
                    verify=False,
                    data=json.dumps(metadata),
                    timeout=timeout,
                )
            except (
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
            ):
                log.info('Unable to connect to licensing server %s', final_url)
                continue
            if not response.ok:
                log.info('Response code: %s', response.status_code)
                continue
            try:
                parsed = json.loads(as_unicode(response.content))
            except Exception as err:
                log.error("Unable to parse license API response: %s", err)
                return 1
            if not parsed.get('success'):
                log.warning(
                    'Failure downloading licenses: %s', parsed.get('message')
                )
                return 1
            if config:
                written = verify_and_write_config(parsed['config'])
            else:
                written = verify_and_write_licenses(parsed['license'])
            return 0 if written else 1
    else:
        log.warning(
            'Key store %s not found; not contacting licensing server',
            uniquekey.KEY_STORE,
        )
    log.warning('Unable to download licenses')
    return 1


# Bundle file read by manual_load() (the license-load command).
MANUAL_FILE = '/var/firmware/atxlicense/certificate.gpg'


def manual_options():
    """Construct the (argument-less) parser for the license-load command.

    The parser exists so that ``--help`` can describe the command.
    """
    import argparse

    parser = argparse.ArgumentParser(
        # Interpolate the real path; the bare name 'MANUAL_FILE' means nothing
        # to someone reading --help.
        description='Load a license bundle from %s (must be run as root)'
        % (MANUAL_FILE,)
    )
    return parser


def manual_load():
    """Command-line to manually load a certificate as the license certificate bundle

    Loads from MANUAL_FILE, verifies it and, if valid, installs it as the
    current license bundle.

    sudo $(which license-load)

    :returns: 0 on success, 1 if the file cannot be read or fails verification
    """
    standardlog.debug('license-load', 'firmware', clear=True)
    log.info("Starting certificate load")
    manual_options().parse_args()  # just to allow for --help description
    try:
        # ``with`` closes the handle instead of leaking it.
        with open(MANUAL_FILE) as handle:
            content = handle.read()
    except (IOError, OSError) as err:
        # Report a missing/unreadable file like any other load failure
        # instead of dying with a traceback.
        log.error("Unable to read certificate file %s: %s", MANUAL_FILE, err)
        log.error("Certificate load failure")
        return 1
    if verify_and_write_licenses(content):
        log.info("Certificate loaded")
        return 0
    log.error("Certificate load failure")
    return 1


def verify_and_write_licenses(certificate):
    """Verify a license bundle and, if it is valid, install it.

    On success the raw ``certificate`` is written to CURRENT_CERTIFICATE_FILE.
    The decoded bundle is written, as indented JSON, to CURRENT_LICENSES_FILE.

    :param certificate: the bundle exactly as received
    :returns: True if the bundle verified and was written, False otherwise
    """
    try:
        bundle = verify_bundle(certificate)
    except (encryption.EncryptionError, ValueError, RuntimeError) as err:
        # verify_bundle also raises ValueError (decrypted payload is not a
        # license file) and RuntimeError (payload is not a license bundle).
        # Treat those as verification failures so callers get False
        # rather than an unhandled exception.
        log.error("Unable to verify downloaded bundle: %s", err)
        return False
    twrite.twrite(CURRENT_CERTIFICATE_FILE, certificate)
    twrite.twrite(CURRENT_LICENSES_FILE, json.dumps(bundle, indent=2))
    return True


def verify_and_write_config(certified):
    """Given a certified server configuration, write to the CURRENT_SERVER_CONFIG file

    :param certified: the signed configuration payload
    :returns: True when the signature checked out and the decoded config was
        written as indented JSON; False (after logging) when verification failed
    """
    try:
        decoded = verify_license(certified)
    except encryption.EncryptionError as err:
        log.error("Unable to verify downloaded config: %s", err)
        return False
    serialized = json.dumps(decoded, indent=2)
    twrite.twrite(CURRENT_SERVER_CONFIG, serialized)
    return True


def filter_expired(certificates):
    """Filter any expired certificates out of the set.

    :param certificates: iterable of decoded certificate dicts, each carrying
        a ``valid_until`` value understood by ``utctime.parse``
    :returns: a new list with only the certificates whose ``valid_until`` is
        strictly later than the current UTC time
    """
    now = utctime.current_utc()

    def still_valid(cert):
        return utctime.parse(cert['valid_until']) > now

    return list(filter(still_valid, certificates))


def verify_bundle(certificate, **named):
    """Load the certificate on the client, verify, unpack and report.

    ``certificate`` is first decoded as a central-licensor license
    (verify_license). If that fails, it is decrypted as a client bundle
    (encryption.decrypt_from_client). The decoded bundle must meet all of
    these checks:

    - it is assigned to this server (or to ``for_server`` when given);
    - it has not expired;
    - it has type LICENSE_BUNDLE;
    - every embedded certificate verifies.

    For a central bundle, every embedded certificate is reported as verified.
    A client bundle must also carry a LICENSE_SERVER certificate whose
    ``license_server_pk`` matches the bundle's ``server_pk``. Only the
    certificates whose latest allotment-log line assigns them to this server
    are reported as verified.

    :param certificate: the bundle exactly as received
    :param named: optional ``for_server`` overrides the key(s) the bundle must
        be assigned to; remaining keywords are forwarded to the helpers
    :returns: the decoded bundle dict with a ``verified`` list added
    :raises encryption.EncryptionError: on the verification failures above
    :raises ValueError: if a decrypted client payload is not JSON
    :raises RuntimeError: if the payload is not a license bundle
    """
    for_server = named.pop('for_server', None)
    if for_server:
        allowed = [for_server]
    else:
        allowed = [uniquekey.get_base_key(), uniquekey.original_key()]
    central_bundle = False
    now = utctime.current_utc()
    try:
        decoded = verify_license(certificate, **named)
        central_bundle = True
    except encryption.EncryptionError:
        # Not a decodable central-licensor license; try a client bundle.
        content = encryption.decrypt_from_client(certificate, **named)
        try:
            decoded = json.loads(as_unicode(content))
        except Exception:
            raise ValueError(
                "Unable to decode signature, properly signed, but not a license file"
            )
    try:
        assigned = decoded['assigned_server']
    except KeyError:
        assigned = None
        log.warning(
            "No assigned_server key in %s",
            decoded,
        )
    if assigned not in allowed:
        raise encryption.EncryptionError(
            "This bundle does not appear to be for this server, signed for server %s, we expected: %s"
            % (
                assigned,
                " or ".join(allowed),
            )
        )
    bundle_valid_until = utctime.parse(decoded['valid_until'])
    if bundle_valid_until < now:
        raise encryption.EncryptionError(
            "License Bundle is stale: %s (current %s)" % (bundle_valid_until, now)
        )
    if decoded['type'] != LICENSE_BUNDLE:
        # 'type_name' is optional here: reading it unconditionally would turn
        # this error into an unrelated KeyError, so fall back to 'type'.
        raise RuntimeError(
            "Expected a license bundle! Got a %s"
            % (decoded.get('type_name', decoded['type']),)
        )
    certificates = []
    for embedded in decoded['certificates']:
        try:
            parsed = verify_license(embedded, **named)
        except Exception as err:
            # Attach the offending certificate for diagnosis, then propagate.
            err.args += (embedded,)
            raise
        # TODO: there are more validity checks we could add here...
        if (not central_bundle) and parsed['license_server'] != decoded[
            'license_server'
        ]:
            log.error(
                'Certificate from wrong Licensing Server: %s %s vs %s',
                parsed['uuid'],
                parsed['license_server'],
                decoded['license_server'],
            )
            raise encryption.EncryptionError(
                "Received License Bundle with Improper Certificate"
            )
        certificates.append(parsed)
    server = next(
        (cert for cert in certificates if cert['type'] == LICENSE_SERVER), None
    )
    if not central_bundle:
        if not server:
            raise encryption.EncryptionError(
                'No license-server certificate found in the bundle'
            )
        if decoded['server_pk'] != server['license_server_pk']:
            log.error('Bundle PK: %s', decoded['server_pk'])
            # Log the key that was actually compared; reading a different key
            # could raise KeyError and hide the error below.
            log.error('Server PK: %s', server.get('license_server_pk'))
            raise encryption.EncryptionError(
                'License bundle was not signed by the license server'
            )
    certificate_map = {cert['uuid']: cert for cert in certificates}
    local = uniquekey.get_base_key()
    if central_bundle:
        decoded['verified'] = certificates[:]
    else:
        verified = []
        for cert_uuid, assignment in decoded.get('allotment_logs', []):
            log.info('Checking assignments for certificate: %s', cert_uuid)
            if not assignment:
                log.info('No assignment for %s', cert_uuid)
                continue
            # Only the last log line decides the current state. Fields are
            # tab-separated: index 1 is the action and, for 'assigned' lines,
            # index 2 is the key of the assignee.
            command = assignment.splitlines()[-1].split('\t')
            if command[1] != 'assigned':
                log.info('Released %s', cert_uuid)
            elif command[2] == local:
                match = certificate_map.get(cert_uuid)
                if match is None:
                    # Never report None as a verified certificate.
                    log.warning(
                        'Assignment for certificate not in bundle: %s', cert_uuid
                    )
                else:
                    verified.append(match)
        decoded['verified'] = verified
    return decoded


def verify_license(certificate, **named):
    """Verify a central-licensor-issued license.

    Checks the signature with encryption.check_signature, then decodes the
    signed content as JSON. Keyword arguments are accepted because callers
    forward theirs, but they are not used.

    :returns: the decoded JSON payload
    :raises encryption.EncryptionError: if the signed content cannot be decoded
        as JSON (a bad signature presumably surfaces from check_signature)
    """
    signed_content = encryption.check_signature(certificate)
    try:
        payload = json.loads(as_unicode(signed_content))
    except Exception:
        raise encryption.EncryptionError(
            "Unable to decode certificate (signed, but not structured properly)"
        )
    return payload
