Revision d12b9f66

b/.gitignore
94 94
/tools/kvm-ifup
95 95
/tools/ensure-dirs
96 96
/tools/vcluster-setup
97
/tools/prepare-node-join
97 98

  
98 99
# scripts
99 100
/scripts/gnt-backup
b/Makefile.am
315 315

  
316 316
pytools_PYTHON = \
317 317
	lib/tools/__init__.py \
318
	lib/tools/ensure_dirs.py
318
	lib/tools/ensure_dirs.py \
319
	lib/tools/prepare_node_join.py
319 320

  
320 321
utils_PYTHON = \
321 322
	lib/utils/__init__.py \
......
578 579

  
579 580
PYTHON_BOOTSTRAP = \
580 581
	$(PYTHON_BOOTSTRAP_SBIN) \
581
	tools/ensure-dirs
582
	tools/ensure-dirs \
583
	tools/prepare-node-join
582 584

  
583 585
qa_scripts = \
584 586
	qa/__init__.py \
......
690 692
	tools/check-cert-expired
691 693

  
692 694
nodist_pkglib_python_scripts = \
693
	tools/ensure-dirs
695
	tools/ensure-dirs \
696
	tools/prepare-node-join
694 697

  
695 698
myexeclib_SCRIPTS = \
696 699
	daemons/daemon-util \
......
822 825
	test/data/bdev-drbd-net-ip4.txt \
823 826
	test/data/bdev-drbd-net-ip6.txt \
824 827
	test/data/cert1.pem \
828
	test/data/cert2.pem \
825 829
	test/data/ip-addr-show-dummy0.txt \
826 830
	test/data/ip-addr-show-lo-ipv4.txt \
827 831
	test/data/ip-addr-show-lo-ipv6.txt \
......
926 930
	test/ganeti.ssh_unittest.py \
927 931
	test/ganeti.storage_unittest.py \
928 932
	test/ganeti.tools.ensure_dirs_unittest.py \
933
	test/ganeti.tools.prepare_node_join_unittest.py \
929 934
	test/ganeti.uidpool_unittest.py \
930 935
	test/ganeti.utils.algo_unittest.py \
931 936
	test/ganeti.utils.filelock_unittest.py \
......
1327 1332
daemons/ganeti-watcher: MODULE = ganeti.watcher
1328 1333
scripts/%: MODULE = ganeti.client.$(subst -,_,$(notdir $@))
1329 1334
tools/ensure-dirs: MODULE = ganeti.tools.ensure_dirs
1335
tools/prepare-node-join: MODULE = ganeti.tools.prepare_node_join
1330 1336
$(HS_BUILT_TEST_HELPERS): TESTROLE = $(patsubst htest/%,%,$@)
1331 1337

  
1332 1338
$(PYTHON_BOOTSTRAP): Makefile | stamp-directories
b/lib/constants.py
2049 2049
SSHK_DSA = "dsa"
2050 2050
SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA])
2051 2051

  
2052
# SSH authorized key types
2053
SSHAK_RSA = "ssh-rsa"
2054
SSHAK_DSS = "ssh-dss"
2055
SSHAK_ALL = frozenset([SSHAK_RSA, SSHAK_DSS])
2056

  
2057
# SSH setup
2058
SSHS_CLUSTER_NAME = "cluster_name"
2059
SSHS_FORCE = "force"
2060
SSHS_SSH_HOST_KEY = "ssh_host_key"
2061
SSHS_SSH_ROOT_KEY = "ssh_root_key"
2062
SSHS_NODE_DAEMON_CERTIFICATE = "node_daemon_certificate"
2063

  
2052 2064
# Do not re-export imported modules
2053 2065
del re, _vcsversion, _autoconf, socket, pathutils
b/lib/ssh.py
49 49

  
50 50

  
51 51
def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
52
                 _homedir_fn=utils.GetHomeDir):
52
                 _homedir_fn=None):
53 53
  """Return the paths of a user's SSH files.
54 54

  
55 55
  @type user: string
......
67 67
    exception is raised if C{~$user/.ssh} is not a directory
68 68

  
69 69
  """
70
  if _homedir_fn is None:
71
    _homedir_fn = utils.GetHomeDir
