diff --git a/decimal.go b/decimal.go index 1729b61..6078361 100644 --- a/decimal.go +++ b/decimal.go @@ -928,8 +928,8 @@ func unquoteIfQuoted(value interface{}) (string, error) { return string(bytes), nil } -// NullDecimal represents a fixed-point decimal. It is immutable. -// number = value * 10 ^ exp +// NullDecimal represents a nullable decimal with compatibility for +// scanning null values from the database. type NullDecimal struct { Decimal Decimal Valid bool @@ -952,3 +952,21 @@ func (d NullDecimal) Value() (driver.Value, error) { } return d.Decimal.Value() } + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (d *NullDecimal) UnmarshalJSON(decimalBytes []byte) error { + if string(decimalBytes) == "null" { + d.Valid = false + return nil + } + d.Valid = true + return d.Decimal.UnmarshalJSON(decimalBytes) +} + +// MarshalJSON implements the json.Marshaler interface. +func (d NullDecimal) MarshalJSON() ([]byte, error) { + if !d.Valid { + return []byte("null"), nil + } + return d.Decimal.MarshalJSON() +} diff --git a/decimal_test.go b/decimal_test.go index d606dc9..9786bcc 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -332,6 +332,98 @@ func TestBadJSON(t *testing.T) { } } +func TestNullDecimalJSON(t *testing.T) { + for _, s := range testTable { + var doc struct { + Amount NullDecimal `json:"amount"` + } + docStr := `{"amount":"` + s + `"}` + docStrNumber := `{"amount":` + s + `}` + err := json.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Errorf("error unmarshaling %s: %v", docStr, err) + } else { + if !doc.Amount.Valid { + t.Errorf("expected %s to be valid (not NULL), got Valid = false", s) + } + if doc.Amount.Decimal.String() != s { + t.Errorf("expected %s, got %s (%s, %d)", + s, doc.Amount.Decimal.String(), + doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp) + } + } + + out, err := json.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != docStr { + t.Errorf("expected %s, got %s", docStr, string(out)) + } + + // make sure unquoted marshalling works too + MarshalJSONWithoutQuotes = true + out, err = json.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != docStrNumber { + t.Errorf("expected %s, got %s", docStrNumber, string(out)) + } + MarshalJSONWithoutQuotes = false + } + + var doc struct { + Amount NullDecimal `json:"amount"` + } + docStr := `{"amount": null}` + err := json.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Errorf("error unmarshaling %s: %v", docStr, err) + } else if doc.Amount.Valid { + t.Errorf("expected null value to have Valid = false, got Valid = true and Decimal = %s (%s, %d)", + doc.Amount.Decimal.String(), + doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp) + } + + expected := `{"amount":null}` + out, err := json.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != expected { + t.Errorf("expected %s, got %s", expected, string(out)) + } + + // make sure unquoted marshalling works too + MarshalJSONWithoutQuotes = true + expectedUnquoted := `{"amount":null}` + out, err = json.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != expectedUnquoted { + t.Errorf("expected %s, got %s", expectedUnquoted, string(out)) + } + MarshalJSONWithoutQuotes = false +} + +func TestNullDecimalBadJSON(t *testing.T) { + for _, testCase := range []string{ + "]o_o[", + "{", + `{"amount":""`, + `{"amount":""}`, + `{"amount":"nope"}`, + `{"amount":nope}`, + `0.333`, + } { + var doc struct { + Amount NullDecimal `json:"amount"` + } + err := json.Unmarshal([]byte(testCase), &doc) + if err == nil { + t.Errorf("expected error, got %+v", doc) + } + } +} + func TestXML(t *testing.T) { for _, s := range testTable { var doc struct {