diff --git a/decimal_magpack_test.go b/decimal_magpack_test.go new file mode 100644 index 0000000..bc865a4 --- /dev/null +++ b/decimal_magpack_test.go @@ -0,0 +1,55 @@ +package decimal + +import ( + "testing" +) + +func TestMsgPack(t *testing.T) { + for _, x := range testTable { + s := x.short + // limit to 31 digits + if len(s) > 31 { + s = s[:31] + if s[30] == '.' { + s = s[:30] + } + } + + // Prepare Test Decimal Data + amount, err := NewFromString(s) + if err != nil { + t.Error(err) + } + s = amount.String() + + // MarshalMsg + var b []byte + out, err := amount.MarshalMsg(b) + if err != nil { + t.Errorf("error marshalMsg %s: %v", s, err) + } + + // check msg type + typ := out[0] & 0xe0 + if typ != 0xa0 { + t.Errorf("error marshalMsg, expected type = %b, got %b", 0xa0, typ) + } + + // check msg len + sz := int(out[0] & 0x1f) + if sz != len(s) { + t.Errorf("error marshalMsg, expected size = %d, got %d", len(s), sz) + } + + // UnmarshalMsg + var unmarshalAmount Decimal + _, err = unmarshalAmount.UnmarshalMsg(out) + if err != nil { + t.Errorf("error unmarshalMsg %s: %v", s, err) + } else if !unmarshalAmount.Equal(amount) { + t.Errorf("expected %s, got %s (%s, %d)", + amount.String(), unmarshalAmount.String(), + unmarshalAmount.value.String(), unmarshalAmount.exp) + } + } +} diff --git a/decimal_msgpack.go b/decimal_msgpack.go new file mode 100644 index 0000000..55d22d2 --- /dev/null +++ b/decimal_msgpack.go @@ -0,0 +1,94 @@ +package decimal + +import ( + "errors" +) + +var ( + errShortBytes = errors.New("msgp: too few bytes left to read object") +) + +// MarshalMsg implements msgp.Marshaler +// Note: limit to 31 digits, if d.IntPart size large than 31, will be lose. +func (d Decimal) MarshalMsg(b []byte) ([]byte, error) { + o := require(b, d.Msgsize()) + str := d.String() + sz := len(str) + // limit to 31 digits + // note, if d.IntPart size large than 3, will be lose. + if sz > 31 { + sz = 31 + // if last char is '.' then limit to 30 digits + if str[30] == '.' { + sz = 30 + } + + str = str[:sz] + } + + o, n := ensure(o, 1+sz) + o[n] = byte(0xa0 | sz) + n++ + + return o[:n+copy(o[n:], str)], nil +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (d *Decimal) UnmarshalMsg(b []byte) ([]byte, error) { + o, err := b, errShortBytes + + l := len(b) + if l < 1 { + return o, err + } + + sz := int(b[0] & 0x1f) + if len(b[1:]) < sz { + return o, err + } + if *d, err = NewFromString(string(b[1 : sz+1])); err == nil { + o = b[sz:] + } + return o, err +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (d Decimal) Msgsize() int { + return 32 +} + +// Require ensures that cap(old)-len(old) >= extra. +func require(old []byte, extra int) []byte { + l := len(old) + c := cap(old) + r := l + extra + if c >= r { + return old + } else if l == 0 { + return make([]byte, 0, extra) + } + // the new size is the greater + // of double the old capacity + // and the sum of the old length + // and the number of new bytes + // necessary. + c <<= 1 + if c < r { + c = r + } + n := make([]byte, l, c) + copy(n, old) + return n +} + +// ensure 'sz' extra bytes in 'b' btw len(b) and cap(b) +func ensure(b []byte, sz int) ([]byte, int) { + l := len(b) + c := cap(b) + if c-l < sz { + o := make([]byte, (2*c)+sz) // exponential growth + n := copy(o, b) + return o[:n+sz], n + } + return b[:l+sz], l +}