mirror of
https://github.com/shopspring/decimal.git
synced 2024-11-23 04:40:49 +01:00
add Root and RootRound methods
This commit is contained in:
parent
0ea7e08d01
commit
c2e758a936
2 changed files with 260 additions and 0 deletions
137
decimal.go
137
decimal.go
|
@ -44,6 +44,10 @@ import (
|
|||
//
|
||||
var DivisionPrecision = 16
|
||||
|
||||
// RootPrecision is the number of decimal places in the result from the Root
|
||||
// method.
|
||||
var RootPrecision = 16
|
||||
|
||||
// MarshalJSONWithoutQuotes should be set to true if you want the decimal to
|
||||
// be JSON marshaled as a number, instead of as a string.
|
||||
// WARNING: this is dangerous for decimals with many digits, since many JSON
|
||||
|
@ -55,9 +59,15 @@ var MarshalJSONWithoutQuotes = false
|
|||
// Zero constant, to make computations faster.
|
||||
var Zero = New(0, 1)
|
||||
|
||||
// oneDec used for incrementing/decrementing
|
||||
var oneDec = New(1, 0)
|
||||
|
||||
// fiveDec used in Cash Rounding
|
||||
var fiveDec = New(5, 0)
|
||||
|
||||
// used for shifting digits
|
||||
var tenDec = New(10, 0)
|
||||
|
||||
var zeroInt = big.NewInt(0)
|
||||
var oneInt = big.NewInt(1)
|
||||
var twoInt = big.NewInt(2)
|
||||
|
@ -557,6 +567,133 @@ func (d Decimal) Pow(d2 Decimal) Decimal {
|
|||
return temp.Mul(temp).Div(d)
|
||||
}
|
||||
|
||||
// factorial assumes d is a positive whole number
|
||||
func (d Decimal) factorial() Decimal {
|
||||
out := oneDec
|
||||
for i := New(2, 0); i.LessThanOrEqual(d); i = i.Add(oneDec) {
|
||||
out = out.Mul(i)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (d Decimal) combination(r Decimal) Decimal {
|
||||
top := d.factorial()
|
||||
bottom := r.factorial().Mul(d.Sub(r).factorial())
|
||||
return top.Div(bottom)
|
||||
}
|
||||
|
||||
func reverseDecimalSlice(slice []Decimal) {
|
||||
for i := len(slice)/2 - 1; i >= 0; i-- {
|
||||
opp := len(slice) - 1 - i
|
||||
slice[i], slice[opp] = slice[opp], slice[i]
|
||||
}
|
||||
}
|
||||
|
||||
// returns groups of n digits from the original Decimal, centered and split on
|
||||
// the position of the decimal place, for purpose of computing the nth root via
|
||||
// the shifting nth root algo.
|
||||
//
|
||||
// 0.303 (n:2) -> [], [30, 30]
|
||||
// 1000 (n:2) -> [10, 00], []
|
||||
// 2310.0241 (n:3) -> [2, 310], [024, 100]
|
||||
func (d Decimal) rootDigitGroups(n Decimal) (left, right []Decimal) {
|
||||
e := tenDec.Pow(n)
|
||||
|
||||
dLeft := d.rescale(0)
|
||||
for {
|
||||
if dLeft.Equal(Zero) {
|
||||
break
|
||||
}
|
||||
next := dLeft.Div(e).rescale(0)
|
||||
left = append(left, dLeft.Sub(next.Mul(e)))
|
||||
dLeft = next
|
||||
}
|
||||
reverseDecimalSlice(left)
|
||||
|
||||
dRight := d.Sub(d.rescale(0))
|
||||
for {
|
||||
if dRight.Equal(Zero) {
|
||||
break
|
||||
}
|
||||
dRightE := dRight.Mul(e)
|
||||
dRightETrunc := dRightE.rescale(0)
|
||||
right = append(right, dRightETrunc)
|
||||
dRight = dRightE.Sub(dRightETrunc)
|
||||
}
|
||||
|
||||
return left, right
|
||||
}
|
||||
|
||||
// Root returns the nth root of d.
|
||||
func (d Decimal) Root(n Decimal) Decimal {
|
||||
return d.RootRound(n, int32(RootPrecision))
|
||||
}
|
||||
|
||||
// RootRound returns the nth root of d. If the result is not a whole number then
|
||||
// the given precision determines the number of decimal places to calculate the
|
||||
// result to.
|
||||
func (d Decimal) RootRound(n Decimal, prec int32) Decimal {
|
||||
// The nth root is calculated via the shifting root algorithm, which was
|
||||
// chosen for being definitely correct up to an arbitrary precision, and not
|
||||
// because it's particularly fast.
|
||||
//
|
||||
// https://en.wikipedia.org/wiki/Shifting_nth_root_algorithm
|
||||
// https://www.wikihow.com/Find-Nth-Roots-by-Hand
|
||||
|
||||
leftGroups, rightGroups := d.rootDigitGroups(n)
|
||||
numLeftGroups := len(leftGroups)
|
||||
|
||||
answer, answerNoDecimal, target := Zero, Zero, Zero
|
||||
var numDigits int32
|
||||
for {
|
||||
if numDigits-int32(numLeftGroups) >= prec {
|
||||
break
|
||||
}
|
||||
|
||||
group := Zero
|
||||
if len(leftGroups) > 0 {
|
||||
group, leftGroups = leftGroups[0], leftGroups[1:]
|
||||
} else if len(rightGroups) > 0 {
|
||||
group, rightGroups = rightGroups[0], rightGroups[1:]
|
||||
}
|
||||
target = target.Mul(tenDec.Pow(n)).Add(group)
|
||||
|
||||
nextDigit, nextSub := oneDec, Zero
|
||||
for ; nextDigit.LessThan(tenDec); nextDigit = nextDigit.Add(oneDec) {
|
||||
tryNextSub := Zero
|
||||
for i := Zero; i.LessThan(n); i = i.Add(oneDec) {
|
||||
sumStep := n.combination(i.Add(oneDec))
|
||||
sumStep = sumStep.Mul(nextDigit.Pow(i))
|
||||
sumStep = sumStep.Mul(answerNoDecimal.Mul(tenDec).Pow(n.Sub(i).Sub(oneDec)))
|
||||
tryNextSub = tryNextSub.Add(sumStep)
|
||||
}
|
||||
tryNextSub = tryNextSub.Mul(nextDigit)
|
||||
if tryNextSub.GreaterThan(target) {
|
||||
break
|
||||
}
|
||||
nextSub = tryNextSub
|
||||
}
|
||||
nextDigit = nextDigit.Sub(oneDec)
|
||||
|
||||
answerNoDecimal = answerNoDecimal.Mul(tenDec).Add(nextDigit)
|
||||
if numDigits < int32(numLeftGroups) {
|
||||
answer = answerNoDecimal
|
||||
} else {
|
||||
shift := tenDec.Pow(New(int64(numDigits)-int64(numLeftGroups)+1, 0))
|
||||
answer = answer.Add(nextDigit.DivRound(shift, prec))
|
||||
}
|
||||
|
||||
target = target.Sub(nextSub)
|
||||
if target.Equal(Zero) && len(leftGroups) == 0 && len(rightGroups) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
numDigits++
|
||||
}
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
// Cmp compares the numbers represented by d and d2 and returns:
|
||||
//
|
||||
// -1 if d < d2
|
||||
|
|
123
decimal_test.go
123
decimal_test.go
|
@ -2094,6 +2094,129 @@ func TestNegativePow(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestFactorial(t *testing.T) {
|
||||
tests := [][2]string{
|
||||
{"0", "1"},
|
||||
{"1", "1"},
|
||||
{"2", "2"},
|
||||
{"3", "6"},
|
||||
{"4", "24"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
in, out := RequireFromString(test[0]), RequireFromString(test[1])
|
||||
if f := in.factorial(); !f.Equal(out) {
|
||||
t.Errorf("!%v should be %v, got %v", in, out, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombination(t *testing.T) {
|
||||
tests := [][3]string{
|
||||
{"4", "1", "4"},
|
||||
{"4", "2", "6"},
|
||||
{"4", "3", "4"},
|
||||
{"4", "4", "1"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
n, r, out := RequireFromString(test[0]), RequireFromString(test[1]), RequireFromString(test[2])
|
||||
if c := n.combination(r); !c.Equal(out) {
|
||||
t.Errorf("C(%v,%v) should be %v, got %v", n, r, out, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRootDigitGroups(t *testing.T) {
|
||||
rng := rand.New(rand.NewSource(0xdead1337))
|
||||
for i := 0; i < 5e4; i++ {
|
||||
dIV, dIE := rng.Int63n(1e7), rng.Int31n(10)
|
||||
if rng.Intn(2) == 0 {
|
||||
dIE = -dIE
|
||||
}
|
||||
|
||||
d := New(dIV, dIE)
|
||||
n := New(rng.Int63n(3)+1, 0)
|
||||
nI := int(n.IntPart())
|
||||
dStr := d.String()
|
||||
dStrSplit := strings.Split(dStr, ".")
|
||||
|
||||
var left, right string
|
||||
left = dStrSplit[0]
|
||||
if len(dStrSplit) == 2 {
|
||||
right = dStrSplit[1]
|
||||
}
|
||||
|
||||
var expLeft, expRight []Decimal
|
||||
for left != "" && left != "0" {
|
||||
if len(left) < nI {
|
||||
expLeft = append(expLeft, RequireFromString(left))
|
||||
left = ""
|
||||
} else {
|
||||
part := left[len(left)-nI:]
|
||||
expLeft = append(expLeft, RequireFromString(part))
|
||||
left = left[:len(left)-nI]
|
||||
}
|
||||
}
|
||||
reverseDecimalSlice(expLeft)
|
||||
|
||||
for right != "" {
|
||||
if len(right) < nI {
|
||||
right += strings.Repeat("0", nI-len(right))
|
||||
expRight = append(expRight, RequireFromString(right))
|
||||
right = ""
|
||||
} else {
|
||||
part := right[:nI]
|
||||
expRight = append(expRight, RequireFromString(part))
|
||||
right = right[nI:]
|
||||
}
|
||||
}
|
||||
|
||||
gotLeft, gotRight := d.rootDigitGroups(n)
|
||||
fatalStr := "(%v).rootDigitGroups(%v)\nexpLeft:%v\ngotLeft:%v\nexpRight:%v\ngotRight:%v"
|
||||
fatalArgs := []interface{}{d, n, expLeft, gotLeft, expRight, gotRight}
|
||||
|
||||
if len(expLeft) != len(gotLeft) || len(expRight) != len(gotRight) {
|
||||
t.Fatalf(fatalStr, fatalArgs...)
|
||||
}
|
||||
|
||||
for j := range expLeft {
|
||||
if !expLeft[j].Equal(gotLeft[j]) {
|
||||
t.Fatalf(fatalStr, fatalArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
for j := range expRight {
|
||||
if !expRight[j].Equal(gotRight[j]) {
|
||||
t.Fatalf(fatalStr, fatalArgs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoot(t *testing.T) {
|
||||
rng := rand.New(rand.NewSource(0xdead1337))
|
||||
for i := 0; i < 2e3; i++ {
|
||||
rootIV, rootIE := rng.Int63n(1e7), rng.Int31n(10)
|
||||
if rng.Intn(2) == 0 {
|
||||
rootIE = -rootIE
|
||||
}
|
||||
|
||||
root := New(rootIV, rootIE)
|
||||
n := New(rng.Int63n(3)+2, 0) // TODO +1, not +2
|
||||
|
||||
d := root
|
||||
for i := int64(1); i < n.IntPart(); i++ {
|
||||
d = d.Mul(root)
|
||||
}
|
||||
|
||||
gotRoot := d.RootRound(n, 32)
|
||||
if !strings.HasPrefix(gotRoot.String(), root.String()) {
|
||||
t.Fatalf("%v root of %v\nexpected:%v\n got:%v", n, d, root, gotRoot)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecimal_Sign(t *testing.T) {
|
||||
if Zero.Sign() != 0 {
|
||||
t.Errorf("%q should have sign 0", Zero)
|
||||
|
|
Loading…
Reference in a new issue