Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.netutils_unittest.py @ 7845b8c8

History | View | Annotate | Download (14 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the netutils module"""
23

    
24
import os
25
import shutil
26
import socket
27
import tempfile
28
import unittest
29

    
30
import testutils
31
from ganeti import constants
32
from ganeti import errors
33
from ganeti import netutils
34
from ganeti import serializer
35
from ganeti import utils
36

    
37

    
38
def _GetSocketCredentials(path):
39
  """Connect to a Unix socket and return remote credentials.
40

41
  """
42
  sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
43
  try:
44
    sock.settimeout(10)
45
    sock.connect(path)
46
    return netutils.GetSocketCredentials(sock)
47
  finally:
48
    sock.close()
49

    
50

    
51
class TestGetSocketCredentials(unittest.TestCase):
52
  def setUp(self):
53
    self.tmpdir = tempfile.mkdtemp()
54
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")
55

    
56
    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
57
    self.listener.settimeout(10)
58
    self.listener.bind(self.sockpath)
59
    self.listener.listen(1)
60

    
61
  def tearDown(self):
62
    self.listener.shutdown(socket.SHUT_RDWR)
63
    self.listener.close()
64
    shutil.rmtree(self.tmpdir)
65

    
66
  def test(self):
67
    (c2pr, c2pw) = os.pipe()
68

    
69
    # Start child process
70
    child = os.fork()
71
    if child == 0:
72
      try:
73
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
74

    
75
        os.write(c2pw, data)
76
        os.close(c2pw)
77

    
78
        os._exit(0)
79
      finally:
80
        os._exit(1)
81

    
82
    os.close(c2pw)
83

    
84
    # Wait for one connection
85
    (conn, _) = self.listener.accept()
86
    conn.recv(1)
87
    conn.close()
88

    
89
    # Wait for result
90
    result = os.read(c2pr, 4096)
91
    os.close(c2pr)
92

    
93
    # Check child's exit code
94
    (_, status) = os.waitpid(child, 0)
95
    self.assertFalse(os.WIFSIGNALED(status))
96
    self.assertEqual(os.WEXITSTATUS(status), 0)
97

    
98
    # Check result
99
    (pid, uid, gid) = serializer.LoadJson(result)
100
    self.assertEqual(pid, os.getpid())
101
    self.assertEqual(uid, os.getuid())
102
    self.assertEqual(gid, os.getgid())
103

    
104

    
105
class TestHostname(unittest.TestCase):
106
  """Testing case for Hostname"""
107

    
108
  def testUppercase(self):
109
    data = "AbC.example.com"
110
    self.assertEqual(netutils.Hostname.GetNormalizedName(data), data.lower())
111

    
112
  def testTooLongName(self):
113
    data = "a.b." + "c" * 255
114
    self.assertRaises(errors.OpPrereqError,
115
                      netutils.Hostname.GetNormalizedName, data)
116

    
117
  def testTrailingDot(self):
118
    data = "a.b.c"
119
    self.assertEqual(netutils.Hostname.GetNormalizedName(data + "."), data)
120

    
121
  def testInvalidName(self):
122
    data = [
123
      "a b",
124
      "a/b",
125
      ".a.b",
126
      "a..b",
127
      ]
128
    for value in data:
129
      self.assertRaises(errors.OpPrereqError,
130
                        netutils.Hostname.GetNormalizedName, value)
131

    
132
  def testValidName(self):
133
    data = [
134
      "a.b",
135
      "a-b",
136
      "a_b",
137
      "a.b.c",
138
      ]
139
    for value in data:
140
      self.assertEqual(netutils.Hostname.GetNormalizedName(value), value)
141

    
142

    
143
class TestIPAddress(unittest.TestCase):
144
  def testIsValid(self):
145
    self.assert_(netutils.IPAddress.IsValid("0.0.0.0"))
146
    self.assert_(netutils.IPAddress.IsValid("127.0.0.1"))
147
    self.assert_(netutils.IPAddress.IsValid("::"))
148
    self.assert_(netutils.IPAddress.IsValid("::1"))
149

    
150
  def testNotIsValid(self):
151
    self.assertFalse(netutils.IPAddress.IsValid("0"))
152
    self.assertFalse(netutils.IPAddress.IsValid("1.1.1.256"))
153
    self.assertFalse(netutils.IPAddress.IsValid("a:g::1"))
154

    
155
  def testGetAddressFamily(self):
156
    fn = netutils.IPAddress.GetAddressFamily
157
    self.assertEqual(fn("127.0.0.1"), socket.AF_INET)
158
    self.assertEqual(fn("10.2.0.127"), socket.AF_INET)
159
    self.assertEqual(fn("::1"), socket.AF_INET6)
160
    self.assertEqual(fn("2001:db8::1"), socket.AF_INET6)
161
    self.assertRaises(errors.IPAddressError, fn, "0")
162

    
163
  def testOwnLoopback(self):
164
    # FIXME: In a pure IPv6 environment this is no longer true
165
    self.assert_(netutils.IPAddress.Own("127.0.0.1"),
166
                 "Should own 127.0.0.1 address")
167

    
168
  def testNotOwnAddress(self):
169
    self.assertFalse(netutils.IPAddress.Own("2001:db8::1"),
170
                     "Should not own IP address 2001:db8::1")
171
    self.assertFalse(netutils.IPAddress.Own("192.0.2.1"),
172
                     "Should not own IP address 192.0.2.1")
173

    
174

    
175
class TestIP4Address(unittest.TestCase):
176
  def testGetIPIntFromString(self):
177
    fn = netutils.IP4Address._GetIPIntFromString
178
    self.assertEqual(fn("0.0.0.0"), 0)
179
    self.assertEqual(fn("0.0.0.1"), 1)
180
    self.assertEqual(fn("127.0.0.1"), 2130706433)
181
    self.assertEqual(fn("192.0.2.129"), 3221226113)
182
    self.assertEqual(fn("255.255.255.255"), 2**32 - 1)
183
    self.assertNotEqual(fn("0.0.0.0"), 1)
184
    self.assertNotEqual(fn("0.0.0.0"), 1)
185

    
186
  def testIsValid(self):
187
    self.assert_(netutils.IP4Address.IsValid("0.0.0.0"))
188
    self.assert_(netutils.IP4Address.IsValid("127.0.0.1"))
189
    self.assert_(netutils.IP4Address.IsValid("192.0.2.199"))
190
    self.assert_(netutils.IP4Address.IsValid("255.255.255.255"))
191

    
192
  def testNotIsValid(self):
193
    self.assertFalse(netutils.IP4Address.IsValid("0"))
194
    self.assertFalse(netutils.IP4Address.IsValid("1"))
195
    self.assertFalse(netutils.IP4Address.IsValid("1.1.1"))
196
    self.assertFalse(netutils.IP4Address.IsValid("255.255.255.256"))
197
    self.assertFalse(netutils.IP4Address.IsValid("::1"))
198

    
199
  def testInNetwork(self):
200
    self.assert_(netutils.IP4Address.InNetwork("127.0.0.0/8", "127.0.0.1"))
201

    
202
  def testNotInNetwork(self):
203
    self.assertFalse(netutils.IP4Address.InNetwork("192.0.2.0/24",
204
                                                   "127.0.0.1"))
205

    
206
  def testIsLoopback(self):
207
    self.assert_(netutils.IP4Address.IsLoopback("127.0.0.1"))
208

    
209
  def testNotIsLoopback(self):
210
    self.assertFalse(netutils.IP4Address.IsLoopback("192.0.2.1"))
211

    
212

    
213
class TestIP6Address(unittest.TestCase):
214
  def testGetIPIntFromString(self):
215
    fn = netutils.IP6Address._GetIPIntFromString
216
    self.assertEqual(fn("::"), 0)
217
    self.assertEqual(fn("::1"), 1)
218
    self.assertEqual(fn("2001:db8::1"),
219
                     42540766411282592856903984951653826561L)
220
    self.assertEqual(fn("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), 2**128-1)
221
    self.assertNotEqual(netutils.IP6Address._GetIPIntFromString("::2"), 1)
222

    
223
  def testIsValid(self):
224
    self.assert_(netutils.IP6Address.IsValid("::"))
225
    self.assert_(netutils.IP6Address.IsValid("::1"))
226
    self.assert_(netutils.IP6Address.IsValid("1" + (":1" * 7)))
227
    self.assert_(netutils.IP6Address.IsValid("ffff" + (":ffff" * 7)))
228
    self.assert_(netutils.IP6Address.IsValid("::"))
229

    
230
  def testNotIsValid(self):
231
    self.assertFalse(netutils.IP6Address.IsValid("0"))
232
    self.assertFalse(netutils.IP6Address.IsValid(":1"))
233
    self.assertFalse(netutils.IP6Address.IsValid("f" + (":f" * 6)))
234
    self.assertFalse(netutils.IP6Address.IsValid("fffg" + (":ffff" * 7)))
235
    self.assertFalse(netutils.IP6Address.IsValid("fffff" + (":ffff" * 7)))
236
    self.assertFalse(netutils.IP6Address.IsValid("1" + (":1" * 8)))
237
    self.assertFalse(netutils.IP6Address.IsValid("127.0.0.1"))
238

    
239
  def testInNetwork(self):
240
    self.assert_(netutils.IP6Address.InNetwork("::1/128", "::1"))
241

    
242
  def testNotInNetwork(self):
243
    self.assertFalse(netutils.IP6Address.InNetwork("2001:db8::1/128", "::1"))
244

    
245
  def testIsLoopback(self):
246
    self.assert_(netutils.IP6Address.IsLoopback("::1"))
247

    
248
  def testNotIsLoopback(self):
249
    self.assertFalse(netutils.IP6Address.IsLoopback("2001:db8::1"))
250

    
251

    
252
class _BaseTcpPingTest:
253
  """Base class for TcpPing tests against listen(2)ing port"""
254
  family = None
255
  address = None
256

    
257
  def setUp(self):
258
    self.listener = socket.socket(self.family, socket.SOCK_STREAM)
259
    self.listener.bind((self.address, 0))
260
    self.listenerport = self.listener.getsockname()[1]
261
    self.listener.listen(1)
262

    
263
  def tearDown(self):
264
    self.listener.shutdown(socket.SHUT_RDWR)
265
    del self.listener
266
    del self.listenerport
267

    
268
  def testTcpPingToLocalHostAccept(self):
269
    self.assert_(netutils.TcpPing(self.address,
270
                                  self.listenerport,
271
                                  timeout=constants.TCP_PING_TIMEOUT,
272
                                  live_port_needed=True,
273
                                  source=self.address,
274
                                  ),
275
                 "failed to connect to test listener")
276

    
277
    self.assert_(netutils.TcpPing(self.address, self.listenerport,
278
                                  timeout=constants.TCP_PING_TIMEOUT,
279
                                  live_port_needed=True),
280
                 "failed to connect to test listener (no source)")
281

    
282

    
283
class TestIP4TcpPing(unittest.TestCase, _BaseTcpPingTest):
284
  """Testcase for IPv4 TCP version of ping - against listen(2)ing port"""
285
  family = socket.AF_INET
286
  address = constants.IP4_ADDRESS_LOCALHOST
287

    
288
  def setUp(self):
289
    unittest.TestCase.setUp(self)
290
    _BaseTcpPingTest.setUp(self)
291

    
292
  def tearDown(self):
293
    unittest.TestCase.tearDown(self)
294
    _BaseTcpPingTest.tearDown(self)
295

    
296

    
297
class TestIP6TcpPing(unittest.TestCase, _BaseTcpPingTest):
298
  """Testcase for IPv6 TCP version of ping - against listen(2)ing port"""
299
  family = socket.AF_INET6
300
  address = constants.IP6_ADDRESS_LOCALHOST
301

    
302
  def setUp(self):
303
    unittest.TestCase.setUp(self)
304
    _BaseTcpPingTest.setUp(self)
305

    
306
  def tearDown(self):
307
    unittest.TestCase.tearDown(self)
308
    _BaseTcpPingTest.tearDown(self)
309

    
310

    
311
class _BaseTcpPingDeafTest:
312
  """Base class for TcpPing tests against non listen(2)ing port"""
313
  family = None
314
  address = None
315

    
316
  def setUp(self):
317
    self.deaflistener = socket.socket(self.family, socket.SOCK_STREAM)
318
    self.deaflistener.bind((self.address, 0))
319
    self.deaflistenerport = self.deaflistener.getsockname()[1]
320

    
321
  def tearDown(self):
322
    del self.deaflistener
323
    del self.deaflistenerport
324

    
325
  def testTcpPingToLocalHostAcceptDeaf(self):
326
    self.assertFalse(netutils.TcpPing(self.address,
327
                                      self.deaflistenerport,
328
                                      timeout=constants.TCP_PING_TIMEOUT,
329
                                      live_port_needed=True,
330
                                      source=self.address,
331
                                      ), # need successful connect(2)
332
                     "successfully connected to deaf listener")
333

    
334
    self.assertFalse(netutils.TcpPing(self.address,
335
                                      self.deaflistenerport,
336
                                      timeout=constants.TCP_PING_TIMEOUT,
337
                                      live_port_needed=True,
338
                                      ), # need successful connect(2)
339
                     "successfully connected to deaf listener (no source)")
340

    
341
  def testTcpPingToLocalHostNoAccept(self):
342
    self.assert_(netutils.TcpPing(self.address,
343
                                  self.deaflistenerport,
344
                                  timeout=constants.TCP_PING_TIMEOUT,
345
                                  live_port_needed=False,
346
                                  source=self.address,
347
                                  ), # ECONNREFUSED is OK
348
                 "failed to ping alive host on deaf port")
349

    
350
    self.assert_(netutils.TcpPing(self.address,
351
                                  self.deaflistenerport,
352
                                  timeout=constants.TCP_PING_TIMEOUT,
353
                                  live_port_needed=False,
354
                                  ), # ECONNREFUSED is OK
355
                 "failed to ping alive host on deaf port (no source)")
356

    
357

    
358
class TestIP4TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
359
  """Testcase for IPv4 TCP version of ping - against non listen(2)ing port"""
360
  family = socket.AF_INET
361
  address = constants.IP4_ADDRESS_LOCALHOST
362

    
363
  def setUp(self):
364
    self.deaflistener = socket.socket(self.family, socket.SOCK_STREAM)
365
    self.deaflistener.bind((self.address, 0))
366
    self.deaflistenerport = self.deaflistener.getsockname()[1]
367

    
368
  def tearDown(self):
369
    del self.deaflistener
370
    del self.deaflistenerport
371

    
372

    
373
class TestIP6TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
374
  """Testcase for IPv6 TCP version of ping - against non listen(2)ing port"""
375
  family = socket.AF_INET6
376
  address = constants.IP6_ADDRESS_LOCALHOST
377

    
378
  def setUp(self):
379
    unittest.TestCase.setUp(self)
380
    _BaseTcpPingDeafTest.setUp(self)
381

    
382
  def tearDown(self):
383
    unittest.TestCase.tearDown(self)
384
    _BaseTcpPingDeafTest.tearDown(self)
385

    
386

    
387
class TestFormatAddress(unittest.TestCase):
388
  """Testcase for FormatAddress"""
389

    
390
  def testFormatAddressUnixSocket(self):
391
    res1 = netutils.FormatAddress(socket.AF_UNIX, ("12352", 0, 0))
392
    self.assertEqual(res1, "pid=12352, uid=0, gid=0")
393

    
394
  def testFormatAddressIP4(self):
395
    res1 = netutils.FormatAddress(socket.AF_INET, ("127.0.0.1", 1234))
396
    self.assertEqual(res1, "127.0.0.1:1234")
397
    res2 = netutils.FormatAddress(socket.AF_INET, ("192.0.2.32", None))
398
    self.assertEqual(res2, "192.0.2.32")
399

    
400
  def testFormatAddressIP6(self):
401
    res1 = netutils.FormatAddress(socket.AF_INET6, ("::1", 1234))
402
    self.assertEqual(res1, "[::1]:1234")
403
    res2 = netutils.FormatAddress(socket.AF_INET6, ("::1", None))
404
    self.assertEqual(res2, "[::1]")
405
    res2 = netutils.FormatAddress(socket.AF_INET6, ("2001:db8::beef", "80"))
406
    self.assertEqual(res2, "[2001:db8::beef]:80")
407

    
408
  def testInvalidFormatAddress(self):
409
    self.assertRaises(errors.ParameterError,
410
                      netutils.FormatAddress, None, ("::1", None))
411
    self.assertRaises(errors.ParameterError,
412
                      netutils.FormatAddress, socket.AF_INET, "127.0.0.1")
413
    self.assertRaises(errors.ParameterError,
414
                      netutils.FormatAddress, socket.AF_INET, ("::1"))
415

    
416

    
417
if __name__ == "__main__":
418
  testutils.GanetiTestProgram()