#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import absolute_import

from datetime import datetime
from select import select
from xml.sax.handler import ContentHandler, feature_namespaces
import logging
import os.path
import socket
import ssl

from defusedxml.sax import make_parser
import nagiosplugin


# XML namespace URIs; the SAX handler compares (namespace, localname) tuples
# built from these values against the element names it receives.
NS_IETF_XMPP_SASL = 'urn:ietf:params:xml:ns:xmpp-sasl'
NS_IETF_XMPP_TLS  = 'urn:ietf:params:xml:ns:xmpp-tls'
NS_IETF_XMPP_STREAMS = 'urn:ietf:params:xml:ns:xmpp-streams'
NS_JABBER_CAPS = 'http://jabber.org/protocol/caps'
NS_ETHERX_STREAMS = 'http://etherx.jabber.org/streams'

# Values stored in the content handler's ``state`` attribute while a stream
# is parsed; handle_xmpp_stanza() compares against them via expected_state.
XMPP_STATE_NEW = 'new'
XMPP_STATE_STREAM_START = 'started'
XMPP_STATE_RECEIVED_FEATURES = 'received features'
XMPP_STATE_FINISHED = 'finished'
XMPP_STATE_ERROR = 'error'
XMPP_STATE_PROCEED_STARTTLS = 'proceed with starttls'

# Module logger; uses the 'nagiosplugin' logger name.
_LOG = logging.getLogger('nagiosplugin')


class XmppException(Exception):
    """
    Error raised for failures while checking an XMPP service.

    The description is available through the ``message`` attribute as well
    as through ``str()`` and ``args`` of the exception.

    :param message: human readable description of the failure

    """
    def __init__(self, message):
        self.message = message
        # pass the message on to Exception, otherwise str(exc) and exc.args
        # are empty and tracebacks or generic handlers lose the description
        super(XmppException, self).__init__(message)


class XmppStreamError(object):
    """
    Container for the details of a received XMPP stream error.

    ``condition`` holds the local name of the error condition element,
    ``text`` the optional descriptive text and ``other_elements`` maps
    element names of additional child elements to dictionaries with the
    collected details of these elements.

    """
    condition = None
    text = None

    def __init__(self):
        # created per instance; a dictionary defined on the class would be
        # shared (and mutated) by every error object
        self.other_elements = {}

    def __str__(self):
        """
        Return ``condition`` or ``condition: text`` if a text is available.

        A placeholder is used when no condition has been recorded, because
        ``__str__`` has to return a string.

        """
        condition = self.condition or 'unknown condition'
        if self.text:
            return "{condition}: {text}".format(
                condition=condition, text=self.text)
        return condition

    def str(self):
        """
        Return the same description as ``str(self)``.

        Kept for callers that use the original method name.

        """
        return self.__str__()