72

  
70 73
  user_dir = _homedir_fn(user)
71 74
  if not user_dir:
72 75
    raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
b/lib/tools/prepare_node_join.py
1
#
2
#
3

  
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

  
21
"""Script to prepare a node for joining a cluster.
22

  
23
"""
24

  
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

  
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

  
43

  
44
_SSH_KEY_LIST = \
45
  ht.TListOf(ht.TAnd(ht.TIsLength(3),
46
                     ht.TItems([
47
                       ht.TElemOf(constants.SSHK_ALL),
48
                       ht.Comment("public")(ht.TNonEmptyString),
49
                       ht.Comment("private")(ht.TNonEmptyString),
50
                       ])))
51

  
52
_DATA_CHECK = ht.TStrictDict(False, True, {
53
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
54
  constants.SSHS_FORCE: ht.TBool,
55
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
56
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
57
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
58
  })
59

  
60
_SSHK_TO_SSHAK = {
61
  constants.SSHK_RSA: constants.SSHAK_RSA,
62
  constants.SSHK_DSA: constants.SSHAK_DSS,
63
  }
64

  
65
_SSH_DAEMON_KEYFILES = {
66
  constants.SSHK_RSA:
67
    (pathutils.SSH_HOST_RSA_PUB, pathutils.SSH_HOST_RSA_PRIV),
68
  constants.SSHK_DSA:
69
    (pathutils.SSH_HOST_DSA_PUB, pathutils.SSH_HOST_DSA_PRIV),
70
    }
71

  
72
assert frozenset(_SSHK_TO_SSHAK.keys()) == constants.SSHK_ALL
73
assert frozenset(_SSHK_TO_SSHAK.values()) == constants.SSHAK_ALL
74

  
75

  
76
class JoinError(errors.GenericError):
77
  """Local class for reporting errors.
78

  
79
  """
80

  
81

  
82
def ParseOptions():
83
  """Parses the options passed to the program.
84

  
85
  @return: Options and arguments
86

  
87
  """
88
  program = os.path.basename(sys.argv[0])
89

  
90
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
91
                                 prog=program)
92
  parser.add_option(cli.DEBUG_OPT)
93
  parser.add_option(cli.VERBOSE_OPT)
94
  parser.add_option(cli.DRY_RUN_OPT)
95

  
96
  (opts, args) = parser.parse_args()
97

  
98
  return VerifyOptions(parser, opts, args)
99

  
100

  
101
def VerifyOptions(parser, opts, args):
102
  """Verifies options and arguments for correctness.
103

  
104
  """
105
  if args:
106
    parser.error("No arguments are expected")
107

  
108
  return opts
109

  
110

  
111
def SetupLogging(opts):
112
  """Configures the logging module.
113

  
114
  """
115
  formatter = logging.Formatter("%(asctime)s: %(message)s")
116

  
117
  stderr_handler = logging.StreamHandler()
118
  stderr_handler.setFormatter(formatter)
119
  if opts.debug:
120
    stderr_handler.setLevel(logging.NOTSET)
121
  elif opts.verbose:
122
    stderr_handler.setLevel(logging.INFO)
123
  else:
124
    stderr_handler.setLevel(logging.WARNING)
125

  
126
  root_logger = logging.getLogger("")
127
  root_logger.setLevel(logging.NOTSET)
128
  root_logger.addHandler(stderr_handler)
129

  
130

  
131
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
132
  """Verifies a certificate against the local node daemon certificate.
133

  
134
  @type cert: string
135
  @param cert: Certificate in PEM format (no key)
136

  
137
  """
138
  try:
139
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
140
  except OpenSSL.crypto.Error, err:
141
    pass
142
  else:
143
    raise JoinError("No private key may be given")
144

  
145
  try:
146
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
147
  except Exception, err:
148
    raise errors.X509CertError("(stdin)",
149
                               "Unable to load certificate: %s" % err)
150

  
151
  try:
152
    noded_pem = utils.ReadFile(_noded_cert_file)
153
  except EnvironmentError, err:
154
    if err.errno != errno.ENOENT:
155
      raise
156

  
157
    logging.debug("Local node certificate was not found (file %s)",
158
                  _noded_cert_file)
