import pytest, json, tempfile, shutil
import asyncio
import pydantic
from mamba_daemon import daemon, models, devicewatcher
import io
from aiohttp import test_utils
from contextlib import contextmanager


@contextmanager
def with_tempdir():
    """Yield the path of a fresh temporary directory and remove it on exit.

    Cleanup runs even when the ``with`` body raises (e.g. a failing
    assertion), so failing tests do not leave ``mamba-*-test`` directories
    behind.
    """
    tmpdir = tempfile.mkdtemp(prefix='mamba-', suffix='-test')
    try:
        yield tmpdir
    finally:
        # Best-effort removal; ignore_errors keeps cleanup from masking the
        # test's own exception.
        shutil.rmtree(tmpdir, ignore_errors=True)


class FakeRequest(object):
    """Minimal request double exposing ``match_info`` and an async ``json()``."""

    def __init__(self, match_info=None, body=None):
        # Any falsy argument collapses to an empty default.
        self.match_info = match_info if match_info else {}
        self.body = body if body else ''

    async def json(self):
        """Decode the stored ``body`` as JSON and return the result."""
        raw_body = self.body
        return json.loads(raw_body)


class FakeSession(object):
    """Async HTTP-session double that replays canned replies per URL.

    ``responses`` maps a URL to a list of replies, which are consumed
    front-to-back. A reply can be one of three things:
    - an exception instance, which is raised;
    - a ``FakeResponse``, which is returned as-is;
    - any other value, which is wrapped as ``FakeResponse(200, value)``.

    Every ``get``/``post`` call is recorded in ``gets``/``posts``.
    """

    def __init__(self, responses: dict[str, list]):
        self.gets: list = []
        self.posts: list = []
        self.responses: dict[str, list] = responses

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning None means exceptions from the `async with` body propagate.
        return

    async def post(self, url, *args, **named):
        """Record a POST call and return the next canned reply for ``url``."""
        self.posts.append({'url': url, 'args': args, 'named': named})
        return self.request(url, *args, **named)

    async def get(self, url, *args, **named):
        """Record a GET call and return the next canned reply for ``url``."""
        self.gets.append({'url': url, 'args': args, 'named': named})
        return self.request(url, *args, **named)

    def request(self, url, *args, **named):
        """Pop and deliver the next canned reply for ``url``.

        Raises:
            RuntimeError: no reply is queued for ``url`` (unknown URL or the
                queue is exhausted).
            Exception: the queued exception instance, if that is the reply.
        """
        try:
            response = self.responses[url].pop(0)
        except (KeyError, IndexError) as err:
            # Format the URL into the message (the old code passed it as a
            # separate exception arg) and keep the original cause.
            raise RuntimeError(f"Test case had no response for {url}") from err
        if isinstance(response, Exception):
            raise response
        if isinstance(response, FakeResponse):
            return response
        return FakeResponse(200, response)


class FakeResponse(object):
    """Canned HTTP response handed out by a fake session.

    ``content`` may take one of three forms:
    - ``bytes``: served verbatim;
    - ``None``: no body, so ``read()`` raises ``IOError``;
    - any other value: served as its ``json.dumps`` text.
    """

    def __init__(self, status, content):
        self.status = status
        # Keep None as-is so read() can simulate a missing body. Serialising
        # it would give the string 'null' and make the IOError path
        # unreachable.
        if content is not None and not isinstance(content, bytes):
            content = json.dumps(content)
        self.content = content

    async def read(self):
        """Return the body, or raise IOError when there is none."""
        if self.content is None:
            raise IOError("No body")
        return self.content


def FakeWatcherClass(session):
    """Build a DeviceWatcher subclass wired to the given fake session.

    The returned class does two things:
    - it maps every path to itself as its URL, so canned-response tables
      can be keyed directly by path constants;
    - ``get_session`` always hands back ``session``.
    """
    assert session, "session cannot be null"
    bound_session = session

    class FakeWatcher(devicewatcher.DeviceWatcher):
        async def get_session(self):
            """Return the session captured when the fake class was built."""
            return bound_session

        def url_for_path(self, path):
            # Identity mapping: no host/port prefix in tests.
            return path

    return FakeWatcher