class XmppClientServerResponseHandler(ContentHandler):
    """
    SAX content handler for a server's response to a client stream.

    The handler records the stream attributes, the announced STARTTLS
    support, SASL mechanisms and entity capabilities and keeps track of the
    progress in ``state`` (one of the ``XMPP_STATE_*`` values).

    :param expect_starttls: whether :meth:`is_valid_start` should insist on
        the server announcing STARTTLS

    """
    # immutable defaults may live on the class; mutable containers are
    # created per instance in __init__
    starttls = False
    tlsrequired = False
    state = XMPP_STATE_NEW
    streaminfo = None
    level = 0

    def __init__(self, expect_starttls):
        self.expect_starttls = expect_starttls
        # A new handler is created after STARTTLS. Class level containers
        # would leak the element stack and the collected data of the
        # unencrypted stream into the handler of the encrypted stream.
        self.seen_elements = set()
        self.mechanisms = []
        self.capabilities = {}
        self.inelem = []
        super(XmppClientServerResponseHandler, self).__init__()

    @staticmethod
    def _attributes_as_dict(attrs):
        """
        Return the attributes of an element as qname -> value dictionary.

        """
        return dict(
            (attrname, attrs.getValueByQName(attrname))
            for attrname in attrs.getQNames())

    def startElementNS(self, name, qname, attrs):
        """
        Record information announced by the start of an element.

        :param name: ``(namespace uri, local name)`` tuple of the element
        :param qname: qualified name of the element (not used)
        :param attrs: attributes of the element

        """
        # the enclosing element; None for the document root (the original
        # lookup via inelem[-2] failed with IndexError in that case)
        parent = self.inelem[-1] if self.inelem else None
        self.inelem.append(name)
        self.seen_elements.add(name)
        if name == (NS_ETHERX_STREAMS, 'stream'):
            self.state = XMPP_STATE_STREAM_START
            self.streaminfo = self._attributes_as_dict(attrs)
        elif name == (NS_IETF_XMPP_TLS, 'starttls'):
            self.starttls = True
        elif (
            parent == (NS_IETF_XMPP_TLS, 'starttls') and
            name == (NS_IETF_XMPP_TLS, 'required')
        ):
            self.tlsrequired = True
            _LOG.info("info other side requires TLS")
        elif name == (NS_JABBER_CAPS, 'c'):
            self.capabilities.update(self._attributes_as_dict(attrs))
        elif name == (NS_ETHERX_STREAMS, 'error'):
            self.state = XMPP_STATE_ERROR
            self.errorinstance = XmppStreamError()
        elif (
            self.state == XMPP_STATE_ERROR and
            name != (NS_IETF_XMPP_STREAMS, 'text')
        ):
            if name[0] == NS_IETF_XMPP_STREAMS:
                # elements of the streams namespace other than <text/> are
                # taken as the error condition
                self.errorinstance.condition = name[1]
            else:
                self.errorinstance.other_elements[name] = {
                    'attrs': self._attributes_as_dict(attrs)}
        _LOG.debug('start %s (%s)', name, attrs.getQNames())

    def endElementNS(self, name, qname):
        """
        Update the state when an element is closed.

        :param name: ``(namespace uri, local name)`` tuple of the element
        :param qname: qualified name of the element (not used)
        :raises XmppException: when a stream error or a STARTTLS failure
            element has been received

        """
        if name == (NS_ETHERX_STREAMS, 'features'):
            self.state = XMPP_STATE_RECEIVED_FEATURES
        elif name == (NS_ETHERX_STREAMS, 'stream'):
            self.state = XMPP_STATE_FINISHED
        elif name == (NS_ETHERX_STREAMS, 'error'):
            # format the error through its str() method; passing the object
            # itself relies on it implementing __str__
            raise XmppException("XMPP stream error: {error}".format(
                error=self.errorinstance.str()))
        elif name == (NS_IETF_XMPP_TLS, 'proceed'):
            self.state = XMPP_STATE_PROCEED_STARTTLS
        elif name == (NS_IETF_XMPP_TLS, 'failure'):
            raise XmppException("starttls initiation failed")
        _LOG.debug('end %s', name)
        self.inelem.pop()

    def characters(self, content):
        """
        Collect character data of SASL mechanisms and stream errors.

        The parser may deliver the text of one element in several chunks,
        therefore error texts are appended instead of replaced.

        :param content: chunk of character data

        """
        elem = self.inelem[-1] if self.inelem else None
        if elem == (NS_IETF_XMPP_SASL, 'mechanism'):
            self.mechanisms.append(content)
        elif self.state == XMPP_STATE_ERROR:
            error = self.errorinstance
            if elem == (NS_IETF_XMPP_STREAMS, 'text'):
                error.text = (error.text or '') + content
            elif elem in error.other_elements:
                details = error.other_elements[elem]
                details['text'] = details.get('text', '') + content
            else:
                # e.g. whitespace directly inside <stream:error/> or inside
                # the condition element; there is no entry to attach it to
                _LOG.debug(
                    'ignored content in %s: %s', self.inelem, content)
        else:
            _LOG.warning('ignored content in %s: %s', self.inelem, content)

    def is_valid_start(self):
        """
        Check that the server answered with a complete feature list.

        A missing or unexpected stream version is only logged as warning.

        :returns: ``True`` when the checks pass
        :raises XmppException: when the feature list is incomplete or the
            expected STARTTLS support has not been announced

        """
        if self.state != XMPP_STATE_RECEIVED_FEATURES:
            raise XmppException('XMPP feature list is not finished')
        if self.expect_starttls is True and self.starttls is False:
            raise XmppException('expected STARTTLS capable service')
        # .get() avoids the KeyError the warning itself used to raise when
        # the version attribute is missing
        version = (self.streaminfo or {}).get('version')
        if version != '1.0':
            _LOG.warning('unknown stream version %s', version)
        return True