159
    return
160

  
161
  try:
162
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
163
  except Exception, err:
164
    raise errors.X509CertError(_noded_cert_file,
165
                               "Unable to load private key: %s" % err)
166

  
167
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
168
  ctx.use_privatekey(key)
169
  ctx.use_certificate(cert)
170
  try:
171
    ctx.check_privatekey()
172
  except OpenSSL.SSL.Error:
173
    raise JoinError("Given cluster certificate does not match local key")
174

  
175

  
176
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
177
  """Verifies cluster certificate.
178

  
179
  @type data: dict
180

  
181
  """
182
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
183
  if cert:
184
    _verify_fn(cert)
185

  
186

  
187
def _VerifyClusterName(name, _ss_cluster_name_file=None):
188
  """Verifies cluster name against a local cluster name.
189

  
190
  @type name: string
191
  @param name: Cluster name
192

  
193
  """
194
  if _ss_cluster_name_file is None:
195
    _ss_cluster_name_file = \
196
      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
197

  
198
  try:
199
    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
200
  except EnvironmentError, err:
201
    if err.errno != errno.ENOENT:
202
      raise
203

  
204
    logging.debug("Local cluster name was not found (file %s)",
205
                  _ss_cluster_name_file)
206
  else:
207
    if name != local_name:
208
      raise JoinError("Current cluster name is '%s'" % local_name)
209

  
210

  
211
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
212
  """Verifies cluster name.
213

  
214
  @type data: dict
215

  
216
  """
217
  name = data.get(constants.SSHS_CLUSTER_NAME)
218
  if name:
219
    _verify_fn(name)
220
  else:
221
    raise JoinError("Cluster name must be specified")
222

  
223

  
224
def _UpdateKeyFiles(keys, dry_run, keyfiles):
225
  """Updates SSH key files.
226

  
227
  @type keys: sequence of tuple; (string, string, string)
228
  @param keys: Keys to write, tuples consist of key type
229
    (L{constants.SSHK_ALL}), public and private key
230
  @type dry_run: boolean
231
  @param dry_run: Whether to perform a dry run
232
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
233
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
234
    names; value tuples consist of public key filename and private key filename
235

  
236
  """
237
  assert set(keyfiles) == constants.SSHK_ALL
238

  
239
  for (kind, public_key, private_key) in keys:
240
    (public_file, private_file) = keyfiles[kind]
241

  
242
    logging.debug("Writing %s ...", public_file)
243
    utils.WriteFile(public_file, data=public_key, mode=0644,
244
                    backup=True, dry_run=dry_run)
245

  
246
    logging.debug("Writing %s ...", private_file)
247
    utils.WriteFile(private_file, data=private_key, mode=0600,
248
                    backup=True, dry_run=dry_run)
249

  
250

  
251
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
252
                    _keyfiles=None):
253
  """Updates SSH daemon's keys.
254

  
255
  Unless C{dry_run} is set, the daemon is restarted at the end.
256

  
257
  @type data: dict
258
  @param data: Input data
259
  @type dry_run: boolean
260
  @param dry_run: Whether to perform a dry run
261

  
262
  """
263
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
264
  if not keys:
265
    return
266

  
267
  if _keyfiles is None:
268
    _keyfiles = _SSH_DAEMON_KEYFILES
269

  
270
  logging.info("Updating SSH daemon key files")
271
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
272

  
273
  if dry_run:
274
    logging.info("This is a dry run, not restarting SSH daemon")
275
  else:
276
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
277
                        interactive=True)
278
    if result.failed:
279
      raise JoinError("Could not reload SSH keys, command '%s'"
280
                      " had exitcode %s and error %s" %
281
                       (result.cmd, result.exit_code, result.output))
282

  
283

  
284
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
285
  """Updates root's SSH keys.
286

  
287
  Root's C{authorized_keys} file is also updated with new public keys.
288

  
289
  @type data: dict
290
  @param data: Input data
291
  @type dry_run: boolean
292
  @param dry_run: Whether to perform a dry run
293

  
294
  """
295
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
296
  if not keys:
297
    return
298

  
299
  (dsa_private_file, dsa_public_file, auth_keys_file) = \
