mirror of
https://github.com/shopspring/decimal.git
synced 2024-11-22 20:40:48 +01:00
add stupid test for extreme values, fix some super-edge-case bugs
This commit is contained in:
parent
54dc68463b
commit
6e97405099
2 changed files with 89 additions and 19 deletions
48
decimal.go
48
decimal.go
|
@ -305,7 +305,8 @@ func (d Decimal) IntPart() int64 {
|
||||||
func (d Decimal) Rat() *big.Rat {
|
func (d Decimal) Rat() *big.Rat {
|
||||||
d.ensureInitialized()
|
d.ensureInitialized()
|
||||||
if d.exp <= 0 {
|
if d.exp <= 0 {
|
||||||
denom := new(big.Int).Exp(tenInt, big.NewInt(int64(-d.exp)), nil)
|
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
||||||
|
denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil)
|
||||||
return new(big.Rat).SetFrac(d.value, denom)
|
return new(big.Rat).SetFrac(d.value, denom)
|
||||||
} else {
|
} else {
|
||||||
mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
|
mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
|
||||||
|
@ -365,7 +366,7 @@ func (d Decimal) StringFixed(places int32) string {
|
||||||
//
|
//
|
||||||
func (d Decimal) Round(places int32) Decimal {
|
func (d Decimal) Round(places int32) Decimal {
|
||||||
// truncate to places + 1
|
// truncate to places + 1
|
||||||
ret := d.rescale(-(places + 1))
|
ret := d.rescale(-places - 1)
|
||||||
|
|
||||||
// add sign(d) * 0.5
|
// add sign(d) * 0.5
|
||||||
if ret.value.Sign() < 0 {
|
if ret.value.Sign() < 0 {
|
||||||
|
@ -389,7 +390,10 @@ func (d Decimal) Floor() Decimal {
|
||||||
d.ensureInitialized()
|
d.ensureInitialized()
|
||||||
|
|
||||||
exp := big.NewInt(10)
|
exp := big.NewInt(10)
|
||||||
exp.Exp(exp, big.NewInt(int64(-d.exp)), nil)
|
|
||||||
|
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
||||||
|
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
|
||||||
|
|
||||||
z := new(big.Int).Div(d.value, exp)
|
z := new(big.Int).Div(d.value, exp)
|
||||||
return Decimal{value: z, exp: 0}
|
return Decimal{value: z, exp: 0}
|
||||||
}
|
}
|
||||||
|
@ -399,7 +403,10 @@ func (d Decimal) Ceil() Decimal {
|
||||||
d.ensureInitialized()
|
d.ensureInitialized()
|
||||||
|
|
||||||
exp := big.NewInt(10)
|
exp := big.NewInt(10)
|
||||||
exp.Exp(exp, big.NewInt(int64(-d.exp)), nil)
|
|
||||||
|
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
||||||
|
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
|
||||||
|
|
||||||
z, m := new(big.Int).DivMod(d.value, exp, new(big.Int))
|
z, m := new(big.Int).DivMod(d.value, exp, new(big.Int))
|
||||||
if m.Cmp(zeroInt) != 0 {
|
if m.Cmp(zeroInt) != 0 {
|
||||||
z.Add(z, oneInt)
|
z.Add(z, oneInt)
|
||||||
|
@ -407,6 +414,22 @@ func (d Decimal) Ceil() Decimal {
|
||||||
return Decimal{value: z, exp: 0}
|
return Decimal{value: z, exp: 0}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Truncate truncates off digits from the number, without rounding.
|
||||||
|
//
|
||||||
|
// NOTE: precision is the last digit that will not be truncated (must be >= 0).
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
|
||||||
|
//
|
||||||
|
func (d Decimal) Truncate(precision int32) Decimal {
|
||||||
|
d.ensureInitialized()
|
||||||
|
if precision >= 0 && -precision > d.exp {
|
||||||
|
return d.rescale(-precision)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||||
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
|
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
|
||||||
str, err := unquoteIfQuoted(decimalBytes)
|
str, err := unquoteIfQuoted(decimalBytes)
|
||||||
|
@ -428,20 +451,6 @@ func (d Decimal) MarshalJSON() ([]byte, error) {
|
||||||
return []byte(str), nil
|
return []byte(str), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Truncate truncates off digits from the number, without rounding.
|
|
||||||
//
|
|
||||||
// NOTE: precision is the last digit that will not be truncated (should be >= 0)
|
|
||||||
//
|
|
||||||
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
|
|
||||||
//
|
|
||||||
func (d Decimal) Truncate(precision int32) Decimal {
|
|
||||||
d.ensureInitialized()
|
|
||||||
if precision >= 0 && -precision > d.exp {
|
|
||||||
return d.rescale(-precision)
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the sql.Scanner interface for database deserialization.
|
// Scan implements the sql.Scanner interface for database deserialization.
|
||||||
func (d *Decimal) Scan(value interface{}) error {
|
func (d *Decimal) Scan(value interface{}) error {
|
||||||
str, err := unquoteIfQuoted(value)
|
str, err := unquoteIfQuoted(value)
|
||||||
|
@ -493,6 +502,9 @@ func (d Decimal) string(trimTrailingZeros bool) string {
|
||||||
str := abs.String()
|
str := abs.String()
|
||||||
|
|
||||||
var intPart, fractionalPart string
|
var intPart, fractionalPart string
|
||||||
|
|
||||||
|
// NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN
|
||||||
|
// and you are on a 32-bit machine. Won't fix this super-edge case.
|
||||||
dExpInt := int(d.exp)
|
dExpInt := int(d.exp)
|
||||||
if len(str) > -dExpInt {
|
if len(str) > -dExpInt {
|
||||||
intPart = str[:len(str)+dExpInt]
|
intPart = str[:len(str)+dExpInt]
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testTable = map[float64]string{
|
var testTable = map[float64]string{
|
||||||
|
@ -423,9 +424,10 @@ func TestDecimal_Uninitialized(t *testing.T) {
|
||||||
a.Sub(b),
|
a.Sub(b),
|
||||||
a.Mul(b),
|
a.Mul(b),
|
||||||
a.Div(New(1, -1)),
|
a.Div(New(1, -1)),
|
||||||
|
a.Round(2),
|
||||||
a.Floor(),
|
a.Floor(),
|
||||||
a.Ceil(),
|
a.Ceil(),
|
||||||
a.Round(2),
|
a.Truncate(2),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, d := range decs {
|
for _, d := range decs {
|
||||||
|
@ -446,6 +448,16 @@ func TestDecimal_Uninitialized(t *testing.T) {
|
||||||
if a.Exponent() != 0 {
|
if a.Exponent() != 0 {
|
||||||
t.Errorf("a.Exponent() != 0")
|
t.Errorf("a.Exponent() != 0")
|
||||||
}
|
}
|
||||||
|
if a.IntPart() != 0 {
|
||||||
|
t.Errorf("a.IntPar() != 0")
|
||||||
|
}
|
||||||
|
f, _ := a.Float64()
|
||||||
|
if f != 0 {
|
||||||
|
t.Errorf("a.Float64() != 0")
|
||||||
|
}
|
||||||
|
if a.Rat().RatString() != "0" {
|
||||||
|
t.Errorf("a.Rat() != 0, got %s", a.Rat().RatString())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDecimal_Add(t *testing.T) {
|
func TestDecimal_Add(t *testing.T) {
|
||||||
|
@ -633,6 +645,52 @@ func TestDecimal_Overflow(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecimal_ExtremeValues(t *testing.T) {
|
||||||
|
// NOTE(vadim): this test takes pretty much forever
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(vadim): Seriously, the numbers invovled are so large that this
|
||||||
|
// test will take way too long, so mark it as success if it takes over
|
||||||
|
// 1 second. The way this test typically fails (integer overflow) is that
|
||||||
|
// a wrong result appears quickly, so if it takes a long time then it is
|
||||||
|
// probably working properly.
|
||||||
|
// Why even bother testing this? Completeness, I guess. -Vadim
|
||||||
|
const timeLimit = 1 * time.Second
|
||||||
|
test := func(f func()) {
|
||||||
|
c := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
f()
|
||||||
|
close(c)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-c:
|
||||||
|
case <-time.After(timeLimit):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test(func() {
|
||||||
|
got := New(123, math.MinInt32).Floor()
|
||||||
|
if !got.Equals(NewFromFloat(0)) {
|
||||||
|
t.Errorf("Error: got %s, expected 0", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
test(func() {
|
||||||
|
got := New(123, math.MinInt32).Ceil()
|
||||||
|
if !got.Equals(NewFromFloat(1)) {
|
||||||
|
t.Errorf("Error: got %s, expected 1", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
test(func() {
|
||||||
|
got := New(123, math.MinInt32).Rat().FloatString(10)
|
||||||
|
expected := "0.0000000000"
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("Error: got %s, expected %s", got, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestIntPart(t *testing.T) {
|
func TestIntPart(t *testing.T) {
|
||||||
for _, testCase := range []struct {
|
for _, testCase := range []struct {
|
||||||
Dec string
|
Dec string
|
||||||
|
|
Loading…
Reference in a new issue