2023-12-08 06:38:21 +03:00
# Unix SMB/CIFS implementation.
# Copyright (C) Catalyst.Net Ltd 2023
#
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
""" Group Key Distribution Service module """
from enum import Enum
from functools import total_ordering
2024-02-02 02:23:58 +03:00
from typing import Optional , Tuple
2023-12-08 06:38:21 +03:00
from cryptography . hazmat . primitives import hashes
from samba import _glue
from samba . dcerpc import gkdi , misc
from samba . ndr import ndr_pack , ndr_unpack
from samba . nt_time import NtTime , NtTimeDelta
2024-02-02 02:23:58 +03:00
# Largest value representable as an unsigned 64-bit integer; used as the upper
# bound when range-checking NT times computed from a GKID.
uint64_max: int = 2**64 - 1

# GKDI parameters exported by the _glue extension module.
# L1/L2 indices of a GKID must be strictly less than the respective iteration
# count, and KEY_CYCLE_DURATION is the NT-time span covered by one L2 index.
L1_KEY_ITERATION: int = _glue.GKDI_L1_KEY_ITERATION
L2_KEY_ITERATION: int = _glue.GKDI_L2_KEY_ITERATION
KEY_CYCLE_DURATION: NtTimeDelta = _glue.GKDI_KEY_CYCLE_DURATION
# NOTE(review): not referenced in this part of the module; presumably the
# tolerated clock difference between client and server -- confirm in callers.
MAX_CLOCK_SKEW: NtTimeDelta = _glue.GKDI_MAX_CLOCK_SKEW

# Required length, in bytes, of every seed key and group key handled here.
KEY_LEN_BYTES: int = 64
2023-12-08 06:38:21 +03:00
class Algorithm(Enum):
    """Hash algorithms that may be named in a set of KDF parameters.

    The value of each member is the algorithm's name as a string, so a member
    can be looked up directly from that name, e.g. ``Algorithm("SHA512")``.
    """

    SHA1 = "SHA1"
    SHA256 = "SHA256"
    SHA384 = "SHA384"
    SHA512 = "SHA512"

    def algorithm(self) -> "hashes.HashAlgorithm":
        """Return a new ``cryptography`` hash object for this algorithm.

        Raises:
            RuntimeError: if ``self`` is not one of the known members.
        """
        if self is Algorithm.SHA1:
            return hashes.SHA1()

        if self is Algorithm.SHA256:
            return hashes.SHA256()

        if self is Algorithm.SHA384:
            return hashes.SHA384()

        if self is Algorithm.SHA512:
            return hashes.SHA512()

        # This must be an f-string, otherwise the literal text "{self}" is
        # reported instead of the offending value.
        raise RuntimeError(f"unknown hash algorithm {self}")

    def __repr__(self) -> str:
        """Use the short ``Algorithm.NAME`` form for repr() as well as str()."""
        return str(self)

    @staticmethod
    def from_kdf_parameters(kdf_param: Optional[bytes]) -> "Algorithm":
        """Determine the hash algorithm from NDR-packed KDF parameters.

        Args:
            kdf_param: a packed ``gkdi.KdfParameters`` blob, or None/empty if
                no parameters were supplied.

        Returns:
            The algorithm named in the parameters, or SHA256 if ``kdf_param``
            is None or empty.

        Raises:
            ValueError: if the named algorithm is not a member of this enum.
        """
        if not kdf_param:
            return Algorithm.SHA256  # the default used by Windows.

        kdf_parameters = ndr_unpack(gkdi.KdfParameters, kdf_param)
        return Algorithm(kdf_parameters.hash_algorithm)
class GkidType(Enum):
    """The kind of key a GKID refers to.

    Members carry no meaningful value (each is a unique ``object()``); they
    are only ever compared by identity.
    """

    DEFAULT = object()
    L0_SEED_KEY = object()
    L1_SEED_KEY = object()
    L2_SEED_KEY = object()

    def description(self) -> str:
        """Return a short phrase (with article) for use in error messages.

        Raises:
            RuntimeError: if ``self`` is not one of the known members.
        """
        if self is GkidType.DEFAULT:
            return "a default GKID"

        if self is GkidType.L0_SEED_KEY:
            return "an L0 seed key"

        if self is GkidType.L1_SEED_KEY:
            return "an L1 seed key"

        if self is GkidType.L2_SEED_KEY:
            return "an L2 seed key"

        # This must be an f-string, otherwise the literal text "{self}" is
        # reported instead of the offending value.
        raise RuntimeError(f"unknown GKID type {self}")
