From 2faaec71dd0f89356bfacc1d5e55822777193402 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Thu, 6 Oct 2016 09:30:46 -0700 Subject: [PATCH] Add unbiased rounding algorithm --- decimal.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++ decimal_test.go | 31 ++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/decimal.go b/decimal.go index 91f9f6c..9e2d9a4 100644 --- a/decimal.go +++ b/decimal.go @@ -56,6 +56,7 @@ var Zero = New(0, 1) var zeroInt = big.NewInt(0) var oneInt = big.NewInt(1) +var twoInt = big.NewInt(2) var fiveInt = big.NewInt(5) var tenInt = big.NewInt(10) @@ -422,6 +423,51 @@ func (d Decimal) Round(places int32) Decimal { return ret } +// RoundFair rounds the decimal to places decimal places. +// If places < 0, it will round the integer part to the nearest 10^(-places). +// +// Unlike Round(), in the case of a a tie where the resulting decimal place +// equals 0.5, this function will round up for odd numbers and down for +// even numbers. Negative values are treated symmetrically. +// +// Example: +// +// NewFromFloat(5.5).Round(0).String() // output: "6" +// NewFromFloat(8.5).Round(0).String() // output: "8" +// +func (d Decimal) RoundFair(places int32) Decimal { + shift := big.NewInt(abs(int64(places)-int64(d.exp)) - 1) + + // First, truncate the number to see if there are trailing decimal places. + // If there are it can't end in 5. + exp := new(big.Int).Exp(tenInt, shift, zeroInt) + rounded := new(big.Int).Quo(d.value, exp) + + tmp := new(big.Int) + if tmp.Mul(rounded, exp).Cmp(d.value) != 0 { + return d.Round(places) + } + + // If the last digit of the number isn't five, then do normal division. + if tmp.Mod(rounded, tenInt).Cmp(fiveInt) != 0 { + return d.Round(places) + } + + ret := Decimal{ + value: rounded.Quo(rounded, tenInt), + exp: -places, + } + + odd := new(big.Int).Mod(ret.value, twoInt).Cmp(zeroInt) == 1 + if odd && ret.value.Sign() >= 0 { + ret.value.Add(ret.value, oneInt) + } else if odd && ret.value.Sign() < 0 { + ret.value.Sub(ret.value, oneInt) + } + + return ret +} + // Floor returns the nearest integer value less than or equal to d. func (d Decimal) Floor() Decimal { d.ensureInitialized() @@ -639,6 +685,14 @@ func Max(first Decimal, rest ...Decimal) Decimal { return ans } +func abs(x int64) int64 { + if x < 0 { + return -x + } + + return x +} + func min(x, y int32) int32 { if x >= y { return y diff --git a/decimal_test.go b/decimal_test.go index b0b2088..dfaa054 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -3,6 +3,7 @@ package decimal import ( "encoding/json" "encoding/xml" + "fmt" "math" "sort" "strconv" @@ -783,6 +784,36 @@ func TestDecimal_ExtremeValues(t *testing.T) { } }) } +func TestRoundFair(t *testing.T) { + testCases := []struct { + original float64 + rounded float64 + places int32 + }{ + {-5.5, -6, 0}, + {-6.5, -6, 0}, + {0.5, 0, 0}, + {-0.5, 0, 0}, + {5.5, 6, 0}, + {6.5, 6, 0}, + {6.51, 7, 0}, + + {650, 650, 0}, + {650, 650, -1}, + {650, 600, -2}, + {550, 600, -2}, + } + + for _, test := range testCases { + t.Run(fmt.Sprintf("(%.1f).RoundFair(%d)==%.0f", test.original, test.places, test.rounded), func(t *testing.T) { + got, _ := NewFromFloat(test.original).RoundFair(test.places).Float64() + expected := test.rounded + if got != expected { + t.Errorf("Error: got %.2f, expected %.2f", got, expected) + } + }) + } +} func TestIntPart(t *testing.T) { for _, testCase := range []struct {