Move cluster verification out of prepare-node-join
author Michael Hanselmann <hansmi@google.com>
Wed, 28 Nov 2012 09:03:40 +0000 (10:03 +0100)
committer Michael Hanselmann <hansmi@google.com>
Wed, 28 Nov 2012 11:48:47 +0000 (12:48 +0100)
A new tool for configuring the node daemon will also have to verify the
cluster name, so it's better to have this function in a central place.
While moving it to ssconf, it is also changed to use “SimpleStore”
instead of reading the ssconf file directly. The tests are updated accordingly.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Helga Velroyen <helgav@google.com>

lib/ssconf.py
lib/tools/prepare_node_join.py
test/ganeti.ssconf_unittest.py
test/ganeti.tools.prepare_node_join_unittest.py

index ad5969c..7e34b5d 100644 (file)
@@ -28,6 +28,7 @@ configuration data, which is mostly static and available to all nodes.
 
 import sys
 import errno
+import logging
 
 from ganeti import errors
 from ganeti import constants
@@ -368,3 +369,21 @@ def CheckMaster(debug, ss=None):
     if debug:
       sys.stderr.write("Not master, exiting.\n")
     sys.exit(constants.EXIT_NOTMASTER)
+
+
+def VerifyClusterName(name, _cfg_location=None):
+  """Verifies cluster name against a local cluster name.
+
+  @type name: string
+  @param name: Cluster name
+
+  """
+  sstore = SimpleStore(cfg_location=_cfg_location)
+
+  try:
+    local_name = sstore.GetClusterName()
+  except errors.ConfigurationError, err:
+    logging.debug("Can't get local cluster name: %s", err)
+  else:
+    if name != local_name:
+      raise errors.GenericError("Current cluster name is '%s'" % local_name)
index b88e02e..2441785 100644 (file)
@@ -150,31 +150,7 @@ def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
     _verify_fn(cert)
 
 
-def _VerifyClusterName(name, _ss_cluster_name_file=None):
-  """Verifies cluster name against a local cluster name.
-
-  @type name: string
-  @param name: Cluster name
-
-  """
-  if _ss_cluster_name_file is None:
-    _ss_cluster_name_file = \
-      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
-
-  try:
-    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
-  except EnvironmentError, err:
-    if err.errno != errno.ENOENT:
-      raise
-
-    logging.debug("Local cluster name was not found (file %s)",
-                  _ss_cluster_name_file)
-  else:
-    if name != local_name:
-      raise JoinError("Current cluster name is '%s'" % local_name)
-
-
-def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
+def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
   """Verifies cluster name.
 
   @type data: dict
index 86d93be..1db88e8 100755 (executable)
@@ -141,5 +141,33 @@ class TestSimpleStore(unittest.TestCase):
                      "cluster.example.com")
 
 
+class TestVerifyClusterName(unittest.TestCase):
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  def testMissingFile(self):
+    tmploc = utils.PathJoin(self.tmpdir, "does-not-exist")
+    ssconf.VerifyClusterName(NotImplemented, _cfg_location=tmploc)
+
+  def testMatchingName(self):
+    tmpfile = utils.PathJoin(self.tmpdir, "ssconf_cluster_name")
+
+    for content in ["cluster.example.com", "cluster.example.com\n\n"]:
+      utils.WriteFile(tmpfile, data=content)
+      ssconf.VerifyClusterName("cluster.example.com",
+                               _cfg_location=self.tmpdir)
+
+  def testNameMismatch(self):
+    tmpfile = utils.PathJoin(self.tmpdir, "ssconf_cluster_name")
+
+    for content in ["something.example.com", "foobar\n\ncluster.example.com"]:
+      utils.WriteFile(tmpfile, data=content)
+      self.assertRaises(errors.GenericError, ssconf.VerifyClusterName,
+                        "cluster.example.com", _cfg_location=self.tmpdir)
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
index 1cda5d2..c014280 100755 (executable)
@@ -130,26 +130,18 @@ class TestVerifyClusterName(unittest.TestCase):
     self.assertRaises(_JoinError, prepare_node_join.VerifyClusterName,
                       {}, _verify_fn=NotImplemented)
 
-  def testMissingFile(self):
-    tmpfile = utils.PathJoin(self.tmpdir, "does-not-exist")
-    prepare_node_join._VerifyClusterName(NotImplemented,
-                                         _ss_cluster_name_file=tmpfile)
-
-  def testMatchingName(self):
-    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
-
-    for content in ["cluster.example.com", "cluster.example.com\n\n"]:
-      utils.WriteFile(tmpfile, data=content)
-      prepare_node_join._VerifyClusterName("cluster.example.com",
-                                           _ss_cluster_name_file=tmpfile)
+  @staticmethod
+  def _FailingVerify(name):
+    assert name == "cluster.example.com"
+    raise errors.GenericError()
 
-  def testNameMismatch(self):
-    tmpfile = utils.PathJoin(self.tmpdir, "cluster_name")
+  def testFailingVerification(self):
+    data = {
+      constants.SSHS_CLUSTER_NAME: "cluster.example.com",
+      }
 
-    for content in ["something.example.com", "foobar\n\ncluster.example.com"]:
-      utils.WriteFile(tmpfile, data=content)
-      self.assertRaises(_JoinError, prepare_node_join._VerifyClusterName,
-                        "cluster.example.com", _ss_cluster_name_file=tmpfile)
+    self.assertRaises(errors.GenericError, prepare_node_join.VerifyClusterName,
+                      data, _verify_fn=self._FailingVerify)
 
 
 class TestUpdateSshDaemon(unittest.TestCase):