diff --git a/decimal.go b/decimal.go index 52afdd4..df99c58 100644 --- a/decimal.go +++ b/decimal.go @@ -19,6 +19,7 @@ package decimal import ( "database/sql/driver" + "encoding/binary" "fmt" "math" "math/big" @@ -597,6 +598,34 @@ func (d Decimal) MarshalJSON() ([]byte, error) { return []byte(str), nil } +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. As a string representation +// is already used when encoding to text, this method stores that string as []byte +func (d *Decimal) UnmarshalBinary(data []byte) error { + // Extract the exponent + d.exp = int32(binary.BigEndian.Uint32(data[:4])) + + // Extract the value + d.value = new(big.Int) + return d.value.GobDecode(data[4:]) +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (d Decimal) MarshalBinary() (data []byte, err error) { + // Write the exponent first since it's a fixed size + v1 := make([]byte, 4) + binary.BigEndian.PutUint32(v1, uint32(d.exp)) + + // Add the value + var v2 []byte + if v2, err = d.value.GobEncode(); err != nil { + return + } + + // Return the byte array + data = append(v1, v2...) + return +} + // Scan implements the sql.Scanner interface for database deserialization. func (d *Decimal) Scan(value interface{}) error { // first try to see if the data is stored in database as a Numeric datatype diff --git a/decimal_test.go b/decimal_test.go index d55861c..b7b9997 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -1519,3 +1519,29 @@ func TestNullDecimal_Value(t *testing.T) { t.Errorf("%v does not equal %v", a, expected) } } + +func TestBinary(t *testing.T) { + for x, _ := range testTable { + + // Create the decimal + d1 := NewFromFloat(x) + + // Encode to binary + b, err := d1.MarshalBinary() + if err != nil { + t.Errorf("error marshalling %v to binary: %v", d1, err) + } + + // Restore from binary + var d2 Decimal + err = (&d2).UnmarshalBinary(b) + if err != nil { + t.Errorf("error unmarshalling from binary: %v", err) + } + + // The restored decimal should equal the original + if !d1.Equals(d2) { + t.Errorf("expected %v when restoring, got %v", d1, d2) + } + } +}