1
0
mirror of https://github.com/systemd/systemd.git synced 2025-01-09 01:18:19 +03:00

ukify: Type-annotate ukify

This commit is contained in:
Jörg Behrmann 2024-10-05 00:15:42 +02:00
parent 2572afa405
commit e95193504b
2 changed files with 138 additions and 79 deletions

21
src/ukify/mypy.ini Normal file
View File

@ -0,0 +1,21 @@
[mypy]
python_version = 3.9
allow_redefinition = True
# belonging to --strict
warn_unused_configs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_untyped_decorators = true
disallow_incomplete_defs = true
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = false
warn_return_any = true
no_implicit_reexport = true
# extra options not in --strict
pretty = true
show_error_codes = true
show_column_numbers = true
warn_unreachable = true
strict_equality = true

View File

@ -40,15 +40,18 @@ import subprocess
import sys
import tempfile
import textwrap
from collections.abc import Sequence
from collections.abc import Iterable, Iterator, Sequence
from hashlib import sha256
from pathlib import Path
from types import ModuleType
from typing import (
IO,
Any,
Callable,
Optional,
TypeVar,
Union,
cast,
)
import pefile # type: ignore
@ -83,7 +86,7 @@ class Style:
reset = '\033[0m' if sys.stderr.isatty() else ''
def guess_efi_arch():
def guess_efi_arch() -> str:
arch = os.uname().machine
for glob, mapping in EFI_ARCH_MAP.items():
@ -118,23 +121,23 @@ def page(text: str, enabled: Optional[bool]) -> None:
print(text)
def shell_join(cmd):
def shell_join(cmd: list[Union[str, Path]]) -> str:
# TODO: drop in favour of shlex.join once shlex.join supports Path.
return ' '.join(shlex.quote(str(x)) for x in cmd)
def round_up(x, blocksize=4096):
def round_up(x: int, blocksize: int = 4096) -> int:
return (x + blocksize - 1) // blocksize * blocksize
def try_import(modname, name=None):
def try_import(modname: str, name: Optional[str] = None) -> ModuleType:
try:
return __import__(modname)
except ImportError as e:
raise ValueError(f'Kernel is compressed with {name or modname}, but module unavailable') from e
def get_zboot_kernel(f):
def get_zboot_kernel(f: IO[bytes]) -> bytes:
"""Decompress zboot efistub kernel if compressed. Return contents."""
# See linux/drivers/firmware/efi/libstub/Makefile.zboot
# and linux/drivers/firmware/efi/libstub/zboot-header.S
@ -157,25 +160,25 @@ def get_zboot_kernel(f):
f.seek(start)
if comp_type.startswith(b'gzip'):
gzip = try_import('gzip')
return gzip.open(f).read(size)
return cast(bytes, gzip.open(f).read(size))
elif comp_type.startswith(b'lz4'):
lz4 = try_import('lz4.frame', 'lz4')
return lz4.frame.decompress(f.read(size))
return cast(bytes, lz4.frame.decompress(f.read(size)))
elif comp_type.startswith(b'lzma'):
lzma = try_import('lzma')
return lzma.open(f).read(size)
return cast(bytes, lzma.open(f).read(size))
elif comp_type.startswith(b'lzo'):
raise NotImplementedError('lzo decompression not implemented')
elif comp_type.startswith(b'xzkern'):
raise NotImplementedError('xzkern decompression not implemented')
elif comp_type.startswith(b'zstd22'):
zstd = try_import('zstd')
return zstd.uncompress(f.read(size))
else:
raise NotImplementedError(f'unknown compressed type: {comp_type}')
return cast(bytes, zstd.uncompress(f.read(size)))
raise NotImplementedError(f'unknown compressed type: {comp_type!r}')
def maybe_decompress(filename):
def maybe_decompress(filename: Union[str, Path]) -> bytes:
"""Decompress file if compressed. Return contents."""
f = open(filename, 'rb')
start = f.read(4)
@ -197,20 +200,20 @@ def maybe_decompress(filename):
if start.startswith(b'\x1f\x8b'):
gzip = try_import('gzip')
return gzip.open(f).read()
return cast(bytes, gzip.open(f).read())
if start.startswith(b'\x28\xb5\x2f\xfd'):
zstd = try_import('zstd')
return zstd.uncompress(f.read())
return cast(bytes, zstd.uncompress(f.read()))
if start.startswith(b'\x02\x21\x4c\x18'):
lz4 = try_import('lz4.frame', 'lz4')
return lz4.frame.decompress(f.read())
return cast(bytes, lz4.frame.decompress(f.read()))
if start.startswith(b'\x04\x22\x4d\x18'):
print('Newer lz4 stream format detected! This may not boot!')
lz4 = try_import('lz4.frame', 'lz4')
return lz4.frame.decompress(f.read())
return cast(bytes, lz4.frame.decompress(f.read()))
if start.startswith(b'\x89LZO'):
# python3-lzo is not packaged for Fedora
@ -218,13 +221,13 @@ def maybe_decompress(filename):
if start.startswith(b'BZh'):
bz2 = try_import('bz2', 'bzip2')
return bz2.open(f).read()
return cast(bytes, bz2.open(f).read())
if start.startswith(b'\x5d\x00\x00'):
lzma = try_import('lzma')
return lzma.open(f).read()
return cast(bytes, lzma.open(f).read())
raise NotImplementedError(f'unknown file format (starts with {start})')
raise NotImplementedError(f'unknown file format (starts with {start!r})')
class Uname:
@ -240,7 +243,7 @@ class Uname:
TEXT_PATTERN = rb'Linux version (?P<version>\d\.\S+) \('
@classmethod
def scrape_x86(cls, filename, opts=None):
def scrape_x86(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
# Based on https://gitlab.archlinux.org/archlinux/mkinitcpio/mkinitcpio/-/blob/master/functions#L136
# and https://docs.kernel.org/arch/x86/boot.html#the-real-mode-kernel-header
with open(filename, 'rb') as f:
@ -253,15 +256,17 @@ class Uname:
f.seek(0x200 + offset)
text = f.read(128)
text = text.split(b'\0', maxsplit=1)[0]
text = text.decode()
decoded = text.decode()
if not (m := re.match(cls.VERSION_PATTERN, text)):
if not (m := re.match(cls.VERSION_PATTERN, decoded)):
raise ValueError(f'Cannot parse version-host-release uname string: {text!r}')
return m.group('version')
@classmethod
def scrape_elf(cls, filename, opts=None):
def scrape_elf(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
readelf = find_tool('readelf', opts=opts)
if not readelf:
raise ValueError('FIXME')
cmd = [
readelf,
@ -282,7 +287,7 @@ class Uname:
return text.rstrip('\0')
@classmethod
def scrape_generic(cls, filename, opts=None):
def scrape_generic(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
# import libarchive
# libarchive-c fails with
# ArchiveError: Unrecognized archive format (errno=84, retcode=-30, archive_p=94705420454656)
@ -296,7 +301,7 @@ class Uname:
return m.group('version').decode()
@classmethod
def scrape(cls, filename, opts=None):
def scrape(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> Optional[str]:
for func in (cls.scrape_x86, cls.scrape_elf, cls.scrape_generic):
try:
version = func(filename, opts=opts)
@ -328,13 +333,13 @@ DEFAULT_SECTIONS_TO_SHOW = {
class Section:
name: str
content: Optional[Path]
tmpfile: Optional[IO] = None
tmpfile: Optional[IO[Any]] = None
measure: bool = False
output_mode: Optional[str] = None
virtual_size: Optional[int] = None
@classmethod
def create(cls, name, contents, **kwargs):
def create(cls, name: str, contents: Union[str, bytes, Path, None], **kwargs: Any) -> 'Section':
if isinstance(contents, (str, bytes)):
mode = 'wt' if isinstance(contents, str) else 'wb'
tmp = tempfile.NamedTemporaryFile(mode=mode, prefix=f'tmp{name}')
@ -347,7 +352,7 @@ class Section:
return cls(name, contents, tmpfile=tmp, **kwargs)
@classmethod
def parse_input(cls, s):
def parse_input(cls, s: str) -> 'Section':
try:
name, contents, *rest = s.split(':')
except ValueError as e:
@ -356,14 +361,15 @@ class Section:
raise ValueError(f'Cannot parse section spec (extraneous parameters): {s!r}')
if contents.startswith('@'):
contents = Path(contents[1:])
sec = cls.create(name, Path(contents[1:]))
else:
sec = cls.create(name, contents)
sec = cls.create(name, contents)
sec.check_name()
return sec
@classmethod
def parse_output(cls, s):
def parse_output(cls, s: str) -> 'Section':
if not (m := re.match(r'([a-zA-Z0-9_.]+):(text|binary)(?:@(.+))?', s)):
raise ValueError(f'Cannot parse section spec: {s!r}')
@ -372,7 +378,7 @@ class Section:
return cls.create(name, out, output_mode=ttype)
def check_name(self):
def check_name(self) -> None:
# PE section names with more than 8 characters are legal, but our stub does
# not support them.
if not self.name.isascii() or not self.name.isprintable():
@ -386,7 +392,7 @@ class UKI:
executable: list[Union[Path, str]]
sections: list[Section] = dataclasses.field(default_factory=list, init=False)
def add_section(self, section):
def add_section(self, section: Section) -> None:
start = 0
# Start search at last .profile section, if there is one
@ -400,7 +406,7 @@ class UKI:
self.sections += [section]
def parse_banks(s):
def parse_banks(s: str) -> list[str]:
banks = re.split(r',|\s+', s)
# TODO: do some sanity checking here
return banks
@ -416,7 +422,7 @@ KNOWN_PHASES = (
)
def parse_phase_paths(s):
def parse_phase_paths(s: str) -> list[str]:
# Split on commas or whitespace here. Commas might be hard to parse visually.
paths = re.split(r',|\s+', s)
@ -428,7 +434,7 @@ def parse_phase_paths(s):
return paths
def check_splash(filename):
def check_splash(filename: Optional[str]) -> None:
if filename is None:
return
@ -442,7 +448,7 @@ def check_splash(filename):
print(f'Splash image {filename} is {img.width}×{img.height} pixels')
def check_inputs(opts):
def check_inputs(opts: argparse.Namespace) -> None:
for name, value in vars(opts).items():
if name in {'output', 'tools'}:
continue
@ -458,7 +464,7 @@ def check_inputs(opts):
check_splash(opts.splash)
def check_cert_and_keys_nonexistent(opts):
def check_cert_and_keys_nonexistent(opts: argparse.Namespace) -> None:
# Raise if any of the keys and certs are found on disk
paths = itertools.chain(
(opts.sb_key, opts.sb_cert),
@ -469,12 +475,16 @@ def check_cert_and_keys_nonexistent(opts):
raise ValueError(f'{path} is present')
def find_tool(name, fallback=None, opts=None):
def find_tool(
name: str,
fallback: Optional[str] = None,
opts: Optional[argparse.Namespace] = None,
) -> Union[str, Path, None]:
if opts and opts.tools:
for d in opts.tools:
tool = d / name
if tool.exists():
return tool
return cast(Path, tool)
if shutil.which(name) is not None:
return name
@ -485,8 +495,8 @@ def find_tool(name, fallback=None, opts=None):
return fallback
def combine_signatures(pcrsigs):
combined = collections.defaultdict(list)
def combine_signatures(pcrsigs: list[dict[str, str]]) -> str:
combined: collections.defaultdict[str, list[str]] = collections.defaultdict(list)
for pcrsig in pcrsigs:
for bank, sigs in pcrsig.items():
for sig in sigs:
@ -495,7 +505,7 @@ def combine_signatures(pcrsigs):
return json.dumps(combined)
def key_path_groups(opts):
def key_path_groups(opts: argparse.Namespace) -> Iterator:
if not opts.pcr_private_keys:
return
@ -510,16 +520,18 @@ def key_path_groups(opts):
)
def pe_strip_section_name(name):
def pe_strip_section_name(name: bytes) -> str:
return name.rstrip(b'\x00').decode()
def call_systemd_measure(uki, opts, profile_start=0):
def call_systemd_measure(uki: UKI, opts: argparse.Namespace, profile_start: int = 0) -> None:
measure_tool = find_tool(
'systemd-measure',
'/usr/lib/systemd/systemd-measure',
opts=opts,
)
if not measure_tool:
raise ValueError('FIXME')
banks = opts.pcr_banks or ()
@ -583,8 +595,8 @@ def call_systemd_measure(uki, opts, profile_start=0):
extra += [f'--public-key={pub_key}']
extra += [f'--phase={phase_path}' for phase_path in group or ()]
print('+', shell_join(cmd + extra))
pcrsig = subprocess.check_output(cmd + extra, text=True)
print('+', shell_join(cmd + extra)) # type: ignore
pcrsig = subprocess.check_output(cmd + extra, text=True) # type: ignore
pcrsig = json.loads(pcrsig)
pcrsigs += [pcrsig]
@ -592,7 +604,7 @@ def call_systemd_measure(uki, opts, profile_start=0):
uki.add_section(Section.create('.pcrsig', combined))
def join_initrds(initrds):
def join_initrds(initrds: list[Path]) -> Union[Path, bytes, None]:
if not initrds:
return None
if len(initrds) == 1:
@ -608,7 +620,10 @@ def join_initrds(initrds):
return b''.join(seq)
def pairwise(iterable):
T = TypeVar('T')
def pairwise(iterable: Iterable[T]) -> Iterator[tuple[T, Optional[T]]]:
a, b = itertools.tee(iterable)
next(b, None)
return zip(a, b)
@ -618,7 +633,7 @@ class PEError(Exception):
pass
def pe_add_sections(uki: UKI, output: str):
def pe_add_sections(uki: UKI, output: str) -> None:
pe = pefile.PE(uki.executable, fast_load=True)
# Old stubs do not have the symbol/string table stripped, even though image files should not have one.
@ -750,7 +765,7 @@ def pe_add_sections(uki: UKI, output: str):
pe.write(output)
def merge_sbat(input_pe: [Path], input_text: [str]) -> str:
def merge_sbat(input_pe: list[Path], input_text: list[str]) -> str:
sbat = []
for f in input_pe:
@ -786,16 +801,21 @@ def merge_sbat(input_pe: [Path], input_text: [str]) -> str:
)
def signer_sign(cmd):
def signer_sign(cmd: list[Union[str, Path]]) -> None:
print('+', shell_join(cmd))
subprocess.check_call(cmd)
def find_sbsign(opts=None):
def find_sbsign(opts: Optional[argparse.Namespace] = None) -> Union[str, Path, None]:
return find_tool('sbsign', opts=opts)
def sbsign_sign(sbsign_tool, input_f, output_f, opts=None):
def sbsign_sign(
sbsign_tool: Union[str, Path],
input_f: str,
output_f: str,
opts: argparse.Namespace,
) -> None:
sign_invocation = [
sbsign_tool,
'--key', opts.sb_key,
@ -810,11 +830,16 @@ def sbsign_sign(sbsign_tool, input_f, output_f, opts=None):
signer_sign(sign_invocation)
def find_pesign(opts=None):
def find_pesign(opts: Optional[argparse.Namespace] = None) -> Union[str, Path, None]:
return find_tool('pesign', opts=opts)
def pesign_sign(pesign_tool, input_f, output_f, opts=None):
def pesign_sign(
pesign_tool: Union[str, Path],
input_f: str,
output_f: str,
opts: argparse.Namespace,
) -> None:
sign_invocation = [
pesign_tool,
'-s',
@ -841,7 +866,7 @@ PESIGCHECK = {
}
def verify(tool, opts):
def verify(tool: dict[str, str], opts: argparse.Namespace) -> bool:
verify_tool = find_tool(tool['name'], opts=opts)
cmd = [
verify_tool,
@ -857,13 +882,13 @@ def verify(tool, opts):
return tool['output'] in info
def make_uki(opts):
def make_uki(opts: argparse.Namespace) -> None:
# kernel payload signing
sign_tool = None
sign_args_present = opts.sb_key or opts.sb_cert_name
sign_kernel = opts.sign_kernel
sign = None
sign: Optional[Callable[[Union[str, Path], str, str, argparse.Namespace], None]] = None
linux = opts.linux
if sign_args_present:
@ -884,6 +909,8 @@ def make_uki(opts):
sign_kernel = verify(verify_tool, opts)
if sign_kernel:
assert sign is not None
assert sign_tool is not None
linux_signed = tempfile.NamedTemporaryFile(prefix='linux-signed')
linux = Path(linux_signed.name)
sign(sign_tool, opts.linux, linux, opts=opts)
@ -1051,8 +1078,9 @@ uki-addon,1,UKI Addon,addon,1,https://www.freedesktop.org/software/systemd/man/l
# UKI signing
if sign_args_present:
assert sign
sign(sign_tool, unsigned_output, opts.output, opts=opts)
assert sign is not None
assert sign_tool is not None
sign(sign_tool, unsigned_output, opts.output, opts)
# We end up with no executable bits, let's reapply them
os.umask(umask := os.umask(0))
@ -1062,7 +1090,7 @@ uki-addon,1,UKI Addon,addon,1,https://www.freedesktop.org/software/systemd/man/l
@contextlib.contextmanager
def temporary_umask(mask: int):
def temporary_umask(mask: int) -> Iterator[None]:
# Drop <mask> bits from umask
old = os.umask(0)
os.umask(old | mask)
@ -1076,7 +1104,7 @@ def generate_key_cert_pair(
common_name: str,
valid_days: int,
keylength: int = 2048,
) -> tuple[bytes]:
) -> tuple[bytes, bytes]:
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
@ -1133,7 +1161,7 @@ def generate_key_cert_pair(
return key_pem, cert_pem
def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes]:
def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes, bytes]:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
@ -1154,7 +1182,7 @@ def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes]:
return priv_key_pem, pub_key_pem
def generate_keys(opts):
def generate_keys(opts: argparse.Namespace) -> None:
work = False
# This will generate keys and certificates and write them to the paths that
@ -1192,7 +1220,10 @@ def generate_keys(opts):
)
def inspect_section(opts, section):
def inspect_section(
opts: argparse.Namespace,
section: pefile.SectionStructure,
) -> tuple[str, Optional[dict[str, Union[int, str]]]]:
name = pe_strip_section_name(section.Name)
# find the config for this section in opts and whether to show it
@ -1233,7 +1264,7 @@ def inspect_section(opts, section):
return name, struct
def inspect_sections(opts):
def inspect_sections(opts: argparse.Namespace) -> None:
indent = 4 if opts.json == 'pretty' else None
for file in opts.files:
@ -1348,7 +1379,7 @@ class ConfigItem:
return self.dest
return self._names()[0].lstrip('-').replace('-', '_')
def add_to(self, parser: argparse.ArgumentParser):
def add_to(self, parser: argparse.ArgumentParser) -> None:
kwargs = {
key: val
for key in dataclasses.asdict(self)
@ -1357,7 +1388,14 @@ class ConfigItem:
args = self._names()
parser.add_argument(*args, **kwargs)
def apply_config(self, namespace, section, group, key, value) -> None:
def apply_config(
self,
namespace: argparse.Namespace,
section: str,
group: Optional[str],
key: str,
value: Any,
) -> None:
assert f'{section}/{key}' == self.config_key
dest = self.argparse_dest()
@ -1662,7 +1700,7 @@ CONFIG_ITEMS = [
CONFIGFILE_ITEMS = {item.config_key: item for item in CONFIG_ITEMS if item.config_key}
def apply_config(namespace, filename=None):
def apply_config(namespace: argparse.Namespace, filename: Union[str, Path, None] = None) -> None:
if filename is None:
if namespace.config:
# Config set by the user, use that.
@ -1694,7 +1732,7 @@ def apply_config(namespace, filename=None):
strict=False,
)
# Do not make keys lowercase
cp.optionxform = lambda option: option
cp.optionxform = lambda option: option # type: ignore
# The API is not great.
read = cp.read(filename)
@ -1718,8 +1756,8 @@ def apply_config(namespace, filename=None):
print(f'Unknown config setting [{section_name}] {key}=')
def config_example():
prev_section = None
def config_example() -> Iterator[str]:
prev_section: Optional[str] = None
for item in CONFIG_ITEMS:
section, key, value = item.config_example()
if section:
@ -1743,7 +1781,7 @@ class PagerHelpAction(argparse._HelpAction): # pylint: disable=protected-access
parser.exit()
def create_parser():
def create_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description='Build and sign Unified Kernel Images',
usage='\n '
@ -1762,7 +1800,7 @@ def create_parser():
item.add_to(p)
# Suppress printing of usage synopsis on errors
p.error = lambda message: p.exit(2, f'{p.prog}: error: {message}\n')
p.error = lambda message: p.exit(2, f'{p.prog}: error: {message}\n') # type: ignore
# Make --help paged
p.add_argument(
@ -1774,7 +1812,7 @@ def create_parser():
return p
def finalize_options(opts):
def finalize_options(opts: argparse.Namespace) -> None:
# Figure out which syntax is being used, one of:
# ukify verb --arg --arg --arg
# ukify linux initrd…
@ -1887,14 +1925,14 @@ def finalize_options(opts):
sys.exit()
def parse_args(args=None):
def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
opts = create_parser().parse_args(args)
apply_config(opts)
finalize_options(opts)
return opts
def main():
def main() -> None:
opts = parse_args()
if opts.verb == 'build':
check_inputs(opts)