Statistics
| Branch: | Tag: | Revision:

root / vncauthproxy / proxy.py @ 75eed2cf

History | View | Annotate | Download (21.4 kB)

1
#!/usr/bin/env python
2
"""
3
vncauthproxy - a VNC authentication proxy
4
"""
5
#
6
# Copyright (c) 2010-2011 Greek Research and Technology Network S.A.
7
#
8
# This program is free software; you can redistribute it and/or modify
9
# it under the terms of the GNU General Public License as published by
10
# the Free Software Foundation; either version 2 of the License, or
11
# (at your option) any later version.
12
#
13
# This program is distributed in the hope that it will be useful, but
14
# WITHOUT ANY WARRANTY; without even the implied warranty of
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16
# General Public License for more details.
17
#
18
# You should have received a copy of the GNU General Public License
19
# along with this program; if not, write to the Free Software
20
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
21
# 02110-1301, USA.
22

    
23
DEFAULT_CTRL_SOCKET = "/var/run/vncauthproxy/ctrl.sock"
24
DEFAULT_LOG_FILE = "/var/log/vncauthproxy/vncauthproxy.log"
25
DEFAULT_PID_FILE = "/var/run/vncauthproxy/vncauthproxy.pid"
26
DEFAULT_CONNECT_TIMEOUT = 30
27
DEFAULT_CONNECT_RETRIES = 3
28
DEFAULT_RETRY_WAIT = 0.1
29
# Default values per http://www.iana.org/assignments/port-numbers
30
DEFAULT_MIN_PORT = 49152
31
DEFAULT_MAX_PORT = 65535
32

    
33
import os
34
import sys
35
import logging
36
import gevent
37
import daemon
38
import random
39
import daemon.pidlockfile
40
import daemon.runner
41

    
42
import rfb
43

    
44
try:
45
    import simplejson as json
46
except ImportError:
47
    import json
48

    
49
from lockfile import LockTimeout
50
from gevent import socket
51
from signal import SIGINT, SIGTERM
52
from gevent import signal
53
from gevent.select import select
54
from time import sleep
55

    
56
logger = None
57

    
58
# Currently, gevent uses libevent-dns for asynchornous DNS resolution,
59
# which opens a socket upon initialization time. Since we can't get the fd
60
# reliably, We have to maintain all file descriptors open (which won't harm
61
# anyway)
62

    
63
class AllFilesDaemonContext(daemon.DaemonContext):
64
    """DaemonContext class keeping all file descriptors open"""
65
    def _get_exclude_file_descriptors(self):
66
        class All:
67
            def __contains__(self, value):
68
                return True
69
        return All()
70

    
71

    
72
class VncAuthProxy(gevent.Greenlet):
73
    """
74
    Simple class implementing a VNC Forwarder with MITM authentication as a
75
    Greenlet
76

77
    VncAuthProxy forwards VNC traffic from a specified port of the local host
78
    to a specified remote host:port. Furthermore, it implements VNC
79
    Authentication, intercepting the client/server handshake and asking the
80
    client for authentication even if the backend requires none.
81

82
    It is primarily intended for use in virtualization environments, as a VNC
83
    ``switch''.
84

85
    """
86
    id = 1
87

    
88
    def __init__(self, logger, listeners, pool, daddr, dport, server, password, connect_timeout):
89
        """
90
        @type logger: logging.Logger
91
        @param logger: the logger to use
92
        @type listeners: list
93
        @param listeners: list of listening sockets to use for client connections
94
        @type pool: list
95
        @param pool: if not None, return the client port number into this port pool
96
        @type daddr: str
97
        @param daddr: destination address (IPv4, IPv6 or hostname)
98
        @type dport: int
99
        @param dport: destination port
100
        @type server: socket
101
        @param server: VNC server socket
102
        @type password: str
103
        @param password: password to request from the client
104
        @type connect_timeout: int
105
        @param connect_timeout: how long to wait for client connections
106
                                (seconds)
107

108
        """
109
        gevent.Greenlet.__init__(self)
110
        self.id = VncAuthProxy.id
111
        VncAuthProxy.id += 1
112
        self.log = logger
113
        self.listeners = listeners
114
        # All listening sockets are assumed to be on the same port
115
        self.sport = listeners[0].getsockname()[1]
116
        self.pool = pool
117
        self.daddr = daddr
118
        self.dport = dport
119
        self.server = server
120
        self.password = password
121
        self.client = None
122
        self.timeout = connect_timeout
123

    
124
    def _cleanup(self):
125
        """Close all active sockets and exit gracefully"""