class InvalidDerivation(Exception):
    """Raised when a seed key cannot be derived from the given GKID."""
class UndefinedStartTime(Exception):
    """Raised when a start time is requested for a GKID that has none."""
@total_ordering
class Gkid:
    """A Group Key Identifier: an immutable (L0, L1, L2) triple of indices.

    An index of -1 means "unspecified". Indices may only be left unspecified
    from the right: (-1, -1, -1) is the default GKID, (l0, -1, -1) names an
    L0 seed key, (l0, l1, -1) an L1 seed key, and (l0, l1, l2) an L2 seed key.
    """

    __slots__ = ["_l0_idx", "_l1_idx", "_l2_idx"]

    max_l0_idx = 0x7FFF_FFFF

    def __init__(self, l0_idx: int, l1_idx: int, l2_idx: int) -> None:
        """Validate and store the three indices.

        Raises:
            ValueError: if an index is out of range, or if a specified index
                follows an unspecified (-1) one.
        """
        if not -1 <= l0_idx <= Gkid.max_l0_idx:
            raise ValueError(f"L0 index {l0_idx} out of range")

        if not -1 <= l1_idx < L1_KEY_ITERATION:
            raise ValueError(f"L1 index {l1_idx} out of range")

        if not -1 <= l2_idx < L2_KEY_ITERATION:
            raise ValueError(f"L2 index {l2_idx} out of range")

        # A more specific index may only be given if the one before it is.
        l1_without_l0 = l0_idx == -1 and l1_idx != -1
        l2_without_l1 = l1_idx == -1 and l2_idx != -1
        if l1_without_l0 or l2_without_l1:
            raise ValueError("invalid combination of negative and non‐negative indices")

        self._l0_idx = l0_idx
        self._l1_idx = l1_idx
        self._l2_idx = l2_idx

    @property
    def l0_idx(self) -> int:
        """The L0 index, or -1 if unspecified."""
        return self._l0_idx

    @property
    def l1_idx(self) -> int:
        """The L1 index, or -1 if unspecified."""
        return self._l1_idx

    @property
    def l2_idx(self) -> int:
        """The L2 index, or -1 if unspecified."""
        return self._l2_idx

    def gkid_type(self) -> GkidType:
        """Classify this GKID by the first index that is unspecified."""
        by_first_missing = (
            (self.l0_idx, GkidType.DEFAULT),
            (self.l1_idx, GkidType.L0_SEED_KEY),
            (self.l2_idx, GkidType.L1_SEED_KEY),
        )
        for index, kind in by_first_missing:
            if index == -1:
                return kind

        return GkidType.L2_SEED_KEY

    def wrapped_l1_idx(self) -> int:
        """Return the L1 index, mapping "unspecified" to L1_KEY_ITERATION."""
        return L1_KEY_ITERATION if self.l1_idx == -1 else self.l1_idx

    def wrapped_l2_idx(self) -> int:
        """Return the L2 index, mapping "unspecified" to L2_KEY_ITERATION."""
        return L2_KEY_ITERATION if self.l2_idx == -1 else self.l2_idx

    def derive_l1_seed_key(self) -> "Gkid":
        """Return the GKID of the next L1 seed key derivable from this one.

        Raises:
            InvalidDerivation: if this is neither an L0 nor an L1 seed key,
                or if the L1 index is already 0.
        """
        kind = self.gkid_type()
        if kind not in (GkidType.L0_SEED_KEY, GkidType.L1_SEED_KEY):
            raise InvalidDerivation(
                "Invalid attempt to derive an L1 seed key from "
                f"{kind.description()}"
            )

        if self.l1_idx == 0:
            raise InvalidDerivation("No further derivation of L1 seed keys is possible")

        next_l1_idx = self.wrapped_l1_idx() - 1
        return Gkid(self.l0_idx, next_l1_idx, self.l2_idx)

    def derive_l2_seed_key(self) -> "Gkid":
        """Return the GKID of the next L2 seed key derivable from this one.

        Raises:
            InvalidDerivation: if this is neither an L1 nor an L2 seed key,
                or if the L2 index is already 0.
        """
        kind = self.gkid_type()
        if kind not in (GkidType.L1_SEED_KEY, GkidType.L2_SEED_KEY):
            raise InvalidDerivation(
                f"Attempt to derive an L2 seed key from {kind.description()}"
            )

        if self.l2_idx == 0:
            raise InvalidDerivation("No further derivation of L2 seed keys is possible")

        next_l2_idx = self.wrapped_l2_idx() - 1
        return Gkid(self.l0_idx, self.l1_idx, next_l2_idx)

    def __str__(self) -> str:
        return f"Gkid({self.l0_idx}, {self.l1_idx}, {self.l2_idx})"

    def __repr__(self) -> str:
        name = type(self).__qualname__
        return f"{name}({self.l0_idx!r}, {self.l1_idx!r}, {self.l2_idx!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Gkid):
            return NotImplemented

        mine = (self.l0_idx, self.l1_idx, self.l2_idx)
        theirs = (other.l0_idx, other.l1_idx, other.l2_idx)
        return mine == theirs

    def _ordering_key(self) -> Tuple[int, int, int]:
        """Return the tuple by which GKIDs are ordered."""
        # DEFAULT is considered less than everything else, so that the
        # lexical ordering requirement in [MS-GKDI] 3.1.4.1.3 (GetKey) makes
        # sense.
        if self.gkid_type() is GkidType.DEFAULT:
            return self.l0_idx, self.l1_idx, self.l2_idx

        # Use the wrapped indices so that L1 seed keys are considered
        # greater than their children L2 seed keys, and L0 seed keys are
        # considered greater than their children L1 seed keys.
        return self.l0_idx, self.wrapped_l1_idx(), self.wrapped_l2_idx()

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Gkid):
            return NotImplemented

        return self._ordering_key() < other._ordering_key()

    def __hash__(self) -> int:
        return hash((self.l0_idx, self.l1_idx, self.l2_idx))

    @staticmethod
    def default() -> "Gkid":
        """Return the default GKID, with every index unspecified."""
        return Gkid(-1, -1, -1)

    @staticmethod
    def l0_seed_key(l0_idx: int) -> "Gkid":
        """Return the GKID naming the L0 seed key with the given index."""
        return Gkid(l0_idx, -1, -1)

    @staticmethod
    def l1_seed_key(l0_idx: int, l1_idx: int) -> "Gkid":
        """Return the GKID naming the L1 seed key with the given indices."""
        return Gkid(l0_idx, l1_idx, -1)

    @staticmethod
    def from_nt_time(nt_time: NtTime) -> "Gkid":
        """Return the fully specified GKID whose key cycle contains nt_time.

        Raises:
            ValueError: if the resulting L0 index is out of range.
        """
        l2_span = L2_KEY_ITERATION * KEY_CYCLE_DURATION  # one L1 index
        l1_span = L1_KEY_ITERATION * l2_span  # one L0 index

        l0, within_l0 = divmod(nt_time, l1_span)
        l1, within_l1 = divmod(within_l0, l2_span)
        l2 = within_l1 // KEY_CYCLE_DURATION

        return Gkid(l0, l1, l2)

    def start_nt_time(self) -> NtTime:
        """Return the NT time at which this GKID's key cycle begins.

        Raises:
            UndefinedStartTime: if any index is unspecified.
            OverflowError: if the result does not fit in 64 unsigned bits.
        """
        kind = self.gkid_type()
        if kind is not GkidType.L2_SEED_KEY:
            raise UndefinedStartTime(
                f"{kind.description()} has no defined start time"
            )

        # Total number of key cycles that precede this GKID.
        elapsed_cycles = (
            self.l0_idx * L1_KEY_ITERATION + self.l1_idx
        ) * L2_KEY_ITERATION + self.l2_idx
        start_time = NtTime(elapsed_cycles * KEY_CYCLE_DURATION)

        if not 0 <= start_time <= uint64_max:
            raise OverflowError(f"start time {start_time} out of range")

        return start_time
