Unit testing for def/ module.

Added unit tests to test code in def/ module.
This commit is contained in:
Pradyumna Kaushik 2019-10-12 06:48:45 +00:00
parent e24b8a08c9
commit bac60e872a
396 changed files with 83991 additions and 13209 deletions

View file

@ -99,6 +99,8 @@ type unmarshalFieldInfo struct {
// if a required field, contains a single set bit at this field's index in the required field list.
reqMask uint64
name string // name of the field, for error reporting
}
var (
@ -136,10 +138,10 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
u.computeUnmarshalInfo()
}
if u.isMessageSet {
return UnmarshalMessageSet(b, m.offset(u.extensions).toExtensions())
return unmarshalMessageSet(b, m.offset(u.extensions).toExtensions())
}
var reqMask uint64 // bitmask of required fields we've seen.
var rnse *RequiredNotSetError // an instance of a RequiredNotSetError returned by a submessage.
var reqMask uint64 // bitmask of required fields we've seen.
var errLater error
for len(b) > 0 {
// Read tag and wire type.
// Special case 1 and 2 byte varints.
@ -178,11 +180,20 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
if r, ok := err.(*RequiredNotSetError); ok {
// Remember this error, but keep parsing. We need to produce
// a full parse even if a required field is missing.
rnse = r
if errLater == nil {
errLater = r
}
reqMask |= f.reqMask
continue
}
if err != errInternalBadWireType {
if err == errInvalidUTF8 {
if errLater == nil {
fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
errLater = &invalidUTF8Error{fullName}
}
continue
}
return err
}
// Fragments with bad wire type are treated as unknown fields.
@ -244,20 +255,16 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
emap[int32(tag)] = e
}
}
if rnse != nil {
// A required field of a submessage/group is missing. Return that error.
return rnse
}
if reqMask != u.reqMask {
if reqMask != u.reqMask && errLater == nil {
// A required field of this message is missing.
for _, n := range u.reqFields {
if reqMask&1 == 0 {
return &RequiredNotSetError{n}
errLater = &RequiredNotSetError{n}
}
reqMask >>= 1
}
}
return nil
return errLater
}
// computeUnmarshalInfo fills in u with information for use
@ -360,26 +367,36 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
}
// Store the info in the correct slot in the message.
u.setTag(tag, toField(&f), unmarshal, reqMask)
u.setTag(tag, toField(&f), unmarshal, reqMask, name)
}
// Find any types associated with oneof fields.
// TODO: XXX_OneofFuncs returns more info than we need. Get rid of some of it?
fn := reflect.Zero(reflect.PtrTo(t)).MethodByName("XXX_OneofFuncs")
// gogo: len(oneofFields) > 0 is needed for embedded oneof messages, without a marshaler and unmarshaler
if fn.IsValid() && len(oneofFields) > 0 {
res := fn.Call(nil)[3] // last return value from XXX_OneofFuncs: []interface{}
for i := res.Len() - 1; i >= 0; i-- {
v := res.Index(i) // interface{}
tptr := reflect.ValueOf(v.Interface()).Type() // *Msg_X
typ := tptr.Elem() // Msg_X
if len(oneofFields) > 0 {
var oneofImplementers []interface{}
switch m := reflect.Zero(reflect.PtrTo(t)).Interface().(type) {
case oneofFuncsIface:
_, _, _, oneofImplementers = m.XXX_OneofFuncs()
case oneofWrappersIface:
oneofImplementers = m.XXX_OneofWrappers()
}
for _, v := range oneofImplementers {
tptr := reflect.TypeOf(v) // *Msg_X
typ := tptr.Elem() // Msg_X
f := typ.Field(0) // oneof implementers have one field
baseUnmarshal := fieldUnmarshaler(&f)
tagstr := strings.Split(f.Tag.Get("protobuf"), ",")[1]
tag, err := strconv.Atoi(tagstr)
tags := strings.Split(f.Tag.Get("protobuf"), ",")
fieldNum, err := strconv.Atoi(tags[1])
if err != nil {
panic("protobuf tag field not an integer: " + tagstr)
panic("protobuf tag field not an integer: " + tags[1])
}
var name string
for _, tag := range tags {
if strings.HasPrefix(tag, "name=") {
name = strings.TrimPrefix(tag, "name=")
break
}
}
// Find the oneof field that this struct implements.
@ -390,14 +407,15 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
// That lets us know where this struct should be stored
// when we encounter it during unmarshaling.
unmarshal := makeUnmarshalOneof(typ, of.ityp, baseUnmarshal)
u.setTag(tag, of.field, unmarshal, 0)
u.setTag(fieldNum, of.field, unmarshal, 0, name)
}
}
}
}
// Get extension ranges, if any.
fn = reflect.Zero(reflect.PtrTo(t)).MethodByName("ExtensionRangeArray")
fn := reflect.Zero(reflect.PtrTo(t)).MethodByName("ExtensionRangeArray")
if fn.IsValid() {
if !u.extensions.IsValid() && !u.oldExtensions.IsValid() && !u.bytesExtensions.IsValid() {
panic("a message with extensions, but no extensions field in " + t.Name())
@ -411,7 +429,7 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
// [0 0] is [tag=0/wiretype=varint varint-encoded-0].
u.setTag(0, zeroField, func(b []byte, f pointer, w int) ([]byte, error) {
return nil, fmt.Errorf("proto: %s: illegal tag 0 (wire type %d)", t, w)
}, 0)
}, 0, "")
// Set mask for required field check.
u.reqMask = uint64(1)<<uint(len(u.reqFields)) - 1
@ -423,8 +441,9 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
// tag = tag # for field
// field/unmarshal = unmarshal info for that field.
// reqMask = if required, bitmask for field position in required field list. 0 otherwise.
func (u *unmarshalInfo) setTag(tag int, field field, unmarshal unmarshaler, reqMask uint64) {
i := unmarshalFieldInfo{field: field, unmarshal: unmarshal, reqMask: reqMask}
// name = short name of the field.
func (u *unmarshalInfo) setTag(tag int, field field, unmarshal unmarshaler, reqMask uint64, name string) {
i := unmarshalFieldInfo{field: field, unmarshal: unmarshal, reqMask: reqMask, name: name}
n := u.typ.NumField()
if tag >= 0 && (tag < 16 || tag < 2*n) { // TODO: what are the right numbers here?
for len(u.dense) <= tag {
@ -455,10 +474,16 @@ func typeUnmarshaler(t reflect.Type, tags string) unmarshaler {
ctype := false
isTime := false
isDuration := false
isWktPointer := false
proto3 := false
validateUTF8 := true
for _, tag := range tagArray[3:] {
if strings.HasPrefix(tag, "name=") {
name = tag[5:]
}
if tag == "proto3" {
proto3 = true
}
if strings.HasPrefix(tag, "customtype=") {
ctype = true
}
@ -468,7 +493,11 @@ func typeUnmarshaler(t reflect.Type, tags string) unmarshaler {
if tag == "stdduration" {
isDuration = true
}
if tag == "wktptr" {
isWktPointer = true
}
}
validateUTF8 = validateUTF8 && proto3
// Figure out packaging (pointer, slice, or both)
slice := false
@ -522,6 +551,112 @@ func typeUnmarshaler(t reflect.Type, tags string) unmarshaler {
return makeUnmarshalDuration(getUnmarshalInfo(t), name)
}
if isWktPointer {
switch t.Kind() {
case reflect.Float64:
if pointer {
if slice {
return makeStdDoubleValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdDoubleValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdDoubleValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdDoubleValueUnmarshaler(getUnmarshalInfo(t), name)
case reflect.Float32:
if pointer {
if slice {
return makeStdFloatValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdFloatValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdFloatValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdFloatValueUnmarshaler(getUnmarshalInfo(t), name)
case reflect.Int64:
if pointer {
if slice {
return makeStdInt64ValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdInt64ValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdInt64ValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdInt64ValueUnmarshaler(getUnmarshalInfo(t), name)
case reflect.Uint64:
if pointer {
if slice {
return makeStdUInt64ValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdUInt64ValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdUInt64ValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdUInt64ValueUnmarshaler(getUnmarshalInfo(t), name)
case reflect.Int32:
if pointer {
if slice {
return makeStdInt32ValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdInt32ValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdInt32ValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdInt32ValueUnmarshaler(getUnmarshalInfo(t), name)
case reflect.Uint32:
if pointer {
if slice {
return makeStdUInt32ValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdUInt32ValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdUInt32ValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdUInt32ValueUnmarshaler(getUnmarshalInfo(t), name)
case reflect.Bool:
if pointer {
if slice {
return makeStdBoolValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdBoolValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdBoolValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdBoolValueUnmarshaler(getUnmarshalInfo(t), name)
case reflect.String:
if pointer {
if slice {
return makeStdStringValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdStringValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdStringValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdStringValueUnmarshaler(getUnmarshalInfo(t), name)
case uint8SliceType:
if pointer {
if slice {
return makeStdBytesValuePtrSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdBytesValuePtrUnmarshaler(getUnmarshalInfo(t), name)
}
if slice {
return makeStdBytesValueSliceUnmarshaler(getUnmarshalInfo(t), name)
}
return makeStdBytesValueUnmarshaler(getUnmarshalInfo(t), name)
default:
panic(fmt.Sprintf("unknown wktpointer type %#v", t))
}
}
// We'll never have both pointer and slice for basic types.
if pointer && slice && t.Kind() != reflect.Struct {
panic("both pointer and slice for basic type in " + t.Name())
@ -656,6 +791,15 @@ func typeUnmarshaler(t reflect.Type, tags string) unmarshaler {
}
return unmarshalBytesValue
case reflect.String:
if validateUTF8 {
if pointer {
return unmarshalUTF8StringPtr
}
if slice {
return unmarshalUTF8StringSlice
}
return unmarshalUTF8StringValue
}
if pointer {
return unmarshalStringPtr
}
@ -1516,9 +1660,6 @@ func unmarshalStringValue(b []byte, f pointer, w int) ([]byte, error) {
return nil, io.ErrUnexpectedEOF
}
v := string(b[:x])
if !utf8.ValidString(v) {
return nil, errInvalidUTF8
}
*f.toString() = v
return b[x:], nil
}
@ -1536,9 +1677,6 @@ func unmarshalStringPtr(b []byte, f pointer, w int) ([]byte, error) {
return nil, io.ErrUnexpectedEOF
}
v := string(b[:x])
if !utf8.ValidString(v) {
return nil, errInvalidUTF8
}
*f.toStringPtr() = &v
return b[x:], nil
}
@ -1556,14 +1694,72 @@ func unmarshalStringSlice(b []byte, f pointer, w int) ([]byte, error) {
return nil, io.ErrUnexpectedEOF
}
v := string(b[:x])
if !utf8.ValidString(v) {
return nil, errInvalidUTF8
}
s := f.toStringSlice()
*s = append(*s, v)
return b[x:], nil
}
func unmarshalUTF8StringValue(b []byte, f pointer, w int) ([]byte, error) {
if w != WireBytes {
return b, errInternalBadWireType
}
x, n := decodeVarint(b)
if n == 0 {
return nil, io.ErrUnexpectedEOF
}
b = b[n:]
if x > uint64(len(b)) {
return nil, io.ErrUnexpectedEOF
}
v := string(b[:x])
*f.toString() = v
if !utf8.ValidString(v) {
return b[x:], errInvalidUTF8
}
return b[x:], nil
}
func unmarshalUTF8StringPtr(b []byte, f pointer, w int) ([]byte, error) {
if w != WireBytes {
return b, errInternalBadWireType
}
x, n := decodeVarint(b)
if n == 0 {
return nil, io.ErrUnexpectedEOF
}
b = b[n:]
if x > uint64(len(b)) {
return nil, io.ErrUnexpectedEOF
}
v := string(b[:x])
*f.toStringPtr() = &v
if !utf8.ValidString(v) {
return b[x:], errInvalidUTF8
}
return b[x:], nil
}
func unmarshalUTF8StringSlice(b []byte, f pointer, w int) ([]byte, error) {
if w != WireBytes {
return b, errInternalBadWireType
}
x, n := decodeVarint(b)
if n == 0 {
return nil, io.ErrUnexpectedEOF
}
b = b[n:]
if x > uint64(len(b)) {
return nil, io.ErrUnexpectedEOF
}
v := string(b[:x])
s := f.toStringSlice()
*s = append(*s, v)
if !utf8.ValidString(v) {
return b[x:], errInvalidUTF8
}
return b[x:], nil
}
var emptyBuf [0]byte
func unmarshalBytesValue(b []byte, f pointer, w int) ([]byte, error) {
@ -1731,6 +1927,9 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
if t == "stdduration" {
valTags = append(valTags, t)
}
if t == "wktptr" {
valTags = append(valTags, t)
}
}
unmarshalKey := typeUnmarshaler(kt, f.Tag.Get("protobuf_key"))
unmarshalVal := typeUnmarshaler(vt, strings.Join(valTags, ","))
@ -1755,6 +1954,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
// Maps will be somewhat slow. Oh well.
// Read key and value from data.
var nerr nonFatal
k := reflect.New(kt)
v := reflect.New(vt)
for len(b) > 0 {
@ -1775,7 +1975,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
err = errInternalBadWireType // skip unknown tag
}
if err == nil {
if nerr.Merge(err) {
continue
}
if err != errInternalBadWireType {
@ -1798,7 +1998,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
// Insert into map.
m.SetMapIndex(k.Elem(), v.Elem())
return r, nil
return r, nerr.E
}
}
@ -1824,15 +2024,16 @@ func makeUnmarshalOneof(typ, ityp reflect.Type, unmarshal unmarshaler) unmarshal
// Unmarshal data into holder.
// We unmarshal into the first field of the holder object.
var err error
var nerr nonFatal
b, err = unmarshal(b, valToPointer(v).offset(field0), w)
if err != nil {
if !nerr.Merge(err) {
return nil, err
}
// Write pointer to holder into target field.
f.asPointerTo(ityp).Elem().Set(v)
return b, nil
return b, nerr.E
}
}
@ -1945,7 +2146,7 @@ func encodeVarint(b []byte, x uint64) []byte {
// If there is an error, it returns 0,0.
func decodeVarint(b []byte) (uint64, int) {
var x, y uint64
if len(b) <= 0 {
if len(b) == 0 {
goto bad
}
x = uint64(b[0])