mirror of
https://github.com/shopspring/decimal.git
synced 2024-11-23 04:40:49 +01:00
Add unbiased rounding algorithm
This commit is contained in:
parent
d6f52241f3
commit
2faaec71dd
2 changed files with 85 additions and 0 deletions
54
decimal.go
54
decimal.go
|
@ -56,6 +56,7 @@ var Zero = New(0, 1)
|
|||
|
||||
var zeroInt = big.NewInt(0)
|
||||
var oneInt = big.NewInt(1)
|
||||
var twoInt = big.NewInt(2)
|
||||
var fiveInt = big.NewInt(5)
|
||||
var tenInt = big.NewInt(10)
|
||||
|
||||
|
@ -422,6 +423,51 @@ func (d Decimal) Round(places int32) Decimal {
|
|||
return ret
|
||||
}
|
||||
|
||||
// RoundFair rounds the decimal to places decimal places.
|
||||
// If places < 0, it will round the integer part to the nearest 10^(-places).
|
||||
//
|
||||
// Unlike Round(), in the case of a a tie where the resulting decimal place
|
||||
// equals 0.5, this function will round up for odd numbers and down for
|
||||
// even numbers. Negative values are treated symmetrically.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// NewFromFloat(5.5).Round(0).String() // output: "6"
|
||||
// NewFromFloat(8.5).Round(0).String() // output: "8"
|
||||
//
|
||||
func (d Decimal) RoundFair(places int32) Decimal {
|
||||
shift := big.NewInt(abs(int64(places)-int64(d.exp)) - 1)
|
||||
|
||||
// First, truncate the number to see if there are trailing decimal places.
|
||||
// If there are it can't end in 5.
|
||||
exp := new(big.Int).Exp(tenInt, shift, zeroInt)
|
||||
rounded := new(big.Int).Quo(d.value, exp)
|
||||
|
||||
tmp := new(big.Int)
|
||||
if tmp.Mul(rounded, exp).Cmp(d.value) != 0 {
|
||||
return d.Round(places)
|
||||
}
|
||||
|
||||
// If the last digit of the number isn't five, then do normal division.
|
||||
if tmp.Mod(rounded, tenInt).Cmp(fiveInt) != 0 {
|
||||
return d.Round(places)
|
||||
}
|
||||
|
||||
ret := Decimal{
|
||||
value: rounded.Quo(rounded, tenInt),
|
||||
exp: -places,
|
||||
}
|
||||
|
||||
odd := new(big.Int).Mod(ret.value, twoInt).Cmp(zeroInt) == 1
|
||||
if odd && ret.value.Sign() >= 0 {
|
||||
ret.value.Add(ret.value, oneInt)
|
||||
} else if odd && ret.value.Sign() < 0 {
|
||||
ret.value.Sub(ret.value, oneInt)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Floor returns the nearest integer value less than or equal to d.
|
||||
func (d Decimal) Floor() Decimal {
|
||||
d.ensureInitialized()
|
||||
|
@ -639,6 +685,14 @@ func Max(first Decimal, rest ...Decimal) Decimal {
|
|||
return ans
|
||||
}
|
||||
|
||||
func abs(x int64) int64 {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
func min(x, y int32) int32 {
|
||||
if x >= y {
|
||||
return y
|
||||
|
|
|
@ -3,6 +3,7 @@ package decimal
|
|||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -783,6 +784,36 @@ func TestDecimal_ExtremeValues(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
func TestRoundFair(t *testing.T) {
|
||||
testCases := []struct {
|
||||
original float64
|
||||
rounded float64
|
||||
places int32
|
||||
}{
|
||||
{-5.5, -6, 0},
|
||||
{-6.5, -6, 0},
|
||||
{0.5, 0, 0},
|
||||
{-0.5, 0, 0},
|
||||
{5.5, 6, 0},
|
||||
{6.5, 6, 0},
|
||||
{6.51, 7, 0},
|
||||
|
||||
{650, 650, 0},
|
||||
{650, 650, -1},
|
||||
{650, 600, -2},
|
||||
{550, 600, -2},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(fmt.Sprintf("(%.1f).RoundFair(%d)==%.0f", test.original, test.places, test.rounded), func(t *testing.T) {
|
||||
got, _ := NewFromFloat(test.original).RoundFair(test.places).Float64()
|
||||
expected := test.rounded
|
||||
if got != expected {
|
||||
t.Errorf("Error: got %.2f, expected %.2f", got, expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntPart(t *testing.T) {
|
||||
for _, testCase := range []struct {
|
||||
|
|
Loading…
Reference in a new issue