diff --git a/decimal.go b/decimal.go index e8d99b4..cbbce7b 100644 --- a/decimal.go +++ b/decimal.go @@ -469,13 +469,29 @@ func (d Decimal) MarshalJSON() ([]byte, error) { // Scan implements the sql.Scanner interface for database deserialization. func (d *Decimal) Scan(value interface{}) error { - str, err := unquoteIfQuoted(value) - if err != nil { + // first try to see if the data is stored in database as a Numeric datatype + switch v := value.(type) { + + case float64: + // numeric in sqlite3 sends us float64 + *d = NewFromFloat(v) + return nil + + case int64: + // at least in sqlite3 when the value is 0 in db, the data is sent + // to us as an int64 instead of a float64 ... + *d = New(v, 0) + return nil + + default: + // default is trying to interpret value stored as string + str, err := unquoteIfQuoted(v) + if err != nil { + return err + } + *d, err = NewFromString(str) return err } - *d, err = NewFromString(str) - - return err } // Value implements the driver.Valuer interface for database serialization. diff --git a/decimal_test.go b/decimal_test.go index 5202234..e560ffb 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -813,6 +813,67 @@ func TestDecimal_Max(t *testing.T) { } } +func TestDecimal_Scan(t *testing.T) { + // test the Scan method the implements the + // sql.Scanner interface + // check for the for different type of values + // that are possible to be received from the database + // drivers + + // in normal operations the db driver (sqlite at least) + // will return an int64 if you specified a numeric format + a := Decimal{} + dbvalue := float64(54.33) + expected := NewFromFloat(dbvalue) + + err := a.Scan(dbvalue) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan(54.33) failed with message: %s", err) + + } else { + // Scan suceeded... test resulting values + if !a.Equals(expected) { + t.Errorf("%s does not equal to %s", a, expected) + } + } + + // at least SQLite returns an int64 when 0 is stored in the db + // and you specified a numeric format on the schema + dbvalue_int := int64(0) + expected = New(dbvalue_int, 0) + + err = a.Scan(dbvalue_int) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan(0) failed with message: %s", err) + + } else { + // Scan suceeded... test resulting values + if !a.Equals(expected) { + t.Errorf("%s does not equal to %s", a, expected) + } + } + + // in case you specified a varchar in your SQL schema, + // the database driver will return byte slice []byte + value_str := "535.666" + dbvalue_str := []byte(value_str) + expected, err = NewFromString(value_str) + + err = a.Scan(dbvalue_str) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan('535.666') failed with message: %s", err) + + } else { + // Scan suceeded... test resulting values + if !a.Equals(expected) { + t.Errorf("%s does not equal to %s", a, expected) + } + } +} + // old tests after this line func TestDecimal_Scale(t *testing.T) {