2020-04-19 23:23:05 +03:00
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
2016-11-29 19:26:36 +03:00
package proto
import (
"errors"
"fmt"
"reflect"
2020-04-19 23:23:05 +03:00
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
"google.golang.org/protobuf/runtime/protoimpl"
)
2018-11-05 06:20:00 +03:00
2020-04-19 23:23:05 +03:00
type (
// ExtensionDesc represents an extension descriptor and
// is used to interact with an extension field in a message.
//
// Variables of this type are generated in code by protoc-gen-go.
ExtensionDesc = protoimpl . ExtensionInfo
2018-11-05 06:20:00 +03:00
2020-04-19 23:23:05 +03:00
// ExtensionRange represents a range of message extensions.
// Used in code generated by protoc-gen-go.
ExtensionRange = protoiface . ExtensionRangeV1
2018-11-05 06:20:00 +03:00
2020-04-19 23:23:05 +03:00
// Deprecated: Do not use; this is an internal type.
Extension = protoimpl . ExtensionFieldV1
2018-11-05 06:20:00 +03:00
2020-04-19 23:23:05 +03:00
// Deprecated: Do not use; this is an internal type.
XXX_InternalExtensions = protoimpl . ExtensionFields
)
2018-11-05 06:20:00 +03:00
2020-04-19 23:23:05 +03:00
// ErrMissingExtension reports whether the extension was not present.
var ErrMissingExtension = errors . New ( "proto: missing extension" )
2018-11-05 06:20:00 +03:00
var errNotExtendable = errors . New ( "proto: not an extendable proto.Message" )
2020-04-19 23:23:05 +03:00
// HasExtension reports whether the extension field is present in m
// either as an explicitly populated field or as an unknown field.
func HasExtension ( m Message , xt * ExtensionDesc ) ( has bool ) {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) {
return false
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
// Check whether any populated known field matches the field number.
xtd := xt . TypeDescriptor ( )
if isValidExtension ( mr . Descriptor ( ) , xtd ) {
has = mr . Has ( xtd )
} else {
mr . Range ( func ( fd protoreflect . FieldDescriptor , _ protoreflect . Value ) bool {
has = int32 ( fd . Number ( ) ) == xt . Field
return ! has
2018-11-05 06:20:00 +03:00
} )
}
2020-04-19 23:23:05 +03:00
// Check whether any unknown field matches the field number.
for b := mr . GetUnknown ( ) ; ! has && len ( b ) > 0 ; {
num , _ , n := protowire . ConsumeField ( b )
has = int32 ( num ) == xt . Field
b = b [ n : ]
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
return has
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// ClearExtension removes the the exntesion field from m
// either as an explicitly populated field or as an unknown field.
func ClearExtension ( m Message , xt * ExtensionDesc ) {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) {
2018-11-05 06:20:00 +03:00
return
}
2016-11-29 19:26:36 +03:00
2020-04-19 23:23:05 +03:00
xtd := xt . TypeDescriptor ( )
if isValidExtension ( mr . Descriptor ( ) , xtd ) {
mr . Clear ( xtd )
} else {
mr . Range ( func ( fd protoreflect . FieldDescriptor , _ protoreflect . Value ) bool {
if int32 ( fd . Number ( ) ) == xt . Field {
mr . Clear ( fd )
return false
}
2016-11-29 19:26:36 +03:00
return true
2020-04-19 23:23:05 +03:00
} )
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
clearUnknown ( mr , fieldNum ( xt . Field ) )
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// ClearAllExtensions clears all extensions from m.
// This includes populated fields and unknown fields in the extension range.
func ClearAllExtensions ( m Message ) {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) {
2018-11-05 06:20:00 +03:00
return
}
2020-04-19 23:23:05 +03:00
mr . Range ( func ( fd protoreflect . FieldDescriptor , _ protoreflect . Value ) bool {
if fd . IsExtension ( ) {
mr . Clear ( fd )
}
return true
} )
clearUnknown ( mr , mr . Descriptor ( ) . ExtensionRanges ( ) )
2016-11-29 19:26:36 +03:00
}
2018-11-05 06:20:00 +03:00
// GetExtension retrieves a proto2 extended field from pb.
//
// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
// then GetExtension parses the encoded field and returns a Go value of the specified type.
// If the field is not present, then the default value is returned (if one is specified),
// otherwise ErrMissingExtension is reported.
//
2020-04-19 23:23:05 +03:00
// If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
// then GetExtension returns the raw encoded bytes for the extension field.
func GetExtension ( m Message , xt * ExtensionDesc ) ( interface { } , error ) {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) || mr . Descriptor ( ) . ExtensionRanges ( ) . Len ( ) == 0 {
return nil , errNotExtendable
}
// Retrieve the unknown fields for this extension field.
var bo protoreflect . RawFields
for bi := mr . GetUnknown ( ) ; len ( bi ) > 0 ; {
num , _ , n := protowire . ConsumeField ( bi )
if int32 ( num ) == xt . Field {
bo = append ( bo , bi [ : n ] ... )
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
bi = bi [ n : ]
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
// For type incomplete descriptors, only retrieve the unknown fields.
if xt . ExtensionType == nil {
return [ ] byte ( bo ) , nil
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// If the extension field only exists as unknown fields, unmarshal it.
// This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
xtd := xt . TypeDescriptor ( )
if ! isValidExtension ( mr . Descriptor ( ) , xtd ) {
return nil , fmt . Errorf ( "proto: bad extended type; %T does not extend %T" , xt . ExtendedType , m )
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
if ! mr . Has ( xtd ) && len ( bo ) > 0 {
m2 := mr . New ( )
if err := ( proto . UnmarshalOptions {
Resolver : extensionResolver { xt } ,
} . Unmarshal ( bo , m2 . Interface ( ) ) ) ; err != nil {
return nil , err
}
if m2 . Has ( xtd ) {
mr . Set ( xtd , m2 . Get ( xtd ) )
clearUnknown ( mr , fieldNum ( xt . Field ) )
}
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// Check whether the message has the extension field set or a default.
var pv protoreflect . Value
switch {
case mr . Has ( xtd ) :
pv = mr . Get ( xtd )
case xtd . HasDefault ( ) :
pv = xtd . Default ( )
default :
2018-11-05 06:20:00 +03:00
return nil , ErrMissingExtension
}
2020-04-19 23:23:05 +03:00
v := xt . InterfaceOf ( pv )
rv := reflect . ValueOf ( v )
if isScalarKind ( rv . Kind ( ) ) {
rv2 := reflect . New ( rv . Type ( ) )
rv2 . Elem ( ) . Set ( rv )
v = rv2 . Interface ( )
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
return v , nil
}
2016-11-29 19:26:36 +03:00
2020-04-19 23:23:05 +03:00
// extensionResolver is a custom extension resolver that stores a single
// extension type that takes precedence over the global registry.
type extensionResolver struct { xt protoreflect . ExtensionType }
2016-11-29 19:26:36 +03:00
2020-04-19 23:23:05 +03:00
func ( r extensionResolver ) FindExtensionByName ( field protoreflect . FullName ) ( protoreflect . ExtensionType , error ) {
if xtd := r . xt . TypeDescriptor ( ) ; xtd . FullName ( ) == field {
return r . xt , nil
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
return protoregistry . GlobalTypes . FindExtensionByName ( field )
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
func ( r extensionResolver ) FindExtensionByNumber ( message protoreflect . FullName , field protoreflect . FieldNumber ) ( protoreflect . ExtensionType , error ) {
if xtd := r . xt . TypeDescriptor ( ) ; xtd . ContainingMessage ( ) . FullName ( ) == message && xtd . Number ( ) == field {
return r . xt , nil
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
return protoregistry . GlobalTypes . FindExtensionByNumber ( message , field )
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// GetExtensions returns a list of the extensions values present in m,
// corresponding with the provided list of extension descriptors, xts.
// If an extension is missing in m, the corresponding value is nil.
func GetExtensions ( m Message , xts [ ] * ExtensionDesc ) ( [ ] interface { } , error ) {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) {
return nil , errNotExtendable
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
vs := make ( [ ] interface { } , len ( xts ) )
for i , xt := range xts {
v , err := GetExtension ( m , xt )
2016-11-29 19:26:36 +03:00
if err != nil {
2020-04-19 23:23:05 +03:00
if err == ErrMissingExtension {
continue
}
return vs , err
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
vs [ i ] = v
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
return vs , nil
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// SetExtension sets an extension field in m to the provided value.
func SetExtension ( m Message , xt * ExtensionDesc , v interface { } ) error {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) || mr . Descriptor ( ) . ExtensionRanges ( ) . Len ( ) == 0 {
return errNotExtendable
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
rv := reflect . ValueOf ( v )
if reflect . TypeOf ( v ) != reflect . TypeOf ( xt . ExtensionType ) {
return fmt . Errorf ( "proto: bad extension value type. got: %T, want: %T" , v , xt . ExtensionType )
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
if rv . Kind ( ) == reflect . Ptr {
if rv . IsNil ( ) {
return fmt . Errorf ( "proto: SetExtension called with nil value of type %T" , v )
}
if isScalarKind ( rv . Elem ( ) . Kind ( ) ) {
v = rv . Elem ( ) . Interface ( )
2018-11-05 06:20:00 +03:00
}
}
2020-04-19 23:23:05 +03:00
xtd := xt . TypeDescriptor ( )
if ! isValidExtension ( mr . Descriptor ( ) , xtd ) {
return fmt . Errorf ( "proto: bad extended type; %T does not extend %T" , xt . ExtendedType , m )
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
mr . Set ( xtd , xt . ValueOf ( v ) )
clearUnknown ( mr , fieldNum ( xt . Field ) )
2016-11-29 19:26:36 +03:00
return nil
}
2020-04-19 23:23:05 +03:00
// SetRawExtension inserts b into the unknown fields of m.
//
// Deprecated: Use Message.ProtoReflect.SetUnknown instead.
func SetRawExtension ( m Message , fnum int32 , b [ ] byte ) {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) {
2018-11-05 06:20:00 +03:00
return
}
2020-04-19 23:23:05 +03:00
// Verify that the raw field is valid.
for b0 := b ; len ( b0 ) > 0 ; {
num , _ , n := protowire . ConsumeField ( b0 )
if int32 ( num ) != fnum {
panic ( fmt . Sprintf ( "mismatching field number: got %d, want %d" , num , fnum ) )
}
b0 = b0 [ n : ]
2018-11-05 06:20:00 +03:00
}
2020-04-19 23:23:05 +03:00
ClearExtension ( m , & ExtensionDesc { Field : fnum } )
mr . SetUnknown ( append ( mr . GetUnknown ( ) , b ... ) )
}
2016-11-29 19:26:36 +03:00
2020-04-19 23:23:05 +03:00
// ExtensionDescs returns a list of extension descriptors found in m,
// containing descriptors for both populated extension fields in m and
// also unknown fields of m that are in the extension range.
// For the later case, an type incomplete descriptor is provided where only
// the ExtensionDesc.Field field is populated.
// The order of the extension descriptors is undefined.
func ExtensionDescs ( m Message ) ( [ ] * ExtensionDesc , error ) {
mr := MessageReflect ( m )
if mr == nil || ! mr . IsValid ( ) || mr . Descriptor ( ) . ExtensionRanges ( ) . Len ( ) == 0 {
return nil , errNotExtendable
}
2016-11-29 19:26:36 +03:00
2020-04-19 23:23:05 +03:00
// Collect a set of known extension descriptors.
extDescs := make ( map [ protoreflect . FieldNumber ] * ExtensionDesc )
mr . Range ( func ( fd protoreflect . FieldDescriptor , v protoreflect . Value ) bool {
if fd . IsExtension ( ) {
xt := fd . ( protoreflect . ExtensionTypeDescriptor )
if xd , ok := xt . Type ( ) . ( * ExtensionDesc ) ; ok {
extDescs [ fd . Number ( ) ] = xd
}
}
return true
} )
// Collect a set of unknown extension descriptors.
extRanges := mr . Descriptor ( ) . ExtensionRanges ( )
for b := mr . GetUnknown ( ) ; len ( b ) > 0 ; {
num , _ , n := protowire . ConsumeField ( b )
if extRanges . Has ( num ) && extDescs [ num ] == nil {
extDescs [ num ] = nil
}
b = b [ n : ]
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// Transpose the set of descriptors into a list.
var xts [ ] * ExtensionDesc
for num , xt := range extDescs {
if xt == nil {
xt = & ExtensionDesc { Field : int32 ( num ) }
}
xts = append ( xts , xt )
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
return xts , nil
2016-11-29 19:26:36 +03:00
}
2020-04-19 23:23:05 +03:00
// isValidExtension reports whether xtd is a valid extension descriptor for md.
func isValidExtension ( md protoreflect . MessageDescriptor , xtd protoreflect . ExtensionTypeDescriptor ) bool {
return xtd . ContainingMessage ( ) == md && md . ExtensionRanges ( ) . Has ( xtd . Number ( ) )
2016-11-29 19:26:36 +03:00
}
2019-07-06 18:16:43 +03:00
2020-04-19 23:23:05 +03:00
// isScalarKind reports whether k is a protobuf scalar kind (except bytes).
// This function exists for historical reasons since the representation of
// scalars differs between v1 and v2, where v1 uses *T and v2 uses T.
func isScalarKind ( k reflect . Kind ) bool {
switch k {
2019-07-06 18:16:43 +03:00
case reflect . Bool , reflect . Int32 , reflect . Int64 , reflect . Uint32 , reflect . Uint64 , reflect . Float32 , reflect . Float64 , reflect . String :
2020-04-19 23:23:05 +03:00
return true
default :
return false
2019-07-06 18:16:43 +03:00
}
}
2020-04-19 23:23:05 +03:00
// clearUnknown removes unknown fields from m where remover.Has reports true.
func clearUnknown ( m protoreflect . Message , remover interface {
Has ( protoreflect . FieldNumber ) bool
} ) {
var bo protoreflect . RawFields
for bi := m . GetUnknown ( ) ; len ( bi ) > 0 ; {
num , _ , n := protowire . ConsumeField ( bi )
if ! remover . Has ( num ) {
bo = append ( bo , bi [ : n ] ... )
2019-07-06 18:16:43 +03:00
}
2020-04-19 23:23:05 +03:00
bi = bi [ n : ]
2019-07-06 18:16:43 +03:00
}
2020-04-19 23:23:05 +03:00
if bi := m . GetUnknown ( ) ; len ( bi ) != len ( bo ) {
m . SetUnknown ( bo )
}
}
type fieldNum protoreflect . FieldNumber
func ( n1 fieldNum ) Has ( n2 protoreflect . FieldNumber ) bool {
return protoreflect . FieldNumber ( n1 ) == n2
2019-07-06 18:16:43 +03:00
}