import typing, logging, os, time, json
import asyncio
from .devicewatcher import DeviceWatcher
from .models import BarConfig
from . import daemonapi, models
import json, gzip, time

log = logging.getLogger(name=__name__)

# On-disk locations used by the daemon (all under the firmware run dir).
DEFAULT_EPG_CACHE = "/var/firmware/run/epg.json.gz"  # gzipped EPG JSON cache
DEFAULT_CHANNEL_CACHE = "/var/firmware/run/channels.json"  # cached channel plan
DEFAULT_STORAGE = "/var/firmware/run/mamba-config.json"  # persisted BarConfig
# Default EPG renewal interval, in hours (one day).
DEFAULT_EPG_RENEWAL_INTERVAL_HOUR = 24
# Seconds between periodic channel up/down syncs (default 30 minutes);
# may be overridden with the CHANNEL_SYNC_INTERVAL environment variable.
CHANNEL_SYNC_INTERVAL = int(os.environ.get("CHANNEL_SYNC_INTERVAL", 1800))


async def run_process(command, input=None):
    """Run a process and wait for it to finish.

    :param command: an argv list (executed directly, no shell) or a string
        (executed through the system shell).
    :param input: optional bytes fed to the child's stdin; stdin is only piped
        when this is truthy.
    :returns: ``(stdout, stderr)`` as bytes.
    :raises RuntimeError: ``(returncode, command, stdout, stderr)`` when the
        process exits with a non-zero status.
    """
    pipe = asyncio.subprocess.PIPE
    stdin = pipe if input else None
    if isinstance(command, list):
        proc = await asyncio.create_subprocess_exec(
            *command, stdin=stdin, stdout=pipe, stderr=pipe
        )
    else:
        proc = await asyncio.create_subprocess_shell(
            command, stdin=stdin, stdout=pipe, stderr=pipe
        )
    try:
        output, err = await proc.communicate(input)
    except asyncio.CancelledError:
        # Don't leave the child running when our caller is cancelled.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                # The child exited between the check and the kill.
                pass
        raise
    if proc.returncode:
        log.warning(
            "Error on %s:\n%s",
            # map(str, ...) so PathLike argv entries don't break the log line
            " ".join(map(str, command)) if isinstance(command, list) else command,
            # stderr need not be UTF-8; never let decoding mask the real failure
            err.decode("utf-8", errors="replace"),
        )
        raise RuntimeError(proc.returncode, command, output, err)
    if proc.returncode is None:
        # Defensive only: communicate() normally waits for the child to exit.
        proc.kill()
    return output, err


# Type alias: an optional (may be None) list of ``models.ChannelSummary``.
# NOTE(review): not referenced anywhere else in this file as seen here.
ChannelList = typing.Optional[typing.List[models.ChannelSummary]]


