#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import absolute_import

from datetime import datetime
import socket
from select import select
from xml.sax import SAXException
from xml.sax.handler import ContentHandler, feature_namespaces

from defusedxml.sax import make_parser
import nagiosplugin


# XML namespaces of the stream-feature elements the response handler
# looks for (compared against SAX (namespace, localname) tuples).
NS_XMPP_SASL = 'urn:ietf:params:xml:ns:xmpp-sasl'  # <mechanism> elements
NS_XMPP_TLS = 'urn:ietf:params:xml:ns:xmpp-tls'  # <starttls> feature
NS_XMPP_CAPS = 'http://jabber.org/protocol/caps'  # capabilities <c> element


class XmppException(Exception):
    """
    Error raised when the XMPP check cannot be carried out, e.g. no socket
    could be opened or the peer gave no valid response. ``Xmpp.probe``
    turns it into a CRITICAL result carrying the exception message.
    """


class XmppClientServerResponseHandler(ContentHandler):
    """
    SAX content handler that collects stream features from an XMPP reply.

    It records every element name seen, the text of each SASL
    ``<mechanism>`` element, whether a ``<starttls>`` feature was announced,
    and the attributes of a capabilities ``<c>`` element. It expects the
    parser to run with namespace processing enabled, so element names arrive
    as ``(namespace, localname)`` tuples.
    """

    def __init__(self):
        ContentHandler.__init__(self)
        # All state is per instance: mutable class-level containers would be
        # shared by (and leak between) every handler instance.
        # (namespace, localname) tuples of every element encountered.
        self.seen_elements = set()
        # Text of each SASL <mechanism> element, in document order.
        self.mechanisms = []
        # Set to True once a <starttls> feature element has been seen.
        self.starttls = False
        # Attribute qname -> value, taken from the capabilities <c> element.
        self.capabilities = {}
        # Stack of currently open element names (innermost last).
        self.inelem = []
        # Not read by this class; kept for compatibility.
        self.level = 0
        # Text chunks of the <mechanism> element currently being read, or
        # None outside such an element. SAX may deliver one text node in
        # several characters() calls, so chunks are joined at the end tag.
        self._mechanism_text = None

    def startElementNS(self, name, qname, attrs):
        """Track the opened element and note the features it announces."""
        self.inelem.append(name)
        self.seen_elements.add(name)
        if name == (NS_XMPP_TLS, 'starttls'):
            self.starttls = True
        elif name == (NS_XMPP_CAPS, 'c'):
            for attr_qname in attrs.getQNames():
                self.capabilities[attr_qname] = \
                    attrs.getValueByQName(attr_qname)
        elif name == (NS_XMPP_SASL, 'mechanism'):
            self._mechanism_text = []

    def endElementNS(self, name, qname):
        """Close the innermost element; finish a pending mechanism name."""
        if (name == (NS_XMPP_SASL, 'mechanism')
                and self._mechanism_text is not None):
            text = ''.join(self._mechanism_text)
            if text:
                self.mechanisms.append(text)
            self._mechanism_text = None
        if self.inelem:
            del self.inelem[-1]

    def characters(self, content):
        """
        Collect the text of SASL <mechanism> elements.

        Other text is ignored. It is deliberately not printed, because
        stdout carries the plugin's status output.
        """
        if (self._mechanism_text is not None and self.inelem
                and self.inelem[-1] == (NS_XMPP_SASL, 'mechanism')):
            self._mechanism_text.append(content)

    def is_valid_start(self):
        """Return whether the reply looked like a valid stream start."""
        return True  # TODO: some real implementation


