Statistics
| Branch: | Tag: | Revision:

root / vncauthproxy / proxy.py @ 512c571e

History | View | Annotate | Download (20.4 kB)

1
#!/usr/bin/env python
2
"""
3
vncauthproxy - a VNC authentication proxy
4
"""
5
#
6
# Copyright (c) 2010-2011 Greek Research and Technology Network S.A.
7
#
8
# This program is free software; you can redistribute it and/or modify
9
# it under the terms of the GNU General Public License as published by
10
# the Free Software Foundation; either version 2 of the License, or
11
# (at your option) any later version.
12
#
13
# This program is distributed in the hope that it will be useful, but
14
# WITHOUT ANY WARRANTY; without even the implied warranty of
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16
# General Public License for more details.
17
#
18
# You should have received a copy of the GNU General Public License
19
# along with this program; if not, write to the Free Software
20
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
21
# 02110-1301, USA.
22

    
23
DEFAULT_CTRL_SOCKET = "/var/run/vncauthproxy/ctrl.sock"
24
DEFAULT_LOG_FILE = "/var/log/vncauthproxy/vncauthproxy.log"
25
DEFAULT_PID_FILE = "/var/run/vncauthproxy/vncauthproxy.pid"
26
DEFAULT_CONNECT_TIMEOUT = 30
27
# Default values per http://www.iana.org/assignments/port-numbers
28
DEFAULT_MIN_PORT = 49152 
29
DEFAULT_MAX_PORT = 65535
30

    
31
import os
32
import sys
33
import logging
34
import gevent
35
import daemon
36
import random
37
import daemon.pidlockfile
38

    
39
import rfb
40
 
41
try:
42
    import simplejson as json
43
except ImportError:
44
    import json
45

    
46
from gevent import socket
47
from signal import SIGINT, SIGTERM
48
from gevent import signal
49
from gevent.select import select
50
from time import sleep
51

    
52
logger = None
53

    
54
# Currently, gevent uses libevent-dns for asynchornous DNS resolution,
55
# which opens a socket upon initialization time. Since we can't get the fd
56
# reliably, We have to maintain all file descriptors open (which won't harm
57
# anyway)
58

    
59
class AllFilesDaemonContext(daemon.DaemonContext):
60
    """DaemonContext class keeping all file descriptors open"""
61
    def _get_exclude_file_descriptors(self):
62
        class All:
63
            def __contains__(self, value):
64
                return True
65
        return All()
66

    
67

    
68
class VncAuthProxy(gevent.Greenlet):
69
    """
70
    Simple class implementing a VNC Forwarder with MITM authentication as a
71
    Greenlet
72

73
    VncAuthProxy forwards VNC traffic from a specified port of the local host
74
    to a specified remote host:port. Furthermore, it implements VNC
75
    Authentication, intercepting the client/server handshake and asking the
76
    client for authentication even if the backend requires none.
77

78
    It is primarily intended for use in virtualization environments, as a VNC
79
    ``switch''.
80

81
    """
82
    id = 1
83

    
84
    def __init__(self, logger, listeners, pool, daddr, dport, server, password, connect_timeout):
85
        """
86
        @type logger: logging.Logger
87
        @param logger: the logger to use
88
        @type listeners: list
89
        @param listeners: list of listening sockets to use for client connections
90
        @type pool: list
91
        @param pool: if not None, return the client port number into this port pool
92
        @type daddr: str
93
        @param daddr: destination address (IPv4, IPv6 or hostname)
94
        @type dport: int
95
        @param dport: destination port
96
        @type server: socket
97
        @param server: VNC server socket
98
        @type password: str
99
        @param password: password to request from the client
100
        @type connect_timeout: int
101
        @param connect_timeout: how long to wait for client connections
102
                                (seconds)
103

104
        """
105
        gevent.Greenlet.__init__(self)
106
        self.id = VncAuthProxy.id
107
        VncAuthProxy.id += 1
108
        self.log = logger
109
        self.listeners = listeners
110
        # All listening sockets are assumed to be on the same port
111
        self.sport = listeners[0].getsockname()[1]
112
        self.pool = pool
113
        self.daddr = daddr
114
        self.dport = dport
115
        self.server = server
116
        self.password = password
117
        self.client = None
118
        self.timeout = connect_timeout
119

    
120
    def _cleanup(self):
