#!/usr/bin/python3

# Copyright 2014..2025, Martin <debacle@debian.org>
# License: AGPL-3+

# Python standard modules
import argparse
import asyncio
import collections
import configparser
import email.mime.text
import email.utils
import hashlib
import html
import logging
import os
import pathlib
import smtplib
import socket
import subprocess
import sys
import textwrap

# additional modules
import apt
import prettytable
import slixmpp

longname = "Pain in the APT"
shortname = "painintheapt"
version = "0.20251121"

# Table headings; lower-cased, they double as the field names of Package.
columns = ["Name", "Installed", "Candidate"]
Package = collections.namedtuple("Package", [heading.lower() for heading in columns])


def getargs():
    """Parse the command line and return the resulting argparse namespace."""
    parser = argparse.ArgumentParser(
        description="Pester people about available package updates by email or jabber.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flags, keyword options) in the order they appear in --help.
    specs = [
        (
            ("-c", "--configfile"),
            dict(default=f"/etc/{shortname}.conf", help="configuration file"),
        ),
        (
            ("-d", "--debug"),
            dict(default=False, action="store_true", help="print debug output to stderr"),
        ),
        (
            ("-f", "--force"),
            dict(
                default=False,
                action="store_true",
                help="send message, even if updates did not change",
            ),
        ),
        (
            ("-s", "--stampfile"),
            dict(help="stamp file", default=f"/var/lib/{shortname}/stamp"),
        ),
        (
            ("-t", "--testmessage"),
            dict(default=False, action="store_true", help="send a test message only"),
        ),
        (
            ("-v", "--version"),
            dict(action="version", version=f"%(prog)s {version}"),
        ),
    ]
    for flags, options in specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()


def update():
    """Refresh the APT package lists and mark a full (dist-)upgrade.

    The cache is updated, re-opened and marked for a dist-upgrade; this
    function does not commit the marked changes.

    Returns:
        A tuple ``(cache, updates)``: the ``apt.Cache`` with the upgrade
        marked, and a list of ``Package`` tuples (name, installed version,
        candidate version) with "-" where a version is absent.
    """
    cache = apt.Cache()
    cache.update()
    cache.open()
    cache.upgrade(dist_upgrade=True)
    updates = []
    for pkg in cache.get_changes():
        # Use the changed package object itself instead of looking it up
        # again by its bare name: a bare-name lookup resolves to the native
        # architecture, which misreports (or fails for) foreign-arch
        # packages.  pkg.name carries ":arch" only for foreign packages.
        installed = pkg.installed.version if pkg.installed else "-"
        candidate = pkg.candidate.version if pkg.candidate else "-"
        updates.append(Package(pkg.name, installed, candidate))
    return cache, updates


def wrap(text, maxwid):
    """Return *text* re-flowed into lines of at most *maxwid* characters."""
    return textwrap.fill(text, maxwid)


# Memoized result of get_changelogs(); None means "not computed yet".
_changes = None


def get_changelogs(cache, send_changes):
    """Download changelogs for all marked changes. Beware: This is very slow.

    Identical changelogs for different binary packages are combined under
    one heading listing all their names.  The formatted text is memoized in
    the module-level ``_changes`` so repeated calls (one per configured
    output section) download the changelogs only once.

    Args:
        cache: an ``apt.Cache`` with changes marked, or None (test message).
        send_changes: changelogs are only fetched when this is exactly True.

    Returns:
        The formatted changelogs, or "" when disabled or *cache* is None.
    """
    global _changes
    if cache is None or send_changes is not True:
        return ""
    # Compare against None so that an empty result is memoized as well.
    if _changes is not None:
        return _changes
    changelogs = collections.defaultdict(list)
    for pkg in cache.get_changes():
        # Ask the changed package object directly; looking it up again by
        # its bare name would resolve to the native-architecture package.
        changelogs[pkg.get_changelog().strip()].append(pkg.name)
    # now do some very fancy formatting
    maxwid = 79
    separator = "\n" + "-" * maxwid + "\n"
    _changes = separator.join(
        sorted(
            wrap(", ".join(sorted(names)), maxwid) + ":\n\n" + changelog
            for changelog, names in changelogs.items()
        )
    )
    return _changes


def maketable(lst):
    """Render package updates as a left-aligned text table.

    Rows are sorted by the first column (package name) and every cell is
    wrapped at 23 characters.
    """
    cell_width = 23
    table = prettytable.PrettyTable(columns)
    table.sortby = columns[0]
    table.align = "l"
    for pkg in lst:
        cells = (pkg.name, pkg.installed, pkg.candidate)
        table.add_row([wrap(cell, cell_width) for cell in cells])
    return table.get_string()


class JabberBot(slixmpp.ClientXMPP):
    """One-shot XMPP client that delivers the update report and disconnects.

    On session start it optionally sets an avatar. It then sends the report
    as a chat message to every JID in *to*, as a groupchat message to *room*,
    and as an Atom entry to the given pubsub service/node, and finally
    disconnects. The connection is initiated from the constructor.
    """

    def __init__(
        self,
        jid,
        password,
        to,
        room,
        pubsub_service,
        pubsub_node,
        nick,
        subject,
        table,
        changes,
        avatar,
        host,
        port,
    ):
        slixmpp.ClientXMPP.__init__(self, jid, password)
        self.to = to
        self.room = room
        self.pubsub_service = pubsub_service
        self.pubsub_node = pubsub_node
        self.nick = nick
        self.subject = subject
        self.table = table
        self.changes = changes
        self.avatar = avatar
        # Empty strings mean "not configured": pass None on to connect().
        self.host = host or None
        self.port = int(port) if port else None

        self.register_plugin("xep_0030")  # Service Discovery
        if room:
            self.register_plugin("xep_0045")  # Multi-User Chat
        if pubsub_service and pubsub_node:
            self.register_plugin("xep_0060")  # Publish-Subscribe
        self.register_plugin("xep_0084")  # User Avatar
        self.register_plugin("xep_0153")  # vCard-Based Avatars
        self.register_plugin("xep_0199")  # XMPP ping
        self.add_event_handler("session_start", self.start)
        self.connect(host=self.host, port=self.port)

    async def start(self, event):
        """Send all configured notifications, then disconnect.

        The disconnect happens in ``finally``: the process waits for the
        disconnect (see sendxmpp), so an error while sending must not leave
        the session open forever.
        """
        try:
            self.send_presence()
            if self.avatar:
                await self.set_avatar()
            pre = "```"
            table = pre + self.table + pre + "\n" if self.table else ""
            # The subject is not shown by all clients, and groupchats have no
            # per message subject, so it is repeated in the body.
            body = "\n".join([self.subject, table, self.changes])
            for to in self.to:
                self.send_message(
                    mto=to,
                    msubject=self.subject,
                    mbody=body,
                    mtype="chat",
                )
            if self.room:
                self.plugin["xep_0045"].join_muc(self.room, self.nick)
                self.send_message(mto=self.room, mbody=body, mtype="groupchat")
            if self.pubsub_service and self.pubsub_node:
                self._publish_atom_entry()
        finally:
            self.disconnect(wait=True)

    def _publish_atom_entry(self):
        """Publish the report as an Atom entry with XHTML content."""
        table = (
            '<pre xmlns="http://www.w3.org/1999/xhtml">' + html.escape(self.table) + "</pre>"
            if self.table
            else ""
        )
        # One paragraph per changelog line; spaces become non-breaking spaces
        # so the changelog indentation survives XHTML rendering.
        changes = html.escape(self.changes).replace("\n", "</p>\n<p>").replace(" ", "&#160;")
        payload = (
            '<entry xmlns="http://www.w3.org/2005/Atom"><title>'
            + html.escape(self.subject)
            + '</title><content type="xhtml"><div>'
            + table
            + "<p>"
            + changes
            + "</p></div></content></entry>"
        )
        self["xep_0060"].publish(
            self.pubsub_service,
            self.pubsub_node,
            payload=slixmpp.xmlstream.ET.fromstring(payload),
        )

    async def _avatar_step(self, awaitable, message: str) -> bool:
        """Await one avatar request; log *message* and return False on failure.

        A failure is recognized both as a returned XMPPError and as a raised
        one. slixmpp may raise IQ errors/timeouts (XMPPError subclasses)
        instead of returning them. Catching them means a rejected avatar
        cannot abort the notification itself.
        """
        try:
            result = await awaitable
        except slixmpp.exceptions.XMPPError as err:
            logging.error(message)
            logging.debug("avatar request failed: %r", err)
            return False
        if isinstance(result, slixmpp.exceptions.XMPPError):
            logging.error(message)
            return False
        return True

    async def set_avatar(self) -> None:
        """Publish the configured image file as vCard avatar and User Avatar.

        Unreadable files and XMPP errors are logged, not raised.
        """
        try:
            with open(self.avatar, "rb") as f:
                avatar: bytes = f.read()
        except IOError:
            logging.error(f"Could not find or open {self.avatar}.")
            return

        # MIME type derived from the file extension, e.g. "image/png".
        avatar_type: str = f"image/{pathlib.Path(self.avatar).suffix.lstrip('.')}"

        await self._avatar_step(
            self["xep_0153"].set_avatar(avatar=avatar, mtype=avatar_type),
            "Could not set vCard avatar.",
        )

        if not await self._avatar_step(
            self["xep_0084"].publish_avatar(avatar),
            "Could not publish User Avatar.",
        ):
            return
        await self._avatar_step(
            self["xep_0084"].publish_avatar_metadata(
                [
                    {
                        "id": self["xep_0084"].generate_id(avatar),
                        "type": avatar_type,
                        "bytes": len(avatar),
                    }
                ]
            ),
            "Could not publish User Avatar metadata.",
        )


def read_password(config, config_dir):
    """Return the password configured for one config section.

    ``password_file`` wins. A relative path is joined to *config_dir*, and
    the file content is returned stripped. Otherwise the deprecated inline
    ``password`` option is used, with a warning on stderr. The warning is
    only printed when that option is actually present. Returns "" when
    neither option is configured.
    """
    password_file = config.get("password_file", "").strip()
    if password_file:
        with open(os.path.join(config_dir, password_file)) as f:
            return f.read().strip()

    if "password" not in config:
        # No credentials configured at all (e.g. an unauthenticated SMTP
        # relay): there is nothing deprecated to warn about.
        return ""
    print("password deprecated, use password_file instead", file=sys.stderr)
    return config.get("password", "")


def sendxmpp(config, config_dir, table, count, host, debug, changes):
    """Send the update report via XMPP (chat, conference room, pubsub).

    Recipients are the comma-separated JIDs in ``to``, the optional MUC
    ``room`` and the optional ``pubsub_service``/``pubsub_node``.  *host* is
    the reported machine and only appears in the subject; the XMPP server
    address comes from the section's own ``host``/``port`` options.  *debug*
    is accepted for signature compatibility and not used here.
    """
    jid = config.get("jid", "")
    password = read_password(config, config_dir)
    # Drop blank entries: an unset or trailing-comma "to" must not produce a
    # message to the empty JID, and "a@x, b@y" must not yield " b@y".
    to = [addr.strip() for addr in config.get("to", "").split(",") if addr.strip()]
    room = config.get("room")
    pubsub_service = config.get("pubsub_service", "").strip()
    pubsub_node = config.get("pubsub_node", "").strip()
    subject = "%d package update(s) for %s" % (count, host)
    avatar = config.get("avatar", "")
    server_host = config.get("host", None)
    server_port = config.get("port", None)
    xmpp = JabberBot(
        jid,
        password,
        to,
        room,
        pubsub_service,
        pubsub_node,
        longname,
        subject,
        table,
        changes,
        avatar,
        server_host,
        server_port,
    )

    # JabberBot connects in its constructor and disconnects after sending:
    # run the loop until then and cancel whatever tasks are left over.
    xmpp.loop.run_until_complete(xmpp.disconnected)
    for task in asyncio.all_tasks(loop=xmpp.loop):
        task.cancel()


def sendsmtp(config, config_dir, table, count, host, debug, changes):
    """Send the update report by SMTP to whomsoever it may concern.

    Connects to ``server``:``port`` (default localhost:25) and always issues
    STARTTLS. It logs in when a username or password is configured, and
    sends one mail to all ``to`` and ``cc`` addresses.
    """
    server = config.get("server", "localhost")
    port = config.getint("port", 25)
    username = config.get("username", "")
    password = read_password(config, config_dir)
    from_ = config.get("from", username)
    to = config.get("to", username)
    cc = config.get("cc", "")

    msg = email.mime.text.MIMEText("\n\n".join([table, changes]).strip(), "plain", "utf-8")
    msg["From"] = from_
    msg["To"] = to
    msg["Subject"] = "%d package update(s) for %s" % (count, host)
    msg["X-Mailer"] = longname
    if cc:
        msg["Cc"] = cc

    # Parse only non-empty header values. Building to + "," + cc with an
    # empty cc leaves a trailing comma. getaddresses() in strict mode (newer
    # Python releases) rejects that whole input as [("", "")].
    addresses = email.utils.getaddresses([value for value in (to, cc) if value])
    # dict.fromkeys drops duplicates (to/cc overlap) but keeps the order.
    recipients = list(dict.fromkeys(addr for _, addr in addresses if addr))

    # The context manager sends QUIT and closes the socket even on errors.
    with smtplib.SMTP(host=server, port=port) as s:
        if debug:
            s.set_debuglevel(True)
        s.starttls()
        s.ehlo_or_helo_if_needed()
        if username or password:
            s.login(username, password)
        s.sendmail(from_, recipients, msg.as_string())


def sendmailx(config, config_dir, table, count, host, debug, changes):
    """Send the update report by mailx to whomsoever it may concern.

    The report is piped to /usr/bin/mailx encoded as UTF-8.  *config_dir*
    and *debug* are accepted for signature compatibility and not used here.

    Raises:
        subprocess.CalledProcessError: if mailx exits with a non-zero status.
    """
    mailx = "/usr/bin/mailx"
    cmd = [
        mailx,
        "-r",
        config.get("from", "root"),
        "-s",
        "%d package update(s) for %s" % (count, host),
    ]
    cc = config.get("cc", "")
    if cc:
        cmd += ["-c", cc]
    # this is taken from apticron
    if os.path.realpath(mailx) == "/usr/bin/heirloom-mailx":
        # heirloom-mailx treats -a as "attach this file" (see its manual
        # page), so no extra headers are passed that way here.
        cmd += ["-S", "ttycharset=utf-8"]
    else:
        cmd += [
            "-a",
            "X-Mailer: " + longname,
            "-a",
            "MIME-Version: 1.0",
            "-a",
            "Content-type: text/plain; charset=UTF-8",
            "-a",
            "Content-transfer-encoding: 8bit",
        ]
    cmd.append(config.get("to", "root"))
    body = "\n\n".join([table, changes]).strip()
    # The pipe is binary (no text mode), so the body must be encoded; writing
    # a str to it raises TypeError.  check=True reports mailx failures.
    subprocess.run(cmd, input=body.encode("utf-8"), check=True)


def has_changed(configfile, table, stampfile):
    """Tell whether the report differs from the one recorded in *stampfile*.

    The fingerprint is the SHA-1 of the configuration file's text followed
    by *table*, so editing the configuration also counts as a change.  A
    missing, unreadable or undecodable stamp file counts as changed.

    Returns:
        A tuple ``(changed, newhash)`` where *newhash* is the new hex digest
        to store in the stamp file.
    """
    hashsum = hashlib.sha1()
    with open(configfile) as f:
        for line in f:
            hashsum.update(line.encode("utf-8"))
    hashsum.update(table.encode("utf-8"))
    newhash = hashsum.hexdigest()
    try:
        with open(stampfile) as f:
            oldhash = f.readline().strip()
    except (OSError, ValueError):
        # ValueError covers UnicodeDecodeError from a corrupted stamp.
        oldhash = "invalid"
    return oldhash != newhash, newhash


class AcquireProgress(apt.progress.text.AcquireProgress):
    """Text download progress that is only shown in debug mode."""

    def __init__(self, debug):
        # Debug output goes to stderr; otherwise progress is discarded.
        if debug:
            sink = sys.stderr
        else:
            sink = open("/dev/null", "w")
        super().__init__(outfile=sink)


if __name__ == "__main__":
    args = getargs()
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.WARNING,
        format="%(levelname)-8s %(message)s",
    )
    config = configparser.ConfigParser()
    config.read(args.configfile)
    config_dir = os.path.dirname(args.configfile)

    fqdn = socket.getfqdn()
    # workaround for dodgy /etc/hosts
    if fqdn in ["localhost", "localhost.localdomain"]:
        fqdn = socket.gethostname() or fqdn

    if args.testmessage:
        cache = None
        count = 0
        table = "this is a test message from painintheapt"
        change = True
    else:
        cache, updates = update()
        count = len(updates)
        table = maketable(updates) if count else ""
        change, newhash = has_changed(args.configfile, table, args.stampfile)

    notify = change or args.force
    ret = 0
    for section, function in [
        ("XMPP", sendxmpp),
        ("SMTP", sendsmtp),
        ("MAILX", sendmailx),
    ]:
        if not (notify and config.has_section(section)):
            continue
        try:
            send_changes = config[section].getboolean("send_changes", True)
            function(
                config[section],
                config_dir,
                table,
                count,
                fqdn,
                args.debug,
                get_changelogs(cache, send_changes),
            )
        except Exception as err:
            # Keep going with the other sections; report which one failed.
            print("%s: %s" % (section, err), file=sys.stderr)
            ret = 1

    if args.testmessage:
        sys.exit(ret)

    # Only record the new state when every delivery succeeded. A failed
    # notification is then retried on the next run instead of being lost,
    # at the cost of re-sending to sections that did succeed.
    if notify and ret == 0:
        with open(args.stampfile, "wb") as f:
            f.write(newhash.encode("utf-8"))

    cache.fetch_archives(progress=AcquireProgress(args.debug))

    sys.exit(ret)
