diff --git a/decimal.go b/decimal.go index 475161d..e68ad2c 100644 --- a/decimal.go +++ b/decimal.go @@ -1381,6 +1381,33 @@ func (d NullDecimal) MarshalJSON() ([]byte, error) { return d.Decimal.MarshalJSON() } +// UnmarshalText implements the encoding.TextUnmarshaler interface for XML +// deserialization +func (d *NullDecimal) UnmarshalText(text []byte) error { + str := string(text) + + // check for empty XML or XML without body e.g., + if str == "" { + d.Valid = false + return nil + } + if err := d.Decimal.UnmarshalText(text); err != nil { + d.Valid = false + return err + } + d.Valid = true + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface for XML +// serialization. +func (d NullDecimal) MarshalText() (text []byte, err error) { + if !d.Valid { + return []byte{}, nil + } + return d.Decimal.MarshalText() +} + // Trig functions // Atan returns the arctangent, in radians, of x. diff --git a/decimal_test.go b/decimal_test.go index 72750ab..1bfca28 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -766,6 +766,96 @@ func TestBadXML(t *testing.T) { } } +func TestNullDecimalXML(t *testing.T) { + // test valid values + for _, x := range testTable { + s := x.short + var doc struct { + XMLName xml.Name `xml:"account"` + Amount NullDecimal `xml:"amount"` + } + docStr := `` + s + `` + err := xml.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Errorf("error unmarshaling %s: %v", docStr, err) + } else if doc.Amount.Decimal.String() != s { + t.Errorf("expected %s, got %s (%s, %d)", + s, doc.Amount.Decimal.String(), + doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp) + } + + out, err := xml.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != docStr { + t.Errorf("expected %s, got %s", docStr, string(out)) + } + } + + var doc struct { + XMLName xml.Name `xml:"account"` + Amount NullDecimal `xml:"amount"` + } + + // test for XML with empty body + docStr := `` + err := xml.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Errorf("error unmarshaling: %s: %v", docStr, err) + } else if doc.Amount.Valid { + t.Errorf("expected null value to have Valid = false, got Valid = true and Decimal = %s (%s, %d)", + doc.Amount.Decimal.String(), + doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp) + } + + expected := `` + out, err := xml.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != expected { + t.Errorf("expected %s, got %s", expected, string(out)) + } + + // test for empty XML + docStr = `` + err = xml.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Errorf("error unmarshaling: %s: %v", docStr, err) + } else if doc.Amount.Valid { + t.Errorf("expected null value to have Valid = false, got Valid = true and Decimal = %s (%s, %d)", + doc.Amount.Decimal.String(), + doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp) + } + + expected = `` + out, err = xml.Marshal(&doc) + if err != nil { + t.Errorf("error marshaling %+v: %v", doc, err) + } else if string(out) != expected { + t.Errorf("expected %s, got %s", expected, string(out)) + } +} + +func TestNullDecimalBadXML(t *testing.T) { + for _, testCase := range []string{ + "o_o", + "7", + ``, + `nope`, + `0.333`, + } { + var doc struct { + XMLName xml.Name `xml:"account"` + Amount NullDecimal `xml:"amount"` + } + err := xml.Unmarshal([]byte(testCase), &doc) + if err == nil { + t.Errorf("expected error, got %+v", doc) + } + } +} + func TestDecimal_rescale(t *testing.T) { type Inp struct { int int64