121
        """Close all active sockets and exit gracefully"""
122
        # Reintroduce the port number of the client socket in
123
        # the port pool, if applicable.
124
        if not self.pool is None:
125
            self.pool.append(self.sport)
126
            self.log.debug("Returned port %d to port pool, contains %d ports",
127
                self.sport, len(self.pool))
128

    
129
        while self.listeners:
130
            self.listeners.pop().close()
131
        if self.server:
132
            self.server.close()
133
        if self.client:
134
            self.client.close()
135

    
136
        raise gevent.GreenletExit
137

    
138
    def info(self, msg):
139
        self.log.info("[C%d] %s" % (self.id, msg))
140

    
141
    def debug(self, msg):
142
        self.log.debug("[C%d] %s" % (self.id, msg))
143

    
144
    def warn(self, msg):
145
        self.log.warn("[C%d] %s" % (self.id, msg))
146

    
147
    def error(self, msg):
148
        self.log.error("[C%d] %s" % (self.id, msg))
149

    
150
    def critical(self, msg):
151
        self.log.critical("[C%d] %s" % (self.id, msg))
152

    
153
    def __str__(self):
154
        return "VncAuthProxy: %d -> %s:%d" % (self.sport, self.daddr, self.dport)
155

    
156
    def _forward(self, source, dest):
157
        """
158
        Forward traffic from source to dest
159

160
        @type source: socket
161
        @param source: source socket
162
        @type dest: socket
163
        @param dest: destination socket
164

165
        """
166

    
167
        while True:
168
            d = source.recv(16384)
169
            if d == '':
170
                if source == self.client:
171
                    self.info("Client connection closed")
172
                else:
173
                    self.info("Server connection closed")
174
                break
175
            dest.sendall(d)
176
        # No need to close the source and dest sockets here.
177
        # They are owned by and will be closed by the original greenlet.
178

    
179
    def _client_handshake(self):
180
        """
181
        Perform handshake/authentication with a connecting client
182

183
        Outline:
184
        1. Client connects
185
        2. We fake RFB 3.8 protocol and require VNC authentication [also supports RFB 3.3]
186
        3. Client accepts authentication method
187
        4. We send an authentication challenge
188
        5. Client sends the authentication response
189
        6. We check the authentication
190

191
        Upon return, self.client socket is connected to the client.
192

193
        """
194
        self.client.send(rfb.RFB_VERSION_3_8 + "\n")
195
        client_version_str = self.client.recv(1024)
196
        client_version = rfb.check_version(client_version_str)
197
        if not client_version:
198
            self.error("Invalid version: %s" % client_version_str)
199
            raise gevent.GreenletExit
200

    
201
        # Both for RFB 3.3 and 3.8
202
        self.debug("Requesting authentication")
203
        auth_request = rfb.make_auth_request(rfb.RFB_AUTHTYPE_VNC,
204
            version=client_version)
205
        self.client.send(auth_request)
206

    
207
        # The client gets to propose an authtype only for RFB 3.8
208
        if client_version == rfb.RFB_VERSION_3_8:
209
            res = self.client.recv(1024)
210
            type = rfb.parse_client_authtype(res)
211
            if type == rfb.RFB_AUTHTYPE_ERROR:
212
                self.warn("Client refused authentication: %s" % res[1:])
213
            else:
214
                self.debug("Client requested authtype %x" % type)
215

    
216
            if type != rfb.RFB_AUTHTYPE_VNC:
217
                self.error("Wrong auth type: %d" % type)
218
                self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
219
                raise gevent.GreenletExit
220
        
221
        # Generate the challenge
222
        challenge = os.urandom(16)
223
        self.client.send(challenge)
224
        response = self.client.recv(1024)
225
        if len(response) != 16:
226
            self.error("Wrong response length %d, should be 16" % len(response))
227
            raise gevent.GreenletExit
228

    
229
        if rfb.check_password(challenge, response, self.password):
230
            self.debug("Authentication successful!")
231
        else:
232
            self.warn("Authentication failed")
233
            self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
234
            raise gevent.GreenletExit
235

    
236
        # Accept the authentication
237
        self.client.send(rfb.to_u32(rfb.RFB_AUTH_SUCCESS))
238
       
239
    def _run(self):
240
        try:
241
            self.log.debug("Waiting for client to connect")