class MambaDaemon(object):
    """Models the entire service that controls the Bar application.

    The daemon owns the persisted ``BarConfig``, one watcher per configured
    device (keyed by the device's ``unique_key``), the most recent channel
    plan, and a cached, channel-filtered subset of the EPG data. Long-running
    work is started as asyncio tasks collected in ``self.tasks``.
    """

    # Factory for per-device watchers; invoked as ``watcherClass(device, daemon)``.
    watcherClass: type = DeviceWatcher

    def __init__(
        self,
        storage: str = DEFAULT_STORAGE,
        epg_cache: str = DEFAULT_EPG_CACHE,
        offline_mode: bool = False,
        epg_renewal_interval: float = DEFAULT_EPG_RENEWAL_INTERVAL_HOUR * 3600,
        log_level: str = "INFO",
        channel_cache: str = DEFAULT_CHANNEL_CACHE,
    ):
        """Initialise in-memory state only; no I/O happens until :meth:`start`.

        :param storage: JSON file the BarConfig is persisted to
        :param epg_cache: gzipped EPG JSON cache file
        :param offline_mode: stored flag (NOTE(review): not read in this file)
        :param epg_renewal_interval: seconds after which cached EPG timestamps
            are shifted forward by :meth:`renew_cached_epg_data`
        :param log_level: log level name passed to the web API task
        :param channel_cache: JSON file caching the channel plan
        """
        # BarConfig updates waiting to be applied by process_config_updates()
        self.config_queue: asyncio.Queue = asyncio.Queue()
        # device unique_key -> watcher
        self.devices: typing.Dict[str, DeviceWatcher] = {}
        self.tasks: typing.List[asyncio.Task] = []
        self.config: typing.Optional[BarConfig] = None
        # Cleared by stop(); background loops exit on their next iteration.
        self.wanted: bool = False
        # time.time() at which epg_data was last built
        self.epg_ts: float = 0
        # cached result of current_epg_data(); None forces a rebuild
        self.epg_data: typing.Optional[dict] = None
        self.epg_cache: str = epg_cache
        self.channel_cache: str = channel_cache
        self.offline_mode: bool = offline_mode
        self.epg_renewal_interval: float = epg_renewal_interval
        # our current (most recent) channel plan
        self.current_channel_plan: typing.Optional[models.ChannelPlan] = None
        # channel plan as of rendering of current epg data
        self.epg_channel_plan: typing.Optional[models.ChannelPlan] = None
        self.storage: str = storage
        self.log_level: str = log_level

    async def start(self):
        """Run the service starting the API and waiting for a configuration.

        Restores the saved config and channel plan (if present), then starts
        the config-processing task and the web API on port 8028. The channel
        sync and EPG renewal loops run only when DEMO_EPGDATA is set.
        Any startup exception is logged; the daemon then waits 30 seconds and
        stops the event loop.
        """
        try:
            self.wanted = True
            log.info("Starting the bar app")
            if os.path.exists(self.storage):
                try:
                    log.info("Loading previous configuration from %s", self.storage)
                    with open(self.storage) as fh:
                        config = BarConfig(**json.loads(fh.read()))
                except Exception:
                    log.exception("Error loading previously set configuration")
                else:
                    log.info("Loaded configuration: %s", config)
                    await self.set_config(config)
            if os.path.exists(self.channel_cache):
                self.current_channel_plan = await self.load_channel_plan()
            self.tasks.append(
                asyncio.create_task(
                    self.process_config_updates(), name="process_config_updates"
                )
            )
            self.tasks.append(
                daemonapi.create_web_task(
                    await self.create_app(),
                    port=8028,
                    logging_level=self.log_level,
                )
            )
            # NOTE(review): the original author marked this as the "wrong flag":
            # both the channel sync and the EPG renewal loop are gated on
            # DEMO_EPGDATA (self.offline_mode is not consulted). Confirm intent.
            if os.environ.get("DEMO_EPGDATA"):
                self.tasks.append(
                    asyncio.create_task(
                        self.periodic_channel_sync(), name="periodic-channel-sync"
                    )
                )
                log.info("Creating epg puller")
                self.tasks.append(
                    asyncio.create_task(
                        self.periodic_epg_data(), name="periodic_epg_data"
                    )
                )
        except Exception:
            log.exception("Crashed during startup")
            await asyncio.sleep(30)
            asyncio.get_running_loop().stop()

    async def stop(self):
        """Stop the service, shutting down all device watchers and background tasks.

        Each watcher's ``stop()`` is awaited; a failure is logged and does not
        prevent the remaining watchers and tasks from being shut down.
        """
        self.wanted = False
        for key, watcher in list(self.devices.items()):
            try:
                await watcher.stop()
            except Exception:
                log.exception("Failed to stop watcher for %s", key)
        for task in self.tasks:
            task.cancel()

    async def periodic_channel_sync(self):
        """Periodically sync active VSBB devices by performing a channel up and down.

        Every CHANNEL_SYNC_INTERVAL seconds, each watcher whose last status
        is present and whose play_status is not "INACTIVE" is stepped up one
        channel and back down. Per-device failures are logged and skipped.
        """
        while self.wanted:
            # Snapshot: ensure_watchers() may replace self.devices while we await.
            for key, watcher in list(self.devices.items()):
                status = watcher.last_status
                if not status or status.play_status == "INACTIVE":
                    continue
                try:
                    await watcher.channel_step("up")
                    log.info("Channel up executed for %s", watcher.config.unique_key)
                    await watcher.channel_step("down")
                    log.info(
                        "Channel down executed for %s", watcher.config.unique_key
                    )
                except Exception as err:
                    log.error(
                        "Error in periodic channel sync for %s: %s",
                        watcher.config.unique_key,
                        err,
                    )

            await asyncio.sleep(CHANNEL_SYNC_INTERVAL)

    async def process_config_updates(self):
        """Given updates to the overall queue, start/stop individual device watchers.

        Falsy queue items and updates equal to the last successfully applied
        config are skipped. On failure the error is logged and the loop pauses
        5 seconds; ``last_applied`` is not advanced, so a re-sent identical
        config will be retried.
        """
        last_applied: typing.Optional[BarConfig] = None
        while self.wanted:
            try:
                update: typing.Optional[BarConfig] = await self.config_queue.get()
                if not update:
                    continue
                if update == last_applied:
                    log.info("No change in the update")
                    continue
                log.info("Configuration updated")
                await self.ensure_watchers(update)
                last_applied = update
            except Exception:
                log.exception("Failure during process whole-daemon config updates")
                await asyncio.sleep(5)

    async def ensure_watchers(self, update: BarConfig):
        """Ensure that exactly the configured devices have watchers.

        New devices get a watcher whose ``start()`` runs as a background task;
        existing watchers receive ``update_config(device)``; watchers for
        devices no longer configured are stopped. Devices without a
        ``unique_key`` (or repeating one already seen) are logged and skipped.
        Exceptions from ``update_config`` propagate, but ``self.devices`` still
        tracks every running watcher so a retry neither orphans nor duplicates.
        """
        seen: typing.Dict[str, DeviceWatcher] = {}
        previous = self.devices
        try:
            for device in update.devices:
                key = device.unique_key
                if not key:
                    log.error("No unique key in device %s", device)
                    continue
                if key in seen:
                    # A second watcher for the same key would orphan the first.
                    log.error("Duplicate unique key %s in device %s", key, device)
                    continue
                watcher = previous.get(key)
                if watcher is None:
                    watcher = self.watcherClass(device, self)
                    self.tasks.append(
                        asyncio.create_task(
                            watcher.start(), name=f"start-watcher-{key}"
                        )
                    )
                else:
                    await watcher.update_config(device)
                seen[key] = watcher
        except BaseException:
            # Keep tracking both the old watchers and any just started.
            self.devices = {**previous, **seen}
            raise
        self.devices = seen
        for key, watcher in previous.items():
            if key not in seen:
                log.info("Shutting down client for %s", watcher)
                try:
                    await watcher.stop()
                except Exception:
                    log.exception("Failed to shut down client for %s", key)

    async def create_app(self):
        """Create the web API

        Note that all this really does is tell the web api that
        this particular daemon is *the* daemon to which to forward
        requests...
        """
        daemonapi.app.daemon = self
        return daemonapi.app

    async def set_config(self, config: BarConfig) -> bool:
        """Set the config on the daemon, persist it and queue it for the watchers.

        The JSON is written to ``<storage>~`` and renamed over ``storage`` so a
        crash mid-write never leaves a truncated file.
        """
        self.config = config

        encoded = config.model_dump_json(indent=2)
        tmp = self.storage + "~"
        with open(tmp, "w") as fh:
            fh.write(encoded)
        os.replace(tmp, self.storage)

        await self.config_queue.put(config)
        return True

    async def get_config(self) -> typing.Optional[BarConfig]:
        """Get the config on the daemon"""
        return self.config

    async def identify(
        self,
        keys: typing.Optional[typing.List[str]],
        message: typing.Optional[str],
        duration: typing.Optional[float] = 10.0,
        style: typing.Optional[str] = "",
        background: typing.Optional[str] = "",
        retries: int = 5,
        delay: float = 1.0,
    ) -> typing.List[models.IdentifyResponse]:
        """Request that devices with unique_key in ``keys`` (all if None) identify.

        Note: this is a long lived request that tries each device concurrently,
        as this operation is intended to provide an immediate on-screen display,
        not an "eventual consistency" operation. An empty ``keys`` list targets
        no devices. Each watcher shows ``message``, or its own configured name
        when ``message`` is empty.

        Returns the watchers' identify responses, in the order the targeted
        watchers were found.
        """
        children = []
        log.info("Identify Keys: %s Message: %s", keys, message)
        for key, watcher in self.devices.items():
            if keys is None or watcher.config.unique_key in keys:
                children.append(watcher)
            elif keys:
                log.info(
                    "Not sending identify to watcher %s", watcher.config.unique_key
                )

        responses = await asyncio.gather(
            *[
                watcher.identify(
                    message or watcher.config.name,
                    duration=duration,
                    retries=retries,
                    delay=delay,
                    style=style,
                    background=background,
                )
                for watcher in children
            ]
        )
        return list(responses)

    async def get_epg(self):
        """Get a (cached) subset of the epg data.

        The cached subset is reused while it is under an hour old and was built
        for the current channel plan; otherwise it is rebuilt.
        """
        if (
            self.epg_ts > time.time() - 3600
            and self.epg_data is not None
            and self.epg_channel_plan == self.current_channel_plan
        ):
            return self.epg_data
        self.epg_data = await self.current_epg_data(
            channel_plan=self.current_channel_plan
        )

        self.epg_ts = time.time()
        return self.epg_data

    async def set_epg(self, data: dict):
        """Validate the received epg data and save it in the gzipped cache file.

        The file is written to ``<epg_cache>~`` and renamed into place. The
        cached subset is invalidated so the next get_epg() rebuilds the
        filtered view from the new file.
        """
        log.info("set_epg on the daemon (live dataserver)")
        try:
            epg_data = models.EPGData(**data)
        except Exception as err:
            log.error("Failed to parse the passed epgdata: %s", err)
            return {"error": True, "messages": [str(err)]}

        try:
            dumped = epg_data.model_dump()
            tmp = self.epg_cache + "~"
            with gzip.open(tmp, "wt", encoding="utf-8") as f:
                json.dump(dumped, f)
            os.replace(tmp, self.epg_cache)
            # epg_data holds the *filtered* subset built by current_epg_data();
            # caching the raw dump here would make get_epg() serve it unfiltered.
            self.epg_data = None
            self.epg_ts = 0
            # NOTE(review): process_config_updates() skips an update equal to
            # the last applied config, so re-queuing an unchanged config is
            # normally a no-op; confirm what this is meant to trigger.
            self.config_queue.put_nowait(self.config)
        except Exception as err:
            log.exception("Failure in saving epg data")
            return {"error": True, "messages": [str(err)]}

        return {"success": True, "error": False}

    async def channel_plan(self) -> typing.Optional[models.ChannelPlan]:
        """Get a channel plan (assumes there's just one across all devices)"""
        return self.current_channel_plan

    async def set_channels(self, data: dict):
        """Parse a received channel plan and apply/cache it via on_channel_plan()"""
        log.info("set_channels on the daemon (live dataserver)")
        try:
            channelplan_data = models.ChannelPlan(**data)
            await self.on_channel_plan(channelplan_data)
        except Exception as err:
            log.exception("Failed to parse the passed channel plan")
            return {"error": True, "messages": [str(err)]}

        return {"success": True, "error": False}

    async def on_channel_plan(self, channel_plan: models.ChannelPlan):
        """Adopt and cache a newly received channel plan if it differs from ours"""
        log.info("Daemon got a new channel plan")
        if channel_plan != self.current_channel_plan:
            await self.save_channel_plan(channel_plan)
            self.current_channel_plan = channel_plan
        else:
            log.info("No changes to the channel plan")

    async def load_channel_plan(self) -> typing.Optional[models.ChannelPlan]:
        """Load channel plan from disk; None if missing or unreadable"""
        if not os.path.exists(self.channel_cache):
            log.warning("No existing channel plan in %s", self.channel_cache)
            return None
        log.warning("Loading channel plan from %s", self.channel_cache)
        try:
            with open(self.channel_cache) as fh:
                content = fh.read()
            return models.ChannelPlan(**json.loads(content))
        except Exception:
            log.exception("Failure reading channel cache")
            return None

    async def save_channel_plan(self, channel_plan: models.ChannelPlan):
        """Save channel plan to disk on update from VSBB.

        Written to ``<channel_cache>~`` and renamed into place so a crash
        mid-write cannot leave a truncated cache for load_channel_plan().
        """
        log.info("Caching channel plan in %s", self.channel_cache)
        tmp = self.channel_cache + "~"
        with open(tmp, "w") as fh:
            fh.write(channel_plan.model_dump_json(indent=2))
        os.replace(tmp, self.channel_cache)

    async def periodic_epg_data(self):
        """Keep the cached demo EPG data current.

        While the config's ``epgfetch`` is the empty string, renew the cached
        EPG timestamps and sleep for the period renew_cached_epg_data()
        returns. Otherwise re-check every 60 seconds. Errors back off for
        30 minutes.
        """
        while self.wanted:
            try:
                if self.config and self.config.epgfetch == "":
                    period = await self.renew_cached_epg_data(self.epg_renewal_interval)
                else:
                    log.info(
                        "No config, or config isn't using demo epgdata (empty epgfetch)"
                    )
                    period = 60
            except Exception:
                log.exception("Failure during pull")
                period = 1800
            await asyncio.sleep(period)

    @staticmethod
    def _redact_command(command):
        """Join ``command`` for logging, masking the value after ``--password``."""
        shown = []
        hide_next = False
        for arg in command:
            shown.append("***" if hide_next else arg)
            hide_next = arg == "--password"
        return " ".join(shown)

    async def pull_demo_epg_data(self):
        """Download the demo EPG data set via FTP and install it as the EPG cache.

        If the cache file is younger than 23 hours, returns the seconds until
        it is 24 hours old without downloading. Otherwise runs the
        fetch/convert/gzip pipeline in a scratch directory (always removed)
        and returns the seconds to wait before the next attempt: 24 hours on
        success, 30 minutes on failure.
        NOTE(review): no caller is visible in this file.
        """
        log.warning("Attempting to pull the demo epg data")
        import tempfile, shutil

        log.warning("DEMO EPGDATA PULLING RUNNING")

        datafile = self.epg_cache
        if os.path.exists(datafile):
            age = time.time() - os.stat(datafile).st_mtime
            if age < (3600 * 23):
                return (3600 * 24) - age

        tmp = tempfile.mkdtemp(dir="/var/firmware/run")
        failed = False
        try:
            # SECURITY NOTE(review): FTP credentials are hard-coded in source;
            # they should come from configuration / secret storage instead.
            commands = [
                [
                    "epgfetch-ftp-client",
                    "--target",
                    tmp,
                    "--user",
                    "atx_test",
                    "--password",
                    "E2F39EC2a7",
                    "ftp://ftp.tvmedia.ca/",
                ],
                [
                    "epgfetch-convert-dsi",
                    "-o",
                    os.path.join(tmp, "epg.json"),
                    os.path.join(tmp, "xmltv.xml"),
                ],
                ["gzip", os.path.join(tmp, "epg.json")],
            ]
            for command in commands:
                # Never write the password into the logs.
                shown = self._redact_command(command)
                try:
                    start = time.time()
                    log.warning("demo-epg-pull: %s", shown)
                    await run_process(command)
                    log.debug("%0.1fs on %s", time.time() - start, shown)
                except Exception as err:
                    log.error("demo-epg-pull failed on %s: %s", shown, err)
                    failed = True
                    break
            if not failed:
                log.info("demo-epg-data pulled")
                os.rename(os.path.join(tmp, "epg.json.gz"), datafile)
                self.epg_data = None
                await self.get_epg()
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
        return 3600 * 24 if not failed else 1800

    async def renew_cached_epg_data(self, offset_sec):
        """Shift cached EPG schedule timestamps forward so demo data stays current.

        Returns 60 if no cache file exists yet: we wait for the EPG data to be
        pushed from the MambaManager. If the cache file is younger than
        ``offset_sec``, nothing changes. Otherwise every schedule's start/end
        moves forward by the file's age, floored to whole hours. The result is
        written back atomically and the served EPG subset is rebuilt.
        Returns the seconds to wait before checking again:
        ``min(offset_sec, 1800)``, or 1800 if writing the cache failed.
        """
        log.warning(
            "renew_cached_epg: Attempting to renew the cached demo epg data timestamps"
        )
        # The daemon can run offline indefinitely once an initial EPG data set
        # has been cached.
        datafile = self.epg_cache
        if not os.path.exists(datafile):
            return 60  # Waits till the EPG data is pushed from the MambaManager

        def to_datetime(epoch_time):
            return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(epoch_time))

        datafile_time = os.stat(datafile).st_mtime
        now_time = time.time()
        diff_time = now_time - datafile_time
        log.debug(
            "%s last update time: %s, time now: %s, diff: %s hours, desired offset: %s hours",
            datafile,
            to_datetime(datafile_time),
            to_datetime(now_time),
            diff_time / 3600,
            offset_sec / 3600,
        )
        retry_period = min(offset_sec, 1800)
        # Only renew if the epg cache is older than the desired offset
        if diff_time < offset_sec:
            log.info(
                "new_cached_epg: EPG data valid, skipping renewal. Will retry in %ss",
                retry_period,
            )
            return retry_period

        # Shift by the real age (whole hours) rather than the nominal offset,
        # so a box that was powered off for longer still lands on "now".
        shift = int(diff_time / 3600) * 3600
        if shift <= 0:
            return retry_period

        with gzip.open(datafile) as fh:
            data = json.loads(fh.read())

        data["schedules"] = [
            [station, program_id, start + shift, end + shift]
            for station, program_id, start, end in data["schedules"]
        ]

        # Write next to the target and rename so readers never see a partial file.
        tmp = datafile + "~"
        try:
            with gzip.open(tmp, "wt", encoding="utf-8") as fh:
                json.dump(data, fh)
            os.replace(tmp, datafile)
        except OSError:
            log.exception("new_cached_epg: failed to write renewed epg data")
            return 1800

        log.info("epg info renewed")
        self.epg_data = None
        await self.get_epg()
        # Re-check on the normal cadence; returning `shift` here would sleep for
        # as long as the box was off, letting the data go stale.
        return retry_period

    async def current_epg_data(
        self, channel_plan: typing.Optional[models.ChannelPlan] = None
    ):
        """Get a current subset of the static epg data-set.

        Reads the gzipped EPG cache, or uses an empty skeleton when it is
        absent. It keeps the schedules overlapping [now - 1h, now + 24h) and
        the programs they reference. If the channel plan has channels, it also
        keeps only stations whose tmsid is in the plan, and "channels" maps
        each kept station to its channel number, sorted numerically.
        Records the plan used in ``self.epg_channel_plan``.
        """
        channel_plan = channel_plan or await self.channel_plan()
        # station tmsid -> channel number string, from the channel plan
        tmsid_to_channel: typing.Dict[str, str] = {}
        if channel_plan is not None:
            for channel in channel_plan.channels:
                tmsid_to_channel[channel.tmsid] = channel.channel
        self.epg_channel_plan = channel_plan

        if not os.path.exists(self.epg_cache):
            log.info("No cache file is present, returning empty epg data...")
            result = {
                "stations": [],
                "schedules": [],
                "programs": [],
                "channels": [],
            }
        else:
            with gzip.open(self.epg_cache) as fh:
                result = json.loads(fh.read())

        filtered_channels = []
        if tmsid_to_channel:
            filtered_stations = [
                station
                for station in result["stations"]
                if station[0] in tmsid_to_channel
            ]
            filtered_channels = [
                {"channel": tmsid_to_channel[station[0]], "station": station}
                for station in filtered_stations
            ]
            # Sort numerically by channel number, e.g. "5.1" -> (5, 1).
            filtered_channels.sort(
                key=lambda entry: tuple(
                    int(part)
                    for part in entry["channel"].replace("-.", ".").split(".")
                ),
            )
        else:
            filtered_stations = result["stations"]

        programs = set()
        schedules = []
        min_end = time.time() - 3600
        max_start = min_end + 3600 * 25
        for schedule in result["schedules"]:
            station, program, start, end = schedule
            if tmsid_to_channel and station not in tmsid_to_channel:
                continue
            if end > min_end and start < max_start:
                programs.add(program)
                schedules.append(schedule)
        result["schedules"] = schedules
        result["programs"] = [
            program for program in result["programs"] if program[0] in programs
        ]
        try:
            # NOTE(review): cache_icons() returns stations with index 5
            # rewritten to the local icon URL, but that result is discarded
            # here. pull_icon() currently uses a hard-coded test URL; confirm
            # before wiring the rewritten stations in.
            await self.cache_icons(filtered_stations)
            result["stations"] = filtered_stations
            result["channels"] = filtered_channels
        except Exception:
            log.exception("Failed to cache all icons")
            raise
        return result

    # async def push_current_epg_data(self, epg_data):
    #     """Push out epg data for each client"""
    #     import json

    #     converted = await self.convert_epg_data(epg_data)
    #     for key, watcher in list(self.devices.items()):
    #         try:
    #             await watcher.push_epg(converted)
    #         except Exception as err:
    #             log.warning("Unable to update the EPG on %s", watcher)

    async def convert_epg_data(self, epg_data):
        """Convert the epg data into a format that can be sent directly to the devices

        It is *crazy* how much easier it was to have the Python do this than to
        have the golang process the json as-is... just astonishly lopsided ease of use.
        """
        return epg_data

    # URL prefix under which cached icons are referenced
    ICON_URL = "/icon-cache/"
    # Local directory cached icons are downloaded into
    ICON_PATH = "/var/firmware/run/icon-cache"

    async def cache_icons(self, stations):
        """Cache the icons for ``stations`` locally; return the rewritten stations"""
        from aiohttp import client

        async with client.ClientSession() as session:
            rewritten_stations = []
            for station in stations:
                rewritten_stations.append(await self.pull_icon(station, session))
            return rewritten_stations

    async def pull_icon(self, station, session):
        """Pull a single station's icon into our cache.

        Downloads the icon (unless already cached) to
        ``ICON_PATH/<tmsid>.<basename>``. Returns a copy of ``station`` with
        index 5 replaced by the matching ``ICON_URL`` path. Returns ``station``
        unchanged when there is no usable URL/filename or the download does
        not answer HTTP 200.
        """
        os.makedirs(self.ICON_PATH, exist_ok=True)

        # icon_url = station[6] -->> station[6] is not available, replaced with station[5].
        # Hardcode for testing
        icon_url = (
            "http://cdn.tvpassport.com/image/station/256x144/v2/s12852_h15_ab.png"
        )
        if not icon_url:
            log.warning("No icon url for %s", station)
            return station
        # okay, so this *should* be a real icon at this point...
        from urllib import parse

        filename = os.path.basename(parse.urlparse(icon_url).path)
        if not filename:
            log.warning("No filename found in url %s", icon_url)
            return station
        tmsid = station[0]
        expected_file = "%s.%s" % (tmsid, filename)
        expected_url = os.path.join(self.ICON_URL, expected_file)
        expected_path = os.path.join(self.ICON_PATH, expected_file)
        if not os.path.exists(expected_path):
            log.info("cache %s => %s", icon_url, expected_path)
            # The context manager releases the connection back to the session.
            async with session.get(icon_url) as response:
                if response.status != 200:
                    log.error("Unable to cache the icon url %s", icon_url)
                    return station
                content = await response.read()
            # Rename into place so a partial write is never treated as cached.
            tmp_path = expected_path + "~"
            with open(tmp_path, "wb") as fh:
                fh.write(content)
            os.replace(tmp_path, expected_path)
        else:
            log.info("Already cached: %s", expected_file)
        station = station[:]
        # station[6] = expected_url -->> station[6] is not available, replaced with station[5].
        station[5] = expected_url
        return station

    async def get_status(self):
        """Get our current operational status.

        Collects each watcher's status_report(). The bar is flagged as errored
        if any device report has errors.
        """
        device_statuses: models.DeviceStatusReport = []
        for key, watcher in list(self.devices.items()):
            device_statuses.append(await watcher.status_report())
        return models.BarStatusReport(
            errored=any(d.errors for d in device_statuses),
            devices=device_statuses,
        )


