Statistics
| Branch: | Tag: | Revision:

root / vncauthproxy / proxy.py @ 7eb27319

History | View | Annotate | Download (21 kB)

1
#!/usr/bin/env python
2
"""
3
vncauthproxy - a VNC authentication proxy
4
"""
5
#
6
# Copyright (c) 2010-2011 Greek Research and Technology Network S.A.
7
#
8
# This program is free software; you can redistribute it and/or modify
9
# it under the terms of the GNU General Public License as published by
10
# the Free Software Foundation; either version 2 of the License, or
11
# (at your option) any later version.
12
#
13
# This program is distributed in the hope that it will be useful, but
14
# WITHOUT ANY WARRANTY; without even the implied warranty of
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16
# General Public License for more details.
17
#
18
# You should have received a copy of the GNU General Public License
19
# along with this program; if not, write to the Free Software
20
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
21
# 02110-1301, USA.
22

    
23
DEFAULT_CTRL_SOCKET = "/var/run/vncauthproxy/ctrl.sock"
24
DEFAULT_LOG_FILE = "/var/log/vncauthproxy/vncauthproxy.log"
25
DEFAULT_PID_FILE = "/var/run/vncauthproxy/vncauthproxy.pid"
26
DEFAULT_CONNECT_TIMEOUT = 30
27
DEFAULT_CONNECT_RETRIES = 3
28
DEFAULT_RETRY_WAIT = 0.1
29
# Default values per http://www.iana.org/assignments/port-numbers
30
DEFAULT_MIN_PORT = 49152 
31
DEFAULT_MAX_PORT = 65535
32

    
33
import os
34
import sys
35
import logging
36
import gevent
37
import daemon
38
import random
39
import daemon.pidlockfile
40

    
41
import rfb
42
 
43
try:
44
    import simplejson as json
45
except ImportError:
46
    import json
47

    
48
from gevent import socket
49
from signal import SIGINT, SIGTERM
50
from gevent import signal
51
from gevent.select import select
52
from time import sleep
53

    
54
logger = None
55

    
56
# Currently, gevent uses libevent-dns for asynchornous DNS resolution,
57
# which opens a socket upon initialization time. Since we can't get the fd
58
# reliably, We have to maintain all file descriptors open (which won't harm
59
# anyway)
60

    
61
class AllFilesDaemonContext(daemon.DaemonContext):
62
    """DaemonContext class keeping all file descriptors open"""
63
    def _get_exclude_file_descriptors(self):
64
        class All:
65
            def __contains__(self, value):
66
                return True
67
        return All()
68

    
69

    
70
class VncAuthProxy(gevent.Greenlet):
71
    """
72
    Simple class implementing a VNC Forwarder with MITM authentication as a
73
    Greenlet
74

75
    VncAuthProxy forwards VNC traffic from a specified port of the local host
76
    to a specified remote host:port. Furthermore, it implements VNC
77
    Authentication, intercepting the client/server handshake and asking the
78
    client for authentication even if the backend requires none.
79

80
    It is primarily intended for use in virtualization environments, as a VNC
81
    ``switch''.
82

83
    """
84
    id = 1
85

    
86
    def __init__(self, logger, listeners, pool, daddr, dport, server, password, connect_timeout):
87
        """
88
        @type logger: logging.Logger
89
        @param logger: the logger to use
90
        @type listeners: list
91
        @param listeners: list of listening sockets to use for client connections
92
        @type pool: list
93
        @param pool: if not None, return the client port number into this port pool
94
        @type daddr: str
95
        @param daddr: destination address (IPv4, IPv6 or hostname)
96
        @type dport: int
97
        @param dport: destination port
98
        @type server: socket
99
        @param server: VNC server socket
100
        @type password: str
101
        @param password: password to request from the client
102
        @type connect_timeout: int
103
        @param connect_timeout: how long to wait for client connections
104
                                (seconds)
105

106
        """
107
        gevent.Greenlet.__init__(self)
108
        self.id = VncAuthProxy.id
109
        VncAuthProxy.id += 1
110
        self.log = logger
111
        self.listeners = listeners
112
        # All listening sockets are assumed to be on the same port
113
        self.sport = listeners[0].getsockname()[1]
114
        self.pool = pool
115
        self.daddr = daddr
116
        self.dport = dport
117
        self.server = server
118
        self.password = password
119
        self.client = None
120
        self.timeout = connect_timeout
121

    
122
    def _cleanup(self):
