diff --git a/decimal.go b/decimal.go index a37a230..3c6e90c 100644 --- a/decimal.go +++ b/decimal.go @@ -19,6 +19,7 @@ package decimal import ( "database/sql/driver" "encoding/binary" + "errors" "fmt" "math" "math/big" @@ -2337,3 +2338,59 @@ func (d Decimal) Tan() Decimal { } return y } + +// More math + +// Sqrt returns the square root of d, accurate to DivisionPrecision decimal places. +// Sqrt is only valid for non-negative numbers; it will return an error otherwise. +func (d Decimal) Sqrt() (Decimal, error) { + s, _, err := d.SqrtRound(int32(DivisionPrecision)) + return s, err +} + +// ErrImaginaryResult indicates an operation that would produce an imaginary result. +var ErrImaginaryResult = errors.New("The result of this operation is imaginary.") + +// SqrtMaxIter sets a limit for number of iterations for the Sqrt function +const SqrtMaxIter = 1000000 + +// SqrtRound returns the square root of d, accurate to precision decimal places. +// The bool precise returns whether the precision was achieved. +// SqrtRound is only valid for non-negative numbers; it will return an error otherwise. +func (d Decimal) SqrtRound(precision int32) (Decimal, bool, error) { + var ( + maxError = New(1, -precision) + one = NewFromFloat(1) + lo, hi Decimal + ) + + // Handle cases where d < 0, d = 0, 0 < d < 1, and d > 1 + if d.GreaterThanOrEqual(one) { + lo = Zero + hi = d + } else if d.Equal(one) { + return one, true, nil + } else if d.LessThan(Zero) { + return Zero, false, ErrImaginaryResult + } else if d.Equal(Zero) { + return Zero, true, nil + } else { + // d is between 0 and 1. Therefore, 0 < d < Sqrt(d) < 1. + lo = d + hi = one + } + + var mid Decimal + for i := 0; i < SqrtMaxIter; i++ { + mid = lo.Add(hi).Div(New(2, 0)) //mid = (lo+hi)/2; + if mid.Mul(mid).Sub(d).Abs().LessThan(maxError) { + return mid, true, nil + } + if mid.Mul(mid).GreaterThan(d) { + hi = mid + } else { + lo = mid + } + } + return mid, false, nil +} diff --git a/decimal_test.go b/decimal_test.go index d398f2d..ba98174 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -3434,6 +3434,40 @@ func TestAvg(t *testing.T) { } } +func TestSqrtRound(t *testing.T) { + i := NewFromFloat(-1) + if _, err := i.Sqrt(); err != ErrImaginaryResult { + t.Errorf("Square root of -1 should produce error") + } + + var vals = map[string]string{ + // value : Sqrt(value) + "0.0": "0.0", + "0.002342": "0.0483942145302514", + "1.0": "1.0", + "3.0": "1.7320508075688773", + "4.0": "2.0", + "4.5": "2.1213203435596426", + "3289854.0": "1813.7954680724064485", + } + + for val, expected := range vals { + v, err := NewFromString(val) + if err != nil { + t.Errorf("error parsing test value into Decimal") + } + + e, err := NewFromString(expected) + if err != nil { + t.Errorf("error parsing test expected value into Decimal") + } + + if sqrt, err := v.Sqrt(); err != nil || !sqrt.Equal(e) { + t.Errorf("Square root of %s should be %s, not %s (error: %s)", v, e, sqrt, err) + } + } +} + func TestRoundBankAnomaly(t *testing.T) { a := New(25, -1) b := New(250, -2)