import hashlib, uuid
import os, tempfile, logging, time, json
import asyncio
import websockets
import aiofiles

# Module logger shared by every part of the upload server.
log = logging.getLogger('wsuploads')

# Optional numeric uid that finished uploads are chown'ed to; empty string means "leave as is".
UPLOAD_OWNER = os.getenv('UPLOAD_OWNER', '').strip()


class Uploader(object):
    """Websocket handler that streams uploaded files to disk, one directory per user.

    A completed upload is stored as ``<uploads_path>/<user_id>/<upload_id>`` plus a
    ``<upload_id>.json`` metadata sidecar. The sidecar is how ``current_ids``,
    ``control_channel`` and ``cleanup`` discover uploads.
    """

    def __init__(self, uploads_path, development=False):
        """Remember the storage root (creating it if needed).

        :param uploads_path: directory under which per-user upload directories live.
        :param development: when truthy, every connection is treated as this user id
            and the ``X-Authentication-Id`` header is not consulted.
        """
        self.uploads_path = uploads_path
        self.development = development
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.uploads_path, exist_ok=True)

    @staticmethod
    def _remove_upload(path):
        """Delete an upload's data file and its ``.json`` sidecar; missing files are fine."""
        for fn in (path, path + '.json'):
            try:
                os.remove(fn)
            except FileNotFoundError:
                pass
            except OSError as err:
                log.warning('Unable to remove %s: %s', fn, err)

    async def current_ids(self, user_id):
        """Return the metadata dicts of all recorded uploads for ``user_id``.

        ``user_id`` may be ``'*'`` to scan every user's directory (as ``cleanup`` does).
        A sidecar that disappears, cannot be read, or does not hold a JSON object is
        skipped with a warning rather than aborting the whole scan. This can happen when
        a concurrent cancel removes it, or when it is read while still being written.
        """
        import glob

        result = []
        user_path = os.path.join(self.uploads_path, str(user_id))
        for path in glob.glob('%s/*.json' % (user_path,)):
            try:
                async with aiofiles.open(path, 'r') as fh:
                    data = json.loads(await fh.read())
            except (OSError, ValueError) as err:
                log.warning('Skipping unreadable upload metadata %s: %s', path, err)
                continue
            if isinstance(data, dict):
                result.append(data)
            else:
                log.warning('Skipping malformed upload metadata %s', path)
        return result

    async def control_channel(
        self, websocket: 'websockets.WebSocketServerProtocol', path: str, user_id: int
    ):
        """Control channel: report current uploads, then handle a single ``cancel``.

        The client first receives ``{"current_files": [...]}``. Each later message is
        parsed as JSON. Messages that are malformed, not an object, or carry a command
        other than ``cancel`` are ignored. A ``cancel`` removes every upload of this user
        whose ``upload_id`` equals ``id``, replies with a success message and closes
        the connection.
        """
        await websocket.send(
            json.dumps({'current_files': await self.current_ids(user_id)})
        )
        while True:
            msg = await websocket.recv()
            try:
                content = json.loads(msg)
            except ValueError:
                log.warning('Ignoring malformed control message from %s', user_id)
                continue
            if not isinstance(content, dict):
                log.warning('Ignoring non-object control message from %s', user_id)
                continue
            log.info("Command channel: %s", content)
            if content.get('command') != 'cancel':
                continue

            upload_id = content.get('id')
            log.info('User cancel of %s %s', user_id, upload_id)
            # Only this user's directory is scanned, so users cannot cancel each other's uploads.
            matching = [
                item
                for item in await self.current_ids(user_id)
                if item.get('upload_id') == upload_id
            ]
            if not matching:
                log.warning('Cancel of unknown upload %s %s', user_id, upload_id)
            for match in matching:
                if match.get('path'):
                    # Removes the data file AND its sidecar so it leaves current_files.
                    self._remove_upload(match['path'])
            await websocket.send(
                json.dumps(
                    {
                        "success": True,
                        "command": "cancel",
                        'id': upload_id,
                    }
                )
            )
            await websocket.close()
            return

    async def _receive_file(self, websocket, final_file, upload_id, user_id):
        """Stream chunks from ``websocket`` into ``final_file`` until an empty message.

        Returns ``(bytes_written, md5_hex, sha512_hex)``. Returns ``None`` when a write
        came up short; the client has then already been sent an error. The file is
        closed before returning in either case.
        """
        written = 0
        md5 = hashlib.md5()
        sha512 = hashlib.sha512()
        async with aiofiles.open(final_file, 'wb') as fh:
            while True:
                # Prompt the client for the next chunk.
                await websocket.send(json.dumps({"id": upload_id, "user_id": user_id}))
                content = await websocket.recv()
                if not content:
                    break  # an empty message marks the end of the upload
                written += len(content)
                if await fh.write(content) != len(content):
                    await websocket.send(
                        json.dumps(
                            {
                                'error': True,
                                'message': 'Unable to write (likely disk quota issue)',
                            }
                        )
                    )
                    return None
                md5.update(content)
                sha512.update(content)
                await websocket.send(json.dumps({"written": written}))
                log.info(
                    'Wrote: %0.1fMiB to %s', written / (1024 * 1024), final_file
                )
        return written, md5.hexdigest(), sha512.hexdigest()

    async def upload_file(self, websocket: 'websockets.WebSocketServerProtocol', path):
        """Handle one websocket connection: either a file upload or a control channel.

        The request path is taken from the ``X-Original-URI`` header when present, else
        from ``path``. Its basename is recorded as the upload's filename, and the
        basename ``.control`` is routed to :meth:`control_channel` instead.

        Outside development mode the user id comes from ``X-Authentication-Id``. A
        missing, zero or non-numeric value is answered with ``Not logged in``.

        Upload protocol:

        - The server sends ``{"id", "user_id"}`` to request each chunk.
        - It writes the reply to disk and answers ``{"written": total}``.
        - An empty message ends the upload. The metadata sidecar is then written and
          ``{"done": total}`` is sent.

        If anything fails before the sidecar exists (short write, dropped connection,
        cancellation), the partial data file is deleted.
        """
        path = websocket.request_headers.get('X-Original-URI') or path
        log.info("upload request: %s", path)

        if self.development:
            user_id = self.development
        else:
            try:
                user_id = int(
                    websocket.request_headers.get('X-Authentication-Id', '0'), 10
                )
            except ValueError:
                # A malformed header is treated exactly like a missing one.
                user_id = 0
            if not user_id:
                await websocket.send(
                    json.dumps({'error': True, 'message': 'Not logged in'})
                )
                return

        filename = os.path.basename(path)
        if filename == '.control':
            return await self.control_channel(websocket, path, user_id)

        upload_id = str(uuid.uuid4())
        final_file = os.path.join(self.uploads_path, str(user_id), upload_id)
        os.makedirs(os.path.dirname(final_file), exist_ok=True)

        try:
            received = await self._receive_file(
                websocket, final_file, upload_id, user_id
            )
            if received is None:
                # Short write: the client was already told. Drop the partial data and
                # do NOT record metadata or report success.
                self._remove_upload(final_file)
                return
            written, md5_hex, sha512_hex = received
            log.info("Finished write of %s %s", user_id, upload_id)
            metadata = {
                'filename': filename,
                'path': final_file,
                'upload_id': upload_id,
                'user_id': user_id,
                'size': written,
                'md5': md5_hex,
                'sha512': sha512_hex,
                'ts': time.time(),
            }
            async with aiofiles.open(final_file + '.json', 'w') as fh:
                await fh.write(json.dumps(metadata))
        except BaseException:
            # Without a sidecar, cleanup() can never find this file, so remove the
            # partial data now. This also covers dropped connections and task cancellation.
            self._remove_upload(final_file)
            raise

        if UPLOAD_OWNER:
            for fn in (final_file, final_file + '.json'):
                try:
                    os.chown(fn, int(UPLOAD_OWNER), -1)
                except (ValueError, OSError, RuntimeError) as err:
                    # Changing ownership is best effort; record why it failed.
                    log.warning('Unable to chown %s to %s: %s', fn, UPLOAD_OWNER, err)
        await websocket.send(json.dumps({"done": written}))
        await websocket.close()

    async def cleanup(self, max_age=3600 * 24 * 7, interval=3600):
        """Forever: every ``interval`` seconds delete uploads older than ``max_age`` seconds.

        Age is measured from the ``ts`` recorded in each upload's metadata sidecar.
        """
        while True:
            try:
                cutoff = time.time() - max_age
                for item in await self.current_ids('*'):
                    if item.get('ts', 0) < cutoff and item.get('path'):
                        self._remove_upload(item['path'])
            except Exception:
                # This runs as a detached background task. An uncaught error would end
                # it silently and uploads would never expire again.
                log.exception('Upload cleanup pass failed')
            await asyncio.sleep(interval)


