diff --git a/decimal.go b/decimal.go index db4efc6..cf42469 100644 --- a/decimal.go +++ b/decimal.go @@ -1360,6 +1360,52 @@ func RescalePair(d1 Decimal, d2 Decimal) (Decimal, Decimal) { return d1, d2.rescale(baseScale) } +// SqrtMaxIter sets a limit for number of iterations for the Sqrt function +const SqrtMaxIter = 100000 + +// Sqrt returns the square root of d, accurate to DivisionPrecision decimal places. +func Sqrt(d Decimal) Decimal { + s, _ := SqrtRound(d, int32(DivisionPrecision)) + return s +} + +// SqrtRound returns the square root of d, accurate to precision decimal places. +// The bool precise returns whether the precision was reached. +func SqrtRound(d Decimal, precision int32) (Decimal, bool) { + maxError := New(1, -precision) + one := NewFromFloat(1) + var lo Decimal + var hi Decimal + // Handle cases where d < 0, d = 0, 0 < d < 1, and d > 1 + if d.GreaterThanOrEqual(one) { + lo = Zero + hi = d + } else if d.Equal(one) { + return one, true + } else if d.LessThan(Zero) { + return NewFromFloat(-1), false // call this an error , cannot take sqrt of neg w/o imaginaries + } else if d.Equal(Zero) { + return Zero, true + } else { + // d is between 0 and 1. Therefore, 0 < d < Sqrt(d) < 1. + lo = d + hi = one + } + var mid Decimal + for i := 0; i < SqrtMaxIter; i++ { + mid = lo.Add(hi).Div(New(2, 0)) //mid = (lo+hi)/2; + if mid.Mul(mid).Sub(d).Abs().LessThan(maxError) { + return mid, true + } + if mid.Mul(mid).GreaterThan(d) { + hi = mid + } else { + lo = mid + } + } + return mid, false +} + func min(x, y int32) int32 { if x >= y { return y