diff --git a/decimal.go b/decimal.go index ad92a90..dbfffb2 100644 --- a/decimal.go +++ b/decimal.go @@ -25,6 +25,8 @@ import ( "math/big" "strconv" "strings" + + "gopkg.in/mgo.v2/bson" ) // DivisionPrecision is the number of decimal places in the result when it @@ -874,6 +876,32 @@ func (d Decimal) MarshalBinary() (data []byte, err error) { return } +// GetBSON implements the bson.Getter interface +func (d Decimal) GetBSON() (interface{}, error) { + // Pass through string to create Mongo Decimal128 type + dec128, err := bson.ParseDecimal128(d.String()) + if err != nil { + return nil, err + } + return dec128, nil +} + +// SetBSON implements the bson.Setter interface +func (d *Decimal) SetBSON(raw bson.Raw) error { + // Unmarshal as Mongo Decimal128 first then pass through string to obtain Decimal + var dec128 bson.Decimal128 + berr := raw.Unmarshal(&dec128) + if berr != nil { + return berr + } + dec, derr := NewFromString(dec128.String()) + if derr != nil { + return derr + } + *d = dec + return nil +} + // Scan implements the sql.Scanner interface for database deserialization. func (d *Decimal) Scan(value interface{}) error { // first try to see if the data is stored in database as a Numeric datatype diff --git a/decimal_test.go b/decimal_test.go index ac80089..6b0cefe 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -12,6 +12,8 @@ import ( "strings" "testing" "time" + + "gopkg.in/mgo.v2/bson" ) type testEnt struct { @@ -2284,3 +2286,81 @@ func TestRoundBankAnomaly(t *testing.T) { t.Errorf("Expected bank rounding %s to equal %s, but it was %s", b, expected, bRounded) } } + +func TestBSON(t *testing.T) { + // Capture positive and negative cases of floating numbers and whole numbers in various bit ranges; also copied scientific notation + // test cases from above + tests := []string{ + "3.14159", + "42", + "42949672960", + "18446744073709551616000", + "-3.14159", + "-42", + "-42949672960", + "-18446744073709551616000", + "0", + "1e9", + "2.41E-3", + "24.2E-4", + "243E-5", + "1e-5", + "245E3", + "1.2345E-1", + "0e5", + "0e-5", + "123.456e0", + "123.456e2", + "123.456e10", + } + + type decStruct struct { + Dec Decimal + } + + // For each test the idea is that the String output of the original parsed decimal should match the String output of the + // decimal after it has been marshalled and unmarshalled into BSON + for i := range tests { + d, errD := NewFromString(tests[i]) + if errD != nil { + t.Errorf("TestBSON failed decimal parsing for case %v because of parse error - %v", tests[i], errD) + } + exp := d.String() + + // Test structure marshalling first + s1 := decStruct{Dec: d} + data, err := bson.Marshal(s1) + if err != nil { + t.Errorf("TestBSON failed structure marshalling for case %v because of marshal error - %v", tests[i], err) + } + s2 := decStruct{} + err = bson.Unmarshal(data, &s2) + if err != nil { + t.Errorf("TestBSON failed structure marshalling for case %v because of unmarshal error - %v", tests[i], err) + } + res := s2.Dec.String() + if exp != res { + t.Errorf("TestBSON failed structure marshalling for case %v got %v expected %v", tests[i], res, exp) + } + + // Next test map marshalling + m := bson.M{"dec": d} + data2, err2 := bson.Marshal(m) + if err2 != nil { + t.Errorf("TestBSON failed map marshalling for case %v because of marshal error - %v", tests[i], err2) + } + m2 := make(bson.M) + err2 = bson.Unmarshal(data2, m2) + if err2 != nil { + t.Errorf("TestBSON failed map marshalling for case %v because of unmarshal error - %v", tests[i], err2) + } + d2, errD2 := NewFromString(m2["dec"].(bson.Decimal128).String()) + if errD2 != nil { + t.Errorf("TestBSON failed map marshalling for case %v because of parse error - %v", tests[i], errD2) + } + res2 := d2.String() + if exp != res2 { + t.Errorf("TestBSON failed map marshalling for case %v got %v expected %v", tests[i], res2, exp) + } + } +}