From 879e52d70a0d7b72c7131dc55fd5502fe8478b38 Mon Sep 17 00:00:00 2001 From: 13981712066 Date: Mon, 6 Dec 2021 09:09:53 +0800 Subject: [PATCH] Add MsgPack Marshal and Unmarshal Test Case --- decimal_magpack_test.go | 55 +++++++++++++++++++++++++++++++++++++++++ decimal_msgpack.go | 15 +++++------ 2 files changed, 63 insertions(+), 7 deletions(-) create mode 100644 decimal_magpack_test.go diff --git a/decimal_magpack_test.go b/decimal_magpack_test.go new file mode 100644 index 0000000..41d9dd2 --- /dev/null +++ b/decimal_magpack_test.go @@ -0,0 +1,55 @@ +package decimal + +import ( + "testing" +) + +func TestMsgPack(t *testing.T) { + for _, x := range testTable { + s := x.short + // limit to 31 digits + if len(s) > 31 { + s = s[:31] + if s[30] == '.' { + s = s[:30] + } + } + + // Prepare Test Decimal Data + amount, err := NewFromString(s) + if err != nil{ + t.Error(err) + } + s = amount.String() + + // MarshalMsg + var b []byte + out, err := amount.MarshalMsg(b) + if err != nil{ + t.Errorf("error marshalMsg %s: %v", s, err) + } + + // check msg type + typ := out[0] & 0xe0 + if typ != 0xa0 { + t.Errorf("error marshalMsg, expected type = %b, got %b", 0xa0, typ) + } + + // check msg len + sz := int(out[0] & 0x1f) + if sz != len(s) { + t.Errorf("error marshalMsg, expected size = %d, got %d", len(s), sz) + } + + // UnmarshalMsg + var unmarshalAmount Decimal + _, err = unmarshalAmount.UnmarshalMsg(out) + if err != nil{ + t.Errorf("error unmarshalMsg %s: %v", s, err) + }else if !unmarshalAmount.Equal(amount) { + t.Errorf("expected %s, got %s (%s, %d)", + amount.String(), unmarshalAmount.String(), + unmarshalAmount.value.String(), unmarshalAmount.exp) + } + } +} \ No newline at end of file diff --git a/decimal_msgpack.go b/decimal_msgpack.go index f0647ae..03a335d 100644 --- a/decimal_msgpack.go +++ b/decimal_msgpack.go @@ -9,17 +9,18 @@ var ( ) // MarshalMsg implements msgp.Marshaler +// Note: limit to 31 digits, if d.IntPart size large than 31, will be lose. func (d Decimal) MarshalMsg(b []byte) (o []byte, err error) { o = require(b, d.Msgsize()) str := d.String() sz := len(str) - // limit to 30 digits - // note, if d.IntPart size large than 30, will be lose. - if sz > 30 { - sz = 30 - // if last char is '.' then limit to 20 digits - if str[29] == '.' { - sz = 29 + // limit to 31 digits + // note, if d.IntPart size large than 31, will be lose. + if sz > 31 { + sz = 31 + // if last char is '.' then limit to 30 digits + if str[30] == '.' { + sz = 30 } str = str[:sz]