mirror of
https://github.com/shopspring/decimal.git
synced 2024-11-23 04:40:49 +01:00
Add unbiased rounding algorithm
This commit is contained in:
parent
d6f52241f3
commit
2faaec71dd
2 changed files with 85 additions and 0 deletions
54
decimal.go
54
decimal.go
|
@ -56,6 +56,7 @@ var Zero = New(0, 1)
|
||||||
|
|
||||||
var zeroInt = big.NewInt(0)
|
var zeroInt = big.NewInt(0)
|
||||||
var oneInt = big.NewInt(1)
|
var oneInt = big.NewInt(1)
|
||||||
|
var twoInt = big.NewInt(2)
|
||||||
var fiveInt = big.NewInt(5)
|
var fiveInt = big.NewInt(5)
|
||||||
var tenInt = big.NewInt(10)
|
var tenInt = big.NewInt(10)
|
||||||
|
|
||||||
|
@ -422,6 +423,51 @@ func (d Decimal) Round(places int32) Decimal {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoundFair rounds the decimal to places decimal places.
|
||||||
|
// If places < 0, it will round the integer part to the nearest 10^(-places).
|
||||||
|
//
|
||||||
|
// Unlike Round(), in the case of a a tie where the resulting decimal place
|
||||||
|
// equals 0.5, this function will round up for odd numbers and down for
|
||||||
|
// even numbers. Negative values are treated symmetrically.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// NewFromFloat(5.5).Round(0).String() // output: "6"
|
||||||
|
// NewFromFloat(8.5).Round(0).String() // output: "8"
|
||||||
|
//
|
||||||
|
func (d Decimal) RoundFair(places int32) Decimal {
|
||||||
|
shift := big.NewInt(abs(int64(places)-int64(d.exp)) - 1)
|
||||||
|
|
||||||
|
// First, truncate the number to see if there are trailing decimal places.
|
||||||
|
// If there are it can't end in 5.
|
||||||
|
exp := new(big.Int).Exp(tenInt, shift, zeroInt)
|
||||||
|
rounded := new(big.Int).Quo(d.value, exp)
|
||||||
|
|
||||||
|
tmp := new(big.Int)
|
||||||
|
if tmp.Mul(rounded, exp).Cmp(d.value) != 0 {
|
||||||
|
return d.Round(places)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the last digit of the number isn't five, then do normal division.
|
||||||
|
if tmp.Mod(rounded, tenInt).Cmp(fiveInt) != 0 {
|
||||||
|
return d.Round(places)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := Decimal{
|
||||||
|
value: rounded.Quo(rounded, tenInt),
|
||||||
|
exp: -places,
|
||||||
|
}
|
||||||
|
|
||||||
|
odd := new(big.Int).Mod(ret.value, twoInt).Cmp(zeroInt) == 1
|
||||||
|
if odd && ret.value.Sign() >= 0 {
|
||||||
|
ret.value.Add(ret.value, oneInt)
|
||||||
|
} else if odd && ret.value.Sign() < 0 {
|
||||||
|
ret.value.Sub(ret.value, oneInt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
// Floor returns the nearest integer value less than or equal to d.
|
// Floor returns the nearest integer value less than or equal to d.
|
||||||
func (d Decimal) Floor() Decimal {
|
func (d Decimal) Floor() Decimal {
|
||||||
d.ensureInitialized()
|
d.ensureInitialized()
|
||||||
|
@ -639,6 +685,14 @@ func Max(first Decimal, rest ...Decimal) Decimal {
|
||||||
return ans
|
return ans
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func abs(x int64) int64 {
|
||||||
|
if x < 0 {
|
||||||
|
return -x
|
||||||
|
}
|
||||||
|
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
func min(x, y int32) int32 {
|
func min(x, y int32) int32 {
|
||||||
if x >= y {
|
if x >= y {
|
||||||
return y
|
return y
|
||||||
|
|
|
@ -3,6 +3,7 @@ package decimal
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -783,6 +784,36 @@ func TestDecimal_ExtremeValues(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
func TestRoundFair(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
original float64
|
||||||
|
rounded float64
|
||||||
|
places int32
|
||||||
|
}{
|
||||||
|
{-5.5, -6, 0},
|
||||||
|
{-6.5, -6, 0},
|
||||||
|
{0.5, 0, 0},
|
||||||
|
{-0.5, 0, 0},
|
||||||
|
{5.5, 6, 0},
|
||||||
|
{6.5, 6, 0},
|
||||||
|
{6.51, 7, 0},
|
||||||
|
|
||||||
|
{650, 650, 0},
|
||||||
|
{650, 650, -1},
|
||||||
|
{650, 600, -2},
|
||||||
|
{550, 600, -2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("(%.1f).RoundFair(%d)==%.0f", test.original, test.places, test.rounded), func(t *testing.T) {
|
||||||
|
got, _ := NewFromFloat(test.original).RoundFair(test.places).Float64()
|
||||||
|
expected := test.rounded
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("Error: got %.2f, expected %.2f", got, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIntPart(t *testing.T) {
|
func TestIntPart(t *testing.T) {
|
||||||
for _, testCase := range []struct {
|
for _, testCase := range []struct {
|
||||||
|
|
Loading…
Reference in a new issue