From c2e758a936e6854ef126ece600ab9b5f4e5b397f Mon Sep 17 00:00:00 2001 From: Brian P Date: Sun, 13 Oct 2019 13:32:12 -0600 Subject: [PATCH] add Root and RootRound methods --- decimal.go | 137 ++++++++++++++++++++++++++++++++++++++++++++++++ decimal_test.go | 123 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) diff --git a/decimal.go b/decimal.go index b23d053..ab1bdb7 100644 --- a/decimal.go +++ b/decimal.go @@ -44,6 +44,10 @@ import ( // var DivisionPrecision = 16 +// RootPrecision is the number of decimal places in the result from the Root +// method. +var RootPrecision = 16 + // MarshalJSONWithoutQuotes should be set to true if you want the decimal to // be JSON marshaled as a number, instead of as a string. // WARNING: this is dangerous for decimals with many digits, since many JSON @@ -55,9 +59,15 @@ var MarshalJSONWithoutQuotes = false // Zero constant, to make computations faster. var Zero = New(0, 1) +// oneDec used for incrementing/decrementing +var oneDec = New(1, 0) + // fiveDec used in Cash Rounding var fiveDec = New(5, 0) +// used for shifting digits +var tenDec = New(10, 0) + var zeroInt = big.NewInt(0) var oneInt = big.NewInt(1) var twoInt = big.NewInt(2) @@ -557,6 +567,133 @@ func (d Decimal) Pow(d2 Decimal) Decimal { return temp.Mul(temp).Div(d) } +// factorial assumes d is a positive whole number +func (d Decimal) factorial() Decimal { + out := oneDec + for i := New(2, 0); i.LessThanOrEqual(d); i = i.Add(oneDec) { + out = out.Mul(i) + } + return out +} + +func (d Decimal) combination(r Decimal) Decimal { + top := d.factorial() + bottom := r.factorial().Mul(d.Sub(r).factorial()) + return top.Div(bottom) +} + +func reverseDecimalSlice(slice []Decimal) { + for i := len(slice)/2 - 1; i >= 0; i-- { + opp := len(slice) - 1 - i + slice[i], slice[opp] = slice[opp], slice[i] + } +} + +// returns groups of n digits from the original Decimal, centered and split on +// the position of the decimal place, for purpose of computing the nth root via +// the shifting nth root algo. +// +// 0.303 (n:2) -> [], [30, 30] +// 1000 (n:2) -> [10, 00], [] +// 2310.0241 (n:3) -> [2, 310], [024, 100] +func (d Decimal) rootDigitGroups(n Decimal) (left, right []Decimal) { + e := tenDec.Pow(n) + + dLeft := d.rescale(0) + for { + if dLeft.Equal(Zero) { + break + } + next := dLeft.Div(e).rescale(0) + left = append(left, dLeft.Sub(next.Mul(e))) + dLeft = next + } + reverseDecimalSlice(left) + + dRight := d.Sub(d.rescale(0)) + for { + if dRight.Equal(Zero) { + break + } + dRightE := dRight.Mul(e) + dRightETrunc := dRightE.rescale(0) + right = append(right, dRightETrunc) + dRight = dRightE.Sub(dRightETrunc) + } + + return left, right +} + +// Root returns the nth root of d. +func (d Decimal) Root(n Decimal) Decimal { + return d.RootRound(n, int32(RootPrecision)) +} + +// RootRound returns the nth root of d. If the result is not a whole number then +// the given precision determines the number of decimal places to calculate the +// result to. +func (d Decimal) RootRound(n Decimal, prec int32) Decimal { + // The nth root is calculated via the shifting root algorithm, which was + // chosen for being definitely correct up to an arbitrary precision, and not + // because it's particularly fast. + // + // https://en.wikipedia.org/wiki/Shifting_nth_root_algorithm + // https://www.wikihow.com/Find-Nth-Roots-by-Hand + + leftGroups, rightGroups := d.rootDigitGroups(n) + numLeftGroups := len(leftGroups) + + answer, answerNoDecimal, target := Zero, Zero, Zero + var numDigits int32 + for { + if numDigits-int32(numLeftGroups) >= prec { + break + } + + group := Zero + if len(leftGroups) > 0 { + group, leftGroups = leftGroups[0], leftGroups[1:] + } else if len(rightGroups) > 0 { + group, rightGroups = rightGroups[0], rightGroups[1:] + } + target = target.Mul(tenDec.Pow(n)).Add(group) + + nextDigit, nextSub := oneDec, Zero + for ; nextDigit.LessThan(tenDec); nextDigit = nextDigit.Add(oneDec) { + tryNextSub := Zero + for i := Zero; i.LessThan(n); i = i.Add(oneDec) { + sumStep := n.combination(i.Add(oneDec)) + sumStep = sumStep.Mul(nextDigit.Pow(i)) + sumStep = sumStep.Mul(answerNoDecimal.Mul(tenDec).Pow(n.Sub(i).Sub(oneDec))) + tryNextSub = tryNextSub.Add(sumStep) + } + tryNextSub = tryNextSub.Mul(nextDigit) + if tryNextSub.GreaterThan(target) { + break + } + nextSub = tryNextSub + } + nextDigit = nextDigit.Sub(oneDec) + + answerNoDecimal = answerNoDecimal.Mul(tenDec).Add(nextDigit) + if numDigits < int32(numLeftGroups) { + answer = answerNoDecimal + } else { + shift := tenDec.Pow(New(int64(numDigits)-int64(numLeftGroups)+1, 0)) + answer = answer.Add(nextDigit.DivRound(shift, prec)) + } + + target = target.Sub(nextSub) + if target.Equal(Zero) && len(leftGroups) == 0 && len(rightGroups) == 0 { + break + } + + numDigits++ + } + + return answer +} + // Cmp compares the numbers represented by d and d2 and returns: // // -1 if d < d2 diff --git a/decimal_test.go b/decimal_test.go index 64f0552..df16013 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -2094,6 +2094,129 @@ func TestNegativePow(t *testing.T) { } } +func TestFactorial(t *testing.T) { + tests := [][2]string{ + {"0", "1"}, + {"1", "1"}, + {"2", "2"}, + {"3", "6"}, + {"4", "24"}, + } + + for _, test := range tests { + in, out := RequireFromString(test[0]), RequireFromString(test[1]) + if f := in.factorial(); !f.Equal(out) { + t.Errorf("!%v should be %v, got %v", in, out, f) + } + } +} + +func TestCombination(t *testing.T) { + tests := [][3]string{ + {"4", "1", "4"}, + {"4", "2", "6"}, + {"4", "3", "4"}, + {"4", "4", "1"}, + } + + for _, test := range tests { + n, r, out := RequireFromString(test[0]), RequireFromString(test[1]), RequireFromString(test[2]) + if c := n.combination(r); !c.Equal(out) { + t.Errorf("C(%v,%v) should be %v, got %v", n, r, out, c) + } + } +} + +func TestRootDigitGroups(t *testing.T) { + rng := rand.New(rand.NewSource(0xdead1337)) + for i := 0; i < 5e4; i++ { + dIV, dIE := rng.Int63n(1e7), rng.Int31n(10) + if rng.Intn(2) == 0 { + dIE = -dIE + } + + d := New(dIV, dIE) + n := New(rng.Int63n(3)+1, 0) + nI := int(n.IntPart()) + dStr := d.String() + dStrSplit := strings.Split(dStr, ".") + + var left, right string + left = dStrSplit[0] + if len(dStrSplit) == 2 { + right = dStrSplit[1] + } + + var expLeft, expRight []Decimal + for left != "" && left != "0" { + if len(left) < nI { + expLeft = append(expLeft, RequireFromString(left)) + left = "" + } else { + part := left[len(left)-nI:] + expLeft = append(expLeft, RequireFromString(part)) + left = left[:len(left)-nI] + } + } + reverseDecimalSlice(expLeft) + + for right != "" { + if len(right) < nI { + right += strings.Repeat("0", nI-len(right)) + expRight = append(expRight, RequireFromString(right)) + right = "" + } else { + part := right[:nI] + expRight = append(expRight, RequireFromString(part)) + right = right[nI:] + } + } + + gotLeft, gotRight := d.rootDigitGroups(n) + fatalStr := "(%v).rootDigitGroups(%v)\nexpLeft:%v\ngotLeft:%v\nexpRight:%v\ngotRight:%v" + fatalArgs := []interface{}{d, n, expLeft, gotLeft, expRight, gotRight} + + if len(expLeft) != len(gotLeft) || len(expRight) != len(gotRight) { + t.Fatalf(fatalStr, fatalArgs...) + } + + for j := range expLeft { + if !expLeft[j].Equal(gotLeft[j]) { + t.Fatalf(fatalStr, fatalArgs...) + } + } + + for j := range expRight { + if !expRight[j].Equal(gotRight[j]) { + t.Fatalf(fatalStr, fatalArgs...) + } + } + } +} + +func TestRoot(t *testing.T) { + rng := rand.New(rand.NewSource(0xdead1337)) + for i := 0; i < 2e3; i++ { + rootIV, rootIE := rng.Int63n(1e7), rng.Int31n(10) + if rng.Intn(2) == 0 { + rootIE = -rootIE + } + + root := New(rootIV, rootIE) + n := New(rng.Int63n(3)+2, 0) // TODO +1, not +2 + + d := root + for i := int64(1); i < n.IntPart(); i++ { + d = d.Mul(root) + } + + gotRoot := d.RootRound(n, 32) + if !strings.HasPrefix(gotRoot.String(), root.String()) { + t.Fatalf("%v root of %v\nexpected:%v\n got:%v", n, d, root, gotRoot) + } + } +} + func TestDecimal_Sign(t *testing.T) { if Zero.Sign() != 0 { t.Errorf("%q should have sign 0", Zero)