From e95193504b4b7b25453f2e6a7d37c05d0d0db68b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Behrmann?= Date: Sat, 5 Oct 2024 00:15:42 +0200 Subject: [PATCH] ukify: Type-annotate ukify --- src/ukify/mypy.ini | 21 +++++ src/ukify/ukify.py | 196 +++++++++++++++++++++++++++------------------ 2 files changed, 138 insertions(+), 79 deletions(-) create mode 100644 src/ukify/mypy.ini diff --git a/src/ukify/mypy.ini b/src/ukify/mypy.ini new file mode 100644 index 00000000000..bae4e70f44e --- /dev/null +++ b/src/ukify/mypy.ini @@ -0,0 +1,21 @@ +[mypy] +python_version = 3.9 +allow_redefinition = True +# belonging to --strict +warn_unused_configs = true +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_untyped_decorators = true +disallow_incomplete_defs = true +check_untyped_defs = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = false +warn_return_any = true +no_implicit_reexport = true +# extra options not in --strict +pretty = true +show_error_codes = true +show_column_numbers = true +warn_unreachable = true +strict_equality = true diff --git a/src/ukify/ukify.py b/src/ukify/ukify.py index 75e0afc6a6c..a76ceadf114 100755 --- a/src/ukify/ukify.py +++ b/src/ukify/ukify.py @@ -40,15 +40,18 @@ import subprocess import sys import tempfile import textwrap -from collections.abc import Sequence +from collections.abc import Iterable, Iterator, Sequence from hashlib import sha256 from pathlib import Path +from types import ModuleType from typing import ( IO, Any, Callable, Optional, + TypeVar, Union, + cast, ) import pefile # type: ignore @@ -83,7 +86,7 @@ class Style: reset = '\033[0m' if sys.stderr.isatty() else '' -def guess_efi_arch(): +def guess_efi_arch() -> str: arch = os.uname().machine for glob, mapping in EFI_ARCH_MAP.items(): @@ -118,23 +121,23 @@ def page(text: str, enabled: Optional[bool]) -> None: print(text) -def shell_join(cmd): +def shell_join(cmd: list[Union[str, Path]]) -> str: # TODO: drop in favour of shlex.join once shlex.join supports Path. return ' '.join(shlex.quote(str(x)) for x in cmd) -def round_up(x, blocksize=4096): +def round_up(x: int, blocksize: int = 4096) -> int: return (x + blocksize - 1) // blocksize * blocksize -def try_import(modname, name=None): +def try_import(modname: str, name: Optional[str] = None) -> ModuleType: try: return __import__(modname) except ImportError as e: raise ValueError(f'Kernel is compressed with {name or modname}, but module unavailable') from e -def get_zboot_kernel(f): +def get_zboot_kernel(f: IO[bytes]) -> bytes: """Decompress zboot efistub kernel if compressed. Return contents.""" # See linux/drivers/firmware/efi/libstub/Makefile.zboot # and linux/drivers/firmware/efi/libstub/zboot-header.S @@ -157,25 +160,25 @@ def get_zboot_kernel(f): f.seek(start) if comp_type.startswith(b'gzip'): gzip = try_import('gzip') - return gzip.open(f).read(size) + return cast(bytes, gzip.open(f).read(size)) elif comp_type.startswith(b'lz4'): lz4 = try_import('lz4.frame', 'lz4') - return lz4.frame.decompress(f.read(size)) + return cast(bytes, lz4.frame.decompress(f.read(size))) elif comp_type.startswith(b'lzma'): lzma = try_import('lzma') - return lzma.open(f).read(size) + return cast(bytes, lzma.open(f).read(size)) elif comp_type.startswith(b'lzo'): raise NotImplementedError('lzo decompression not implemented') elif comp_type.startswith(b'xzkern'): raise NotImplementedError('xzkern decompression not implemented') elif comp_type.startswith(b'zstd22'): zstd = try_import('zstd') - return zstd.uncompress(f.read(size)) - else: - raise NotImplementedError(f'unknown compressed type: {comp_type}') + return cast(bytes, zstd.uncompress(f.read(size))) + + raise NotImplementedError(f'unknown compressed type: {comp_type!r}') -def maybe_decompress(filename): +def maybe_decompress(filename: Union[str, Path]) -> bytes: """Decompress file if compressed. Return contents.""" f = open(filename, 'rb') start = f.read(4) @@ -197,20 +200,20 @@ def maybe_decompress(filename): if start.startswith(b'\x1f\x8b'): gzip = try_import('gzip') - return gzip.open(f).read() + return cast(bytes, gzip.open(f).read()) if start.startswith(b'\x28\xb5\x2f\xfd'): zstd = try_import('zstd') - return zstd.uncompress(f.read()) + return cast(bytes, zstd.uncompress(f.read())) if start.startswith(b'\x02\x21\x4c\x18'): lz4 = try_import('lz4.frame', 'lz4') - return lz4.frame.decompress(f.read()) + return cast(bytes, lz4.frame.decompress(f.read())) if start.startswith(b'\x04\x22\x4d\x18'): print('Newer lz4 stream format detected! This may not boot!') lz4 = try_import('lz4.frame', 'lz4') - return lz4.frame.decompress(f.read()) + return cast(bytes, lz4.frame.decompress(f.read())) if start.startswith(b'\x89LZO'): # python3-lzo is not packaged for Fedora @@ -218,13 +221,13 @@ def maybe_decompress(filename): if start.startswith(b'BZh'): bz2 = try_import('bz2', 'bzip2') - return bz2.open(f).read() + return cast(bytes, bz2.open(f).read()) if start.startswith(b'\x5d\x00\x00'): lzma = try_import('lzma') - return lzma.open(f).read() + return cast(bytes, lzma.open(f).read()) - raise NotImplementedError(f'unknown file format (starts with {start})') + raise NotImplementedError(f'unknown file format (starts with {start!r})') class Uname: @@ -240,7 +243,7 @@ class Uname: TEXT_PATTERN = rb'Linux version (?P\d\.\S+) \(' @classmethod - def scrape_x86(cls, filename, opts=None): + def scrape_x86(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str: # Based on https://gitlab.archlinux.org/archlinux/mkinitcpio/mkinitcpio/-/blob/master/functions#L136 # and https://docs.kernel.org/arch/x86/boot.html#the-real-mode-kernel-header with open(filename, 'rb') as f: @@ -253,15 +256,17 @@ class Uname: f.seek(0x200 + offset) text = f.read(128) text = text.split(b'\0', maxsplit=1)[0] - text = text.decode() + decoded = text.decode() - if not (m := re.match(cls.VERSION_PATTERN, text)): + if not (m := re.match(cls.VERSION_PATTERN, decoded)): raise ValueError(f'Cannot parse version-host-release uname string: {text!r}') return m.group('version') @classmethod - def scrape_elf(cls, filename, opts=None): + def scrape_elf(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str: readelf = find_tool('readelf', opts=opts) + if not readelf: + raise ValueError('FIXME') cmd = [ readelf, @@ -282,7 +287,7 @@ class Uname: return text.rstrip('\0') @classmethod - def scrape_generic(cls, filename, opts=None): + def scrape_generic(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str: # import libarchive # libarchive-c fails with # ArchiveError: Unrecognized archive format (errno=84, retcode=-30, archive_p=94705420454656) @@ -296,7 +301,7 @@ class Uname: return m.group('version').decode() @classmethod - def scrape(cls, filename, opts=None): + def scrape(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> Optional[str]: for func in (cls.scrape_x86, cls.scrape_elf, cls.scrape_generic): try: version = func(filename, opts=opts) @@ -328,13 +333,13 @@ DEFAULT_SECTIONS_TO_SHOW = { class Section: name: str content: Optional[Path] - tmpfile: Optional[IO] = None + tmpfile: Optional[IO[Any]] = None measure: bool = False output_mode: Optional[str] = None virtual_size: Optional[int] = None @classmethod - def create(cls, name, contents, **kwargs): + def create(cls, name: str, contents: Union[str, bytes, Path, None], **kwargs: Any) -> 'Section': if isinstance(contents, (str, bytes)): mode = 'wt' if isinstance(contents, str) else 'wb' tmp = tempfile.NamedTemporaryFile(mode=mode, prefix=f'tmp{name}') @@ -347,7 +352,7 @@ class Section: return cls(name, contents, tmpfile=tmp, **kwargs) @classmethod - def parse_input(cls, s): + def parse_input(cls, s: str) -> 'Section': try: name, contents, *rest = s.split(':') except ValueError as e: @@ -356,14 +361,15 @@ class Section: raise ValueError(f'Cannot parse section spec (extraneous parameters): {s!r}') if contents.startswith('@'): - contents = Path(contents[1:]) + sec = cls.create(name, Path(contents[1:])) + else: + sec = cls.create(name, contents) - sec = cls.create(name, contents) sec.check_name() return sec @classmethod - def parse_output(cls, s): + def parse_output(cls, s: str) -> 'Section': if not (m := re.match(r'([a-zA-Z0-9_.]+):(text|binary)(?:@(.+))?', s)): raise ValueError(f'Cannot parse section spec: {s!r}') @@ -372,7 +378,7 @@ class Section: return cls.create(name, out, output_mode=ttype) - def check_name(self): + def check_name(self) -> None: # PE section names with more than 8 characters are legal, but our stub does # not support them. if not self.name.isascii() or not self.name.isprintable(): @@ -386,7 +392,7 @@ class UKI: executable: list[Union[Path, str]] sections: list[Section] = dataclasses.field(default_factory=list, init=False) - def add_section(self, section): + def add_section(self, section: Section) -> None: start = 0 # Start search at last .profile section, if there is one @@ -400,7 +406,7 @@ class UKI: self.sections += [section] -def parse_banks(s): +def parse_banks(s: str) -> list[str]: banks = re.split(r',|\s+', s) # TODO: do some sanity checking here return banks @@ -416,7 +422,7 @@ KNOWN_PHASES = ( ) -def parse_phase_paths(s): +def parse_phase_paths(s: str) -> list[str]: # Split on commas or whitespace here. Commas might be hard to parse visually. paths = re.split(r',|\s+', s) @@ -428,7 +434,7 @@ def parse_phase_paths(s): return paths -def check_splash(filename): +def check_splash(filename: Optional[str]) -> None: if filename is None: return @@ -442,7 +448,7 @@ def check_splash(filename): print(f'Splash image {filename} is {img.width}×{img.height} pixels') -def check_inputs(opts): +def check_inputs(opts: argparse.Namespace) -> None: for name, value in vars(opts).items(): if name in {'output', 'tools'}: continue @@ -458,7 +464,7 @@ def check_inputs(opts): check_splash(opts.splash) -def check_cert_and_keys_nonexistent(opts): +def check_cert_and_keys_nonexistent(opts: argparse.Namespace) -> None: # Raise if any of the keys and certs are found on disk paths = itertools.chain( (opts.sb_key, opts.sb_cert), @@ -469,12 +475,16 @@ def check_cert_and_keys_nonexistent(opts): raise ValueError(f'{path} is present') -def find_tool(name, fallback=None, opts=None): +def find_tool( + name: str, + fallback: Optional[str] = None, + opts: Optional[argparse.Namespace] = None, +) -> Union[str, Path, None]: if opts and opts.tools: for d in opts.tools: tool = d / name if tool.exists(): - return tool + return cast(Path, tool) if shutil.which(name) is not None: return name @@ -485,8 +495,8 @@ def find_tool(name, fallback=None, opts=None): return fallback -def combine_signatures(pcrsigs): - combined = collections.defaultdict(list) +def combine_signatures(pcrsigs: list[dict[str, str]]) -> str: + combined: collections.defaultdict[str, list[str]] = collections.defaultdict(list) for pcrsig in pcrsigs: for bank, sigs in pcrsig.items(): for sig in sigs: @@ -495,7 +505,7 @@ def combine_signatures(pcrsigs): return json.dumps(combined) -def key_path_groups(opts): +def key_path_groups(opts: argparse.Namespace) -> Iterator: if not opts.pcr_private_keys: return @@ -510,16 +520,18 @@ def key_path_groups(opts): ) -def pe_strip_section_name(name): +def pe_strip_section_name(name: bytes) -> str: return name.rstrip(b'\x00').decode() -def call_systemd_measure(uki, opts, profile_start=0): +def call_systemd_measure(uki: UKI, opts: argparse.Namespace, profile_start: int = 0) -> None: measure_tool = find_tool( 'systemd-measure', '/usr/lib/systemd/systemd-measure', opts=opts, ) + if not measure_tool: + raise ValueError('FIXME') banks = opts.pcr_banks or () @@ -583,8 +595,8 @@ def call_systemd_measure(uki, opts, profile_start=0): extra += [f'--public-key={pub_key}'] extra += [f'--phase={phase_path}' for phase_path in group or ()] - print('+', shell_join(cmd + extra)) - pcrsig = subprocess.check_output(cmd + extra, text=True) + print('+', shell_join(cmd + extra)) # type: ignore + pcrsig = subprocess.check_output(cmd + extra, text=True) # type: ignore pcrsig = json.loads(pcrsig) pcrsigs += [pcrsig] @@ -592,7 +604,7 @@ def call_systemd_measure(uki, opts, profile_start=0): uki.add_section(Section.create('.pcrsig', combined)) -def join_initrds(initrds): +def join_initrds(initrds: list[Path]) -> Union[Path, bytes, None]: if not initrds: return None if len(initrds) == 1: @@ -608,7 +620,10 @@ def join_initrds(initrds): return b''.join(seq) -def pairwise(iterable): +T = TypeVar('T') + + +def pairwise(iterable: Iterable[T]) -> Iterator[tuple[T, Optional[T]]]: a, b = itertools.tee(iterable) next(b, None) return zip(a, b) @@ -618,7 +633,7 @@ class PEError(Exception): pass -def pe_add_sections(uki: UKI, output: str): +def pe_add_sections(uki: UKI, output: str) -> None: pe = pefile.PE(uki.executable, fast_load=True) # Old stubs do not have the symbol/string table stripped, even though image files should not have one. @@ -750,7 +765,7 @@ def pe_add_sections(uki: UKI, output: str): pe.write(output) -def merge_sbat(input_pe: [Path], input_text: [str]) -> str: +def merge_sbat(input_pe: list[Path], input_text: list[str]) -> str: sbat = [] for f in input_pe: @@ -786,16 +801,21 @@ def merge_sbat(input_pe: [Path], input_text: [str]) -> str: ) -def signer_sign(cmd): +def signer_sign(cmd: list[Union[str, Path]]) -> None: print('+', shell_join(cmd)) subprocess.check_call(cmd) -def find_sbsign(opts=None): +def find_sbsign(opts: Optional[argparse.Namespace] = None) -> Union[str, Path, None]: return find_tool('sbsign', opts=opts) -def sbsign_sign(sbsign_tool, input_f, output_f, opts=None): +def sbsign_sign( + sbsign_tool: Union[str, Path], + input_f: str, + output_f: str, + opts: argparse.Namespace, +) -> None: sign_invocation = [ sbsign_tool, '--key', opts.sb_key, @@ -810,11 +830,16 @@ def sbsign_sign(sbsign_tool, input_f, output_f, opts=None): signer_sign(sign_invocation) -def find_pesign(opts=None): +def find_pesign(opts: Optional[argparse.Namespace] = None) -> Union[str, Path, None]: return find_tool('pesign', opts=opts) -def pesign_sign(pesign_tool, input_f, output_f, opts=None): +def pesign_sign( + pesign_tool: Union[str, Path], + input_f: str, + output_f: str, + opts: argparse.Namespace, +) -> None: sign_invocation = [ pesign_tool, '-s', @@ -841,7 +866,7 @@ PESIGCHECK = { } -def verify(tool, opts): +def verify(tool: dict[str, str], opts: argparse.Namespace) -> bool: verify_tool = find_tool(tool['name'], opts=opts) cmd = [ verify_tool, @@ -857,13 +882,13 @@ def verify(tool, opts): return tool['output'] in info -def make_uki(opts): +def make_uki(opts: argparse.Namespace) -> None: # kernel payload signing sign_tool = None sign_args_present = opts.sb_key or opts.sb_cert_name sign_kernel = opts.sign_kernel - sign = None + sign: Optional[Callable[[Union[str, Path], str, str, argparse.Namespace], None]] = None linux = opts.linux if sign_args_present: @@ -884,6 +909,8 @@ def make_uki(opts): sign_kernel = verify(verify_tool, opts) if sign_kernel: + assert sign is not None + assert sign_tool is not None linux_signed = tempfile.NamedTemporaryFile(prefix='linux-signed') linux = Path(linux_signed.name) sign(sign_tool, opts.linux, linux, opts=opts) @@ -1051,8 +1078,9 @@ uki-addon,1,UKI Addon,addon,1,https://www.freedesktop.org/software/systemd/man/l # UKI signing if sign_args_present: - assert sign - sign(sign_tool, unsigned_output, opts.output, opts=opts) + assert sign is not None + assert sign_tool is not None + sign(sign_tool, unsigned_output, opts.output, opts) # We end up with no executable bits, let's reapply them os.umask(umask := os.umask(0)) @@ -1062,7 +1090,7 @@ uki-addon,1,UKI Addon,addon,1,https://www.freedesktop.org/software/systemd/man/l @contextlib.contextmanager -def temporary_umask(mask: int): +def temporary_umask(mask: int) -> Iterator[None]: # Drop bits from umask old = os.umask(0) os.umask(old | mask) @@ -1076,7 +1104,7 @@ def generate_key_cert_pair( common_name: str, valid_days: int, keylength: int = 2048, -) -> tuple[bytes]: +) -> tuple[bytes, bytes]: from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -1133,7 +1161,7 @@ def generate_key_cert_pair( return key_pem, cert_pem -def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes]: +def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes, bytes]: from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -1154,7 +1182,7 @@ def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes]: return priv_key_pem, pub_key_pem -def generate_keys(opts): +def generate_keys(opts: argparse.Namespace) -> None: work = False # This will generate keys and certificates and write them to the paths that @@ -1192,7 +1220,10 @@ def generate_keys(opts): ) -def inspect_section(opts, section): +def inspect_section( + opts: argparse.Namespace, + section: pefile.SectionStructure, +) -> tuple[str, Optional[dict[str, Union[int, str]]]]: name = pe_strip_section_name(section.Name) # find the config for this section in opts and whether to show it @@ -1233,7 +1264,7 @@ def inspect_section(opts, section): return name, struct -def inspect_sections(opts): +def inspect_sections(opts: argparse.Namespace) -> None: indent = 4 if opts.json == 'pretty' else None for file in opts.files: @@ -1348,7 +1379,7 @@ class ConfigItem: return self.dest return self._names()[0].lstrip('-').replace('-', '_') - def add_to(self, parser: argparse.ArgumentParser): + def add_to(self, parser: argparse.ArgumentParser) -> None: kwargs = { key: val for key in dataclasses.asdict(self) @@ -1357,7 +1388,14 @@ class ConfigItem: args = self._names() parser.add_argument(*args, **kwargs) - def apply_config(self, namespace, section, group, key, value) -> None: + def apply_config( + self, + namespace: argparse.Namespace, + section: str, + group: Optional[str], + key: str, + value: Any, + ) -> None: assert f'{section}/{key}' == self.config_key dest = self.argparse_dest() @@ -1662,7 +1700,7 @@ CONFIG_ITEMS = [ CONFIGFILE_ITEMS = {item.config_key: item for item in CONFIG_ITEMS if item.config_key} -def apply_config(namespace, filename=None): +def apply_config(namespace: argparse.Namespace, filename: Union[str, Path, None] = None) -> None: if filename is None: if namespace.config: # Config set by the user, use that. @@ -1694,7 +1732,7 @@ def apply_config(namespace, filename=None): strict=False, ) # Do not make keys lowercase - cp.optionxform = lambda option: option + cp.optionxform = lambda option: option # type: ignore # The API is not great. read = cp.read(filename) @@ -1718,8 +1756,8 @@ def apply_config(namespace, filename=None): print(f'Unknown config setting [{section_name}] {key}=') -def config_example(): - prev_section = None +def config_example() -> Iterator[str]: + prev_section: Optional[str] = None for item in CONFIG_ITEMS: section, key, value = item.config_example() if section: @@ -1743,7 +1781,7 @@ class PagerHelpAction(argparse._HelpAction): # pylint: disable=protected-access parser.exit() -def create_parser(): +def create_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description='Build and sign Unified Kernel Images', usage='\n ' @@ -1762,7 +1800,7 @@ def create_parser(): item.add_to(p) # Suppress printing of usage synopsis on errors - p.error = lambda message: p.exit(2, f'{p.prog}: error: {message}\n') + p.error = lambda message: p.exit(2, f'{p.prog}: error: {message}\n') # type: ignore # Make --help paged p.add_argument( @@ -1774,7 +1812,7 @@ def create_parser(): return p -def finalize_options(opts): +def finalize_options(opts: argparse.Namespace) -> None: # Figure out which syntax is being used, one of: # ukify verb --arg --arg --arg # ukify linux initrd… @@ -1887,14 +1925,14 @@ def finalize_options(opts): sys.exit() -def parse_args(args=None): +def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace: opts = create_parser().parse_args(args) apply_config(opts) finalize_options(opts) return opts -def main(): +def main() -> None: opts = parse_args() if opts.verb == 'build': check_inputs(opts)