Revision dffa96d6

b/lib/ssconf.py
28 28

  
29 29
import sys
30 30
import errno
31
import logging
31 32

  
32 33
from ganeti import errors
33 34
from ganeti import constants
......
368 369
    if debug:
369 370
      sys.stderr.write("Not master, exiting.\n")
370 371
    sys.exit(constants.EXIT_NOTMASTER)
372

  
373

  
374
def VerifyClusterName(name, _cfg_location=None):
375
  """Verifies cluster name against a local cluster name.
376

  
377
  @type name: string
378
  @param name: Cluster name
379

  
380
  """
381
  sstore = SimpleStore(cfg_location=_cfg_location)
382

  
383
  try:
384
    local_name = sstore.GetClusterName()
385
  except errors.ConfigurationError, err:
386
    logging.debug("Can't get local cluster name: %s", err)
387
  else:
388
    if name != local_name:
389
      raise errors.GenericError("Current cluster name is '%s'" % local_name)
b/lib/tools/prepare_node_join.py
150 150
    _verify_fn(cert)
151 151

  
152 152

  
153
def _VerifyClusterName(name, _ss_cluster_name_file=None):
154
  """Verifies cluster name against a local cluster name.
155

  
156
  @type name: string
157
  @param name: Cluster name
158

  
159
  """
160
  if _ss_cluster_name_file is None:
161
    _ss_cluster_name_file = \
162
      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
163

  
164
  try:
165
    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
166
  except EnvironmentError, err:
167
    if err.errno != errno.ENOENT:
168
      raise
169

  
170
    logging.debug("Local cluster name was not found (file %s)",
171
                  _ss_cluster_name_file)
172
  else:
173
    if name != local_name:
174
      raise JoinError("Current cluster name is '%s'" % local_name)
175

  
176

  
177
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
153
def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
178 154
  """Verifies cluster name.
179 155

  
180 156
  @type data: dict
b/test/ganeti.ssconf_unittest.py
141 141
                     "cluster.example.com")
142 142

  
143 143

  
144
class TestVerifyClusterName(unittest.TestCase):
145
  def setUp(self):
146
    self.tmpdir = tempfile.mkdtemp()
147

  
148
  def tearDown(self):
149
    shutil.rmtree(self.tmpdir)
150

  
151
  def testMissingFile(self):
152
    tmploc = utils.PathJoin(self.tmpdir, "does-not-exist")
153
    ssconf.VerifyClusterName(NotImplemented, _cfg_location=tmploc)
154

  
155
  def testMatchingName(self):
156
    tmpfile = utils.PathJoin(self.tmpdir, "ssconf_cluster_name")
157

  
158
    for content in ["cluster.example.com", "cluster.example.com\n\n"]:
159
      utils.WriteFile(tmpfile, data=content)
160
      ssconf.VerifyClusterName("cluster.example.com",
161
                               _cfg_location=self.tmpdir)
162

  
163
  def testNameMismatch(self):
164
    tmpfile = utils.PathJoin(self.tmpdir, "ssconf_cluster_name")
165

  
166
    for content in ["something.example.com", "foobar\n\ncluster.example.com"]:
167
      utils.WriteFile(tmpfile, data=content)
168
      self.assertRaises(errors.GenericError, ssconf.VerifyClusterName,
169
                        "cluster.example.com", _cfg_location=self.tmpdir)
170

  
171

  
144 172
if __name__ == "__main__":
145 173
  testutils.GanetiTestProgram()
b/test/ganeti.tools.prepare_node_join_unittest.py
130 130
    self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
131 131
                      {}, _verify_fn=NotImplemented)
132 132

  
133
  def testMissingFile(self):
134
    tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist")
135
    prepare_node_join._VerifyClusterName(NotImplemented,
136
                                         _ss_cluster_name_file=tmpfile)
137

  
138
  def testMatchingName(self):
139
    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
140

  
141
    for content in ["cluster.example.com", "cluster.example.com\n\n"]:
142
      utils.WriteFile(tmpfile, data=content)
143
      prepare_node_join._VerifyClusterName("cluster.example.com",
144
                                           _ss_cluster_name_file=tmpfile)
133
  @staticmethod
134
  def _FailingVerify(name):
135
    assert name == "cluster.example.com"
136
    raise errors.GenericError()
145 137

  
146
  def testNameMismatch(self):
147
    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
138
  def testFailingVerification(self):
139
    data = {
140
      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
141
      }
148 142

  
149
    for content in ["something.example.com", "foobar\n\ncluster.example.com"]:
150
      utils.WriteFile(tmpfile, data=content)
151
      self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName,
152
                        "cluster.example.com", _ss_cluster_name_file=tmpfile)
143
    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
144
                      data, _verify_fn=self._FailingVerify)
153 145

  
154 146

  
155 147
class TestUpdateSshDaemon(unittest.TestCase):

Also available in: Unified diff