add stupid test for extreme values, fix some super-edge-case bugs

This commit is contained in:
Vadim Graboys 2015-06-14 15:39:05 -04:00
parent 54dc68463b
commit 6e97405099
2 changed files with 89 additions and 19 deletions

View file

@ -305,7 +305,8 @@ func (d Decimal) IntPart() int64 {
func (d Decimal) Rat() *big.Rat { func (d Decimal) Rat() *big.Rat {
d.ensureInitialized() d.ensureInitialized()
if d.exp <= 0 { if d.exp <= 0 {
denom := new(big.Int).Exp(tenInt, big.NewInt(int64(-d.exp)), nil) // NOTE(vadim): must negate after casting to prevent int32 overflow
denom := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil)
return new(big.Rat).SetFrac(d.value, denom) return new(big.Rat).SetFrac(d.value, denom)
} else { } else {
mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil) mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
@ -365,7 +366,7 @@ func (d Decimal) StringFixed(places int32) string {
// //
func (d Decimal) Round(places int32) Decimal { func (d Decimal) Round(places int32) Decimal {
// truncate to places + 1 // truncate to places + 1
ret := d.rescale(-(places + 1)) ret := d.rescale(-places - 1)
// add sign(d) * 0.5 // add sign(d) * 0.5
if ret.value.Sign() < 0 { if ret.value.Sign() < 0 {
@ -389,7 +390,10 @@ func (d Decimal) Floor() Decimal {
d.ensureInitialized() d.ensureInitialized()
exp := big.NewInt(10) exp := big.NewInt(10)
exp.Exp(exp, big.NewInt(int64(-d.exp)), nil)
// NOTE(vadim): must negate after casting to prevent int32 overflow
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
z := new(big.Int).Div(d.value, exp) z := new(big.Int).Div(d.value, exp)
return Decimal{value: z, exp: 0} return Decimal{value: z, exp: 0}
} }
@ -399,7 +403,10 @@ func (d Decimal) Ceil() Decimal {
d.ensureInitialized() d.ensureInitialized()
exp := big.NewInt(10) exp := big.NewInt(10)
exp.Exp(exp, big.NewInt(int64(-d.exp)), nil)
// NOTE(vadim): must negate after casting to prevent int32 overflow
exp.Exp(exp, big.NewInt(-int64(d.exp)), nil)
z, m := new(big.Int).DivMod(d.value, exp, new(big.Int)) z, m := new(big.Int).DivMod(d.value, exp, new(big.Int))
if m.Cmp(zeroInt) != 0 { if m.Cmp(zeroInt) != 0 {
z.Add(z, oneInt) z.Add(z, oneInt)
@ -407,6 +414,22 @@ func (d Decimal) Ceil() Decimal {
return Decimal{value: z, exp: 0} return Decimal{value: z, exp: 0}
} }
// Truncate truncates off digits from the number, without rounding.
//
// NOTE: precision is the last digit that will not be truncated (must be >= 0).
//
// Example:
//
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
//
func (d Decimal) Truncate(precision int32) Decimal {
d.ensureInitialized()
if precision >= 0 && -precision > d.exp {
return d.rescale(-precision)
}
return d
}
// UnmarshalJSON implements the json.Unmarshaler interface. // UnmarshalJSON implements the json.Unmarshaler interface.
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
str, err := unquoteIfQuoted(decimalBytes) str, err := unquoteIfQuoted(decimalBytes)
@ -428,20 +451,6 @@ func (d Decimal) MarshalJSON() ([]byte, error) {
return []byte(str), nil return []byte(str), nil
} }
// Truncate truncates off digits from the number, without rounding.
//
// NOTE: precision is the last digit that will not be truncated (should be >= 0)
//
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
//
func (d Decimal) Truncate(precision int32) Decimal {
d.ensureInitialized()
if precision >= 0 && -precision > d.exp {
return d.rescale(-precision)
}
return d
}
// Scan implements the sql.Scanner interface for database deserialization. // Scan implements the sql.Scanner interface for database deserialization.
func (d *Decimal) Scan(value interface{}) error { func (d *Decimal) Scan(value interface{}) error {
str, err := unquoteIfQuoted(value) str, err := unquoteIfQuoted(value)
@ -493,6 +502,9 @@ func (d Decimal) string(trimTrailingZeros bool) string {
str := abs.String() str := abs.String()
var intPart, fractionalPart string var intPart, fractionalPart string
// NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN
// and you are on a 32-bit machine. Won't fix this super-edge case.
dExpInt := int(d.exp) dExpInt := int(d.exp)
if len(str) > -dExpInt { if len(str) > -dExpInt {
intPart = str[:len(str)+dExpInt] intPart = str[:len(str)+dExpInt]

View file

@ -6,6 +6,7 @@ import (
"math" "math"
"strings" "strings"
"testing" "testing"
"time"
) )
var testTable = map[float64]string{ var testTable = map[float64]string{
@ -423,9 +424,10 @@ func TestDecimal_Uninitialized(t *testing.T) {
a.Sub(b), a.Sub(b),
a.Mul(b), a.Mul(b),
a.Div(New(1, -1)), a.Div(New(1, -1)),
a.Round(2),
a.Floor(), a.Floor(),
a.Ceil(), a.Ceil(),
a.Round(2), a.Truncate(2),
} }
for _, d := range decs { for _, d := range decs {
@ -446,6 +448,16 @@ func TestDecimal_Uninitialized(t *testing.T) {
if a.Exponent() != 0 { if a.Exponent() != 0 {
t.Errorf("a.Exponent() != 0") t.Errorf("a.Exponent() != 0")
} }
if a.IntPart() != 0 {
t.Errorf("a.IntPar() != 0")
}
f, _ := a.Float64()
if f != 0 {
t.Errorf("a.Float64() != 0")
}
if a.Rat().RatString() != "0" {
t.Errorf("a.Rat() != 0, got %s", a.Rat().RatString())
}
} }
func TestDecimal_Add(t *testing.T) { func TestDecimal_Add(t *testing.T) {
@ -633,6 +645,52 @@ func TestDecimal_Overflow(t *testing.T) {
} }
} }
func TestDecimal_ExtremeValues(t *testing.T) {
// NOTE(vadim): this test takes pretty much forever
if testing.Short() {
t.Skip()
}
// NOTE(vadim): Seriously, the numbers invovled are so large that this
// test will take way too long, so mark it as success if it takes over
// 1 second. The way this test typically fails (integer overflow) is that
// a wrong result appears quickly, so if it takes a long time then it is
// probably working properly.
// Why even bother testing this? Completeness, I guess. -Vadim
const timeLimit = 1 * time.Second
test := func(f func()) {
c := make(chan bool)
go func() {
f()
close(c)
}()
select {
case <-c:
case <-time.After(timeLimit):
}
}
test(func() {
got := New(123, math.MinInt32).Floor()
if !got.Equals(NewFromFloat(0)) {
t.Errorf("Error: got %s, expected 0", got)
}
})
test(func() {
got := New(123, math.MinInt32).Ceil()
if !got.Equals(NewFromFloat(1)) {
t.Errorf("Error: got %s, expected 1", got)
}
})
test(func() {
got := New(123, math.MinInt32).Rat().FloatString(10)
expected := "0.0000000000"
if got != expected {
t.Errorf("Error: got %s, expected %s", got, expected)
}
})
}
func TestIntPart(t *testing.T) { func TestIntPart(t *testing.T) {
for _, testCase := range []struct { for _, testCase := range []struct {
Dec string Dec string