class SeedKeyPair:
    """An optional L1 and optional L2 seed key, with identifying metadata.

    Either key may be None; a key that is present must be exactly
    KEY_LEN_BYTES long.
    """

    __slots__ = ["l1_key", "l2_key", "gkid", "hash_algorithm", "root_key_id"]

    def __init__(
        self,
        l1_key: Optional[bytes],
        l2_key: Optional[bytes],
        gkid: Gkid,
        hash_algorithm: Algorithm,
        root_key_id: misc.GUID,
    ) -> None:
        """Store the keys and metadata after checking the key lengths.

        Raises:
            ValueError: if a supplied key is not KEY_LEN_BYTES long.
        """
        for level, candidate in (("L1", l1_key), ("L2", l2_key)):
            if candidate is not None and len(candidate) != KEY_LEN_BYTES:
                raise ValueError(
                    f"{level} key ({repr(candidate)}) must be {KEY_LEN_BYTES} bytes"
                )

        self.l1_key = l1_key
        self.l2_key = l2_key
        self.gkid = gkid
        self.hash_algorithm = hash_algorithm
        self.root_key_id = root_key_id

    def __str__(self) -> str:
        l1_hex = self.l1_key.hex() if self.l1_key is not None else None
        l2_hex = self.l2_key.hex() if self.l2_key is not None else None

        keys = f"L1Key({l1_hex}), L2Key({l2_hex})"
        metadata = f"{self.gkid}, {self.root_key_id}, {self.hash_algorithm}"
        return f"SeedKeyPair({keys}, {metadata})"

    def __repr__(self) -> str:
        name = type(self).__qualname__
        args = ", ".join(
            (
                repr(self.l1_key),
                repr(self.l2_key),
                repr(self.gkid),
                repr(self.hash_algorithm),
                repr(self.root_key_id),
            )
        )
        return f"{name}({args})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SeedKeyPair):
            return NotImplemented

        def fields(pair: "SeedKeyPair") -> tuple:
            return (
                pair.l1_key,
                pair.l2_key,
                pair.gkid,
                pair.hash_algorithm,
                pair.root_key_id,
            )

        return fields(self) == fields(other)

    def __hash__(self) -> int:
        # The root key ID is hashed via its NDR-packed bytes rather than
        # directly.
        packed_root_key_id = ndr_pack(self.root_key_id)
        return hash(
            (
                self.l1_key,
                self.l2_key,
                self.gkid,
                self.hash_algorithm,
                packed_root_key_id,
            )
        )