123
        """Close all active sockets and exit gracefully"""
124
        # Reintroduce the port number of the client socket in
125
        # the port pool, if applicable.
126
        if not self.pool is None:
127
            self.pool.append(self.sport)
128
            self.log.debug("Returned port %d to port pool, contains %d ports",
129
                self.sport, len(self.pool))
130

    
131
        while self.listeners:
132
            self.listeners.pop().close()
133
        if self.server:
134
            self.server.close()
135
        if self.client:
136
            self.client.close()
137

    
138
        raise gevent.GreenletExit
139

    
140
    def info(self, msg):
141
        self.log.info("[C%d] %s" % (self.id, msg))
142

    
143
    def debug(self, msg):
144
        self.log.debug("[C%d] %s" % (self.id, msg))
145

    
146
    def warn(self, msg):
147
        self.log.warn("[C%d] %s" % (self.id, msg))
148

    
149
    def error(self, msg):
150
        self.log.error("[C%d] %s" % (self.id, msg))
151

    
152
    def critical(self, msg):
153
        self.log.critical("[C%d] %s" % (self.id, msg))
154

    
155
    def __str__(self):
156
        return "VncAuthProxy: %d -> %s:%d" % (self.sport, self.daddr, self.dport)
157

    
158
    def _forward(self, source, dest):
159
        """
160
        Forward traffic from source to dest
161

162
        @type source: socket
163
        @param source: source socket
164
        @type dest: socket
165
        @param dest: destination socket
166

167
        """
168

    
169
        while True:
170
            d = source.recv(16384)
171
            if d == '':
172
                if source == self.client:
173
                    self.info("Client connection closed")
174
                else:
175
                    self.info("Server connection closed")
176
                break
177
            dest.sendall(d)
178
        # No need to close the source and dest sockets here.
179
        # They are owned by and will be closed by the original greenlet.
180

    
181
    def _client_handshake(self):
182
        """
183
        Perform handshake/authentication with a connecting client
184

185
        Outline:
186
        1. Client connects
187
        2. We fake RFB 3.8 protocol and require VNC authentication [also supports RFB 3.3]
188
        3. Client accepts authentication method
189
        4. We send an authentication challenge
190
        5. Client sends the authentication response
191
        6. We check the authentication
192

193
        Upon return, self.client socket is connected to the client.
194

195
        """
196
        self.client.send(rfb.RFB_VERSION_3_8 + "\n")
197
        client_version_str = self.client.recv(1024)
198
        client_version = rfb.check_version(client_version_str)
199
        if not client_version:
200
            self.error("Invalid version: %s" % client_version_str)
201
            raise gevent.GreenletExit
202

    
203
        # Both for RFB 3.3 and 3.8
204
        self.debug("Requesting authentication")
205
        auth_request = rfb.make_auth_request(rfb.RFB_AUTHTYPE_VNC,
206
            version=client_version)
207
        self.client.send(auth_request)
208

    
209
        # The client gets to propose an authtype only for RFB 3.8
210
        if client_version == rfb.RFB_VERSION_3_8:
211
            res = self.client.recv(1024)
212
            type = rfb.parse_client_authtype(res)
213
            if type == rfb.RFB_AUTHTYPE_ERROR:
214
                self.warn("Client refused authentication: %s" % res[1:])
215
            else:
216
                self.debug("Client requested authtype %x" % type)
217

    
218
            if type != rfb.RFB_AUTHTYPE_VNC:
219
                self.error("Wrong auth type: %d" % type)
220
                self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
221
                raise gevent.GreenletExit
222
        
223
        # Generate the challenge
224
        challenge = os.urandom(16)
225
        self.client.send(challenge)
226
        response = self.client.recv(1024)
227
        if len(response) != 16:
228
            self.error("Wrong response length %d, should be 16" % len(response))
229
            raise gevent.GreenletExit
230

    
231
        if rfb.check_password(challenge, response, self.password):
232
            self.debug("Authentication successful!")
233
        else:
234
            self.warn("Authentication failed")
235
            self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
236
            raise gevent.GreenletExit
237

    
238
        # Accept the authentication
239
        self.client.send(rfb.to_u32(rfb.RFB_AUTH_SUCCESS))
240
       
241
    def _run(self):
242
        try:
243
            self.log.debug("Waiting for client to connect")
244
            rlist, _, _ = select(self.listeners, [], [], timeout=self.timeout)
245

    
246
            if not rlist:
