diff --git a/decimal.go b/decimal.go index cf42469..bffac10 100644 --- a/decimal.go +++ b/decimal.go @@ -19,6 +19,7 @@ package decimal import ( "database/sql/driver" "encoding/binary" + "errors" "fmt" "math" "math/big" @@ -1360,52 +1361,6 @@ func RescalePair(d1 Decimal, d2 Decimal) (Decimal, Decimal) { return d1, d2.rescale(baseScale) } -// SqrtMaxIter sets a limit for number of iterations for the Sqrt function -const SqrtMaxIter = 100000 - -// Sqrt returns the square root of d, accurate to DivisionPrecision decimal places. -func Sqrt(d Decimal) Decimal { - s, _ := SqrtRound(d, int32(DivisionPrecision)) - return s -} - -// SqrtRound returns the square root of d, accurate to precision decimal places. -// The bool precise returns whether the precision was reached. -func SqrtRound(d Decimal, precision int32) (Decimal, bool) { - maxError := New(1, -precision) - one := NewFromFloat(1) - var lo Decimal - var hi Decimal - // Handle cases where d < 0, d = 0, 0 < d < 1, and d > 1 - if d.GreaterThanOrEqual(one) { - lo = Zero - hi = d - } else if d.Equal(one) { - return one, true - } else if d.LessThan(Zero) { - return NewFromFloat(-1), false // call this an error , cannot take sqrt of neg w/o imaginaries - } else if d.Equal(Zero) { - return Zero, true - } else { - // d is between 0 and 1. Therefore, 0 < d < Sqrt(d) < 1. - lo = d - hi = one - } - var mid Decimal - for i := 0; i < SqrtMaxIter; i++ { - mid = lo.Add(hi).Div(New(2, 0)) //mid = (lo+hi)/2; - if mid.Mul(mid).Sub(d).Abs().LessThan(maxError) { - return mid, true - } - if mid.Mul(mid).GreaterThan(d) { - hi = mid - } else { - lo = mid - } - } - return mid, false -} - func min(x, y int32) int32 { if x >= y { return y @@ -1723,3 +1678,59 @@ func (d Decimal) Tan() Decimal { } return y } + +// More math + +// Sqrt returns the square root of d, accurate to DivisionPrecision decimal places. +// Sqrt is only valid for non-negative numbers; it will return an error otherwise. +func (d Decimal) Sqrt() (Decimal, error) { + s, _, err := d.SqrtRound(int32(DivisionPrecision)) + return s, err +} + +// ErrImaginaryResult indicates an operation that would produce an imaginary result. +var ErrImaginaryResult = errors.New("The result of this operation is imaginary.") + +// SqrtMaxIter sets a limit for number of iterations for the Sqrt function +const SqrtMaxIter = 1000000 + +// SqrtRound returns the square root of d, accurate to precision decimal places. +// The bool precise returns whether the precision was achieved. +// SqrtRound is only valid for non-negative numbers; it will return an error otherwise. +func (d Decimal) SqrtRound(precision int32) (Decimal, bool, error) { + var ( + maxError = New(1, -precision) + one = NewFromFloat(1) + lo, hi Decimal + ) + + // Handle cases where d < 0, d = 0, 0 < d < 1, and d > 1 + if d.GreaterThanOrEqual(one) { + lo = Zero + hi = d + } else if d.Equal(one) { + return one, true, nil + } else if d.LessThan(Zero) { + return Zero, false, ErrImaginaryResult + } else if d.Equal(Zero) { + return Zero, true, nil + } else { + // d is between 0 and 1. Therefore, 0 < d < Sqrt(d) < 1. + lo = d + hi = one + } + + var mid Decimal + for i := 0; i < SqrtMaxIter; i++ { + mid = lo.Add(hi).Div(New(2, 0)) //mid = (lo+hi)/2; + if mid.Mul(mid).Sub(d).Abs().LessThan(maxError) { + return mid, true, nil + } + if mid.Mul(mid).GreaterThan(d) { + hi = mid + } else { + lo = mid + } + } + return mid, false, nil +} diff --git a/decimal_test.go b/decimal_test.go index 3667754..f8bcc62 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -2881,6 +2881,40 @@ func TestAvg(t *testing.T) { } } +func TestSqrtRound(t *testing.T) { + i := NewFromFloat(-1) + if _, err := i.Sqrt(); err != ErrImaginaryResult { + t.Errorf("Square root of -1 should produce error") + } + + var vals = map[string]string{ + // value : Sqrt(value) + "0.0": "0.0", + "0.002342": "0.0483942145302514", + "1.0": "1.0", + "3.0": "1.7320508075688773", + "4.0": "2.0", + "4.5": "2.1213203435596426", + "3289854.0": "1813.7954680724064485", + } + + for val, expected := range vals { + v, err := NewFromString(val) + if err != nil { + t.Errorf("error parsing test value into Decimal") + } + + e, err := NewFromString(expected) + if err != nil { + t.Errorf("error parsing test expected value into Decimal") + } + + if sqrt, err := v.Sqrt(); err != nil || !sqrt.Equal(e) { + t.Errorf("Square root of %s should be %s, not %s (error: %s)", v, e, sqrt, err) + } + } +} + func TestRoundBankAnomaly(t *testing.T) { a := New(25, -1) b := New(250, -2)