class Xmpp(nagiosplugin.Resource):
    """
    Resource that connects to an XMPP service and measures the setup time.

    For client to server checks a stream is opened, optionally upgraded
    with STARTTLS and closed again. When certificates are checked the
    remaining validity of the peer certificate is recorded in ``daysleft``.

    :param host_address: host name or address to connect to
    :param port: TCP port to connect to
    :param ipv6: ``True`` to enforce IPv6, ``False`` to enforce IPv4,
        ``None`` for no restriction
    :param is_server: whether to perform a server to server check (not
        implemented, see :meth:`handle_server`)
    :param starttls: whether STARTTLS support is expected and used
    :param servername: name used in the stream header and for TLS
    :param checkcerts: whether the peer certificate is verified
    :param caroots: file or directory with CA certificates or ``None``

    """
    state = nagiosplugin.Unknown
    cause = None
    socket = None
    daysleft = None

    def __init__(
        self, host_address, port, ipv6, is_server, starttls,
        servername, checkcerts, caroots
    ):
        self.address = host_address
        self.port = port
        self.ipv6 = ipv6
        self.is_server = is_server
        self.starttls = starttls
        self.servername = servername
        self.checkcerts = checkcerts
        self.caroots = caroots
        self.make_parser()
        self.set_content_handler()

    def make_parser(self):
        """
        Create a fresh namespace aware SAX parser in ``self.parser``.

        """
        self.parser = make_parser()
        self.parser.setFeature(feature_namespaces, True)

    def set_content_handler(self):
        """
        Attach a new content handler to the current parser.

        Nothing is attached for server to server checks.

        """
        if self.is_server:
            pass  # TODO: make server parser
        else:
            self.contenthandler = XmppClientServerResponseHandler(
                expect_starttls=self.starttls)
            self.parser.setContentHandler(self.contenthandler)

    def get_addrinfo(self):
        """
        Resolve the configured address for the requested address family.

        :returns: result of :func:`socket.getaddrinfo` for TCP streams

        """
        if self.ipv6 is None:
            addrfamily = 0
        elif self.ipv6 is True:
            addrfamily = socket.AF_INET6
        else:
            addrfamily = socket.AF_INET
        return socket.getaddrinfo(
            self.address, self.port, addrfamily, socket.SOCK_STREAM,
            socket.IPPROTO_TCP)

    def open_socket(self, addrinfo):
        """
        Connect to the first reachable entry of ``addrinfo``.

        :param addrinfo: entries as returned by :func:`socket.getaddrinfo`
        :returns: the connected socket
        :raises XmppException: when no entry could be connected, which
            includes an empty ``addrinfo``

        """
        for af, socktype, proto, _canonname, sa in addrinfo:
            try:
                candidate = socket.socket(af, socktype, proto)
            except socket.error:
                continue
            try:
                candidate.connect(sa)
            except socket.error:
                candidate.close()
                continue
            return candidate
        raise XmppException("could not open socket")

    def handle_xmpp_stanza(
        self, message_str, timeout=0.1, expected_state=None
    ):
        """
        Send ``message_str`` and feed the response into the parser.

        Data is read until the peer closes the connection or no further
        data arrives within ``timeout`` seconds.

        :param message_str: text to send, encoded as UTF-8
        :param timeout: seconds to wait for (more) response data
        :param expected_state: handler state required after parsing the
            response or ``None`` to skip the check
        :raises XmppException: when the handler is not in the expected
            state

        """
        self.socket.sendall(message_str.encode('utf-8'))
        chunks = []
        while True:
            rready, _wready, _xready = select(
                [self.socket], [], [], timeout)
            if self.socket not in rready:
                break
            data = self.socket.recv(4096)
            if not data:
                break
            chunks.append(data)
        self.parser.feed(b''.join(chunks).decode('utf-8'))
        if (
            expected_state is not None and
            self.contenthandler.state != expected_state
        ):
            raise XmppException(
                "unexpected state %s" % self.contenthandler.state)

    def handle_server(self):
        """
        Placeholder for server to server checks (not implemented).

        """
        pass

    def start_stream(self):
        """
        Open a client stream and wait for the server's feature list.

        :raises XmppException: when no complete feature list is received

        """
        self.handle_xmpp_stanza((
            "<?xml version='1.0' ?><stream:stream to='{servername}' "
            "xmlns='jabber:client' "
            "xmlns:stream='http://etherx.jabber.org/streams' "
            "version='1.0'>"
        ).format(servername=self.servername),
            expected_state=XMPP_STATE_RECEIVED_FEATURES)

    def setup_ssl_context(self):
        """
        Build the SSL context used for STARTTLS.

        :returns: the configured :class:`ssl.SSLContext`
        :raises XmppException: when certificates should be checked but the
            certificate store contains no CA certificates

        """
        context = ssl.create_default_context()
        # add the flags instead of assigning them: an assignment would drop
        # the other options create_default_context() has already set
        context.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
        if not self.checkcerts:
            # check_hostname has to be disabled before CERT_NONE is set
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        else:
            context.verify_mode = ssl.CERT_REQUIRED
            if self.caroots:
                if os.path.isfile(self.caroots):
                    kwargs = {'cafile': self.caroots}
                else:
                    kwargs = {'capath': self.caroots}
                context.load_verify_locations(**kwargs)
            else:
                context.load_default_certs()
            stats = context.cert_store_stats()
            if stats['x509_ca'] == 0:
                _LOG.info(
                    "tried to load CA certificates from default locations, but"
                    " could not find any CA certificates.")
                raise XmppException('no CA certificates found')
            else:
                _LOG.debug('certificate store statistics: %s', stats)
        return context

    def initiate_tls(self):
        """
        Upgrade the connection with STARTTLS and restart the stream.

        When certificates are checked the number of days until the peer
        certificate expires is stored in ``self.daysleft``.

        :raises XmppException: when the server does not proceed, the TLS
            handshake or certificate verification fails or the restarted
            stream is not valid

        """
        self.handle_xmpp_stanza(
            "<starttls xmlns='{xmlns}'/>".format(xmlns=NS_IETF_XMPP_TLS),
            expected_state=XMPP_STATE_PROCEED_STARTTLS)
        sslcontext = self.setup_ssl_context()
        try:
            self.socket = sslcontext.wrap_socket(
                self.socket, server_hostname=self.servername)
        # CertificateError is a subclass of SSLError on current Python
        # versions and has to be handled first to be reachable
        except ssl.CertificateError as certerr:
            raise XmppException("Certificate error %s" % certerr)
        except ssl.SSLError as ssle:
            # strerror may be None, fall back to the exception itself
            raise XmppException("SSL error %s" % (ssle.strerror or ssle))
        self.starttls = False
        # reset infos retrieved previously as written in RFC 3920 sec. 5.
        self.parser.reset()
        if self.checkcerts:
            certinfo = self.socket.getpeercert()
            _LOG.debug("got the following certificate info: %s", certinfo)
            _LOG.info(
                "certificate is valid from %s until %s",
                certinfo['notBefore'], certinfo['notAfter'])
            enddate = ssl.cert_time_to_seconds(certinfo['notAfter'])
            remaining = datetime.fromtimestamp(enddate) - datetime.now()
            self.daysleft = remaining.days
        # start new parsing
        self.make_parser()
        self.set_content_handler()
        self.start_stream()
        if not self.contenthandler.is_valid_start():
            raise XmppException("no valid response to XMPP client request")

    def handle_client(self):
        """
        Run the client to server conversation on the open socket.

        :raises XmppException: when the server's response is not valid

        """
        self.start_stream()
        if not self.contenthandler.is_valid_start():
            raise XmppException("no valid response to XMPP client request")
        if self.starttls is True or self.contenthandler.tlsrequired:
            self.initiate_tls()
        self.handle_xmpp_stanza("</stream:stream>")

    def probe(self):
        """
        Perform the check and yield the resulting metrics.

        On address resolution or XMPP failures ``state`` and ``cause`` are
        set and a single ``time`` metric with the value ``"unknown"`` is
        yielded. Otherwise the elapsed ``time`` in seconds and ``daysleft``
        are yielded.

        """
        start = datetime.now()
        failure = None
        try:
            addrinfo = self.get_addrinfo()
            self.socket = self.open_socket(addrinfo)
            try:
                if self.is_server:
                    self.handle_server()
                else:
                    self.handle_client()
            finally:
                self.socket.close()
            self.parser.close()
        except socket.gaierror as e:
            failure = str(e)
        except XmppException as e:
            failure = e.message
        if failure is not None:
            self.state = nagiosplugin.Critical
            self.cause = failure
            # This method is a generator: the metric has to be yielded. A
            # "return <metric>" would be discarded and the failure would
            # never reach the context that reports state and cause.
            yield nagiosplugin.Metric("time", "unknown")
            return
        end = datetime.now()
        yield nagiosplugin.Metric(
            'time', (end - start).total_seconds(), 's', min=0)
        yield nagiosplugin.Metric('daysleft', self.daysleft, 'd')