def sample_config():
    """Return a new BarConfig-shaped dict describing a single test device."""
    settings = dict(
        power=True,
        mute=False,
        volume=50,
        channel='1',
        station='23605',
        subtitles='eng',
    )
    device = dict(
        unique_key='example',
        ip_address='192.168.18.2',
        port=5000,
        settings=settings,
    )
    return dict(site_id=235, devices=[device])


def sample_anneal_responses():
    """Build a fresh table of canned device replies for a successful anneal.

    Keys are DeviceWatcher path constants. Each value is a list of replies
    that FakeSession pops in order. A new structure is built on every call,
    so callers may mutate it freely.
    """
    watcher_cls = devicewatcher.DeviceWatcher

    def ok(result):
        # The {'success': True, 'result': ...} envelope shared by most entries.
        return {'success': True, 'result': result}

    # (tmsid, channel, callsign, label) for each channel-plan entry.
    plan_rows = (
        ('176.stations.xmltv.tvmedia.ca', '20', 'TCM', 'Turner Classic Movies USA'),
        ('1202.stations.xmltv.tvmedia.ca', '21', 'A&E', 'A&E US - Pacific Feed'),
        (
            '1206.stations.xmltv.tvmedia.ca',
            '22',
            'DISC',
            'Discovery Channel (US) - Pacific Feed',
        ),
        (
            '1209.stations.xmltv.tvmedia.ca',
            '23',
            'TOON',
            'Cartoon Network USA - Pacific Feed',
        ),
        ('1262.stations.xmltv.tvmedia.ca', '24', 'AMC-P', 'AMC - Pacific Feed'),
        ('1949.stations.xmltv.tvmedia.ca', '25', 'KTTV', 'FOX (KTTV) Los Angeles, CA'),
        ('2013.stations.xmltv.tvmedia.ca', '26', 'MUNDO', 'Telemundo  - Pacific Feed'),
        (
            '2172.stations.xmltv.tvmedia.ca',
            '27',
            'KCBS-TV',
            'CBS (KCBS) Los Angeles, CA',
        ),
        (
            '2589.stations.xmltv.tvmedia.ca',
            '28',
            'KABC-TV',
            'ABC (KABC) Los Angeles, CA',
        ),
        ('2592.stations.xmltv.tvmedia.ca', '29', 'BSSC', 'Bally Sports SoCal'),
    )
    channels = [
        {'tmsid': tmsid, 'channel': number, 'callsign': callsign, 'label': label}
        for tmsid, number, callsign, label in plan_rows
    ]
    status = {
        'uptime': 0.0,
        'sleep': False,
        'station': 23605,
        'play_status': '',
        'volume': 50,
        'mute': False,
    }
    return {
        watcher_cls.SLEEP_PATH: [ok(False)],
        watcher_cls.STATUS_PATH: [ok(status)],
        watcher_cls.CHANNEL_PLAN_PATH: [{'success': True, 'channels': channels}],
        watcher_cls.CHANNEL_PATH: [ok(23605)],
        watcher_cls.STATION_PATH: [ok(23605)],
        watcher_cls.VOLUME_PATH: [ok(50)],
        watcher_cls.CC_PATH: [ok(False)],
        watcher_cls.MUTE_PATH: [ok(False)],
    }


@pytest.mark.asyncio
async def test_set_config():
    """set_config accepts a BarConfig and get_config returns an equal config."""
    service = daemon.MambaDaemon()
    expected = models.BarConfig(**sample_config())

    accepted = await service.set_config(expected)
    assert accepted == True, accepted

    stored = await service.get_config()
    assert stored == expected, stored
    first_settings = stored.devices[0].settings
    assert first_settings.volume == 50, first_settings