126
        # Reintroduce the port number of the client socket in
127
        # the port pool, if applicable.
128
        if not self.pool is None:
129
            self.pool.append(self.sport)
130
            self.log.debug("Returned port %d to port pool, contains %d ports",
131
                self.sport, len(self.pool))
132

    
133
        while self.listeners:
134
            self.listeners.pop().close()
135
        if self.server:
136
            self.server.close()
137
        if self.client:
138
            self.client.close()
139

    
140
        raise gevent.GreenletExit
141

    
142
    def info(self, msg):
143
        self.log.info("[C%d] %s" % (self.id, msg))
144

    
145
    def debug(self, msg):
146
        self.log.debug("[C%d] %s" % (self.id, msg))
147

    
148
    def warn(self, msg):
149
        self.log.warn("[C%d] %s" % (self.id, msg))
150

    
151
    def error(self, msg):
152
        self.log.error("[C%d] %s" % (self.id, msg))
153

    
154
    def critical(self, msg):
155
        self.log.critical("[C%d] %s" % (self.id, msg))
156

    
157
    def __str__(self):
158
        return "VncAuthProxy: %d -> %s:%d" % (self.sport, self.daddr, self.dport)
159

    
160
    def _forward(self, source, dest):
161
        """
162
        Forward traffic from source to dest
163

164
        @type source: socket
165
        @param source: source socket
166
        @type dest: socket
167
        @param dest: destination socket
168

169
        """
170

    
171
        while True:
172
            d = source.recv(16384)
173
            if d == '':
174
                if source == self.client:
175
                    self.info("Client connection closed")
176
                else:
177
                    self.info("Server connection closed")
178
                break
179
            dest.sendall(d)
180
        # No need to close the source and dest sockets here.
181
        # They are owned by and will be closed by the original greenlet.
182

    
183
    def _client_handshake(self):
184
        """
185
        Perform handshake/authentication with a connecting client
186

187
        Outline:
188
        1. Client connects
189
        2. We fake RFB 3.8 protocol and require VNC authentication [also supports RFB 3.3]
190
        3. Client accepts authentication method
191
        4. We send an authentication challenge
192
        5. Client sends the authentication response
193
        6. We check the authentication
194

195
        Upon return, self.client socket is connected to the client.
196

197
        """
198
        self.client.send(rfb.RFB_VERSION_3_8 + "\n")
199
        client_version_str = self.client.recv(1024)
200
        client_version = rfb.check_version(client_version_str)
201
        if not client_version:
202
            self.error("Invalid version: %s" % client_version_str)
203
            raise gevent.GreenletExit
204

    
205
        # Both for RFB 3.3 and 3.8
206
        self.debug("Requesting authentication")
207
        auth_request = rfb.make_auth_request(rfb.RFB_AUTHTYPE_VNC,
208
            version=client_version)
209
        self.client.send(auth_request)
210

    
211
        # The client gets to propose an authtype only for RFB 3.8
212
        if client_version == rfb.RFB_VERSION_3_8:
213
            res = self.client.recv(1024)
214
            type = rfb.parse_client_authtype(res)
215
            if type == rfb.RFB_AUTHTYPE_ERROR:
216
                self.warn("Client refused authentication: %s" % res[1:])
217
            else:
218
                self.debug("Client requested authtype %x" % type)
219

    
220
            if type != rfb.RFB_AUTHTYPE_VNC:
221
                self.error("Wrong auth type: %d" % type)
222
                self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
223
                raise gevent.GreenletExit
224

    
225
        # Generate the challenge
226
        challenge = os.urandom(16)
227
        self.client.send(challenge)
228
        response = self.client.recv(1024)
229
        if len(response) != 16:
230
            self.error("Wrong response length %d, should be 16" % len(response))
231
            raise gevent.GreenletExit
232

    
233
        if rfb.check_password(challenge, response, self.password):
234
            self.debug("Authentication successful!")
235
        else:
236
            self.warn("Authentication failed")
237
            self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
238
            raise gevent.GreenletExit
239

    
240
        # Accept the authentication
241
        self.client.send(rfb.to_u32(rfb.RFB_AUTH_SUCCESS))
242

    
243
    def _run(self):
244
        try:
245
            self.log.debug("Waiting for client to connect")
246
            rlist, _, _ = select(self.listeners, [], [], timeout=self.timeout)
247

    
248
            if not rlist:
249
                self.info("Timed out, no connection after %d sec" % self.timeout)
250
                raise gevent.GreenletExit
251

    
252
            for sock in rlist:
253
                self.client, addrinfo = sock.accept()
254
                self.info("Connection from %s:%d" % addrinfo[:2])
255

    
256
                # Close all listening sockets, we only want a one-shot connection
257
                # from a single client.
258
                while self.listeners:
259
                    self.listeners.pop().close()
260
                break
261

    
262
            # Perform RFB handshake with the client.
263
            self._client_handshake()
264

    
265
            # Bridge both connections through two "forwarder" greenlets.
266
            self.workers = [gevent.spawn(self._forward, self.client, self.server),
267
                gevent.spawn(self._forward, self.server, self.client)]
268

    
269
            # If one greenlet goes, the other has to go too.
270
            self.workers[0].link(self.workers[1])
271
            self.workers[1].link(self.workers[0])
272
            gevent.joinall(self.workers)
273
            del self.workers
274
            raise gevent.GreenletExit
275
        except Exception, e:
276
            # Any unhandled exception in the previous block
277
            # is an error and must be logged accordingly
278
            if not isinstance(e, gevent.GreenletExit):
279
                self.log.exception(e)
280
            raise e
281
        finally:
282
            self._cleanup()
283

    
284

    
285
def fatal_signal_handler(signame):
286
    logger.info("Caught %s, will raise SystemExit" % signame)
287
    raise SystemExit
288

    
289
def get_listening_sockets(sport):
290
    sockets = []
291

    
292
    # Use two sockets, one for IPv4, one for IPv6. IPv4-to-IPv6 mapped
293
    # addresses do not work reliably everywhere (under linux it may have
294
    # been disabled in /proc/sys/net/ipv6/bind_ipv6_only).