class XmppContext(nagiosplugin.ScalarContext):
    """
    Scalar context that reports a failure recorded by the resource.

    """

    def evaluate(self, metric, resource):
        """
        Return the resource's state and cause if a cause has been set.

        Without a cause the evaluation of the parent class is used.

        """
        if not resource.cause:
            return super(XmppContext, self).evaluate(metric, resource)
        return nagiosplugin.Result(resource.state, resource.cause, metric)


class DaysValidContext(nagiosplugin.Context):
    """
    Context evaluating the remaining validity of a certificate in days.

    :param name: name of the context
    :param warndays: day count used to build the warning range
    :param critdays: day count used to build the critical range
    :param fmt_metric: format of the metric description

    """

    def __init__(
        self, name, warndays=0, critdays=0,
        fmt_metric='certificate expires in {value} days'):
        super(DaysValidContext, self).__init__(name, fmt_metric=fmt_metric)
        warnspec = '@%d:' % warndays
        critspec = '@%d:' % critdays
        self.warning = nagiosplugin.Range(warnspec)
        self.critical = nagiosplugin.Range(critspec)

    def evaluate(self, metric, resource):
        """
        Map the metric to Ok, Warn or Critical.

        The result is Ok without further details when certificates are not
        checked or no value is available. The critical range is consulted
        before the warning range.

        """
        days = metric.value
        if not resource.checkcerts or days is None:
            return nagiosplugin.Result(nagiosplugin.Ok)
        hint = self.describe(metric)
        if self.critical.match(days):
            state = nagiosplugin.Critical
        elif self.warning.match(days):
            state = nagiosplugin.Warn
        else:
            state = nagiosplugin.Ok
        return nagiosplugin.Result(state, hint, metric)

    def performance(self, metric, resource):
        """
        Return ``daysvalid`` performance data or ``None`` if not available.

        """
        if not resource.checkcerts or metric.value is None:
            return None
        return nagiosplugin.Performance('daysvalid', metric.value, 'd')


