mirror of
https://github.com/shopspring/decimal.git
synced 2024-11-23 04:40:49 +01:00
add Root and RootRound methods
This commit is contained in:
parent
0ea7e08d01
commit
c2e758a936
2 changed files with 260 additions and 0 deletions
137
decimal.go
137
decimal.go
|
@ -44,6 +44,10 @@ import (
|
||||||
//
|
//
|
||||||
var DivisionPrecision = 16
|
var DivisionPrecision = 16
|
||||||
|
|
||||||
|
// RootPrecision is the number of decimal places in the result from the Root
|
||||||
|
// method.
|
||||||
|
var RootPrecision = 16
|
||||||
|
|
||||||
// MarshalJSONWithoutQuotes should be set to true if you want the decimal to
|
// MarshalJSONWithoutQuotes should be set to true if you want the decimal to
|
||||||
// be JSON marshaled as a number, instead of as a string.
|
// be JSON marshaled as a number, instead of as a string.
|
||||||
// WARNING: this is dangerous for decimals with many digits, since many JSON
|
// WARNING: this is dangerous for decimals with many digits, since many JSON
|
||||||
|
@ -55,9 +59,15 @@ var MarshalJSONWithoutQuotes = false
|
||||||
// Zero constant, to make computations faster.
|
// Zero constant, to make computations faster.
|
||||||
var Zero = New(0, 1)
|
var Zero = New(0, 1)
|
||||||
|
|
||||||
|
// oneDec used for incrementing/decrementing
|
||||||
|
var oneDec = New(1, 0)
|
||||||
|
|
||||||
// fiveDec used in Cash Rounding
|
// fiveDec used in Cash Rounding
|
||||||
var fiveDec = New(5, 0)
|
var fiveDec = New(5, 0)
|
||||||
|
|
||||||
|
// used for shifting digits
|
||||||
|
var tenDec = New(10, 0)
|
||||||
|
|
||||||
var zeroInt = big.NewInt(0)
|
var zeroInt = big.NewInt(0)
|
||||||
var oneInt = big.NewInt(1)
|
var oneInt = big.NewInt(1)
|
||||||
var twoInt = big.NewInt(2)
|
var twoInt = big.NewInt(2)
|
||||||
|
@ -557,6 +567,133 @@ func (d Decimal) Pow(d2 Decimal) Decimal {
|
||||||
return temp.Mul(temp).Div(d)
|
return temp.Mul(temp).Div(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// factorial assumes d is a positive whole number
|
||||||
|
func (d Decimal) factorial() Decimal {
|
||||||
|
out := oneDec
|
||||||
|
for i := New(2, 0); i.LessThanOrEqual(d); i = i.Add(oneDec) {
|
||||||
|
out = out.Mul(i)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Decimal) combination(r Decimal) Decimal {
|
||||||
|
top := d.factorial()
|
||||||
|
bottom := r.factorial().Mul(d.Sub(r).factorial())
|
||||||
|
return top.Div(bottom)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseDecimalSlice(slice []Decimal) {
|
||||||
|
for i := len(slice)/2 - 1; i >= 0; i-- {
|
||||||
|
opp := len(slice) - 1 - i
|
||||||
|
slice[i], slice[opp] = slice[opp], slice[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns groups of n digits from the original Decimal, centered and split on
|
||||||
|
// the position of the decimal place, for purpose of computing the nth root via
|
||||||
|
// the shifting nth root algo.
|
||||||
|
//
|
||||||
|
// 0.303 (n:2) -> [], [30, 30]
|
||||||
|
// 1000 (n:2) -> [10, 00], []
|
||||||
|
// 2310.0241 (n:3) -> [2, 310], [024, 100]
|
||||||
|
func (d Decimal) rootDigitGroups(n Decimal) (left, right []Decimal) {
|
||||||
|
e := tenDec.Pow(n)
|
||||||
|
|
||||||
|
dLeft := d.rescale(0)
|
||||||
|
for {
|
||||||
|
if dLeft.Equal(Zero) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next := dLeft.Div(e).rescale(0)
|
||||||
|
left = append(left, dLeft.Sub(next.Mul(e)))
|
||||||
|
dLeft = next
|
||||||
|
}
|
||||||
|
reverseDecimalSlice(left)
|
||||||
|
|
||||||
|
dRight := d.Sub(d.rescale(0))
|
||||||
|
for {
|
||||||
|
if dRight.Equal(Zero) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
dRightE := dRight.Mul(e)
|
||||||
|
dRightETrunc := dRightE.rescale(0)
|
||||||
|
right = append(right, dRightETrunc)
|
||||||
|
dRight = dRightE.Sub(dRightETrunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
return left, right
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root returns the nth root of d.
|
||||||
|
func (d Decimal) Root(n Decimal) Decimal {
|
||||||
|
return d.RootRound(n, int32(RootPrecision))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RootRound returns the nth root of d. If the result is not a whole number then
|
||||||
|
// the given precision determines the number of decimal places to calculate the
|
||||||
|
// result to.
|
||||||
|
func (d Decimal) RootRound(n Decimal, prec int32) Decimal {
|
||||||
|
// The nth root is calculated via the shifting root algorithm, which was
|
||||||
|
// chosen for being definitely correct up to an arbitrary precision, and not
|
||||||
|
// because it's particularly fast.
|
||||||
|
//
|
||||||
|
// https://en.wikipedia.org/wiki/Shifting_nth_root_algorithm
|
||||||
|
// https://www.wikihow.com/Find-Nth-Roots-by-Hand
|
||||||
|
|
||||||
|
leftGroups, rightGroups := d.rootDigitGroups(n)
|
||||||
|
numLeftGroups := len(leftGroups)
|
||||||
|
|
||||||
|
answer, answerNoDecimal, target := Zero, Zero, Zero
|
||||||
|
var numDigits int32
|
||||||
|
for {
|
||||||
|
if numDigits-int32(numLeftGroups) >= prec {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
group := Zero
|
||||||
|
if len(leftGroups) > 0 {
|
||||||
|
group, leftGroups = leftGroups[0], leftGroups[1:]
|
||||||
|
} else if len(rightGroups) > 0 {
|
||||||
|
group, rightGroups = rightGroups[0], rightGroups[1:]
|
||||||
|
}
|
||||||
|
target = target.Mul(tenDec.Pow(n)).Add(group)
|
||||||
|
|
||||||
|
nextDigit, nextSub := oneDec, Zero
|
||||||
|
for ; nextDigit.LessThan(tenDec); nextDigit = nextDigit.Add(oneDec) {
|
||||||
|
tryNextSub := Zero
|
||||||
|
for i := Zero; i.LessThan(n); i = i.Add(oneDec) {
|
||||||
|
sumStep := n.combination(i.Add(oneDec))
|
||||||
|
sumStep = sumStep.Mul(nextDigit.Pow(i))
|
||||||
|
sumStep = sumStep.Mul(answerNoDecimal.Mul(tenDec).Pow(n.Sub(i).Sub(oneDec)))
|
||||||
|
tryNextSub = tryNextSub.Add(sumStep)
|
||||||
|
}
|
||||||
|
tryNextSub = tryNextSub.Mul(nextDigit)
|
||||||
|
if tryNextSub.GreaterThan(target) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nextSub = tryNextSub
|
||||||
|
}
|
||||||
|
nextDigit = nextDigit.Sub(oneDec)
|
||||||
|
|
||||||
|
answerNoDecimal = answerNoDecimal.Mul(tenDec).Add(nextDigit)
|
||||||
|
if numDigits < int32(numLeftGroups) {
|
||||||
|
answer = answerNoDecimal
|
||||||
|
} else {
|
||||||
|
shift := tenDec.Pow(New(int64(numDigits)-int64(numLeftGroups)+1, 0))
|
||||||
|
answer = answer.Add(nextDigit.DivRound(shift, prec))
|
||||||
|
}
|
||||||
|
|
||||||
|
target = target.Sub(nextSub)
|
||||||
|
if target.Equal(Zero) && len(leftGroups) == 0 && len(rightGroups) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
numDigits++
|
||||||
|
}
|
||||||
|
|
||||||
|
return answer
|
||||||
|
}
|
||||||
|
|
||||||
// Cmp compares the numbers represented by d and d2 and returns:
|
// Cmp compares the numbers represented by d and d2 and returns:
|
||||||
//
|
//
|
||||||
// -1 if d < d2
|
// -1 if d < d2
|
||||||
|
|
123
decimal_test.go
123
decimal_test.go
|
@ -2094,6 +2094,129 @@ func TestNegativePow(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFactorial(t *testing.T) {
|
||||||
|
tests := [][2]string{
|
||||||
|
{"0", "1"},
|
||||||
|
{"1", "1"},
|
||||||
|
{"2", "2"},
|
||||||
|
{"3", "6"},
|
||||||
|
{"4", "24"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
in, out := RequireFromString(test[0]), RequireFromString(test[1])
|
||||||
|
if f := in.factorial(); !f.Equal(out) {
|
||||||
|
t.Errorf("!%v should be %v, got %v", in, out, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCombination(t *testing.T) {
|
||||||
|
tests := [][3]string{
|
||||||
|
{"4", "1", "4"},
|
||||||
|
{"4", "2", "6"},
|
||||||
|
{"4", "3", "4"},
|
||||||
|
{"4", "4", "1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
n, r, out := RequireFromString(test[0]), RequireFromString(test[1]), RequireFromString(test[2])
|
||||||
|
if c := n.combination(r); !c.Equal(out) {
|
||||||
|
t.Errorf("C(%v,%v) should be %v, got %v", n, r, out, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRootDigitGroups(t *testing.T) {
|
||||||
|
rng := rand.New(rand.NewSource(0xdead1337))
|
||||||
|
for i := 0; i < 5e4; i++ {
|
||||||
|
dIV, dIE := rng.Int63n(1e7), rng.Int31n(10)
|
||||||
|
if rng.Intn(2) == 0 {
|
||||||
|
dIE = -dIE
|
||||||
|
}
|
||||||
|
|
||||||
|
d := New(dIV, dIE)
|
||||||
|
n := New(rng.Int63n(3)+1, 0)
|
||||||
|
nI := int(n.IntPart())
|
||||||
|
dStr := d.String()
|
||||||
|
dStrSplit := strings.Split(dStr, ".")
|
||||||
|
|
||||||
|
var left, right string
|
||||||
|
left = dStrSplit[0]
|
||||||
|
if len(dStrSplit) == 2 {
|
||||||
|
right = dStrSplit[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
var expLeft, expRight []Decimal
|
||||||
|
for left != "" && left != "0" {
|
||||||
|
if len(left) < nI {
|
||||||
|
expLeft = append(expLeft, RequireFromString(left))
|
||||||
|
left = ""
|
||||||
|
} else {
|
||||||
|
part := left[len(left)-nI:]
|
||||||
|
expLeft = append(expLeft, RequireFromString(part))
|
||||||
|
left = left[:len(left)-nI]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reverseDecimalSlice(expLeft)
|
||||||
|
|
||||||
|
for right != "" {
|
||||||
|
if len(right) < nI {
|
||||||
|
right += strings.Repeat("0", nI-len(right))
|
||||||
|
expRight = append(expRight, RequireFromString(right))
|
||||||
|
right = ""
|
||||||
|
} else {
|
||||||
|
part := right[:nI]
|
||||||
|
expRight = append(expRight, RequireFromString(part))
|
||||||
|
right = right[nI:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gotLeft, gotRight := d.rootDigitGroups(n)
|
||||||
|
fatalStr := "(%v).rootDigitGroups(%v)\nexpLeft:%v\ngotLeft:%v\nexpRight:%v\ngotRight:%v"
|
||||||
|
fatalArgs := []interface{}{d, n, expLeft, gotLeft, expRight, gotRight}
|
||||||
|
|
||||||
|
if len(expLeft) != len(gotLeft) || len(expRight) != len(gotRight) {
|
||||||
|
t.Fatalf(fatalStr, fatalArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := range expLeft {
|
||||||
|
if !expLeft[j].Equal(gotLeft[j]) {
|
||||||
|
t.Fatalf(fatalStr, fatalArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := range expRight {
|
||||||
|
if !expRight[j].Equal(gotRight[j]) {
|
||||||
|
t.Fatalf(fatalStr, fatalArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoot(t *testing.T) {
|
||||||
|
rng := rand.New(rand.NewSource(0xdead1337))
|
||||||
|
for i := 0; i < 2e3; i++ {
|
||||||
|
rootIV, rootIE := rng.Int63n(1e7), rng.Int31n(10)
|
||||||
|
if rng.Intn(2) == 0 {
|
||||||
|
rootIE = -rootIE
|
||||||
|
}
|
||||||
|
|
||||||
|
root := New(rootIV, rootIE)
|
||||||
|
n := New(rng.Int63n(3)+2, 0) // TODO +1, not +2
|
||||||
|
|
||||||
|
d := root
|
||||||
|
for i := int64(1); i < n.IntPart(); i++ {
|
||||||
|
d = d.Mul(root)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotRoot := d.RootRound(n, 32)
|
||||||
|
if !strings.HasPrefix(gotRoot.String(), root.String()) {
|
||||||
|
t.Fatalf("%v root of %v\nexpected:%v\n got:%v", n, d, root, gotRoot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecimal_Sign(t *testing.T) {
|
func TestDecimal_Sign(t *testing.T) {
|
||||||
if Zero.Sign() != 0 {
|
if Zero.Sign() != 0 {
|
||||||
t.Errorf("%q should have sign 0", Zero)
|
t.Errorf("%q should have sign 0", Zero)
|
||||||
|
|
Loading…
Reference in a new issue