From 9edc6e0c3ecdf26d1c150843de731365f27308f8 Mon Sep 17 00:00:00 2001
From: Jan Dittberner <jan@dittberner.info>
Date: Tue, 10 Feb 2015 03:20:31 +0100
Subject: [PATCH] add some initial XML processing for C2S

---
 check_xmpp | 154 +++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 121 insertions(+), 33 deletions(-)

diff --git a/check_xmpp b/check_xmpp
index af9f456..4d43765 100755
--- a/check_xmpp
+++ b/check_xmpp
@@ -1,9 +1,56 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from datetime import datetime
+import socket
+from xml.sax import make_parser
+from xml.sax.handler import ContentHandler, feature_namespaces
 
 import nagiosplugin
-import socket
-from datetime import datetime
+
+
+NS_XMPP_SASL = 'urn:ietf:params:xml:ns:xmpp-sasl'
+NS_XMPP_TLS  = 'urn:ietf:params:xml:ns:xmpp-tls'
+NS_XMPP_CAPS = 'http://jabber.org/protocol/caps'
+
+
+class XmppException(Exception):
+    """
+    Custom exception class.
+
+    """
+
+
+class XmppClientServerResponseHandler(ContentHandler):
+    seen_elements = set()
+    mechanisms = []
+    starttls = False
+    capabilities = {}
+
+    inelem = []
+    level = 0
+
+    def startElementNS(self, name, qname, attrs):
+        self.inelem.append(name)
+        self.seen_elements.add(name)
+        if name == (NS_XMPP_TLS, 'starttls'):
+            self.starttls = True
+        elif name == (NS_XMPP_CAPS, 'c'):
+            for qname in attrs.getQNames():
+                self.capabilities[qname] = attrs.getValueByQName(qname)
+
+    def endElementNS(self, name, qname):
+        del self.inelem[-1]
+
+    def characters(self, content):
+        if self.inelem[-1] == (NS_XMPP_SASL, 'mechanism'):
+            self.mechanisms.append(content)
+        else:
+            print self.inelem, content
+
+    def is_valid_start(self):
+        return True  # TODO: some real implementation
 
 
 class Xmpp(nagiosplugin.Resource):
@@ -18,36 +65,13 @@ class Xmpp(nagiosplugin.Resource):
         self.is_server = is_server
         self.starttls = starttls
         self.servername = servername
-
-    def probe(self):
-        start = datetime.now()
-        try:
-            for res in self.get_addrinfo():
-                af, socktype, proto, canonname, sa = res
-                try:
-                    s = socket.socket(af, socktype, proto)
-                except socket.error:
-                    s = None
-                    continue
-                try:
-                    s.connect(sa)
-                except socket.error:
-                    s.close()
-                    s = None
-                    continue
-                break
-            if s is None:
-                self.state = nagiosplugin.Critical
-                self.cause = 'could not open socket'
-                return nagiosplugin.Metric("time", "unknown")
-            s.close()
-        except socket.gaierror as e:
-            self.state = nagiosplugin.Critical
-            self.cause = str(e)
-            return nagiosplugin.Metric("time", "unknown")
-        end = datetime.now()
-        return nagiosplugin.Metric(
-            'time', (end - start).total_seconds(), 's', min=0)
+        self.parser = make_parser()
+        self.parser.setFeature(feature_namespaces, True)
+        if self.is_server:
+            pass  # TODO: make server parser
+        else:
+            self.contenthandler = XmppClientServerResponseHandler()
+            self.parser.setContentHandler(self.contenthandler)
 
     def get_addrinfo(self):
         if self.ipv6 is None:
@@ -61,6 +85,70 @@ class Xmpp(nagiosplugin.Resource):
             socket.IPPROTO_TCP)
         self.result = nagiosplugin.Critical
 
+    def open_socket(self, addrinfo):
+        for res in addrinfo:
+            af, socktype, proto, canonname, sa = res
+            try:
+                s = socket.socket(af, socktype, proto)
+            except socket.error:
+                s = None
+                continue
+            try:
+                s.connect(sa)
+            except socket.error:
+                s.close()
+                s = None
+                continue
+            break
+            if s is None:
+                raise XmppException("could not open socket")
+        return s
+
+    def handle_server(self, xmppsocket):
+        pass
+
+    def is_valid_client_response(self, xmldata):
+        self.parser.feed(xmldata)
+        return self.contenthandler.is_valid_start()
+
+    def handle_client(self, xmppsocket):
+        xmppsocket.sendall((
+            "<?xml version='1.0' ?><stream:stream to='{servername}' "
+            "xmlns='jabber:client' "
+            "xmlns:stream='http://etherx.jabber.org/streams' "
+            "version='1.0'>"
+        ).format(servername=self.servername))
+        if not self.is_valid_client_response(xmppsocket.recv(4096)):
+            raise XmppException("no valid response to XMPP client request")
+        xmppsocket.sendall("</stream:stream>")
+        self.parser.feed(xmppsocket.recv(4096))
+
+
+    def probe(self):
+        start = datetime.now()
+        try:
+            addrinfo = self.get_addrinfo()
+            xmppsocket = self.open_socket(addrinfo)
+            try:
+                if self.is_server:
+                    self.handle_server(xmppsocket)
+                else:
+                    self.handle_client(xmppsocket)
+            finally:
+                xmppsocket.close()
+            self.parser.close()
+        except socket.gaierror as e:
+            self.state = nagiosplugin.Critical
+            self.cause = str(e)
+            return nagiosplugin.Metric("time", "unknown")
+        except XmppException as e:
+            self.state = nagiosplugin.Critical
+            self.cause = e.message
+            return nagiosplugin.Metric("time", "unknown")
+        end = datetime.now()
+        return nagiosplugin.Metric(
+            'time', (end - start).total_seconds(), 's', min=0)
+
 
 class XmppContext(nagiosplugin.ScalarContext):
 
@@ -119,7 +207,7 @@ def main():
         Xmpp(**kwargs),
         XmppContext('time', warning, critical)
     )
-    check.main()
+    check.main(timeout=0)
 
 
 if __name__ == "__main__":