2022-08-06 14:48:49 +12:00
/* SPDX-License-Identifier: GPL-2.0 */
# ifndef MEAN_AND_VARIANCE_H_
# define MEAN_AND_VARIANCE_H_
# include <linux/types.h>
# include <linux/limits.h>
2023-06-04 17:58:56 -04:00
# include <linux/math.h>
2022-08-06 14:48:49 +12:00
# include <linux/math64.h>
/* floor(sqrt(U64_MAX)) == 2^32 - 1: the largest value whose square fits in a u64. */
#define SQRT_U64_MAX 0xffffffffULL
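/*
 * u128_u: u128 "user mode" type - a portable unsigned 128-bit integer,
 * because not all architectures support a native int128 type. When
 * __int128 is unavailable, a fallback built from two u64 halves is used.
 */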
# ifdef __SIZEOF_INT128__
/* Native representation: a single unsigned __int128, kept 16-byte aligned. */
typedef struct {
	unsigned __int128 v;
} __aligned(16) u128_u;
/* Zero-extend a u64 into a u128_u. */
static inline u128_u u64_to_u128(u64 a)
{
	u128_u r = { .v = a };

	return r;
}
/* Low 64 bits of @a (truncating conversion). */
static inline u64 u128_lo(u128_u a)
{
	return (u64)a.v;
}
/* High 64 bits of @a. */
static inline u64 u128_hi(u128_u a)
{
	unsigned __int128 top = a.v >> 64;

	return (u64)top;
}
/* @a + @b, wrapping modulo 2^128. */
static inline u128_u u128_add(u128_u a, u128_u b)
{
	return (u128_u){ .v = a.v + b.v };
}
/* @a - @b, wrapping modulo 2^128. */
static inline u128_u u128_sub(u128_u a, u128_u b)
{
	return (u128_u){ .v = a.v - b.v };
}
/*
 * Shift @a left by @shift bits. @shift must be in [0, 127]: shifting a
 * 128-bit type by its width or more, or by a negative amount, is undefined.
 */
static inline u128_u u128_shl(u128_u a, s8 shift)
{
	return (u128_u){ .v = a.v << shift };
}
/* Exact 128-bit square of a u64 (cannot overflow: (2^64 - 1)^2 < 2^128). */
static inline u128_u u128_square(u64 a)
{
	unsigned __int128 wide = a;

	return (u128_u){ .v = wide * wide };
}
# else
/*
 * Fallback representation: two 64-bit halves, value = hi * 2^64 + lo.
 * Kept 16-byte aligned like the __int128 variant.
 */
typedef struct {
	u64 hi, lo;
} __aligned(16) u128_u;
/* conversions */
/* Zero-extend a u64 into a u128_u: the value lands in the low half. */
static inline u128_u u64_to_u128(u64 a)
{
	u128_u r = { .hi = 0, .lo = a };

	return r;
}
/* Low 64 bits of @a, read straight from the stored low half. */
static inline u64 u128_lo(u128_u a)
{
	u64 low = a.lo;

	return low;
}
/* High 64 bits of @a, read straight from the stored high half. */
static inline u64 u128_hi(u128_u a)
{
	u64 high = a.hi;

	return high;
}
/* arithmetic */
/* @a + @b modulo 2^128, propagating the carry out of the low half. */
static inline u128_u u128_add(u128_u a, u128_u b)
{
	u64 low = a.lo + b.lo;
	/* The unsigned low sum wrapped iff it ended up smaller than an addend. */
	u64 carry = low < a.lo;

	return (u128_u){ .hi = a.hi + b.hi + carry, .lo = low };
}
/* @a - @b modulo 2^128, propagating the borrow out of the low half. */
static inline u128_u u128_sub(u128_u a, u128_u b)
{
	u64 low = a.lo - b.lo;
	/* The low subtraction wraps (needs a borrow) exactly when b.lo > a.lo. */
	u64 borrow = b.lo > a.lo;

	return (u128_u){ .hi = a.hi - b.hi - borrow, .lo = low };
}
/*
 * u128_shl() - shift a two-half 128-bit value left by @shift bits.
 * @i: value to shift
 * @shift: shift amount; must be in [0, 127]
 *
 * In C, shifting a u64 by 64 or more bits is undefined behaviour, so each
 * half is only ever shifted by an amount in [0, 63]:
 * - @shift == 0 returns @i unchanged (this avoids i.lo >> 64).
 * - @shift >= 64 moves the low word into the high word.
 * - 0 < @shift < 64 carries the top bits of lo into hi.
 */
static inline u128_u u128_shl(u128_u i, s8 shift)
{
	u128_u r;

	if (shift == 0)
		return i;

	if (shift >= 64) {
		/* The whole low word moves up; nothing remains below bit 64. */
		r.hi = i.lo << (shift - 64);
		r.lo = 0;
	} else {
		/* The bits shifted out of lo become the bottom bits of hi. */
		r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
		r.lo = i.lo << shift;
	}
	return r;
}
/*
 * u128_square() - exact 128-bit square of a u64, without __int128.
 * @i: value to square
 *
 * Split i = h * 2^32 + l with h, l < 2^32, so every partial product fits
 * in a u64:
 *
 *   i^2 = h^2 * 2^64 + 2 * (h * l) * 2^32 + l^2
 *
 * The h^2 term is placed directly in the high half rather than being
 * shifted left by 64 bits. The cross term is computed once and added twice,
 * because 2 * h * l itself may not fit in a u64.
 */
static inline u128_u u128_square(u64 i)
{
	u64 h = i >> 32, l = i & U32_MAX;
	u64 cross = h * l;
	u128_u r = { .hi = h * h, .lo = 0 };

	r = u128_add(r, u128_shl(u64_to_u128(cross), 32));
	r = u128_add(r, u128_shl(u64_to_u128(cross), 32));
	r = u128_add(r, u64_to_u128(l * l));
	return r;
}
# endif
/* Assemble a u128_u from its halves: (@hi << 64) + @lo. */
static inline u128_u u64s_to_u128(u64 hi, u64 lo)
{
	u128_u upper = u128_shl(u64_to_u128(hi), 64);

	return u128_add(upper, u64_to_u128(lo));
}
/*
 * Divide the 128-bit @n by the 64-bit @d. Declared here; the definition is
 * not in this header. NOTE(review): rounding and d == 0 behaviour are not
 * visible here - presumably truncating division; confirm in the definition.
 */
u128_u u128_div(u128_u n, u64 d);
/*
 * Running accumulator over a stream of s64 samples; all three fields are
 * advanced by mean_and_variance_update().
 */
struct mean_and_variance {
	s64 n;			/* number of samples added so far */
	s64 sum;		/* sum of the samples */
	u128_u sum_squares;	/* sum of the squared samples (128-bit to limit overflow) */
};
/* Exponentially weighted variant */
struct mean_and_variance_weighted {
	bool init;	/* NOTE(review): presumably "first sample seen" - confirm in the .c */
	u8 weight;	/* weight, stored as its base 2 logarithm */
	s64 mean;
	u64 variance;
};
/**
 * fast_divpow2() - compute n / (1 << d) with a shift, rounding towards 0
 * @n: numerator
 * @d: base 2 logarithm of the denominator; must be less than 64
 *
 * A right shift of a negative value rounds towards negative infinity
 * (assuming an arithmetic shift, as GCC and Clang implement it). Negative
 * numerators are therefore first biased by 2^@d - 1, which yields the same
 * result as C's truncating division.
 *
 * Return: @n divided by 2^@d, rounded towards 0.
 */
static inline s64 fast_divpow2(s64 n, u8 d)
{
	/*
	 * Build the bias in u64: an int "1 << d" is undefined for d >= 31, and
	 * (s64)1 << 63 overflows. 2^d - 1 always fits in an s64 for d < 64, and
	 * n + bias cannot overflow because the bias is only applied when n < 0.
	 */
	s64 bias = n < 0 ? (s64)((1ULL << d) - 1) : 0;

	return (n + bias) >> d;
}
/**
* mean_and_variance_update ( ) - update a mean_and_variance struct @ s1 with a new sample @ v1
* and return it .
* @ s1 : the mean_and_variance to update .
* @ v1 : the new sample .
*
* see linked pdf equation 12.
*/
2023-05-25 22:22:25 -04:00
static inline void
mean_and_variance_update ( struct mean_and_variance * s , s64 v )
{
s - > n + + ;
s - > sum + = v ;
s - > sum_squares = u128_add ( s - > sum_squares , u128_square ( abs ( v ) ) ) ;
2022-08-06 14:48:49 +12:00
}
/*
 * Accessors declared here, defined out of line (not in this header).
 * The first three take a struct mean_and_variance by value. The _weighted
 * ones take or update a struct mean_and_variance_weighted.
 */
s64 mean_and_variance_get_mean(struct mean_and_variance s);
u64 mean_and_variance_get_variance(struct mean_and_variance s);
u32 mean_and_variance_get_stddev(struct mean_and_variance s);
void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v);
s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
#endif // MEAN_AND_VARIANCE_H_