def get_options():
    """Build and return the daemon's command-line parser (unparsed).

    Defaults come from the module-level DEFAULT_* constants. ``--redis``
    defaults to the REDIS_URL environment variable when it is set.
    NOTE(review): main() in this file does not read ``--redis`` or ``--listen``.
    """
    import argparse

    argument_specs = [
        (
            ("-r", "--redis"),
            {
                "default": os.environ.get("REDIS_URL", "redis://localhost"),
                "help": "Override the default redis db",
            },
        ),
        (
            ("-l", "--listen"),
            {
                "default": "127.0.0.1",
                "help": "Host on which to listen, default lo, use 0.0.0.0 for a public host",
            },
        ),
        (
            ("--log-level",),
            {"default": "INFO", "choices": ["DEBUG", "INFO", "WARNING", "ERROR"]},
        ),
        (
            ("--config",),
            {
                "default": DEFAULT_STORAGE,
                "help": "Configuration storage location (BarConfig) %s"
                % (DEFAULT_STORAGE,),
            },
        ),
        (
            ("--epg-cache",),
            {
                "default": DEFAULT_EPG_CACHE,
                "help": "EPG data cache-file to load for data sub-setting %s"
                % (DEFAULT_EPG_CACHE,),
            },
        ),
        (
            ("--channel-cache",),
            {
                "default": DEFAULT_CHANNEL_CACHE,
                "help": "VSBB Channel play cache-file to store/load %s"
                % (DEFAULT_CHANNEL_CACHE,),
            },
        ),
        (
            ("-o", "--offline-mode"),
            {
                "action": "store_true",
                "default": False,
                "help": "In offline mode the cached epg data will be renewed and recycled",
            },
        ),
        (
            ("-i", "--renewal-interval"),
            {
                "type": int,
                "default": DEFAULT_EPG_RENEWAL_INTERVAL_HOUR * 3600,
                "help": "Number of seconds after which the epg data must be renewed in offline mode, default (24hrs in seconds)",
            },
        ),
    ]

    parser = argparse.ArgumentParser(
        description="Daemon providing backend operations for the mamba bar application (define DEMO_EPGDATA for demo data support)"
    )
    for flags, settings in argument_specs:
        parser.add_argument(*flags, **settings)
    return parser