242
            rlist, _, _ = select(self.listeners, [], [], timeout=self.timeout)
243

    
244
            if not rlist:
245
                self.info("Timed out, no connection after %d sec" % self.timeout)
246
                raise gevent.GreenletExit
247

    
248
            for sock in rlist:
249
                self.client, addrinfo = sock.accept()
250
                self.info("Connection from %s:%d" % addrinfo[:2])
251

    
252
                # Close all listening sockets, we only want a one-shot connection
253
                # from a single client.
254
                while self.listeners:
255
                    self.listeners.pop().close()
256
                break
257
       
258
            # Perform RFB handshake with the client.
259
            self._client_handshake()
260

    
261
            # Bridge both connections through two "forwarder" greenlets.
262
            self.workers = [gevent.spawn(self._forward, self.client, self.server),
263
                gevent.spawn(self._forward, self.server, self.client)]
264
            
265
            # If one greenlet goes, the other has to go too.
266
            self.workers[0].link(self.workers[1])
267
            self.workers[1].link(self.workers[0])
268
            gevent.joinall(self.workers)
269
            del self.workers
270
            raise gevent.GreenletExit
271
        except Exception, e:
272
            # Any unhandled exception in the previous block
273
            # is an error and must be logged accordingly
274
            if not isinstance(e, gevent.GreenletExit):
275
                self.log.exception(e)
276
            raise e
277
        finally:
278
            self._cleanup()
279

    
280

    
281
def fatal_signal_handler(signame):
282
    logger.info("Caught %s, will raise SystemExit" % signame)
283
    raise SystemExit
284

    
285
def get_listening_sockets(sport):
286
    sockets = []
287

    
288
    # Use two sockets, one for IPv4, one for IPv6. IPv4-to-IPv6 mapped
289
    # addresses do not work reliably everywhere (under linux it may have
290
    # been disabled in /proc/sys/net/ipv6/bind_ipv6_only).