295
    for res in socket.getaddrinfo(None, sport, socket.AF_UNSPEC,
296
                                  socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
297
        af, socktype, proto, canonname, sa = res
298
        try:
299
            s = None
300
            s = socket.socket(af, socktype, proto)
301
            if af == socket.AF_INET6:
302
                # Bind v6 only when AF_INET6, otherwise either v4 or v6 bind
303
                # will fail.
304
                s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
305
            s.bind(sa)
306
            s.listen(1)
307
            sockets.append(s)
308
            logger.debug("Listening on %s:%d" % sa[:2])
309
        except socket.error, msg:
310
            logger.error("Error binding to %s:%d: %s" %
311
                           (sa[0], sa[1], msg[1]))
312
            if s:
313
                s.close()
314
            while sockets:
315
                sockets.pop().close()
316

    
317
            # Make sure we fail immediately if we cannot get a socket
318
            raise msg
319

    
320
    return sockets
321

    
322
def perform_server_handshake(daddr, dport, tries, retry_wait):
323
    """
324
    Initiate a connection with the backend server and perform basic
325
    RFB 3.8 handshake with it.
326

327
    Returns a socket connected to the backend server.
328

329
    """
330
    server = None
331

    
332
    while tries:
333
        tries -= 1
334

    
335
        # Initiate server connection
336
        for res in socket.getaddrinfo(daddr, dport, socket.AF_UNSPEC,
337
                                      socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
338
            af, socktype, proto, canonname, sa = res
339
            try:
340
                server = socket.socket(af, socktype, proto)
341
            except socket.error, msg:
342
                server = None
343
                continue
344

    
345
            try:
346
                logger.debug("Connecting to %s:%s" % sa[:2])
347
                server.connect(sa)
348
                logger.debug("Connection to %s:%s successful" % sa[:2])
349
            except socket.error, msg:
350
                server.close()
351
                server = None
352
                continue
353

    
354
            # We succesfully connected to the server
355
            tries = 0
356
            break
357

    
358
        # Wait and retry
359
        sleep(retry_wait)
360

    
361
    if server is None:
362
        raise Exception("Failed to connect to server")
363

    
364
    version = server.recv(1024)
365
    if not rfb.check_version(version):
366
        raise Exception("Unsupported RFB version: %s" % version.strip())
367

    
368
    server.send(rfb.RFB_VERSION_3_8 + "\n")
369

    
370
    res = server.recv(1024)
371
    types = rfb.parse_auth_request(res)
372
    if not types:
373
        raise Exception("Error handshaking with the server")
374

    
375
    else:
376
        logger.debug("Supported authentication types: %s" %
377
                       " ".join([str(x) for x in types]))
378

    
379
    if rfb.RFB_AUTHTYPE_NONE not in types:
380
        raise Exception("Error, server demands authentication")
381

    
382
    server.send(rfb.to_u8(rfb.RFB_AUTHTYPE_NONE))
383

    
384
    # Check authentication response
385
    res = server.recv(4)
386
    res = rfb.from_u32(res)
387

    
388
    if res != 0:
389
        raise Exception("Authentication error")
390

    
391
    return server
392

    
393
def parse_arguments(args):
394
    from optparse import OptionParser
395

    
396
    parser = OptionParser()
397
    parser.add_option("-s", "--socket", dest="ctrl_socket",
398
                      default=DEFAULT_CTRL_SOCKET,
399
                      metavar="PATH",
400
                      help="UNIX socket path for control connections (default: %s" %
401
                          DEFAULT_CTRL_SOCKET)
402
    parser.add_option("-d", "--debug", action="store_true", dest="debug",
403
                      help="Enable debugging information")
404
    parser.add_option("-l", "--log", dest="log_file",
405
                      default=DEFAULT_LOG_FILE,
406
                      metavar="FILE",
407
                      help="Write log to FILE instead of %s" % DEFAULT_LOG_FILE),
408
    parser.add_option('--pid-file', dest="pid_file",
409
                      default=DEFAULT_PID_FILE,
410
                      metavar='PIDFILE',
411
                      help="Save PID to file (default: %s)" %
412
                          DEFAULT_PID_FILE)
413
    parser.add_option("-t", "--connect-timeout", dest="connect_timeout",
414
                      default=DEFAULT_CONNECT_TIMEOUT, type="int", metavar="SECONDS",
415
                      help="How long to listen for clients to forward")
416
    parser.add_option("-r", "--connect-retries", dest="connect_retries",
417
                      default=DEFAULT_CONNECT_RETRIES, type="int",
418
                      metavar="RETRIES",
419
                      help="How many times to try to connect to the server")
420
    parser.add_option("-w", "--retry-wait", dest="retry_wait",
421
                      default=DEFAULT_RETRY_WAIT, type="float", metavar="SECONDS",
422
                      help="How long to wait between retrying to connect to the server")
423
    parser.add_option("-p", "--min-port", dest="min_port",
424
                      default=DEFAULT_MIN_PORT, type="int", metavar="MIN_PORT",
425
                      help="The minimum port to use for automatically-allocated ephemeral ports")
426
    parser.add_option("-P", "--max-port", dest="max_port",
427
                      default=DEFAULT_MAX_PORT, type="int", metavar="MAX_PORT",
428
                      help="The minimum port to use for automatically-allocated ephemeral ports")
429

    
430
    return parser.parse_args(args)
431

    
432

    
433
def main():
434
    """Run the daemon from the command line."""
435

    
436
    (opts, args) = parse_arguments(sys.argv[1:])
437

    
438
    # Create pidfile
439
    pidf = daemon.pidlockfile.TimeoutPIDLockFile(
440
        opts.pid_file, 10)
441

    
442
    # Initialize logger
443
    lvl = logging.DEBUG if opts.debug else logging.INFO
444

    
445
    global logger
446
    logger = logging.getLogger("vncauthproxy")
447
    logger.setLevel(lvl)
448
    formatter = logging.Formatter("%(asctime)s %(module)s[%(process)d] %(levelname)s: %(message)s",
449
        "%Y-%m-%d %H:%M:%S")
450
    handler = logging.FileHandler(opts.log_file)
451
    handler.setFormatter(formatter)
452
    logger.addHandler(handler)
453

    
454
    # Become a daemon:
455
    # Redirect stdout and stderr to handler.stream to catch
456
    # early errors in the daemonization process [e.g., pidfile creation]
457
    # which will otherwise go to /dev/null.
458
    daemon_context = AllFilesDaemonContext(
459
        pidfile=pidf,
460
        umask=0022,
461
        stdout=handler.stream,
462
        stderr=handler.stream,
463
        files_preserve=[handler.stream])
464

    
465
    # Remove any stale PID files, left behind by previous invocations
466
    if daemon.runner.is_pidfile_stale(pidf):
467
        logger.warning("Removing stale PID lock file %s", pidf.path)
468
        pidf.break_lock()
469

    
470
    try:
471
        daemon_context.open()
472
    except (daemon.pidlockfile.AlreadyLocked, LockTimeout):
473
        logger.critical("Failed to lock PID file %s, another instance running?",
474
                        pidf.path)
475
        sys.exit(1)
476
    logger.info("Became a daemon")
477

    
478
    # A fork() has occured while daemonizing,
479
    # we *must* reinit gevent
480
    gevent.reinit()
481

    
482
    if os.path.exists(opts.ctrl_socket):
483
        logger.critical("Socket '%s' already exists" % opts.ctrl_socket)
484
        sys.exit(1)
485

    
486
    # TODO: make this tunable? chgrp as well?
487
    old_umask = os.umask(0007)
488

    
489
    ctrl = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
490
    ctrl.bind(opts.ctrl_socket)
491

    
492
    os.umask(old_umask)
493

    
494
    ctrl.listen(1)
495
    logger.info("Initialized, waiting for control connections at %s" %
496
                 opts.ctrl_socket)
497

    
498
    # Catch signals to ensure graceful shutdown,
499
    # e.g., to make sure the control socket gets unlink()ed.
500
    #
501
    # Uses gevent.signal so the handler fires even during
502
    # gevent.socket.accept()
503
    gevent.signal(SIGINT, fatal_signal_handler, "SIGINT")
504
    gevent.signal(SIGTERM, fatal_signal_handler, "SIGTERM")
505

    
506
    # Init ephemeral port pool
507
    ports = range(opts.min_port, opts.max_port + 1)
508

    
509
    while True:
510
        try:
511
            client, addr = ctrl.accept()
512
            logger.info("New control connection")
513

    
514
            # Receive and parse a client request.
515
            response = {
516
                "source_port": 0,
517
                "status": "FAILED"
518
            }
519
            try:
520
                # TODO: support multiple forwardings in the same message?
521
                #
522
                # Control request, in JSON:
523
                #
524
                # {
525
                #     "source_port": <source port or 0 for automatic allocation>,
526
                #     "destination_address": <destination address of backend server>,
527
                #     "destination_port": <destination port>
528
                #     "password": <the password to use for MITM authentication of clients>
529
                # }
530
                #
531
                # The <password> is used for MITM authentication of clients
532
                # connecting to <source_port>, who will subsequently be forwarded
533
                # to a VNC server at <destination_address>:<destination_port>
534
                #
535
                # Control reply, in JSON:
536
                # {
537
                #     "source_port": <the allocated source port>
538
                #     "status": <one of "OK" or "FAILED">
539
                # }
540
                buf = client.recv(1024)
541
                req = json.loads(buf)
542

    
543
                sport_orig = int(req['source_port'])
544
                daddr = req['destination_address']
545
                dport = int(req['destination_port'])
546
                password = req['password']
547
            except Exception, e:
548
                logger.warn("Malformed request: %s" % buf)
549
                client.send(json.dumps(response))
550
                client.close()
551
                continue
552

    
553
            # Spawn a new Greenlet to service the request.
554
            server = None
555
            try:
556
                # If the client has so indicated, pick an ephemeral source port
557
                # randomly, and remove it from the port pool.
558
                if sport_orig == 0:
559
                    sport = random.choice(ports)
560
                    ports.remove(sport)
561
                    logger.debug("Got port %d from port pool, contains %d ports",
562
                        sport, len(ports))
563
                    pool = ports
564
                else:
565
                    sport = sport_orig
566
                    pool = None
567

    
568
                listeners = get_listening_sockets(sport)
569
                server = perform_server_handshake(daddr, dport,
570
                    opts.connect_retries, opts.retry_wait)
571

    
572
                VncAuthProxy.spawn(logger, listeners, pool, daddr, dport,
573
                    server, password, opts.connect_timeout)
574

    
575
                logger.info("New forwarding [%d (req'd by client: %d) -> %s:%d]" %
576
                    (sport, sport_orig, daddr, dport))
577
                response = {
578
                    "source_port": sport,
579
                    "status": "OK"
580
                }
581
            except IndexError:
582
                logger.error("FAILED forwarding, out of ports for [req'd by "
583
                    "client: %d -> %s:%d]" % (sport_orig, daddr, dport))
584
            except Exception, msg:
585
                logger.error(msg)
586
                logger.error("FAILED forwarding [%d (req'd by client: %d) -> %s:%d]" %
587
                    (sport, sport_orig, daddr, dport))
588
                if not pool is None:
589
                    pool.append(sport)
590
                    logger.debug("Returned port %d to port pool, contains %d ports",
591
                        sport, len(pool))
592
                if not server is None:
593
                    server.close()
594
            finally:
595
                client.send(json.dumps(response))
596
                client.close()
597
        except Exception, e:
598
            logger.exception(e)
599
            continue
600
        except SystemExit:
601
            break
602

    
603
    logger.info("Unlinking control socket at %s" %
604
                 opts.ctrl_socket)
605
    os.unlink(opts.ctrl_socket)
606
    daemon_context.close()
607
    sys.exit(0)