diff --git a/README.md b/README.md
index 96c70eb..72b7c80 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,25 @@ Should work anywhere with POSIX `sh`, POSIX.2 `od`, and either a `sh`-builtin
 Python module used by scripts here. Includes low-level pure-Python RPM header
 parsing and RPM tag metadata! Fun!
 
+## `dino/`
+
+A Python module (that uses `rpmtoys`) I'm using for prototyping the
+work-in-progress [DINO] package repo/packfile format.
+
+## `mkdino.py`
+
+A simple CLI to build [DINO] packfiles out of sets of RPMs, extract RPMs from
+packfiles, examine packfile contents, etc.
+
+Requirements:
+
+* [python-libarchive-c]: `dnf install python3-libarchive-c` or `pip-3 install libarchive-c`
+* [zstandard]: `pip-3 install zstandard`
+
+[python-libarchive-c]: https://github.com/Changaco/python-libarchive-c
+[zstandard]: https://pypi.org/project/zstandard/
+[DINO]: https://github.com/wgwoods/libdino
+
 ## `measure-metadata.py`
 
 A script to examine actual RPM headers and determine the amount of space used
diff --git a/dino/__init__.py b/dino/__init__.py
index 9e133df..f67868e 100644
--- a/dino/__init__.py
+++ b/dino/__init__.py
@@ -32,7 +32,7 @@
 
 from .const import *
 from .section import *
-from .struct import Dhdrp, Shdrp, StringTable
+from .dstruct import Dhdrp, Shdrp, StringTable
 from .compression import get_compressor, get_decompressor
 
-# This only exports the public-facing stuff enums and classes.
+# This only exports the public-facing stuff: enums and classes.
@@ -178,12 +178,14 @@ def section_index(self, section):
     # TODO: this needs a progress callback or something...
     def write_to(self, fobj):
         wrote = fobj.write(self.pack_hdrs())
-        for n,(name,sec) in enumerate(self.sections()):
-            # FIXME: pass through the compressor?
+        for sec in self.sectab:
+            # FIXME: pass through the compressor if that flag is set
+            # compr = self.get_compressor()
             wrote += sec.write_to(fobj)
         return wrote
 
     def get_compressor(self, level=None):
+        # TODO: compression_opts!
         return get_compressor(self.compression_id, level=level)
 
     def get_decompressor(self):
diff --git a/dino/compression.py b/dino/compression.py
index b8f607d..edbae73 100644
--- a/dino/compression.py
+++ b/dino/compression.py
@@ -1,36 +1,109 @@
 # dino.compression - compression/decompression helpers
 
+import logging as log
+
 from .const import CompressionID
 
-# TODO: Define a CompressionOpts structure that we can store in the header
+# TODO: Define a CompressionOpts structure that we can store in the header,
+# like squashfs does...
+
+available_compressors = {"zstd", "xz"}
+
+DEFAULT_COMPRESSION_LEVEL = {
+    CompressionID.XZ: 2,    # Fedora default (ca. F30)
+    CompressionID.ZSTD: 10, # Diminishing returns above here...
+}
+
+DEFAULT_CHUNK_SIZE = 4*1024
+
+
+class CompressionStreamWriter(object):
+    def __init__(self, cobj, fobj):
+        self._cobj = cobj
+        self._fobj = fobj
+
+    def write(self, data):
+        return self._fobj.write(self._cobj.compress(data))
+
+    def flush(self):
+        r = self._fobj.write(self._cobj.flush())
+        self._cobj = None
+        return r
+
+class MultiCompressor(object):
+    def __init__(self, make_compress_obj, **kwargs):
+        if not callable(make_compress_obj):
+            raise ValueError(f'{make_compress_obj} is not callable')
+        self._mkcobj = make_compress_obj
+        self.args = kwargs
+        log.debug("MultiCompressor(%s, kwargs=%s)", make_compress_obj, kwargs)
+
+    def copy_stream(self, inf, outf, size=0, read_size=None, write_size=None):
+        if read_size is None:
+            read_size = DEFAULT_CHUNK_SIZE
+        if write_size is None:
+            write_size = DEFAULT_CHUNK_SIZE
+        read = 0
+        wrote = 0
+        to_read = size or -1
+        cobj = self._mkcobj(**self.args)
+        while to_read < 0 or read < to_read:
+            chunk = inf.read(read_size if to_read < 0 else min(read_size, to_read-read))
+            if not chunk:
+                break
+            read += len(chunk)
+            wrote += outf.write(cobj.compress(chunk))
+        wrote += outf.write(cobj.flush())
+        return read, wrote
+
+class CopyStreamMultiCompressor(MultiCompressor):
+    def __init__(self, cctx):
+        self._cctx = cctx
+
+    def copy_stream(self, inf, outf, size=0, read_size=None, write_size=None):
+        kwargs = dict()
+        if size:
+            kwargs['size'] = size
+        if read_size:
+            kwargs['read_size'] = read_size
+        if write_size:
+            kwargs['write_size'] = write_size
+        return self._cctx.copy_stream(inf, outf, **kwargs)
+
+
+# Utility function to get CompressionID by id or name (or None)
+cidmap = {n.lower():cid for n,cid in CompressionID.__members__.items()}
+cidmap['gzip'] = cidmap['zlib']
+cidmap['gz'] = cidmap['gzip']
+def get_compressid(which):
+    if isinstance(which, int):
+        return CompressionID(which)
+    if which is None:
+        return CompressionID.NONE
+    if not isinstance(which, str):
+        which = str(which, 'ascii', 'ignore')
+    return cidmap.get(which.lower())
 
 # We don't import the compression modules at the toplevel because I want this
 # to work even if you don't have Every Compression Library installed.
 # As long as you have the ones you actually use, we should be fine.
 def get_compressor(which, level=None):
-    which = CompressionID(which)
+    which = get_compressid(which)
+    if level is None or level < 0:
+        level = DEFAULT_COMPRESSION_LEVEL.get(which)
     if which == CompressionID.ZSTD:
         import zstandard as zstd
-        if level and level < 0:
-            level = zstd.MAX_COMPRESSION_LEVEL
         cctx = zstd.ZstdCompressor(write_content_size=True, level=level)
-        return cctx
+        return CopyStreamMultiCompressor(cctx)
     elif which == CompressionID.XZ:
         import lzma
-        if level and level < 0:
-            level = 9
-        # TODO: this doesn't support zstd's copy_stream function.
-        # Might need a wrapper object to make the different compressors
-        # all play nice, while still making sure they can flush and
-        # start a new compression frame when needed..
-        cctx = lzma.LZMACompressor(preset=level)
-        return cctx
+        return MultiCompressor(lzma.LZMACompressor, preset=level)
     else:
         raise NotImplementedError(f"{which.name} not implemented!")
 
 def get_decompressor(which):
-    which = CompressionID(which)
+    which = get_compressid(which)
     if which == CompressionID.ZSTD:
         import zstandard as zstd
         return zstd.ZstdDecompressor()
@@ -39,3 +112,8 @@ def get_decompressor(which):
         return lzma.LZMADecompressor()
     else:
-        raise NotImplementedError("{which.name} not implemented!")
+        raise NotImplementedError(f"{which.name} not implemented!")
+
+# FIXME: need tests to confirm that each chunk of output from the compressor
+# can be individually decompressed....
+# FIXME: also need some benchmarks to compare performance of algorithms
+# (compression ratio, decompression speed/mem use)
diff --git a/dino/const.py b/dino/const.py
index 7075f61..589dbe0 100644
--- a/dino/const.py
+++ b/dino/const.py
@@ -53,12 +53,22 @@ class Arch(IntEnum):
     RISCV = 243
 
 class HeaderEncoding(IntFlag):
-    LE    = 0b00000000  # Little-endian is the default; no bit set
-    BE    = 0b00000001  # Big-endian
-    OFF64 = 0b00000010  # TODO: 64-bit sizes/offsets
+    # Endianness could be a single bit - 0=LE, 1=BE - but I'd prefer that an
+    # empty value be invalid, to reduce the probability that garbage data will
+    # be interpreted as valid. So instead we'll use the low two bits to store
+    # one of two valid ELF EI_DATA values - LSB=1, MSB=2.
+    # Anything else is invalid and should be rejected.
+    INVALID = 0b00000000  # No bits set! Invalid!
+    LE      = 0b00000001  # Little-endian (our default)
+    BE      = 0b00000010  # Big-endian
+    SEC64   = 0b00000100  # 64-bit sizes/offsets in sectab
 
     def byteorder(self):
-        return self & 0b1
+        bo = self & 0b11
+        if bo == 0b11:
+            return self.INVALID
+        else:
+            return bo
 
     def endian(self):
         return '>' if self.byteorder() == self.BE else '<'
@@ -121,3 +131,4 @@ class SectionType(IntEnum):
 class SectionFlags(IntFlag):
     NONE       = 0b00000000
     COMPRESSED = 0b00000001
+    VARINT     = 0b00000010
diff --git a/dino/struct.py b/dino/dstruct.py
similarity index 100%
rename from dino/struct.py
rename to dino/dstruct.py
diff --git a/dino/section.py b/dino/section.py
index db507b7..325fd0a 100644
--- a/dino/section.py
+++ b/dino/section.py
@@ -3,11 +3,14 @@
 from io import BytesIO
 from struct import Struct
 from collections import Counter
+from dataclasses import dataclass
 from tempfile import SpooledTemporaryFile
+from enum import IntFlag
 
 from .util import copy_stream
 from .const import SectionFlags, SectionType, NAME_IDX_NONE
-from .struct import Shdrp
+from .varint import varint_encode, varint_iter_decode
+from .dstruct import Shdrp
 from .fileview import FileView
 
 class BaseSection(object):
@@ -45,6 +48,9 @@ def flags(self):
     def info(self):
         return self._info
 
+    def _parse_info(self):
+        pass
+
     @property
     def size(self):
         return 0
@@ -72,7 +78,6 @@ def pack_hdr(self):
                          self.info, self.size, self.count)
 
     def write_to(self, fobj):
-        self.fobj.seek(0)
         return copy_stream(self.fobj, fobj, size=self.size)
 
     def tobytes(self):
@@ -138,6 +143,61 @@ def from_file(self, fobj, size, count=0):
         # Maybe we should just make dino objects mmap-able?
         self._data = FileView(fobj, fobj.tell(), size)
 
+    def write_to(self, fobj):
+        oldpos = self.fobj.tell()
+        self.fobj.seek(0)
+        r = copy_stream(self.fobj, fobj, size=self.size)
+        self.fobj.seek(oldpos)
+        return r
+
+class IndexFlags(IntFlag):
+    NONE     = 0
+    NoFanout = 1 << 0
+    Off64    = 1 << 1
+    UncSize  = 1 << 2
+
+@dataclass
+class IndexInfo:
+    othersec: int = 0
+    keysize: int = 0
+    fanout: bool = True
+    off64: bool = False
+    unc_size: bool = True
+
+    @property
+    def flags(self):
+        return (IndexFlags.NONE |
+                (not self.fanout and IndexFlags.NoFanout) |
+                (self.off64 and IndexFlags.Off64) |
+                (self.unc_size and IndexFlags.UncSize))
+
+    def to_int(self):
+        if self.othersec < 0 or self.othersec > 0xff:
+            raise ValueError(f"invalid othersec {self.othersec}")
+        if self.keysize < 0 or self.keysize > 0xff:
+            raise ValueError(f"invalid keysize {self.keysize}")
+        return (self.keysize | (self.othersec << 8) | (int(self.flags) << 16))
+
+    @classmethod
+    def from_int(cls, info):
+        flags = IndexFlags((info >> 16) & 0xff)
+        return cls(keysize=info & 0xff,
+                   othersec=(info >> 8) & 0xff,
+                   fanout=IndexFlags.NoFanout not in flags,
+                   off64=IndexFlags.Off64 in flags,
+                   unc_size=IndexFlags.UncSize in flags)
+
+# FIXME use logging for this!!
+DEBUG=1
+if DEBUG:
+    def dprint(*args, **kwargs):
+        print(*args, **kwargs)
+else:
+    def dprint(*args, **kwargs):
+        pass
+
+
 # TODO: make fanout and sizes optional
 class IndexSection(BaseSection):
     '''
@@ -146,33 +206,61 @@ class IndexSection(BaseSection):
     '''
     typeid = SectionType.Index
    datatype = dict
-    offset_sfmt = 'II'
-    fanout_sfmt = '256I'
 
-    def __init__(self, *args, othersec=None, othersec_idx=None, keysize=32, endian='<', **kwargs):
+    def __init__(self, *args, othersec=None, othersec_idx=None, keysize=32,
+                 fanout=True, off64=False, unc_size=True, varint=False,
+                 endian='<', **kwargs):
         # TODO: flag for whether or not there's a full fanout table
         #       (so we can skip it for small indexes)
-        # TODO: flag for whether we have offsets and sizes or just offsets
         # TODO: flag for varint encoding of offsets/sizes
         BaseSection.__init__(self, *args, **kwargs)
         if not (othersec_idx or isinstance(othersec, BaseSection)):
-            raise ValueError("expected BaseSection, got {type(othersec)}")
-        self.keysize = keysize
+            raise ValueError(f"expected BaseSection, got {type(othersec)}")
+
+        # these control the output encoding and can be set/changed whenever
         self.endian = endian
+        self.fanout = fanout
+        self.varint = varint
+        # off64 is an output encoding setting that gets set automatically
+        # if a 64-bit offset/size is added
+        self._off64 = off64
+        # keysize and unc_size can't be changed once an index is created
+        self._keysize = keysize
+        self._unc_size = unc_size
+        # references to the section we're an index over
         self._othersec = othersec
        self._othersec_idx = othersec_idx
-        self._key_s = Struct(f'{self.endian}{self.keysize}s')
-        self._offset_s = Struct(f'{self.endian}{self.offset_sfmt}')
-        self._fanout_s = Struct(f'{self.endian}{self.fanout_sfmt}')
+
+        # set up the structs for packing keys, values, and the fanout table
+        self._key_s = Struct(f'{self._keysize}s')
+        valfmt = ('Q' if off64 else 'I') * (3 if unc_size else 2)
+        self._val_s = Struct(f'{self.endian}{valfmt}')
+        self._fanout_s = Struct(f'{self.endian}256I')
+
+        if unc_size:
+            self.add = self.add3
+        else:
+            self.add = self.add2
+
+    @property
+    def keysize(self):
+        return self._keysize
+
+    @staticmethod
+    def parse_info(info):
+        return IndexInfo.from_int(info)
 
     @classmethod
     def from_hdr(cls, shdr):
-        keysize = shdr.info & 0xff
-        othersec_idx = (shdr.info >> 8) & 0xff
+        info = cls.parse_info(shdr.info)
         return cls(name_idx=shdr.name, flags=shdr.flags,
-                   keysize=keysize,
+                   othersec_idx=info.othersec,
+                   keysize=info.keysize,
+                   fanout=info.fanout,
+                   off64=info.off64,
+                   unc_size=info.unc_size,
+                   varint=bool(shdr.flags & SectionFlags.VARINT))
-                   othersec_idx=othersec_idx)
 
     @property
     def count(self):
@@ -180,13 +268,21 @@ def count(self):
 
     @property
     def size(self):
-        return (self._fanout_s.size +
-                self.count*(self._key_s.size + self._offset_s.size))
+        if self.varint:
+            return (len(self.make_fanout()) +
+                    (self.count*self.keysize) +
+                    sum(len(varint_encode(i)) for v in self.values() for i in v))
+        else:
+            return (self._fanout_s.size +
+                    self.count*(self.keysize + self._val_s.size))
 
     @property
     def info(self):
-        return ((0xff & self.keysize) |
-                ((0xff & self.othersec.idx) << 8))
+        return IndexInfo(keysize=self.keysize,
+                         othersec=self.othersec.idx if self.othersec else 0xff,
+                         fanout=self.fanout,
+                         unc_size=self._unc_size,
+                         off64=self._off64).to_int()
 
     @property
     def othersec(self):
@@ -210,29 +306,57 @@ def get(self, key, default=None):
 
     def __contains__(self, key):
         return key in self._data
 
-    def add(self, key, offset, size):
+    def add2(self, key, offset, size):
         self._data[self._key_s.pack(key)] = (offset, size)
 
+    def add3(self, key, offset, size, uncsize):
+        self._data[self._key_s.pack(key)] = (offset, size, uncsize)
+
     def remove(self, key):
         del self._data[key]
 
     def make_fanout(self):
         counts = Counter(k[0] for k in self.keys())
-        fanout = [0] * 257
-        for i in range(256):
-            fanout[i+1] = fanout[i] + counts[i]
-        return self._fanout_s.pack(*fanout[1:])
+        if self.varint:
+            # varint-encoded fanout just gives the counts for each byte
+            return self._varint_pack(*[counts[b] for b in range(256)])
+        else:
+            fanout = [0] * 257
+            for i in range(256):
+                fanout[i+1] = fanout[i] + counts[i]
+            return self._fanout_s.pack(*fanout[1:])
+
+    def _varint_pack(self, *values):
+        return b''.join(varint_encode(i) for i in values)
 
     def write_to(self, fobj):
         if self.count == 0:
             return 0
-        # TODO: if count is small we should skip fanout..
-        wrote = fobj.write(self.make_fanout())
-        keys, offsets = zip(*(sorted(self.items())))
+        dprint(f"writing index: fanout={self.fanout} varint={self.varint} "
+               f"unc_size={self._unc_size} keysize={self.keysize} "
+               f"count={self.count}")
+        wrote = 0
+        if self.fanout:
+            wrote += fobj.write(self.make_fanout())
+        keys, vals = zip(*(sorted(self.items())))
+        dprint(f"  fanout: {wrote:7} bytes")
+
+        prevpos = wrote
         for k in keys:
             wrote += fobj.write(self._key_s.pack(k))
-        for o in offsets:
-            wrote += fobj.write(self._offset_s.pack(*o))
+        dprint(f"    keys: {wrote-prevpos:7} bytes")
+
+        if self.varint:
+            valpack = self._varint_pack
+        else:
+            valpack = self._val_s.pack
+
+        prevpos = wrote
+        for v in vals:
+            wrote += fobj.write(valpack(*v))
+        dprint(f"    vals: {wrote-prevpos:7} bytes")
+        dprint(f"   total: {wrote:7} bytes")
+
+        return wrote
 
     def from_file(self, fobj, size, count=0):
@@ -245,13 +369,48 @@
         if size == 0:
             self._data = self.datatype()
             return
-        fanout = self._fanout_s.unpack(fobj.read(self._fanout_s.size))
-        keycount = fanout[-1]
-        if count:
-            assert keycount == count
-        keys = [i[0] for i in self._key_s.iter_unpack(fobj.read(self.keysize*keycount))]
-        offs = self._offset_s.iter_unpack(fobj.read(self._offset_s.size * keycount))
-        self._data = self.datatype(zip(keys, offs))
+
+        dprint(f"reading index: fanout={self.fanout} varint={self.varint} "
+               f"unc_size={self._unc_size} keysize={self.keysize} "
+               f"count={count}")
+
+        data = fobj.read(size)
+        keypos = 0
+        if self.fanout:
+            if self.varint:
+                # NOTE: varint-encoded fanout is a sequence of counts, not a
+                # running count..
+                fanout = []
+                fv = 0
+                for v, n in varint_iter_decode(data, 256):
+                    fv += v
+                    keypos += n
+                    fanout.append(fv)
+            else:
+                keypos = self._fanout_s.size
+                fanout = self._fanout_s.unpack(data[0:keypos])
+            if count:
+                assert count == fanout[-1]
+            else:
+                count = fanout[-1]
+            dprint(f"  fanout: {keypos:7} bytes, count={fanout[-1]}")
+        keylen = self.keysize * count
+        valpos = keypos + keylen
+        keydata = data[keypos:valpos]
+        valdata = data[valpos:]
+        dprint(f"    keys: {valpos-keypos:7} bytes")
+        dprint(f"    vals: {len(valdata):7} bytes")
+        keys = [i[0] for i in self._key_s.iter_unpack(keydata)]
+        if self.varint:
+            vals = [i[0] for i in varint_iter_decode(valdata)]
+            n, m = divmod(len(vals), count)
+            assert (m == 0), "Incorrect/corrupt index"
+            vals = [tuple(vals[i:i+n]) for i in range(0,len(vals),n)]
+        else:
+            if (len(valdata) % self._val_s.size):
+                print(f"wtf: size {self._val_s.size} * count {count} != {len(valdata)}")
+            vals = self._val_s.iter_unpack(valdata)
+        self._data = self.datatype(zip(keys, vals))
 
 class RPMSection(BlobSection):
     '''A section containing one or more RPM headers.'''
diff --git a/dino/varint.py b/dino/varint.py
index 8f1537e..b0da2a6 100644
--- a/dino/varint.py
+++ b/dino/varint.py
@@ -6,6 +6,7 @@
     'VARINT_MAXVAL_WIDTH',
     'varint_encode',
     'varint_decode',
+    'varint_iter_decode',
 ]
 
 from ctypes import c_ulonglong as uintmax
@@ -47,16 +48,28 @@ def varint_encode(val):
 
 def varint_decode(varint):
     '''Decode the varint byte sequence and return (val, nbytes)'''
-    n = 0
-    val = varint[n] & 127
-    while (varint[n] & 128):
+    b = varint[0]
+    val = b & 127
+    n = 1
+    while (b & 128):
         if (val & VARINT_MS7B_MASK):
             raise ValueError("varint overflow")
         val += 1
+        b = varint[n]
         n += 1
-        val = (val << 7) + (varint[n] & 127)
+        val = (val << 7) + (b & 127)
     return val, n
 
+def varint_iter_decode(data, maxcount=None):
+    end = len(data)
+    itemsleft = maxcount if maxcount else end
+    data = memoryview(data)
+    pos = 0
+    while itemsleft and pos < end:
+        val, size = varint_decode(data[pos:])
+        yield val, size
+        pos += size
+        itemsleft -= 1
+
 def varint_str(varint):
     varint_bytes = [f"{'+' if b & 128 else '.'}{b&127:02x}" for b in varint]
@@ -65,6 +78,7 @@ def varint_str(varint):
 
 if __name__ == '__main__':
     assert (varint_encode(0) == b'\0'), "varint_encode(0) != b'\0'"
+    assert (varint_decode(b'\0') == (0,1))
     assert (varint_encode(127) == b'\x7f')
     varint = varint_encode(128)
     varexp = b'\x80\x00'
@@ -74,10 +88,12 @@
     vnext = varint_encode(val + 1)
     assert (len(vmax) == width)
     assert (len(vnext) == width + 1)
-    d_vmax, _ = varint_decode(vmax)
-    d_vnext, _ = varint_decode(vnext)
+    d_vmax, w_vmax = varint_decode(vmax)
+    d_vnext, w_vnext = varint_decode(vnext)
     assert (d_vmax == val), f'{d_vmax} != {val}'
     assert (d_vnext == val + 1), f'{d_vnext} != {val+1}'
+    assert (w_vmax == width)
+    assert (w_vnext == width + 1)
     vmax = varint_encode(VARINT_MAXVAL)
     assert(len(vmax) == VARINT_MAXLEN)
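
A few usage sketches follow; these are illustrations of the pieces the patch adds, not part of the patch itself.

The compression helpers are easiest to exercise through the stdlib xz path: `get_compressor("xz")` wraps `lzma.LZMACompressor` in a `MultiCompressor`, and each `copy_stream()` call produces one complete frame that can be decompressed on its own (which is what the first FIXME in `dino/compression.py` wants tests for):

```python
# Minimal sketch: compress a stream with the xz MultiCompressor and check
# that the output is one self-contained frame. Uses only stdlib + dino.
import io
import lzma

from dino.compression import get_compressor

data = b"hello dino" * 1000
comp = get_compressor("xz")    # MultiCompressor(lzma.LZMACompressor, preset=2)
src, dst = io.BytesIO(data), io.BytesIO()
read, wrote = comp.copy_stream(src, dst)
assert read == len(data)
assert lzma.decompress(dst.getvalue()) == data
```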
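The `HeaderEncoding` change borrows ELF's `EI_DATA` values (LSB=1, MSB=2) for the low two bits, so an all-zero byte can never pass for a valid encoding. Concretely:

```python
from dino.const import HeaderEncoding

enc = HeaderEncoding.LE | HeaderEncoding.SEC64  # little-endian, 64-bit sectab
assert enc.byteorder() == HeaderEncoding.LE
assert enc.endian() == '<'
# zeroed (or garbage 0b11) low bits are now detectably invalid:
assert HeaderEncoding(0).byteorder() == HeaderEncoding.INVALID
assert HeaderEncoding(0b11).byteorder() == HeaderEncoding.INVALID
```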
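`IndexInfo` packs the index section's parameters into the 32-bit `shdr.info` field: keysize in bits 0-7, the indexed section's number in bits 8-15, and `IndexFlags` from bit 16 up. A round-trip through `to_int()`/`from_int()` shows the layout:

```python
from dino.section import IndexInfo, IndexFlags

info = IndexInfo(othersec=3, keysize=32)  # fanout and unc_size default to True
packed = info.to_int()
assert packed == 32 | (3 << 8) | (int(IndexFlags.UncSize) << 16)
assert IndexInfo.from_int(packed) == info
```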
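The fanout table is the same trick git packfile indexes use: entry `b` holds the number of keys whose first byte is `<= b`, so `fanout[b-1]:fanout[b]` bounds the slice of the sorted keys that start with byte `b`. The patch only writes and reads the table; the helper below is hypothetical, just to show what a consumer gets out of it:

```python
# Hypothetical lookup against an unpacked 256-entry cumulative fanout table
# and the sorted key list, as produced by IndexSection.from_file().
from bisect import bisect_left

def find_key(keys, fanout, key):
    b = key[0]
    lo = fanout[b-1] if b > 0 else 0  # first possible slot for this byte
    hi = fanout[b]                    # one past the last possible slot
    i = bisect_left(keys, key, lo, hi)
    return i if i < hi and keys[i] == key else None
```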
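Finally, the varints are the offset-style ones from git's packfiles: bit 7 of each byte is a continuation flag, and each continuation byte implicitly adds one, so every value has exactly one encoding - two-byte values start at 128 (`b'\x80\x00'`) rather than overlapping the one-byte range. A round-trip consistent with the asserts at the bottom of `varint.py`:

```python
from dino.varint import varint_encode, varint_decode, varint_iter_decode

assert varint_encode(127) == b'\x7f'
assert varint_encode(128) == b'\x80\x00'       # not b'\x81\x00'; no redundant forms
assert varint_decode(b'\x80\x00') == (128, 2)  # returns (value, bytes consumed)

# varint_iter_decode walks a packed buffer, yielding (value, nbytes) pairs:
packed = b''.join(varint_encode(v) for v in (0, 127, 128, 300))
assert [v for v, n in varint_iter_decode(packed)] == [0, 127, 128, 300]
```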