291
    for res in socket.getaddrinfo(None, sport, socket.AF_UNSPEC,
292
                                  socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
293
        af, socktype, proto, canonname, sa = res
294
        try:
295
            s = None
296
            s = socket.socket(af, socktype, proto)
297
            if af == socket.AF_INET6:
298
                # Bind v6 only when AF_INET6, otherwise either v4 or v6 bind
299
                # will fail.
300
                s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
301
            s.bind(sa)
302
            s.listen(1)
303
            sockets.append(s)
304
            logger.debug("Listening on %s:%d" % sa[:2])
305
        except socket.error, msg:
306
            logger.error("Error binding to %s:%d: %s" %
307
                           (sa[0], sa[1], msg[1]))
308
            if s:
309
                s.close()
310
            while sockets:
311
                sockets.pop().close()
312
            
313
            # Make sure we fail immediately if we cannot get a socket
314
            raise msg
315
    
316
    return sockets
317

    
318
def perform_server_handshake(daddr, dport):
319
    """
320
    Initiate a connection with the backend server and perform basic
321
    RFB 3.8 handshake with it.
322

323
    Returns a socket connected to the backend server.
324

325
    """
326
    server = None
327
    # Try to connect to the server
328
    tries = 50
329

    
330
    while tries:
331
        tries -= 1
332

    
333
        # Initiate server connection
334
        for res in socket.getaddrinfo(daddr, dport, socket.AF_UNSPEC,
335
                                      socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
336
            af, socktype, proto, canonname, sa = res
337
            try:
338
                server = socket.socket(af, socktype, proto)
339
            except socket.error, msg:
340
                server = None
341
                continue
342

    
343
            try:
344
                logger.debug("Connecting to %s:%s" % sa[:2])
345
                server.connect(sa)
346
                logger.debug("Connection to %s:%s successful" % sa[:2])
347
            except socket.error, msg:
348
                server.close()
349
                server = None
350
                continue
351

    
352
            # We succesfully connected to the server
353
            tries = 0
354
            break
355

    
356
        # Wait and retry
357
        sleep(0.2)
358

    
359
    if server is None:
360
        raise Exception("Failed to connect to server")
361

    
362
    version = server.recv(1024)
363
    if not rfb.check_version(version):
364
        raise Exception("Unsupported RFB version: %s" % version.strip())
365

    
366
    server.send(rfb.RFB_VERSION_3_8 + "\n")
367

    
368
    res = server.recv(1024)
369
    types = rfb.parse_auth_request(res)
370
    if not types:
371
        raise Exception("Error handshaking with the server")
372

    
373
    else:
374
        logger.debug("Supported authentication types: %s" %
375
                       " ".join([str(x) for x in types]))
376

    
377
    if rfb.RFB_AUTHTYPE_NONE not in types:
378
        raise Exception("Error, server demands authentication")
379

    
380
    server.send(rfb.to_u8(rfb.RFB_AUTHTYPE_NONE))
381

    
382
    # Check authentication response
383
    res = server.recv(4)
384
    res = rfb.from_u32(res)
385

    
386
    if res != 0:
387
        raise Exception("Authentication error")
388

    
389
    return server
390

    
391
def parse_arguments(args):
392
    from optparse import OptionParser
393

    
394
    parser = OptionParser()
395
    parser.add_option("-s", "--socket", dest="ctrl_socket",
396
                      default=DEFAULT_CTRL_SOCKET,
397
                      metavar="PATH",
398
                      help="UNIX socket path for control connections (default: %s" %
399
                          DEFAULT_CTRL_SOCKET)
400
    parser.add_option("-d", "--debug", action="store_true", dest="debug",
401
                      help="Enable debugging information")
402
    parser.add_option("-l", "--log", dest="log_file",
403
                      default=DEFAULT_LOG_FILE,
404
                      metavar="FILE",
405
                      help="Write log to FILE instead of %s" % DEFAULT_LOG_FILE),
406
    parser.add_option('--pid-file', dest="pid_file",
407
                      default=DEFAULT_PID_FILE,
408
                      metavar='PIDFILE',
409
                      help="Save PID to file (default: %s)" %
410
                          DEFAULT_PID_FILE)
411
    parser.add_option("-t", "--connect-timeout", dest="connect_timeout",
412
                      default=DEFAULT_CONNECT_TIMEOUT, type="int", metavar="SECONDS",
413
                      help="How long to listen for clients to forward")
414
    parser.add_option("-p", "--min-port", dest="min_port",
415
                      default=DEFAULT_MIN_PORT, type="int", metavar="MIN_PORT",
416
                      help="The minimum port to use for automatically-allocated ephemeral ports")
417
    parser.add_option("-P", "--max-port", dest="max_port",
418
                      default=DEFAULT_MAX_PORT, type="int", metavar="MAX_PORT",
419
                      help="The minimum port to use for automatically-allocated ephemeral ports")
420

    
421
    return parser.parse_args(args)
422

    
423

    
424
def main():
425
    """Run the daemon from the command line."""
426

    
427
    (opts, args) = parse_arguments(sys.argv[1:])
428

    
429
    # Create pidfile
430
    pidf = daemon.pidlockfile.TimeoutPIDLockFile(
431
        opts.pid_file, 10)
432
    
433
    # Initialize logger
434
    lvl = logging.DEBUG if opts.debug else logging.INFO
435

    
436
    global logger
437
    logger = logging.getLogger("vncauthproxy")
438
    logger.setLevel(lvl)
439
    formatter = logging.Formatter("%(asctime)s %(module)s[%(process)d] %(levelname)s: %(message)s",
440
        "%Y-%m-%d %H:%M:%S")
441
    handler = logging.FileHandler(opts.log_file)
442
    handler.setFormatter(formatter)
443
    logger.addHandler(handler)
444

    
445
    # Become a daemon:
446
    # Redirect stdout and stderr to handler.stream to catch
447
    # early errors in the daemonization process [e.g., pidfile creation]
448
    # which will otherwise go to /dev/null.
449
    daemon_context = AllFilesDaemonContext(
450
        pidfile=pidf,
451
        umask=0022,
452
        stdout=handler.stream,
453
        stderr=handler.stream,
454
        files_preserve=[handler.stream])
455
    daemon_context.open()
456
    logger.info("Became a daemon")
457

    
458
    # A fork() has occured while daemonizing,
459
    # we *must* reinit gevent
460
    gevent.reinit()
461

    
462
    if os.path.exists(opts.ctrl_socket):
463
        logger.critical("Socket '%s' already exists" % opts.ctrl_socket)
464
        sys.exit(1)
465

    
466
    # TODO: make this tunable? chgrp as well?
467
    old_umask = os.umask(0007)
468

    
469
    ctrl = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
470
    ctrl.bind(opts.ctrl_socket)
471

    
472
    os.umask(old_umask)
473

    
474
    ctrl.listen(1)
475
    logger.info("Initialized, waiting for control connections at %s" %
476
                 opts.ctrl_socket)
477

    
478
    # Catch signals to ensure graceful shutdown,
479
    # e.g., to make sure the control socket gets unlink()ed.
480
    #
481
    # Uses gevent.signal so the handler fires even during
482
    # gevent.socket.accept()
483
    gevent.signal(SIGINT, fatal_signal_handler, "SIGINT")
484
    gevent.signal(SIGTERM, fatal_signal_handler, "SIGTERM")
485

    
486
    # Init ephemeral port pool
487
    ports = range(opts.min_port, opts.max_port + 1) 
488

    
489
    while True:
490
        try:
491
            client, addr = ctrl.accept()
492
            logger.info("New control connection")
493
           
494
            # Receive and parse a client request.
495
            response = {
496
                "source_port": 0,
497
                "status": "FAILED"
498
            }
499
            try:
500
                # TODO: support multiple forwardings in the same message?
501
                # 
502
                # Control request, in JSON:
503
                #
504
                # {
505
                #     "source_port": <source port or 0 for automatic allocation>,
506
                #     "destination_address": <destination address of backend server>,
507
                #     "destination_port": <destination port>
508
                #     "password": <the password to use for MITM authentication of clients>
509
                # }
510
                # 
511
                # The <password> is used for MITM authentication of clients
512
                # connecting to <source_port>, who will subsequently be forwarded
513
                # to a VNC server at <destination_address>:<destination_port>
514
                #
515
                # Control reply, in JSON:
516
                # {
517
                #     "source_port": <the allocated source port>
518
                #     "status": <one of "OK" or "FAILED">
519
                # }
520
                buf = client.recv(1024)
521
                req = json.loads(buf)
522
                
523
                sport_orig = int(req['source_port'])
524
                daddr = req['destination_address']
525
                dport = int(req['destination_port'])
526
                password = req['password']
527
            except Exception, e:
528
                logger.warn("Malformed request: %s" % buf)
529
                client.send(json.dumps(response))
530
                client.close()
531
                continue
532
            
533
            # Spawn a new Greenlet to service the request.
534
            server = None
535
            try:
536
                # If the client has so indicated, pick an ephemeral source port
537
                # randomly, and remove it from the port pool.
538
                if sport_orig == 0:
539
                    sport = random.choice(ports)
540
                    ports.remove(sport)
541
                    logger.debug("Got port %d from port pool, contains %d ports",
542
                        sport, len(ports))
543
                    pool = ports
544
                else:
545
                    sport = sport_orig
546
                    pool = None
547

    
548
                listeners = get_listening_sockets(sport)
549
                server = perform_server_handshake(daddr, dport)
550

    
551
                VncAuthProxy.spawn(logger, listeners, pool, daddr, dport,
552
                    server, password, opts.connect_timeout)
553

    
554
                logger.info("New forwarding [%d (req'd by client: %d) -> %s:%d]" %
555
                    (sport, sport_orig, daddr, dport))
556
                response = {
557
                    "source_port": sport,
558
                    "status": "OK"
559
                }
560
            except IndexError:
561
                logger.error("FAILED forwarding, out of ports for [req'd by "
562
                    "client: %d -> %s:%d]" % (sport_orig, daddr, dport))
563
            except Exception, msg:
564
                logger.error(msg)
565
                logger.error("FAILED forwarding [%d (req'd by client: %d) -> %s:%d]" %
566
                    (sport, sport_orig, daddr, dport))
567
                if not pool is None:
568
                    pool.append(sport)
569
                    logger.debug("Returned port %d to port pool, contains %d ports",
570
                        sport, len(pool))
571
                if not server is None:
572
                    server.close()
573
            finally:
574
                client.send(json.dumps(response))
575
                client.close()
576
        except Exception, e:
577
            logger.exception(e)
578
            continue
579
        except SystemExit:
580
            break
581
 
582
    logger.info("Unlinking control socket at %s" %
583
                 opts.ctrl_socket)
584
    os.unlink(opts.ctrl_socket)
585
    daemon_context.close()
586
    sys.exit(0)