Add unbiased rounding algorithm

This commit is contained in:
Connor Peet 2016-10-06 09:30:46 -07:00
parent d6f52241f3
commit 2faaec71dd
No known key found for this signature in database
GPG key ID: CF8FD2EA0DBC61BD
2 changed files with 85 additions and 0 deletions

View file

@ -56,6 +56,7 @@ var Zero = New(0, 1)
var zeroInt = big.NewInt(0) var zeroInt = big.NewInt(0)
var oneInt = big.NewInt(1) var oneInt = big.NewInt(1)
var twoInt = big.NewInt(2)
var fiveInt = big.NewInt(5) var fiveInt = big.NewInt(5)
var tenInt = big.NewInt(10) var tenInt = big.NewInt(10)
@ -422,6 +423,51 @@ func (d Decimal) Round(places int32) Decimal {
return ret return ret
} }
// RoundFair rounds the decimal to places decimal places.
// If places < 0, it will round the integer part to the nearest 10^(-places).
//
// Unlike Round(), in the case of a a tie where the resulting decimal place
// equals 0.5, this function will round up for odd numbers and down for
// even numbers. Negative values are treated symmetrically.
//
// Example:
//
// NewFromFloat(5.5).Round(0).String() // output: "6"
// NewFromFloat(8.5).Round(0).String() // output: "8"
//
func (d Decimal) RoundFair(places int32) Decimal {
shift := big.NewInt(abs(int64(places)-int64(d.exp)) - 1)
// First, truncate the number to see if there are trailing decimal places.
// If there are it can't end in 5.
exp := new(big.Int).Exp(tenInt, shift, zeroInt)
rounded := new(big.Int).Quo(d.value, exp)
tmp := new(big.Int)
if tmp.Mul(rounded, exp).Cmp(d.value) != 0 {
return d.Round(places)
}
// If the last digit of the number isn't five, then do normal division.
if tmp.Mod(rounded, tenInt).Cmp(fiveInt) != 0 {
return d.Round(places)
}
ret := Decimal{
value: rounded.Quo(rounded, tenInt),
exp: -places,
}
odd := new(big.Int).Mod(ret.value, twoInt).Cmp(zeroInt) == 1
if odd && ret.value.Sign() >= 0 {
ret.value.Add(ret.value, oneInt)
} else if odd && ret.value.Sign() < 0 {
ret.value.Sub(ret.value, oneInt)
}
return ret
}
// Floor returns the nearest integer value less than or equal to d. // Floor returns the nearest integer value less than or equal to d.
func (d Decimal) Floor() Decimal { func (d Decimal) Floor() Decimal {
d.ensureInitialized() d.ensureInitialized()
@ -639,6 +685,14 @@ func Max(first Decimal, rest ...Decimal) Decimal {
return ans return ans
} }
func abs(x int64) int64 {
if x < 0 {
return -x
}
return x
}
func min(x, y int32) int32 { func min(x, y int32) int32 {
if x >= y { if x >= y {
return y return y

View file

@ -3,6 +3,7 @@ package decimal
import ( import (
"encoding/json" "encoding/json"
"encoding/xml" "encoding/xml"
"fmt"
"math" "math"
"sort" "sort"
"strconv" "strconv"
@ -783,6 +784,36 @@ func TestDecimal_ExtremeValues(t *testing.T) {
} }
}) })
} }
func TestRoundFair(t *testing.T) {
testCases := []struct {
original float64
rounded float64
places int32
}{
{-5.5, -6, 0},
{-6.5, -6, 0},
{0.5, 0, 0},
{-0.5, 0, 0},
{5.5, 6, 0},
{6.5, 6, 0},
{6.51, 7, 0},
{650, 650, 0},
{650, 650, -1},
{650, 600, -2},
{550, 600, -2},
}
for _, test := range testCases {
t.Run(fmt.Sprintf("(%.1f).RoundFair(%d)==%.0f", test.original, test.places, test.rounded), func(t *testing.T) {
got, _ := NewFromFloat(test.original).RoundFair(test.places).Float64()
expected := test.rounded
if got != expected {
t.Errorf("Error: got %.2f, expected %.2f", got, expected)
}
})
}
}
func TestIntPart(t *testing.T) { func TestIntPart(t *testing.T) {
for _, testCase := range []struct { for _, testCase := range []struct {