@nagiosplugin.guarded
def main():
    """
    Parse the command line and run the XMPP check.

    Threshold and verbosity options are passed to the contexts and the
    check; the remaining options are passed to :class:`Xmpp`.

    """
    import argparse
    argparser = argparse.ArgumentParser(description="Check XMPP services")
    argparser.add_argument(
        "-H", "--host-address", required=True, help="host address")
    argparser.add_argument("-p", "--port", type=int, help="port")

    # connection type: server to server or client to server
    conntype = argparser.add_mutually_exclusive_group()
    conntype.add_argument(
        "--s2s", action='store_true', dest="is_server",
        help="server to server (s2s)")
    conntype.add_argument(
        "--c2s", action='store_false', dest="is_server",
        help="client to server (c2s)")

    # address family selection
    family = argparser.add_mutually_exclusive_group()
    family.add_argument(
        "-4", "--ipv4", action='store_false', dest="ipv6",
        help="enforce IPv4")
    family.add_argument(
        "-6", "--ipv6", action='store_true', dest="ipv6",
        help="enforce IPv6")

    argparser.add_argument("--servername", help="server name to be used")
    argparser.add_argument(
        "--starttls", action='store_true',
        help="check whether the service allows starttls")
    argparser.add_argument(
        "-w", "--warning", default='', metavar="SECONDS",
        help="return warning if connection setup takes longer than SECONDS")
    argparser.add_argument(
        "-c", "--critical", default='', metavar="SECONDS",
        help="return critical if connection setup takes longer than SECONDS")
    argparser.add_argument(
        "--no-check-certificates", action='store_false', dest='checkcerts',
        help="do not check whether the XMPP server's certificate is valid")
    argparser.add_argument(
        "-r", "--ca-roots", dest="caroots",
        help="path to a file or directory where CA certifcates can be found")
    argparser.add_argument(
        "--warn-days", type=int, default=0, dest="warndays",
        help=(
            "set state to WARN if the certificate is valid for not more than "
            "the given number of days"))
    argparser.add_argument(
        "--crit-days", type=int, default=0, dest="critdays",
        help=(
            "set state to CRITICAL if the certificate is valid for not more "
            "than the given number of days"))
    argparser.add_argument('-v', '--verbose', action='count', default=0)
    argparser.set_defaults(is_server=False, ipv6=None, checkcerts=True)

    options = argparser.parse_args()
    if options.port is None:
        # default ports for s2s and c2s connections
        options.port = 5269 if options.is_server else 5222
    if options.servername is None:
        options.servername = options.host_address

    # everything that is not consumed here is passed on to the resource
    resource_args = vars(options)
    warning = resource_args.pop('warning')
    critical = resource_args.pop('critical')
    warndays = resource_args.pop('warndays')
    critdays = resource_args.pop('critdays')
    verbose = resource_args.pop('verbose')

    resource = Xmpp(**resource_args)
    time_context = XmppContext('time', warning, critical)
    days_context = DaysValidContext('daysleft', warndays, critdays)
    check = nagiosplugin.Check(resource, time_context, days_context)
    check.main(verbose=verbose, timeout=0)


# Run the check only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()