300
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
301
                     kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn)
302
  (rsa_private_file, rsa_public_file, _) = \
303
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
304
                     kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn)
305

  
306
  _UpdateKeyFiles(keys, dry_run, {
307
    constants.SSHK_RSA: (rsa_public_file, rsa_private_file),
308
    constants.SSHK_DSA: (dsa_public_file, dsa_private_file),
309
    })
310

  
311
  if dry_run:
312
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
313
  else:
314
    for (kind, public_key, _) in keys:
315
      line = "%s %s" % (_SSHK_TO_SSHAK[kind], public_key)
316
      utils.AddAuthorizedKey(auth_keys_file, line)
317

  
318

  
319
def LoadData(raw):
320
  """Parses and verifies input data.
321

  
322
  @rtype: dict
323

  
324
  """
325
  try:
326
    data = serializer.LoadJson(raw)
327
  except Exception, err:
328
    raise errors.ParseError("Can't parse input data: %s" % err)
329

  
330
  if not _DATA_CHECK(data):
331
    raise errors.ParseError("Input data does not match expected format: %s" %
332
                            _DATA_CHECK)
333

  
334
  return data
335

  
336

  
337
def Main():
338
  """Main routine.
339

  
340
  """
341
  opts = ParseOptions()
342

  
343
  SetupLogging(opts)
344

  
345
  try:
346
    data = LoadData(sys.stdin.read())
347

  
348
    # Check if input data is correct
349
    VerifyClusterName(data)
350
    VerifyCertificate(data)
351

  
352
    # Update SSH files
353
    UpdateSshDaemon(data, opts.dry_run)
354
    UpdateSshRoot(data, opts.dry_run)
355

  
356
    logging.info("Setup finished successfully")
357
  except Exception, err: # pylint: disable=W0703
358
    logging.debug("Caught unhandled exception", exc_info=True)
359

  
360
    (retcode, message) = cli.FormatError(err)
361
    logging.error(message)
362

  
363
    return retcode
364
  else:
365
    return constants.EXIT_SUCCESS
b/lib/utils/io.py
828 828
  return None
829 829

  
830 830

  
831
_SSH_KEYS_WITH_TWO_PARTS = frozenset(["ssh-dss", "ssh-rsa"])
832

  
833

  
834 831
def _SplitSshKey(key):
835 832
  """Splits a line for SSH's C{authorized_keys} file.
836 833

  
......
845 842
  """
846 843
  parts = key.split()
847 844

  
848
  if parts and parts[0] in _SSH_KEYS_WITH_TWO_PARTS:
845
  if parts and parts[0] in constants.SSHAK_ALL:
849 846
    # If the key has no options in front of it, we only want the significant
850 847
    # fields
851 848
    return (False, parts[:2])