247
                self.info("Timed out, no connection after %d sec" % self.timeout)
248
                raise gevent.GreenletExit
249

    
250
            for sock in rlist:
251
                self.client, addrinfo = sock.accept()
252
                self.info("Connection from %s:%d" % addrinfo[:2])
253

    
254
                # Close all listening sockets, we only want a one-shot connection
255
                # from a single client.
256
                while self.listeners:
257
                    self.listeners.pop().close()
258
                break
259
       
260
            # Perform RFB handshake with the client.
261
            self._client_handshake()
262

    
263
            # Bridge both connections through two "forwarder" greenlets.
264
            self.workers = [gevent.spawn(self._forward, self.client, self.server),
265
                gevent.spawn(self._forward, self.server, self.client)]
266
            
267
            # If one greenlet goes, the other has to go too.
268
            self.workers[0].link(self.workers[1])
269
            self.workers[1].link(self.workers[0])
270
            gevent.joinall(self.workers)
271
            del self.workers
272
            raise gevent.GreenletExit
273
        except Exception, e:
274
            # Any unhandled exception in the previous block
275
            # is an error and must be logged accordingly
276
            if not isinstance(e, gevent.GreenletExit):
277
                self.log.exception(e)
278
            raise e
279
        finally:
280
            self._cleanup()
281

    
282

    
283
def fatal_signal_handler(signame):
284
    logger.info("Caught %s, will raise SystemExit" % signame)
285
    raise SystemExit
286

    
287
def get_listening_sockets(sport):
288
    sockets = []
289

    
290
    # Use two sockets, one for IPv4, one for IPv6. IPv4-to-IPv6 mapped
291
    # addresses do not work reliably everywhere (under linux it may have
292
    # been disabled in /proc/sys/net/ipv6/bind_ipv6_only).