class Xmpp(nagiosplugin.Resource):
    """
    nagiosplugin resource that opens an XMPP stream and times the exchange.

    Only the client-to-server (c2s) exchange is implemented; the
    server-to-server (s2s) path is still a stub.
    """
    # Overwritten by probe() on failure; a set ``cause`` tells XmppContext to
    # report ``state``/``cause`` instead of evaluating the timing metric.
    state = nagiosplugin.Unknown
    cause = None

    def __init__(self, host_address, port, ipv6, is_server, starttls,
                 servername):
        """
        :param host_address: host name or address to connect to
        :param port: TCP port
        :param ipv6: True forces IPv6, False forces IPv4, None allows either
        :param is_server: check s2s instead of c2s
        :param starttls: stored on the instance; not used by this class yet
        :param servername: value of the ``to`` attribute in the stream header
        """
        self.address = host_address
        self.port = port
        self.ipv6 = ipv6
        self.is_server = is_server
        self.starttls = starttls
        self.servername = servername
        self.parser = make_parser()
        self.parser.setFeature(feature_namespaces, True)
        if self.is_server:
            pass  # TODO: make server parser
        else:
            self.contenthandler = XmppClientServerResponseHandler()
            self.parser.setContentHandler(self.contenthandler)

    def get_addrinfo(self):
        """Resolve the target to TCP address tuples for the chosen family."""
        if self.ipv6 is None:
            addrfamily = 0  # let the resolver return any family
        elif self.ipv6 is True:
            addrfamily = socket.AF_INET6
        else:
            addrfamily = socket.AF_INET
        return socket.getaddrinfo(
            self.address, self.port, addrfamily, socket.SOCK_STREAM,
            socket.IPPROTO_TCP)

    def open_socket(self, addrinfo):
        """
        Connect to the first reachable entry of *addrinfo*.

        :returns: the connected socket
        :raises XmppException: if no entry could be connected to (also when
            *addrinfo* is empty)
        """
        last_error = None
        for af, socktype, proto, _canonname, sa in addrinfo:
            try:
                s = socket.socket(af, socktype, proto)
            except socket.error as e:
                last_error = e
                continue
            try:
                s.connect(sa)
            except socket.error as e:
                last_error = e
                s.close()
                continue
            return s
        if last_error is None:
            raise XmppException("could not open socket")
        raise XmppException("could not open socket: {0}".format(last_error))

    def handle_server(self, xmppsocket):
        """Server-to-server check; not implemented yet."""
        pass

    def handle_xmpp_stanza(self, xmppsocket, message_str, timeout=0.1):
        """
        Send *message_str* and feed the reply to the SAX parser.

        Reading stops when the peer closes the connection or when no data
        arrives within *timeout* seconds.
        """
        xmppsocket.sendall(message_str.encode('utf-8'))
        while True:
            rready, _, _ = select([xmppsocket], [], [], timeout)
            if xmppsocket not in rready:
                break
            data = xmppsocket.recv(4096)
            if not data:
                break
            # Feed raw bytes: decoding each chunk on its own fails when a
            # multi-byte UTF-8 sequence is split across recv() calls; the
            # expat-based parser decodes the byte stream incrementally.
            self.parser.feed(data)

    def handle_client(self, xmppsocket):
        """Open a c2s stream, validate the reply and close the stream."""
        self.handle_xmpp_stanza(xmppsocket, (
            "<?xml version='1.0' ?><stream:stream to='{servername}' "
            "xmlns='jabber:client' "
            "xmlns:stream='http://etherx.jabber.org/streams' "
            "version='1.0'>"
        ).format(servername=self.servername))
        if not self.contenthandler.is_valid_start():
            raise XmppException("no valid response to XMPP client request")
        self.handle_xmpp_stanza(xmppsocket, "</stream:stream>")

    def _fail(self, cause):
        """Record a critical failure and return a placeholder time metric."""
        self.state = nagiosplugin.Critical
        self.cause = cause
        return nagiosplugin.Metric("time", "unknown")

    def probe(self):
        """
        Run the check and return the elapsed-time metric in seconds.

        Resolution, connection and socket I/O errors, malformed XML replies
        and XmppException all set ``state`` to Critical and ``cause`` to the
        error text. In that case a metric with value "unknown" is returned.
        """
        start = datetime.now()
        try:
            addrinfo = self.get_addrinfo()
            xmppsocket = self.open_socket(addrinfo)
            try:
                if self.is_server:
                    self.handle_server(xmppsocket)
                else:
                    self.handle_client(xmppsocket)
            finally:
                xmppsocket.close()
            self.parser.close()
        # socket.error also covers socket.gaierror (name resolution).
        except (socket.error, SAXException, XmppException) as e:
            # str(e) instead of e.message: the latter does not exist on
            # Python 3 exceptions.
            return self._fail(str(e))
        end = datetime.now()
        return nagiosplugin.Metric(
            'time', (end - start).total_seconds(), 's', min=0)


class XmppContext(nagiosplugin.ScalarContext):
    """
    Scalar context that reports the resource's own failure when it has one.

    If the probed resource set a ``cause``, its ``state`` and ``cause`` are
    returned as the result. Otherwise the normal ScalarContext range check
    is applied to the metric.
    """

    def evaluate(self, metric, resource):
        if not resource.cause:
            return super(XmppContext, self).evaluate(metric, resource)
        return nagiosplugin.Result(resource.state, resource.cause, metric)


@nagiosplugin.guarded
def main():
    """
    Parse the command line, then build and run the XMPP check.

    Without -p, the port is 5269 for --s2s and 5222 otherwise. Without
    --servername, the host address is used as the server name.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Check XMPP services")
    parser.add_argument(
        "-H", "--host-address", help="host address", required=True)
    parser.add_argument("-p", "--port", help="port", type=int)

    role = parser.add_mutually_exclusive_group()
    role.add_argument(
        "--s2s", dest="is_server", action='store_true',
        help="server to server (s2s)")
    role.add_argument(
        "--c2s", dest="is_server", action='store_false',
        help="client to server (c2s)")

    family = parser.add_mutually_exclusive_group()
    family.add_argument(
        "-4", "--ipv4", dest="ipv6", action='store_false',
        help="enforce IPv4")
    family.add_argument(
        "-6", "--ipv6", dest="ipv6", action='store_true',
        help="enforce IPv6")

    parser.add_argument("--servername", help="server name to be used")
    parser.add_argument(
        "--starttls", action='store_true',
        help="check whether the service allows starttls")
    parser.set_defaults(is_server=False, ipv6=None)
    parser.add_argument(
        "-w", "--warning", metavar="SECONDS", default='',
        help="return warning if connection setup takes longer than SECONDS")
    parser.add_argument(
        "-c", "--critical", metavar="SECONDS", default='',
        help="return critical if connection setup takes longer than SECONDS")

    options = parser.parse_args()
    if options.port is None:
        options.port = 5269 if options.is_server else 5222
    if options.servername is None:
        options.servername = options.host_address

    # Whatever remains after removing the thresholds maps 1:1 onto
    # Xmpp.__init__ keyword arguments.
    resource_args = vars(options)
    warning_range = resource_args.pop('warning')
    critical_range = resource_args.pop('critical')
    nagiosplugin.Check(
        Xmpp(**resource_args),
        XmppContext('time', warning_range, critical_range),
    ).main(timeout=0)


if __name__ == "__main__":
    # Run the check only when executed as a script, not when imported.
    main()