class GroupKey:
    """A group key together with the metadata that identifies it.

    The key may be None; a key that is present must be exactly KEY_LEN_BYTES
    long.
    """

    __slots__ = ["gkid", "key", "hash_algorithm", "root_key_id"]

    def __init__(
        self,
        key: Optional[bytes],
        gkid: "Gkid",
        hash_algorithm: "Algorithm",
        root_key_id: "misc.GUID",
    ) -> None:
        """Store the key and metadata after checking the key length.

        Raises:
            ValueError: if a supplied key is not KEY_LEN_BYTES long.
        """
        if key is not None and len(key) != KEY_LEN_BYTES:
            raise ValueError(f"Key ({repr(key)}) must be {KEY_LEN_BYTES} bytes")

        self.key = key
        self.gkid = gkid
        self.hash_algorithm = hash_algorithm
        self.root_key_id = root_key_id

    def __str__(self) -> str:
        # The constructor accepts a missing key, so str() must not assume
        # there is one to hex-encode.
        key_hex = None if self.key is None else self.key.hex()

        return (
            f"GroupKey(Key({key_hex}), {self.gkid}, {self.hash_algorithm}, "
            f"{self.root_key_id})"
        )

    def __repr__(self) -> str:
        cls = type(self)
        return (
            f"{cls.__qualname__}({repr(self.key)}, {repr(self.gkid)}, "
            f"{repr(self.hash_algorithm)}, {repr(self.root_key_id)})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, GroupKey):
            return NotImplemented

        return (self.key, self.gkid, self.hash_algorithm, self.root_key_id) == (
            other.key,
            other.gkid,
            other.hash_algorithm,
            other.root_key_id,
        )

    def __hash__(self) -> int:
        # The root key ID is hashed via its NDR-packed bytes rather than
        # directly.
        return hash(
            (self.key, self.gkid, self.hash_algorithm, ndr_pack(self.root_key_id))
        )