b/test/data/cert2.pem
1
-----BEGIN PRIVATE KEY-----
2
MIIBUwIBADANBgkqhkiG9w0BAQEFAASCAT0wggE5AgEAAkEAt8OZYvvi8noVPlpR
3
/SrHcya9ne7RG5DjvMssksUqyGriUs/WGnpZlL4nz+BcLFGwNNntoxqR30Tjk47S
4
cmSBRQIDAQABAkAqTP5MCMuPIYcuWUAyVNygpzRS3JyKCepClUpnZreYdo4sUQE3
5
/AM7xeb92R06iZ3f9/MPrbaMKTWRh3uCyfKBAiEA5TxdacnVxdS8+ZLyys4p/C1s
6
iajrarBb/j+NIAnsdnECIQDNOCDO7Jq/iN5qE4Vbi/3zmnP1Ca5aBo+KJ/hhSjRq
7
FQIgIBpWEqybbXsfg+waaGB67MAHxTeM0IImP/LydpwtK2ECIB3SrlHj6Ik1Jr1b
8
oOGw8nLYW0mc4o2KrolxTZM16XARAiBKW3aSjY5UrnoEqa8pAeiO8LJaRj73Epmr
9
zC89IuLZfg==
10
-----END PRIVATE KEY-----
11
-----BEGIN CERTIFICATE-----
12
MIIB0zCCAX2gAwIBAgIJAKrAqGX6UolVMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
13
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
14
aWRnaXRzIFB0eSBMdGQwHhcNMTIxMDE5MTQ1NjA4WhcNMTIxMDIwMTQ1NjA4WjBF
15
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
16
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALfD
17
mWL74vJ6FT5aUf0qx3MmvZ3u0RuQ47zLLJLFKshq4lLP1hp6WZS+J8/gXCxRsDTZ
18
7aMakd9E45OO0nJkgUUCAwEAAaNQME4wHQYDVR0OBBYEFA1Fc/GIVtd6nMocrSsA
19
e5bxmVhMMB8GA1UdIwQYMBaAFA1Fc/GIVtd6nMocrSsAe5bxmVhMMAwGA1UdEwQF
20
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQCTUwzDGU+IJTQ3PIJrA3fHMyKbBvc4Rkvi
21
ZNFsmgsidWhb+5APlPjtlS7rXlonNHBzDoGb4RNArtxhEx+rBcAE
22
-----END CERTIFICATE-----
b/test/ganeti.tools.prepare_node_join_unittest.py
1
#!/usr/bin/python
2
#
3

  
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

  
21

  
22
"""Script for testing ganeti.tools.prepare_node_join"""
23

  
24
import unittest
25
import shutil
26
import tempfile
27
import os.path
28
import OpenSSL
29

  
30
from ganeti import errors
31
from ganeti import constants
32
from ganeti import serializer
33
from ganeti import pathutils
34
from ganeti import compat
35
from ganeti import utils
36
from ganeti.tools import prepare_node_join
37

  
38
import testutils
39

  
40

  
41
_JoinError = prepare_node_join.JoinError
42

  
43

  
44
class TestLoadData(unittest.TestCase):
45
  def testNoJson(self):
46
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "")
47
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, "}")
48

  
49
  def testInvalidDataStructure(self):
50
    raw = serializer.DumpJson({
51
      "some other thing": False,
52
      })
53
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
54

  
55
    raw = serializer.DumpJson([])
56
    self.assertRaises(errors.ParseError, prepare_node_join.LoadData, raw)
57

  
58
  def testValidData(self):
59
    raw = serializer.DumpJson({})
60
    self.assertEqual(prepare_node_join.LoadData(raw), {})
61

  
62

  
63
class TestVerifyCertificate(testutils.GanetiTestCase):
64
  def setUp(self):
65
    testutils.GanetiTestCase.setUp(self)
66
    self.tmpdir = tempfile.mkdtemp()
67

  
68
  def tearDown(self):
69
    testutils.GanetiTestCase.tearDown(self)
70
    shutil.rmtree(self.tmpdir)
71

  
72
  def testNoCert(self):
73
    prepare_node_join.VerifyCertificate({}, _verify_fn=NotImplemented)
74

  
75
  def testMismatchingKey(self):
76
    other_cert = self._TestDataFilename("cert1.pem")
77
    node_cert = self._TestDataFilename("cert2.pem")
78

  
79
    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
80
                      utils.ReadFile(other_cert), _noded_cert_file=node_cert)
81

  
82
  def testGivenPrivateKey(self):
83
    cert_filename = self._TestDataFilename("cert2.pem")
84
    cert_pem = utils.ReadFile(cert_filename)
85

  
86
    self.assertRaises(_JoinError, prepare_node_join._VerifyCertificate,
87
                      cert_pem, _noded_cert_file=cert_filename)
88

  
89
  def testMatchingKey(self):
90
    cert_filename = self._TestDataFilename("cert2.pem")
91

  
92
    # Extract certificate
93
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
94
                                           utils.ReadFile(cert_filename))
95
    cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
96
                                               cert)
97

  
98
    prepare_node_join._VerifyCertificate(cert_pem,
99
                                         _noded_cert_file=cert_filename)
100

  
101
  def testMissingFile(self):
102
    cert = self._TestDataFilename("cert1.pem")
103
    nodecert = utils.PathJoin(self.tmpdir, "does-not-exist")
104
    prepare_node_join._VerifyCertificate(utils.ReadFile(cert),
105
                                         _noded_cert_file=nodecert)
106

  
107
  def testInvalidCertificate(self):