def generic_success_responses(overrides):
    """Return the canned anneal replies with ``overrides`` applied on top.

    Entries in ``overrides`` replace those with the same path key; all other
    paths keep their default successful replies.
    """
    merged_replies = sample_anneal_responses()
    merged_replies.update(overrides)
    return merged_replies


@pytest.mark.asyncio
async def test_identify():
    """identify() returns nothing for an unknown key and succeeds for 'example'.

    IDENTIFY_PATH replies are popped front-to-back in this order:
    1. an HTTP 500;
    2. a TimeoutError;
    3. a successful result.
    """
    bar = daemon.MambaDaemon()
    session = FakeSession(
        generic_success_responses(
            {
                # The success is queued last. The final identify() call below
                # presumably relies on identify()'s default retry behaviour to
                # get past the 500 and the timeout — TODO confirm.
                devicewatcher.DeviceWatcher.IDENTIFY_PATH: [
                    FakeResponse(
                        500, {"error": True, "messages": ["Internal failure"]}
                    ),
                    TimeoutError(),
                    {'success': True, 'result': 'example'},
                ],
            }
        )
    )
    # Make the daemon build watchers that talk to the fake session.
    bar.watcherClass = FakeWatcherClass(session)

    sample = models.BarConfig(**sample_config())
    await bar.ensure_watchers(sample)
    assert 'example' in bar.devices, bar.devices
    watcher = bar.devices['example']
    # Crude wait, presumably to let the watcher's startup work run before
    # checking `wanted`; timing-dependent (TODO: use an explicit readiness signal).
    await asyncio.sleep(0.02)  # bleh
    assert watcher.wanted, "did not start the watcher"
    # 'moo' is not a configured device, so no identify succeeds.
    successes = await bar.identify(['moo'], 'zoo', retries=1, delay=0.01)
    assert not successes, successes
    successes = await bar.identify(['example'], 'zoo')
    assert successes, successes
    assert successes[0].unique_key == 'example'
    assert successes[0].success == True, successes[0]


@pytest.mark.asyncio
async def test_watcher_start():
    """Starting a watcher creates tasks; stopping it clears them again."""
    device = models.BarConfig(**sample_config()).devices[0]
    watcher = devicewatcher.DeviceWatcher(device, None)

    await watcher.start()
    assert watcher.tasks

    await watcher.stop()
    assert not watcher.tasks


@pytest.mark.asyncio
async def test_get_url():
    """url_for_path starts with the device's http://ip:port and ends with the path."""
    device = models.BarConfig(**sample_config()).devices[0]
    watcher = devicewatcher.DeviceWatcher(device, None)
    identify_path = watcher.IDENTIFY_PATH
    full_url = watcher.url_for_path(identify_path)
    assert full_url.startswith('http://192.168.18.2:5000')
    assert full_url.endswith(identify_path)


@pytest.mark.asyncio
async def test_handle_no_body():
    """client_post raises RuntimeError when IDENTIFY_PATH answers 500 with no content."""
    session = FakeSession(
        generic_success_responses(
            {
                devicewatcher.DeviceWatcher.IDENTIFY_PATH: [
                    FakeResponse(500, None),
                ],
            }
        )
    )
    config = models.BarConfig(**sample_config()).devices[0]
    watcher = FakeWatcherClass(session)(config, None)
    # pytest.raises fails the test ("DID NOT RAISE") if no RuntimeError escapes.
    with pytest.raises(RuntimeError):
        await watcher.client_post(watcher.IDENTIFY_PATH, None, session)


@pytest.mark.asyncio
async def test_update_config():
    """update_config enqueues the new DeviceConfig on ``updated_queue``."""
    sample = models.BarConfig(**sample_config())
    config = sample.devices[0]
    watcher = devicewatcher.DeviceWatcher(config, None)
    # Deep copy: a shallow model_copy() would share the nested settings
    # object, so flipping power below would also mutate the watcher's config.
    second: models.DeviceConfig = config.model_copy(deep=True)
    second.settings.power = False
    assert second.settings is not config.settings
    await watcher.update_config(second)
    updated = watcher.updated_queue.get_nowait()
    assert updated == second