293
    for res in socket.getaddrinfo(None, sport, socket.AF_UNSPEC,
294
                                  socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
295
        af, socktype, proto, canonname, sa = res
296
        try:
297
            s = None
298
            s = socket.socket(af, socktype, proto)
299
            if af == socket.AF_INET6:
300
                # Bind v6 only when AF_INET6, otherwise either v4 or v6 bind
301
                # will fail.
302
                s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
303
            s.bind(sa)
304
            s.listen(1)
305
            sockets.append(s)
306
            logger.debug("Listening on %s:%d" % sa[:2])
307
        except socket.error, msg:
308
            logger.error("Error binding to %s:%d: %s" %
309
                           (sa[0], sa[1], msg[1]))
310
            if s:
311
                s.close()
312
            while sockets:
313
                sockets.pop().close()
314
            
315
            # Make sure we fail immediately if we cannot get a socket
316
            raise msg
317
    
318
    return sockets
319

    
320
def perform_server_handshake(daddr, dport, tries, retry_wait):
321
    """
322
    Initiate a connection with the backend server and perform basic
323
    RFB 3.8 handshake with it.
324

325
    Returns a socket connected to the backend server.
326

327
    """
328
    server = None
329

    
330
    while tries:
331
        tries -= 1
332

    
333
        # Initiate server connection
334
        for res in socket.getaddrinfo(daddr, dport, socket.AF_UNSPEC,
335
                                      socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
336
            af, socktype, proto, canonname, sa = res
337
            try:
338
                server = socket.socket(af, socktype, proto)
339
            except socket.error, msg:
340
                server = None
341
                continue
342

    
343
            try:
344
                logger.debug("Connecting to %s:%s" % sa[:2])
345
                server.connect(sa)
346
                logger.debug("Connection to %s:%s successful" % sa[:2])
347
            except socket.error, msg:
348
                server.close()
349
                server = None
350
                continue
351

    
352
            # We succesfully connected to the server
353
            tries = 0
354
            break
355

    
356
        # Wait and retry
357
        sleep(retry_wait)
358

    
359
    if server is None:
360
        raise Exception("Failed to connect to server")
361

    
362
    version = server.recv(1024)
363
    if not rfb.check_version(version):
364
        raise Exception("Unsupported RFB version: %s" % version.strip())
365

    
366
    server.send(rfb.RFB_VERSION_3_8 + "\n")
367

    
368
    res = server.recv(1024)
369
    types = rfb.parse_auth_request(res)
370
    if not types:
371
        raise Exception("Error handshaking with the server")
372

    
373
    else:
374
        logger.debug("Supported authentication types: %s" %
375
                       " ".join([str(x) for x in types]))
376

    
377
    if rfb.RFB_AUTHTYPE_NONE not in types:
378
        raise Exception("Error, server demands authentication")
379

    
380
    server.send(rfb.to_u8(rfb.RFB_AUTHTYPE_NONE))
381

    
382
    # Check authentication response
383
    res = server.recv(4)
384
    res = rfb.from_u32(res)
385

    
386
    if res != 0:
387
        raise Exception("Authentication error")
388

    
389
    return server
390

    
391
def parse_arguments(args):
392
    from optparse import OptionParser
393

    
394
    parser = OptionParser()
395
    parser.add_option("-s", "--socket", dest="ctrl_socket",
396
                      default=DEFAULT_CTRL_SOCKET,
397
                      metavar="PATH",
398
                      help="UNIX socket path for control connections (default: %s" %
399
                          DEFAULT_CTRL_SOCKET)
400
    parser.add_option("-d", "--debug", action="store_true", dest="debug",
401
                      help="Enable debugging information")
402
    parser.add_option("-l", "--log", dest="log_file",
403
                      default=DEFAULT_LOG_FILE,
404
                      metavar="FILE",
405
                      help="Write log to FILE instead of %s" % DEFAULT_LOG_FILE),
406
    parser.add_option('--pid-file', dest="pid_file",
407
                      default=DEFAULT_PID_FILE,
408
                      metavar='PIDFILE',
409
                      help="Save PID to file (default: %s)" %
410
                          DEFAULT_PID_FILE)
411
    parser.add_option("-t", "--connect-timeout", dest="connect_timeout",
412
                      default=DEFAULT_CONNECT_TIMEOUT, type="int", metavar="SECONDS",
413
                      help="How long to listen for clients to forward")
414
    parser.add_option("-r", "--connect-retries", dest="connect_retries",
415
                      default=DEFAULT_CONNECT_RETRIES, type="int",
416
                      metavar="RETRIES",
417
                      help="How many times to try to connect to the server")
418
    parser.add_option("-w", "--retry-wait", dest="retry_wait",
419
                      default=DEFAULT_RETRY_WAIT, type="float", metavar="SECONDS",
420
                      help="How long to wait between retrying to connect to the server")
421
    parser.add_option("-p", "--min-port", dest="min_port",
422
                      default=DEFAULT_MIN_PORT, type="int", metavar="MIN_PORT",
423
                      help="The minimum port to use for automatically-allocated ephemeral ports")
424
    parser.add_option("-P", "--max-port", dest="max_port",
425
                      default=DEFAULT_MAX_PORT, type="int", metavar="MAX_PORT",
426
                      help="The minimum port to use for automatically-allocated ephemeral ports")
427

    
428
    return parser.parse_args(args)
429

    
430

    
431
def main():
432
    """Run the daemon from the command line."""
433

    
434
    (opts, args) = parse_arguments(sys.argv[1:])
435

    
436
    # Create pidfile
437
    pidf = daemon.pidlockfile.TimeoutPIDLockFile(
438
        opts.pid_file, 10)
439
    
440
    # Initialize logger
441
    lvl = logging.DEBUG if opts.debug else logging.INFO
442

    
443
    global logger
444
    logger = logging.getLogger("vncauthproxy")
445
    logger.setLevel(lvl)
446
    formatter = logging.Formatter("%(asctime)s %(module)s[%(process)d] %(levelname)s: %(message)s",
447
        "%Y-%m-%d %H:%M:%S")
448
    handler = logging.FileHandler(opts.log_file)
449
    handler.setFormatter(formatter)
450
    logger.addHandler(handler)
451

    
452
    # Become a daemon:
453
    # Redirect stdout and stderr to handler.stream to catch
454
    # early errors in the daemonization process [e.g., pidfile creation]
455
    # which will otherwise go to /dev/null.
456
    daemon_context = AllFilesDaemonContext(
457
        pidfile=pidf,
458
        umask=0022,
459
        stdout=handler.stream,
460
        stderr=handler.stream,
461
        files_preserve=[handler.stream])
462
    daemon_context.open()
463
    logger.info("Became a daemon")
464

    
465
    # A fork() has occured while daemonizing,
466
    # we *must* reinit gevent
467
    gevent.reinit()
468

    
469
    if os.path.exists(opts.ctrl_socket):
470
        logger.critical("Socket '%s' already exists" % opts.ctrl_socket)
471
        sys.exit(1)
472

    
473
    # TODO: make this tunable? chgrp as well?
474
    old_umask = os.umask(0007)
475

    
476
    ctrl = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
477
    ctrl.bind(opts.ctrl_socket)
478

    
479
    os.umask(old_umask)
480

    
481
    ctrl.listen(1)
482
    logger.info("Initialized, waiting for control connections at %s" %
483
                 opts.ctrl_socket)
484

    
485
    # Catch signals to ensure graceful shutdown,
486
    # e.g., to make sure the control socket gets unlink()ed.
487
    #
488
    # Uses gevent.signal so the handler fires even during
489
    # gevent.socket.accept()
490
    gevent.signal(SIGINT, fatal_signal_handler, "SIGINT")
491
    gevent.signal(SIGTERM, fatal_signal_handler, "SIGTERM")
492

    
493
    # Init ephemeral port pool
494
    ports = range(opts.min_port, opts.max_port + 1) 
495

    
496
    while True:
497
        try:
498
            client, addr = ctrl.accept()
499
            logger.info("New control connection")
500
           
501
            # Receive and parse a client request.
502
            response = {
503
                "source_port": 0,
504
                "status": "FAILED"
505
            }
506
            try:
507
                # TODO: support multiple forwardings in the same message?
508
                # 
509
                # Control request, in JSON:
510
                #
511
                # {
512
                #     "source_port": <source port or 0 for automatic allocation>,
513
                #     "destination_address": <destination address of backend server>,
514
                #     "destination_port": <destination port>
515
                #     "password": <the password to use for MITM authentication of clients>
516
                # }
517
                # 
518
                # The <password> is used for MITM authentication of clients
519
                # connecting to <source_port>, who will subsequently be forwarded
520
                # to a VNC server at <destination_address>:<destination_port>
521
                #
522
                # Control reply, in JSON:
523
                # {
524
                #     "source_port": <the allocated source port>
525
                #     "status": <one of "OK" or "FAILED">
526
                # }
527
                buf = client.recv(1024)
528
                req = json.loads(buf)
529
                
530
                sport_orig = int(req['source_port'])
531
                daddr = req['destination_address']
532
                dport = int(req['destination_port'])
533
                password = req['password']
534
            except Exception, e:
535
                logger.warn("Malformed request: %s" % buf)
536
                client.send(json.dumps(response))
537
                client.close()
538
                continue
539
            
540
            # Spawn a new Greenlet to service the request.
541
            server = None
542
            try:
543
                # If the client has so indicated, pick an ephemeral source port
544
                # randomly, and remove it from the port pool.
545
                if sport_orig == 0:
546
                    sport = random.choice(ports)
547
                    ports.remove(sport)
548
                    logger.debug("Got port %d from port pool, contains %d ports",
549
                        sport, len(ports))
550
                    pool = ports
551
                else:
552
                    sport = sport_orig
553
                    pool = None
554

    
555
                listeners = get_listening_sockets(sport)
556
                server = perform_server_handshake(daddr, dport,
557
                    opts.connect_retries, opts.retry_wait)
558

    
559
                VncAuthProxy.spawn(logger, listeners, pool, daddr, dport,
560
                    server, password, opts.connect_timeout)
561

    
562
                logger.info("New forwarding [%d (req'd by client: %d) -> %s:%d]" %
563
                    (sport, sport_orig, daddr, dport))
564
                response = {
565
                    "source_port": sport,
566
                    "status": "OK"
567
                }
568
            except IndexError:
569
                logger.error("FAILED forwarding, out of ports for [req'd by "
570
                    "client: %d -> %s:%d]" % (sport_orig, daddr, dport))
571
            except Exception, msg:
572
                logger.error(msg)
573
                logger.error("FAILED forwarding [%d (req'd by client: %d) -> %s:%d]" %
574
                    (sport, sport_orig, daddr, dport))
575
                if not pool is None:
576
                    pool.append(sport)
577
                    logger.debug("Returned port %d to port pool, contains %d ports",
578
                        sport, len(pool))
579
                if not server is None:
580
                    server.close()
581
            finally:
582
                client.send(json.dumps(response))
583
                client.close()
584
        except Exception, e:
585
            logger.exception(e)
586
            continue
587
        except SystemExit:
588
            break
589
 
590
    logger.info("Unlinking control socket at %s" %
591
                 opts.ctrl_socket)
592
    os.unlink(opts.ctrl_socket)
593
    daemon_context.close()
594
    sys.exit(0)