108
    self.assertRaises(errors.X509CertError,
109
                      prepare_node_join._VerifyCertificate,
110
                      "Something that's not a certificate",
111
                      _noded_cert_file=NotImplemented)
112

  
113
  def testNoPrivateKey(self):
114
    cert = self._TestDataFilename("cert1.pem")
115
    self.assertRaises(errors.X509CertError,
116
                      prepare_node_join._VerifyCertificate,
117
                      utils.ReadFile(cert), _noded_cert_file=cert)
118

  
119

  
120
class TestVerifyClusterName(unittest.TestCase):
121
  def setUp(self):
122
    unittest.TestCase.setUp(self)
123
    self.tmpdir = tempfile.mkdtemp()
124

  
125
  def tearDown(self):
126
    unittest.TestCase.tearDown(self)
127
    shutil.rmtree(self.tmpdir)
128

  
129
  def testNoName(self):
130
    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
131
                      {}, _verify_fn=NotImplemented)
132

  
133
  def testMissingFile(self):
134
    tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist")
135
    prepare_node_join._VerifyClusterName(NotImplemented,
136
                                         _ss_cluster_name_file=tmpfile)
137

  
138
  def testMatchingName(self):
139
    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
140

  
141
    for content in ["cluster.example.com", "cluster.example.com\n\n"]:
142
      utils.WriteFile(tmpfile, data=content)
143
      prepare_node_join._VerifyClusterName("cluster.example.com",
144
                                           _ss_cluster_name_file=tmpfile)
145

  
146
  def testNameMismatch(self):
147
    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
148

  
149
    for content in ["something.example.com", "foobar\n\ncluster.example.com"]:
150
      utils.WriteFile(tmpfile, data=content)
151
      self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName,
152
                        "cluster.example.com", _ss_cluster_name_file=tmpfile)
153

  
154

  
155
class TestUpdateSshDaemon(unittest.TestCase):
156
  def setUp(self):
157
    unittest.TestCase.setUp(self)
158
    self.tmpdir = tempfile.mkdtemp()
159

  
160
    self.keyfiles = {
161
      constants.SSHK_RSA:
162
        (utils.PathJoin(self.tmpdir, "rsa.public"),
163
         utils.PathJoin(self.tmpdir, "rsa.private")),
164
      constants.SSHK_DSA:
165
        (utils.PathJoin(self.tmpdir, "dsa.public"),
166
         utils.PathJoin(self.tmpdir, "dsa.private")),
167
      }
168

  
169
  def tearDown(self):
170
    unittest.TestCase.tearDown(self)
171
    shutil.rmtree(self.tmpdir)
172

  
173
  def testNoKeys(self):
174
    data_empty_keys = {
175
      constants.SSHS_SSH_HOST_KEY: [],
176
      }
177

  
178
    for data in [{}, data_empty_keys]:
179
      for dry_run in [False, True]:
180
        prepare_node_join.UpdateSshDaemon(data, dry_run,
181
                                          _runcmd_fn=NotImplemented,
182
                                          _keyfiles=NotImplemented)
183
    self.assertEqual(os.listdir(self.tmpdir), [])
184

  
185
  def _TestDryRun(self, data):
186
    prepare_node_join.UpdateSshDaemon(data, True, _runcmd_fn=NotImplemented,
187
                                      _keyfiles=self.keyfiles)
188
    self.assertEqual(os.listdir(self.tmpdir), [])
189

  
190
  def testDryRunRsa(self):
191
    self._TestDryRun({
192
      constants.SSHS_SSH_HOST_KEY: [
193
        (constants.SSHK_RSA, "rsapub", "rsapriv"),
194
        ],
195
      })
196

  
197
  def testDryRunDsa(self):
198
    self._TestDryRun({
199
      constants.SSHS_SSH_HOST_KEY: [
200
        (constants.SSHK_DSA, "dsapub", "dsapriv"),
201
        ],
202
      })
203

  
204
  def _RunCmd(self, fail, cmd, interactive=NotImplemented):
205
    self.assertTrue(interactive)
206
    self.assertEqual(cmd, [pathutils.DAEMON_UTIL, "reload-ssh-keys"])
207
    if fail:
