Revision a9d68e40

b/lib/utils/io.py
117 117
  if backup and not dry_run and os.path.isfile(file_name):
118 118
    CreateBackup(file_name)
119 119

  
120
  dir_name, base_name = os.path.split(file_name)
121
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
120
  # Whether temporary file needs to be removed (e.g. if any error occurs)
122 121
  do_remove = True
123
  # here we need to make sure we remove the temp file, if any error
124
  # leaves it in place
122

  
123
  # Function result
124
  result = None
125

  
126
  (dir_name, base_name) = os.path.split(file_name)
127
  (fd, new_name) = tempfile.mkstemp(suffix=".new", prefix=base_name,
128
                                    dir=dir_name)
125 129
  try:
126
    if uid != -1 or gid != -1:
127
      os.chown(new_name, uid, gid)
128
    if mode:
129
      os.chmod(new_name, mode)
130
    if callable(prewrite):
131
      prewrite(fd)
132
    if data is not None:
133
      os.write(fd, data)
134
    else:
135
      fn(fd)
136
    if callable(postwrite):
137
      postwrite(fd)
138
    os.fsync(fd)
139
    if atime is not None and mtime is not None:
140
      os.utime(new_name, (atime, mtime))
130
    try:
131
      if uid != -1 or gid != -1:
132
        os.chown(new_name, uid, gid)
133
      if mode:
134
        os.chmod(new_name, mode)
135
      if callable(prewrite):
136
        prewrite(fd)
137
      if data is not None:
138
        os.write(fd, data)
139
      else:
140
        fn(fd)
141
      if callable(postwrite):
142
        postwrite(fd)
143
      os.fsync(fd)
144
      if atime is not None and mtime is not None:
145
        os.utime(new_name, (atime, mtime))
146
    finally:
147
      # Close file unless the file descriptor should be returned
148
      if close:
149
        os.close(fd)
150
      else:
151
        result = fd
152

  
153
    # Rename file to destination name
141 154
    if not dry_run:
142 155
      os.rename(new_name, file_name)
156
      # Successful, no need to remove anymore
143 157
      do_remove = False
144 158
  finally:
145
    if close:
146
      os.close(fd)
147
      result = None
148
    else:
149
      result = fd
150 159
    if do_remove:
151 160
      RemoveFile(new_name)
152 161

  
b/test/ganeti.utils.io_unittest.py
234 234

  
235 235
class TestWriteFile(unittest.TestCase):
236 236
  def setUp(self):
237
    self.tmpdir = None
237 238
    self.tfile = tempfile.NamedTemporaryFile()
238 239
    self.did_pre = False
239 240
    self.did_post = False
240 241
    self.did_write = False
241 242

  
243
  def tearDown(self):
244
    if self.tmpdir:
245
      shutil.rmtree(self.tmpdir)
246

  
242 247
  def markPre(self, fd):
243 248
    self.did_pre = True
244 249

  
......
260 265
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
261 266
                      self.tfile.name, data="test", atime=0)
262 267

  
263
  def testCalls(self):
264
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
265
                    prewrite=self.markPre, postwrite=self.markPost)
268
  def testPreWrite(self):
269
    utils.WriteFile(self.tfile.name, data="", prewrite=self.markPre)
266 270
    self.assertTrue(self.did_pre)
271
    self.assertFalse(self.did_post)
272
    self.assertFalse(self.did_write)
273

  
274
  def testPostWrite(self):
275
    utils.WriteFile(self.tfile.name, data="", postwrite=self.markPost)
276
    self.assertFalse(self.did_pre)
267 277
    self.assertTrue(self.did_post)
278
    self.assertFalse(self.did_write)
279

  
280
  def testWriteFunction(self):
281
    utils.WriteFile(self.tfile.name, fn=self.markWrite)
282
    self.assertFalse(self.did_pre)
283
    self.assertFalse(self.did_post)
268 284
    self.assertTrue(self.did_write)
269 285

  
270 286
  def testDryRun(self):
......
293 309
    finally:
294 310
      os.close(fd)
295 311

  
312
  def testNoLeftovers(self):
313
    self.tmpdir = tempfile.mkdtemp()
314
    self.assertEqual(utils.WriteFile(utils.PathJoin(self.tmpdir, "test"),
315
                                     data="abc"),
316
                     None)
317
    self.assertEqual(os.listdir(self.tmpdir), ["test"])
318

  
319
  def testFailRename(self):
320
    self.tmpdir = tempfile.mkdtemp()
321
    target = utils.PathJoin(self.tmpdir, "target")
322
    os.mkdir(target)
323
    self.assertRaises(OSError, utils.WriteFile, target, data="abc")
324
    self.assertTrue(os.path.isdir(target))
325
    self.assertEqual(os.listdir(self.tmpdir), ["target"])
326
    self.assertFalse(os.listdir(target))
327

  
328
  def testFailRenameDryRun(self):
329
    self.tmpdir = tempfile.mkdtemp()
330
    target = utils.PathJoin(self.tmpdir, "target")
331
    os.mkdir(target)
332
    self.assertEqual(utils.WriteFile(target, data="abc", dry_run=True), None)
333
    self.assertTrue(os.path.isdir(target))
334
    self.assertEqual(os.listdir(self.tmpdir), ["target"])
335
    self.assertFalse(os.listdir(target))
336

  
337
  def testBackup(self):
338
    self.tmpdir = tempfile.mkdtemp()
339
    testfile = utils.PathJoin(self.tmpdir, "test")
340

  
341
    self.assertEqual(utils.WriteFile(testfile, data="foo", backup=True), None)
342
    self.assertEqual(utils.ReadFile(testfile), "foo")
343
    self.assertEqual(os.listdir(self.tmpdir), ["test"])
344

  
345
    # Write again
346
    assert os.path.isfile(testfile)
347
    self.assertEqual(utils.WriteFile(testfile, data="bar", backup=True), None)
348
    self.assertEqual(utils.ReadFile(testfile), "bar")
349
    self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
350
    self.assertTrue("test" in os.listdir(self.tmpdir))
351
    self.assertEqual(len(os.listdir(self.tmpdir)), 2)
352

  
353
    # Write again as dry-run
354
    assert os.path.isfile(testfile)
355
    self.assertEqual(utils.WriteFile(testfile, data="000", backup=True,
356
                                     dry_run=True),
357
                     None)
358
    self.assertEqual(utils.ReadFile(testfile), "bar")
359
    self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
360
    self.assertTrue("test" in os.listdir(self.tmpdir))
361
    self.assertEqual(len(os.listdir(self.tmpdir)), 2)
362

  
296 363

  
297 364
class TestFileID(testutils.GanetiTestCase):
298 365
  def testEquality(self):

Also available in: Unified diff