mirror of
https://github.com/shopspring/decimal.git
synced 2024-11-22 20:40:48 +01:00
add stupid test for extreme values, fix some super-edge-case bugs
This commit is contained in:
parent
54dc68463b
commit
6e97405099
2 changed files with 89 additions and 19 deletions
48
decimal.go
48
decimal.go
|
@ -305,7 +305,8 @@ func (d Decimal) IntPart() int64 {
|
|||
func (d Decimal) Rat() *big.Rat {
|
||||
d.ensureInitialized()
|
||||
if d.exp <= 0 {
|
||||
denom := new(big.Int).Exp(tenInt, big.NewInt(int64(-d.exp)), nil)
|
||||
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
||||
denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil)
|
||||
return new(big.Rat).SetFrac(d.value, denom)
|
||||
} else {
|
||||
mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
|
||||
|
@ -365,7 +366,7 @@ func (d Decimal) StringFixed(places int32) string {
|
|||
//
|
||||
func (d Decimal) Round(places int32) Decimal {
|
||||
// truncate to places + 1
|
||||
ret := d.rescale(-(places + 1))
|
||||
ret := d.rescale(-places - 1)
|
||||
|
||||
// add sign(d) * 0.5
|
||||
if ret.value.Sign() < 0 {
|
||||
|
@ -389,7 +390,10 @@ func (d Decimal) Floor() Decimal {
|
|||
d.ensureInitialized()
|
||||
|
||||
exp := big.NewInt(10)
|
||||
exp.Exp(exp, big.NewInt(int64(-d.exp)), nil)
|
||||
|
||||
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
||||
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
|
||||
|
||||
z := new(big.Int).Div(d.value, exp)
|
||||
return Decimal{value: z, exp: 0}
|
||||
}
|
||||
|
@ -399,7 +403,10 @@ func (d Decimal) Ceil() Decimal {
|
|||
d.ensureInitialized()
|
||||
|
||||
exp := big.NewInt(10)
|
||||
exp.Exp(exp, big.NewInt(int64(-d.exp)), nil)
|
||||
|
||||
// NOTE(vadim): must negate after casting to prevent int32 overflow
|
||||
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
|
||||
|
||||
z, m := new(big.Int).DivMod(d.value, exp, new(big.Int))
|
||||
if m.Cmp(zeroInt) != 0 {
|
||||
z.Add(z, oneInt)
|
||||
|
@ -407,6 +414,22 @@ func (d Decimal) Ceil() Decimal {
|
|||
return Decimal{value: z, exp: 0}
|
||||
}
|
||||
|
||||
// Truncate truncates off digits from the number, without rounding.
|
||||
//
|
||||
// NOTE: precision is the last digit that will not be truncated (must be >= 0).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
|
||||
//
|
||||
func (d Decimal) Truncate(precision int32) Decimal {
|
||||
d.ensureInitialized()
|
||||
if precision >= 0 && -precision > d.exp {
|
||||
return d.rescale(-precision)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
|
||||
str, err := unquoteIfQuoted(decimalBytes)
|
||||
|
@ -428,20 +451,6 @@ func (d Decimal) MarshalJSON() ([]byte, error) {
|
|||
return []byte(str), nil
|
||||
}
|
||||
|
||||
// Truncate truncates off digits from the number, without rounding.
|
||||
//
|
||||
// NOTE: precision is the last digit that will not be truncated (should be >= 0)
|
||||
//
|
||||
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
|
||||
//
|
||||
func (d Decimal) Truncate(precision int32) Decimal {
|
||||
d.ensureInitialized()
|
||||
if precision >= 0 && -precision > d.exp {
|
||||
return d.rescale(-precision)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface for database deserialization.
|
||||
func (d *Decimal) Scan(value interface{}) error {
|
||||
str, err := unquoteIfQuoted(value)
|
||||
|
@ -493,6 +502,9 @@ func (d Decimal) string(trimTrailingZeros bool) string {
|
|||
str := abs.String()
|
||||
|
||||
var intPart, fractionalPart string
|
||||
|
||||
// NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN
|
||||
// and you are on a 32-bit machine. Won't fix this super-edge case.
|
||||
dExpInt := int(d.exp)
|
||||
if len(str) > -dExpInt {
|
||||
intPart = str[:len(str)+dExpInt]
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var testTable = map[float64]string{
|
||||
|
@ -423,9 +424,10 @@ func TestDecimal_Uninitialized(t *testing.T) {
|
|||
a.Sub(b),
|
||||
a.Mul(b),
|
||||
a.Div(New(1, -1)),
|
||||
a.Round(2),
|
||||
a.Floor(),
|
||||
a.Ceil(),
|
||||
a.Round(2),
|
||||
a.Truncate(2),
|
||||
}
|
||||
|
||||
for _, d := range decs {
|
||||
|
@ -446,6 +448,16 @@ func TestDecimal_Uninitialized(t *testing.T) {
|
|||
if a.Exponent() != 0 {
|
||||
t.Errorf("a.Exponent() != 0")
|
||||
}
|
||||
if a.IntPart() != 0 {
|
||||
t.Errorf("a.IntPar() != 0")
|
||||
}
|
||||
f, _ := a.Float64()
|
||||
if f != 0 {
|
||||
t.Errorf("a.Float64() != 0")
|
||||
}
|
||||
if a.Rat().RatString() != "0" {
|
||||
t.Errorf("a.Rat() != 0, got %s", a.Rat().RatString())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecimal_Add(t *testing.T) {
|
||||
|
@ -633,6 +645,52 @@ func TestDecimal_Overflow(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDecimal_ExtremeValues(t *testing.T) {
|
||||
// NOTE(vadim): this test takes pretty much forever
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
// NOTE(vadim): Seriously, the numbers invovled are so large that this
|
||||
// test will take way too long, so mark it as success if it takes over
|
||||
// 1 second. The way this test typically fails (integer overflow) is that
|
||||
// a wrong result appears quickly, so if it takes a long time then it is
|
||||
// probably working properly.
|
||||
// Why even bother testing this? Completeness, I guess. -Vadim
|
||||
const timeLimit = 1 * time.Second
|
||||
test := func(f func()) {
|
||||
c := make(chan bool)
|
||||
go func() {
|
||||
f()
|
||||
close(c)
|
||||
}()
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(timeLimit):
|
||||
}
|
||||
}
|
||||
|
||||
test(func() {
|
||||
got := New(123, math.MinInt32).Floor()
|
||||
if !got.Equals(NewFromFloat(0)) {
|
||||
t.Errorf("Error: got %s, expected 0", got)
|
||||
}
|
||||
})
|
||||
test(func() {
|
||||
got := New(123, math.MinInt32).Ceil()
|
||||
if !got.Equals(NewFromFloat(1)) {
|
||||
t.Errorf("Error: got %s, expected 1", got)
|
||||
}
|
||||
})
|
||||
test(func() {
|
||||
got := New(123, math.MinInt32).Rat().FloatString(10)
|
||||
expected := "0.0000000000"
|
||||
if got != expected {
|
||||
t.Errorf("Error: got %s, expected %s", got, expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntPart(t *testing.T) {
|
||||
for _, testCase := range []struct {
|
||||
Dec string
|
||||
|
|
Loading…
Reference in a new issue