208
      exit_code = constants.EXIT_FAILURE
209
    else:
210
      exit_code = constants.EXIT_SUCCESS
211
    return utils.RunResult(exit_code, None, "stdout", "stderr",
212
                           utils.ShellQuoteArgs(cmd),
213
                           NotImplemented, NotImplemented)
214

  
215
  def _TestUpdate(self, failcmd):
216
    data = {
217
      constants.SSHS_SSH_HOST_KEY: [
218
        (constants.SSHK_DSA, "dsapub", "dsapriv"),
219
        (constants.SSHK_RSA, "rsapub", "rsapriv"),
220
        ],
221
      }
222
    runcmd_fn = compat.partial(self._RunCmd, failcmd)
223
    if failcmd:
224
      self.assertRaises(_JoinError, prepare_node_join.UpdateSshDaemon,
225
                        data, False, _runcmd_fn=runcmd_fn,
226
                        _keyfiles=self.keyfiles)
227
    else:
228
      prepare_node_join.UpdateSshDaemon(data, False, _runcmd_fn=runcmd_fn,
229
                                        _keyfiles=self.keyfiles)
230
    self.assertEqual(sorted(os.listdir(self.tmpdir)), sorted([
231
      "rsa.private", "rsa.public",
232
      "dsa.private", "dsa.public",
233
      ]))
234
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.public")),
235
                     "rsapub")
236
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "rsa.private")),
237
                     "rsapriv")
238
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.public")),
239
                     "dsapub")
240
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.tmpdir, "dsa.private")),
241
                     "dsapriv")
242

  
243
  def testSuccess(self):
244
    self._TestUpdate(False)
245

  
246
  def testFailure(self):
247
    self._TestUpdate(True)
248

  
249

  
250
class TestUpdateSshRoot(unittest.TestCase):
251
  def setUp(self):
252
    unittest.TestCase.setUp(self)
253
    self.tmpdir = tempfile.mkdtemp()
254
    self.sshdir = utils.PathJoin(self.tmpdir, ".ssh")
255

  
256
  def tearDown(self):
257
    unittest.TestCase.tearDown(self)
258
    shutil.rmtree(self.tmpdir)
259

  
260
  def _GetHomeDir(self, user):
261
    self.assertEqual(user, constants.SSH_LOGIN_USER)
262
    return self.tmpdir
263

  
264
  def testNoKeys(self):
265
    data_empty_keys = {
266
      constants.SSHS_SSH_ROOT_KEY: [],
267
      }
268

  
269
    for data in [{}, data_empty_keys]:
270
      for dry_run in [False, True]:
271
        prepare_node_join.UpdateSshRoot(data, dry_run,
272
                                        _homedir_fn=NotImplemented)
273
    self.assertEqual(os.listdir(self.tmpdir), [])
274

  
275
  def testDryRun(self):
276
    data = {
277
      constants.SSHS_SSH_ROOT_KEY: [
278
        (constants.SSHK_RSA, "aaa", "bbb"),
279
        ]
280
      }
281

  
282
    prepare_node_join.UpdateSshRoot(data, True,
283
                                    _homedir_fn=self._GetHomeDir)
284
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
285
    self.assertEqual(os.listdir(self.sshdir), [])
286

  
287
  def testUpdate(self):
288
    data = {
289
      constants.SSHS_SSH_ROOT_KEY: [
290
        (constants.SSHK_DSA, "pubdsa", "privatedsa"),
291
        ]
292
      }
293

  
294
    prepare_node_join.UpdateSshRoot(data, False,
295
                                    _homedir_fn=self._GetHomeDir)
296
    self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
297
    self.assertEqual(sorted(os.listdir(self.sshdir)),
298
                     sorted(["authorized_keys", "id_dsa", "id_dsa.pub"]))
299
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa")),
300
                     "privatedsa")
301
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir, "id_dsa.pub")),
302
                     "pubdsa")
303
    self.assertEqual(utils.ReadFile(utils.PathJoin(self.sshdir,
304
                                                   "authorized_keys")),
305
                     "ssh-dss pubdsa\n")
306

  
307

  
308
if __name__ == "__main__":
309
  testutils.GanetiTestProgram()

Also available in: Unified diff