@pytest.mark.asyncio
async def test_anneal_full():
    """anneal_config reports success when every canned device call succeeds."""
    fake_session = FakeSession(sample_anneal_responses())
    device = models.BarConfig(**sample_config()).devices[0]
    watcher_cls = FakeWatcherClass(fake_session)
    watcher = watcher_cls(device, None)
    annealed = await watcher.anneal_config()
    assert annealed


@pytest.mark.asyncio
async def test_anneal_call_failure():
    """A failed mute call makes anneal_config fail and records a 'mute' error."""
    canned = sample_anneal_responses()
    mute_replies = canned[devicewatcher.DeviceWatcher.MUTE_PATH]
    mute_replies[0]['success'] = False

    device = models.BarConfig(**sample_config()).devices[0]
    watcher = FakeWatcherClass(FakeSession(canned))(device, None)

    assert not await watcher.anneal_config()
    assert 'mute' in watcher.errors, watcher.errors
    assert watcher.errors['mute'], watcher.errors


@pytest.mark.asyncio
async def test_sleep_turns_power_off():
    """Setting ``sleep`` to True on device settings turns ``power`` off.

    This test used to be named ``test_anneal_full``, which shadowed the
    earlier test of that name so the anneal test never ran.
    """
    sample = models.BarConfig(**sample_config())
    settings: models.DeviceSettings = sample.devices[0].settings
    assert settings.power == True
    assert settings.sleep == False
    settings.sleep = True
    assert settings.power == False


@pytest.mark.asyncio
async def test_settings_validate():
    """DeviceIdentity rejects ip_address '22' and each of the listed port values."""
    # NOTE(review): each case supplies only one field. If DeviceIdentity has
    # other required fields, the ValidationError may come from the missing
    # field rather than the bad value. Consider starting from a known-valid
    # identity — TODO confirm DeviceIdentity's required fields.
    rejected_cases = (
        {'ip_address': '22'},
        {'port': 1023},
        {'port': 65536},
        {'port': 4320},
        {'port': 4321},
        {'port': 9222},
        {'port': 11800},
    )
    for bad_fields in rejected_cases:
        with pytest.raises(pydantic.ValidationError):
            models.DeviceIdentity(**bad_fields)


@pytest.mark.asyncio
async def test_epg_gets_channels():
    """A device's channel plan is reflected in the daemon's EPG channel list."""
    expected_channel = {
        'channel': '43',
        'station': [
            "1262.stations.xmltv.tvmedia.ca",
            "AMC - Pacific Feed",
            "AMC-P",
            "",
            "",
            "",
            "http://cdn.tvpassport.com/image/station/256x144/v2/s10021_h15_ab.png",
        ],
    }
    with with_tempdir() as storage_dir:
        service = daemon.MambaDaemon(storage=storage_dir)
        device_config = models.DeviceConfig(
            unique_key='TESTDEVICE',
            ip_address='192.168.18.2',
            port=5000,
            name='Top TV',
            settings=models.DeviceSettings(
                power=True,
                channel='43',
            ),
        )
        watcher = devicewatcher.DeviceWatcher(device_config, service=service)
        service.devices['TESTDEVICE'] = watcher

        amc = models.ChannelSummary(
            tmsid='1262.stations.xmltv.tvmedia.ca',
            channel='43',
            label='AMC - Pacific Feed',
            callsign='AMC-P',
        )
        watcher.channel_plan = models.ChannelPlan(channels=[amc])

        # TODO: is the epg.json currently shared or not?
        await service.on_channel_plan(watcher.channel_plan)
        epg = await service.current_epg_data()
        channel_list = epg.get('channels')
        assert channel_list
        assert len(channel_list) == 1
        assert channel_list[0] == expected_channel
