Create BSD_Disklabel class in disklabel.py
authorNikos Skalkotos <skalkoto@grnet.gr>
Thu, 20 Feb 2014 08:11:55 +0000 (10:11 +0200)
committerNikos Skalkotos <skalkoto@grnet.gr>
Thu, 20 Feb 2014 08:11:55 +0000 (10:11 +0200)
We need this class to support {Free,Net}BSD disklabels

snf-image-helper/disklabel.py

index b0c4b5e..3bc241e 100755 (executable)
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 # 02110-1301, USA.
 
-"""This module provides the code for handling OpenBSD disklabels"""
+"""This module provides the code for handling BSD disklabels"""
 
 import struct
 import sys
+import os
 import cStringIO
 import optparse
 
@@ -36,8 +37,6 @@ LABELOFFSET = 0
 BBSIZE = 8192  # size of boot area with label
 SBSIZE = 8192  # max size of fs superblock
 
-DISKMAGIC = 0x82564557
-
 
 class MBR(object):
     """Represents a Master Boot Record."""
@@ -158,10 +157,74 @@ class MBR(object):
             ret += "Partition %d: %s\n" % (i, self.part[i])
         ret += "Signature: %s %s\n" % (hex(ord(self.signature[0])),
                                        hex(ord(self.signature[1])))
-        return ret
+        title = "Master Boot Record"
+        return "%s\n%s\n%s\n" % (title, len(title) * "=", ret)
 
 
 class Disklabel:
+    """Represents an BSD Disklabel"""
+
+    def __init__(self, disk):
+        """Create a DiskLabel instance"""
+        self.disk = disk
+        self.part_num = None
+        self.disklabel = None
+
+        with open(disk, "rb") as d:
+            sector0 = d.read(BLOCKSIZE)
+            self.mbr = MBR(sector0)
+
+            for i in range(4):
+                ptype = self.mbr.part[i].type
+                if ptype in (0xa5, 0xa6, 0xa9):
+                    d.seek(BLOCKSIZE * self.mbr.part[i].first_sector)
+                    self.part_num = i
+                    if ptype == 0xa5:  # FreeBSD
+                        self.disklabel = BSD_Disklabel(d)
+                    elif ptype == 0xa6:  # OpenBSD
+                        self.disklabel = OpenBSD_Disklabel(d)
+                    else:  # NetBSD
+                        self.disklabel = BSD_Disklabel(d)
+                    break
+
+        assert self.disklabel is not None, "No *BSD partition found"
+
+    def write(self):
+        """Write the disklabel back to the media"""
+        with open(self.disk, 'rw+b') as d:
+            d.write(self.mbr.pack())
+
+            d.seek(self.mbr.part[self.part_num].first_sector * BLOCKSIZE)
+            self.disklabel.write_to(d)
+
+    def __str__(self):
+        return str(self.mbr) + str(self.disklabel)
+
+    def enlarge_disk(self, new_size):
+        """Enlarge the size of the disk and return the last usable sector"""
+
+        # Fix the disklabel
+        end = self.disklabel.enlarge_disk(new_size)
+
+        # Fix the MBR
+        start = self.mbr.part[self.part_num].first_sector
+        self.mbr.part[self.part_num].sector_count = end - start + 1
+
+        cylinder = end // (self.disklabel.ntracks * self.disklabel.nsectors)
+        header = (end // self.disklabel.nsectors) % self.disklabel.ntracks
+        sector = (end % self.disklabel.nsectors) + 1
+        chs = MBR.Partition.pack_chs(cylinder, header, sector)
+        self.mbr.part[self.part_num].end = chs
+
+    def enlarge_last_partition(self):
+        self.disklabel.enlarge_last_partition()
+
+
+class BSD_Disklabel:
+    pass
+
+
+class OpenBSD_Disklabel:
     """Represents an OpenBSD Disklabel"""
     format = "<IHH16s16sIIIIII8sIHHIII20sHH16sIHHII364s"
     """
@@ -258,39 +321,29 @@ class Disklabel:
                                           tmp.frag, tmp.cpg)
 
         def getpsize(self, i):
+            """Get size for partition i"""
             return (self.part[i].sizeh << 32) + self.part[i].size
 
         def setpoffset(self, i, offset):
-            """Set  offset for partition i"""
+            """Set offset for partition i"""
             tmp = self.part[i]
             self.part[i] = self.Partition(tmp.size, offset & 0xffffffff,
                                           offset >> 32, tmp.sizeh, tmp.frag,
                                           tmp.cpg)
 
         def getpoffset(self, i):
+            """Get offset for partition i"""
             return (self.part[i].offseth << 32) + self.part[i].offset
 
-    def __init__(self, disk):
-        """Create a DiskLabel instance"""
-        self.disk = disk
-        self.part_num = None
-
-        with open(disk, "rb") as d:
-            sector0 = d.read(BLOCKSIZE)
-            self.mbr = MBR(sector0)
-
-            for i in range(4):
-                if self.mbr.part[i].type == 0xa6:  # OpenBSD type
-                    self.part_num = i
-                    break
+    DISKMAGIC = 0x82564557
 
-            assert self.part_num is not None, "No OpenBSD partition found"
+    def __init__(self, device):
+        """Create a DiskLabel instance"""
 
-            d.seek(BLOCKSIZE * self.mbr.part[self.part_num].first_sector)
-            part_sector0 = d.read(BLOCKSIZE)
-            # The offset of the disklabel from the begining of the
-            # partition is one sector
-            part_sector1 = d.read(BLOCKSIZE)
+        device.seek(BLOCKSIZE, os.SEEK_CUR)
+        # The offset of the disklabel from the beginning of the partition is
+        # one sector
+        sector1 = device.read(BLOCKSIZE)
 
         (self.magic,
          self.dtype,
@@ -319,9 +372,9 @@ class Disklabel:
          self.npartitions,
          self.bbsize,
          self.sbsize,
-         ptable_raw) = struct.unpack(self.format, part_sector1)
+         ptable_raw) = struct.unpack(self.format, sector1)
 
-        assert self.magic == DISKMAGIC, "Disklabel is not valid"
+        assert self.magic == self.DISKMAGIC, "Disklabel is not valid"
 
         self.ptable = self.PartitionTable(ptable_raw, self.npartitions)
 
@@ -391,18 +444,19 @@ class Disklabel:
         return (self.bstarth << 32) + self.bstart
 
     def setbend(self, bend):
-        """Set end of useable region"""
+        """Set size of useable region"""
         self.bendh = bend >> 32
         self.bend = bend & 0xffffffff
 
     def getbend(self):
+        """Get size of usable region"""
         return (self.bendh << 32) + self.bend
 
     def enlarge_disk(self, new_size):
-        """Enlarge the size of the disk"""
+        """Enlarge the size of the disk and return the last usable sector"""
 
-        assert new_size >= self.secperunit, \
-            "New size cannot be smaller that %s" % self.secperunit
+        assert new_size >= self.getdsize(), \
+            "New size cannot be smaller that %s" % self.getdsize()
 
         # Fix the disklabel
         self.setdsize(new_size)
@@ -412,26 +466,21 @@ class Disklabel:
         # Partition 'c' descriptes the entire disk
         self.ptable.setpsize(2, new_size)
 
-        # Fix the MBR table
-        start = self.mbr.part[self.part_num].first_sector
-        self.mbr.part[self.part_num].sector_count = self.getbend() - start
-
-        lba = self.getbend() - 1
-        cylinder = lba // (self.ntracks * self.nsectors)
-        header = (lba // self.nsectors) % self.ntracks
-        sector = (lba % self.nsectors) + 1
-        chs = MBR.Partition.pack_chs(cylinder, header, sector)
-        self.mbr.part[self.part_num].end = chs
-
+        # Update the checksum
         self.checksum = self.compute_checksum()
 
-    def write(self):
-        """Write the disklabel back to the media"""
-        with open(self.disk, 'rw+b') as d:
-            d.write(self.mbr.pack())
+        # getbend() gives back the size of the usable region and not the end of
+        # the usable region. I named it like this because this is how it is
+        # named in OpenBSD. To get the last usable sector you need to reduce
+        # this value by one.
+        return self.getbend() - 1
+
+    def write_to(self, device):
+        """Write the disklabel to a device"""
 
-            d.seek((self.mbr.part[self.part_num].first_sector + 1) * BLOCKSIZE)
-            d.write(self.pack())
+        # The disklabel starts at sector 1
+        device.seek(BLOCKSIZE, os.SEEK_CUR)
+        device.write(self.pack())
 
     def get_last_partition_id(self):
         """Returns the id of the last partition"""
@@ -471,12 +520,10 @@ class Disklabel:
 
     def __str__(self):
         """Print the Disklabel"""
-        title1 = "Master Boot Record"
-        title2 = "Disklabel"
 
+        title = "Disklabel"
         return \
-            "%s\n%s\n%s\n" % (title1, len(title1) * "=", str(self.mbr)) + \
-            "%s\n%s\n" % (title2, len(title2) * "=") + \
+            "%s\n%s\n" % (title, len(title) * "=") + \
             "Magic Number: 0x%x\n" % self.magic + \
             "Drive type: %d\n" % self.dtype + \
             "Subtype: %d\n" % self.subtype + \