def get_options():
    """Build the command-line parser for the upload server.

    Returns an ``argparse.ArgumentParser``; call ``parse_args()`` on it to get
    ``port`` (int), ``development`` (int user id, 0 = disabled) and ``directory``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Provides streaming upload of large files over websockets"
    )
    parser.add_argument(
        '-p',
        '--port',
        # Without type=int a port given on the command line would stay a str,
        # unlike the int default.
        type=int,
        default=8766,
        help="Port on which to listen, default 8766",
    )
    parser.add_argument(
        '--development',
        type=int,
        default=0,
        help='If specified, run such that all uploads are done for this user (normally 1 will use the factory user)',
    )
    parser.add_argument(
        '-d',
        '--directory',
        default='/var/firmware/protected/uploads',
        help="Directory in which uploads are stored, default /var/firmware/protected/uploads",
    )
    return parser


def main():
    """Parse options, start the background cleanup task and serve uploads forever."""
    logging.basicConfig(level=logging.INFO)
    options = get_options().parse_args()
    log.info(
        "Using upload directory: %s listening on %s", options.directory, options.port
    )
    u = Uploader(options.directory, development=options.development)

    # Create and install the loop explicitly. Calling asyncio.get_event_loop() (or
    # ensure_future()) with no current loop is deprecated since Python 3.10 and fails
    # on newer releases.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Hold a strong reference so the background task cannot be garbage-collected.
    cleanup_task = loop.create_task(u.cleanup())
    start_server = websockets.serve(u.upload_file, '0.0.0.0', options.port)

    loop.run_until_complete(start_server)
    loop.run_forever()


# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    main()