def main():
    """Entry point: parse options, configure logging and run the daemon.

    Runs the event loop until interrupted, or until MambaDaemon.start()
    stops it after a startup crash. It then stops the daemon (bounded to
    10 seconds), cancels the remaining tasks and closes the loop.
    """
    options = get_options().parse_args()
    log_level = getattr(logging, options.log_level)
    # these get discarded by the fastapi setup...
    logging.basicConfig(level=log_level)
    log.info("Options: %s", options)

    # Create and install the loop *before* constructing the daemon.
    # asyncio.get_event_loop() without a running loop is deprecated and raises
    # on newer Pythons. On older Pythons, asyncio.Queue() in
    # MambaDaemon.__init__ binds to the current loop when it is constructed.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    service = MambaDaemon(
        epg_cache=options.epg_cache,
        storage=options.config,
        offline_mode=options.offline_mode,
        epg_renewal_interval=options.renewal_interval,
        log_level=options.log_level,
        channel_cache=options.channel_cache,
    )
    try:
        loop.create_task(service.start())
        try:
            loop.run_forever()
        except (KeyboardInterrupt, SystemExit) as err:
            if log_level > logging.INFO:
                log.warning("Interrupt during mainloop: %s", err)
            else:
                log.exception("Interrupt during mainloop: %s", err)
    finally:
        log.info("Shutting down asyncio")
        logging.getLogger("asyncio").setLevel(logging.CRITICAL)
        try:
            # Give watchers/tasks a bounded chance to shut down cleanly.
            loop.run_until_complete(asyncio.wait_for(service.stop(), timeout=10))
        except Exception:
            log.exception("Error while stopping the daemon")
        # Cancel and drain whatever is still pending so nothing is destroyed
        # mid-flight when the loop closes.
        pending = asyncio.all_tasks(loop)
        for task in pending:
            task.cancel()
        if pending:
            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()


if __name__ == "__main__":
    main()
