From 02487a179e892664332d301fb864aca8304795dd Mon Sep 17 00:00:00 2001 From: Craig Jackson Date: Wed, 2 Nov 2016 10:15:14 -0600 Subject: [PATCH 1/2] Added NullDecimal. Changed Decimal to implement driver.Valuer. From comments in PR 16. --- decimal.go | 30 ++++++++-- decimal_test.go | 142 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 160 insertions(+), 12 deletions(-) diff --git a/decimal.go b/decimal.go index 91f9f6c..c4a73c5 100644 --- a/decimal.go +++ b/decimal.go @@ -521,10 +521,7 @@ func (d *Decimal) Scan(value interface{}) error { } // Value implements the driver.Valuer interface for database serialization. -func (d *Decimal) Value() (driver.Value, error) { - if d == nil { - return nil, nil - } +func (d Decimal) Value() (driver.Value, error) { return d.String(), nil } @@ -672,3 +669,28 @@ func unquoteIfQuoted(value interface{}) (string, error) { } return string(bytes), nil } + +// NullDecimal represents a fixed-point decimal. It is immutable. +// number = value * 10 ^ exp +type NullDecimal struct { + Decimal Decimal + Valid bool +} + +// Scan implements the sql.Scanner interface for database deserialization. +func (d *NullDecimal) Scan(value interface{}) error { + if value == nil { + d.Valid = false + return nil + } + d.Valid = true + return d.Decimal.Scan(value) +} + +// Value implements the driver.Valuer interface for database serialization. +func (d NullDecimal) Value() (driver.Value, error) { + if !d.Valid { + return nil, nil + } + return d.Decimal.String(), nil +} diff --git a/decimal_test.go b/decimal_test.go index b0b2088..f8b0312 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -1,6 +1,7 @@ package decimal import ( + "database/sql/driver" "encoding/json" "encoding/xml" "math" @@ -941,19 +942,16 @@ func TestDecimal_Scan(t *testing.T) { } func TestDecimal_Value(t *testing.T) { - // check that nil is handled appropriately - var decimalPtr *Decimal - value, err := decimalPtr.Value() - if err != nil { - t.Errorf("(*Decimal)().Value() failed with message: %s", err) - } else if value != nil { - t.Errorf("%v is not nil", value) + // Make sure this does implement the database/sql's driver.Valuer interface + var d Decimal + if _, ok := interface{}(d).(driver.Valuer); !ok { + t.Error("Decimal does not implement driver.Valuer") } // check that normal case is handled appropriately a := New(1234, -2) expected := "12.34" - value, err = a.Value() + value, err := a.Value() if err != nil { t.Errorf("Decimal(12.34).Value() failed with message: %s", err) } else if value.(string) != expected { @@ -1059,3 +1057,131 @@ func Benchmark_Cmp(b *testing.B) { sort.Sort(decimals) } } + +func TestNullDecimal_Scan(t *testing.T) { + // test the Scan method that implements the + // sql.Scanner interface + // check for the for different type of values + // that are possible to be received from the database + // drivers + + // in normal operations the db driver (sqlite at least) + // will return an int64 if you specified a numeric format + + // Make sure handles nil values + a := NullDecimal{} + var dbvaluePtr interface{} + err := a.Scan(dbvaluePtr) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan(nil) failed with message: %s", err) + } else { + if a.Valid { + t.Errorf("%s is not null", a.Decimal) + } + } + + dbvalue := float64(54.33) + expected := NewFromFloat(dbvalue) + + err = a.Scan(dbvalue) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan(54.33) failed with message: %s", err) + + } else { + // Scan suceeded... test resulting values + if !a.Valid { + t.Errorf("%s is null", a.Decimal) + } else if !a.Decimal.Equals(expected) { + t.Errorf("%s does not equal to %s", a.Decimal, expected) + } + } + + // at least SQLite returns an int64 when 0 is stored in the db + // and you specified a numeric format on the schema + dbvalueInt := int64(0) + expected = New(dbvalueInt, 0) + + err = a.Scan(dbvalueInt) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan(0) failed with message: %s", err) + + } else { + // Scan suceeded... test resulting values + if !a.Valid { + t.Errorf("%s is null", a.Decimal) + } else if !a.Decimal.Equals(expected) { + t.Errorf("%s does not equal to %s", a, expected) + } + } + + // in case you specified a varchar in your SQL schema, + // the database driver will return byte slice []byte + valueStr := "535.666" + dbvalueStr := []byte(valueStr) + expected, err = NewFromString(valueStr) + if err != nil { + t.Fatal(err) + } + + err = a.Scan(dbvalueStr) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan('535.666') failed with message: %s", err) + + } else { + // Scan suceeded... test resulting values + if !a.Valid { + t.Errorf("%s is null", a.Decimal) + } else if !a.Decimal.Equals(expected) { + t.Errorf("%s does not equal to %s", a, expected) + } + } + + // lib/pq can also return strings + expected, err = NewFromString(valueStr) + if err != nil { + t.Fatal(err) + } + + err = a.Scan(valueStr) + if err != nil { + // Scan failed... no need to test result value + t.Errorf("a.Scan('535.666') failed with message: %s", err) + } else { + // Scan suceeded... test resulting values + if !a.Valid { + t.Errorf("%s is null", a.Decimal) + } else if !a.Decimal.Equals(expected) { + t.Errorf("%s does not equal to %s", a, expected) + } + } +} + +func TestNullDecimal_Value(t *testing.T) { + // Make sure this does implement the database/sql's driver.Valuer interface + var nullDecimal NullDecimal + if _, ok := interface{}(nullDecimal).(driver.Valuer); !ok { + t.Error("NullDecimal does not implement driver.Valuer") + } + + // check that null is handled appropriately + value, err := nullDecimal.Value() + if err != nil { + t.Errorf("NullDecimal{}.Valid() failed with message: %s", err) + } else if value != nil { + t.Errorf("%v is not nil", value) + } + + // check that normal case is handled appropriately + a := NullDecimal{Decimal: New(1234, -2), Valid: true} + expected := "12.34" + value, err = a.Value() + if err != nil { + t.Errorf("NullDecimal(12.34).Value() failed with message: %s", err) + } else if value.(string) != expected { + t.Errorf("%s does not equal to %s", a, expected) + } +} From f03103a27dcb0bcb35f980d13c749a4ff2212edb Mon Sep 17 00:00:00 2001 From: Craig Jackson Date: Thu, 10 Nov 2016 16:07:19 -0700 Subject: [PATCH 2/2] Have NullDecimal.Value() use Decimal.Value(). --- decimal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decimal.go b/decimal.go index c4a73c5..b4d611f 100644 --- a/decimal.go +++ b/decimal.go @@ -692,5 +692,5 @@ func (d NullDecimal) Value() (driver.Value, error) { if !d.Valid { return nil, nil } - return d.Decimal.String(), nil + return d.Decimal.Value() }