From 78289cc844703d5715b0fc512d0795b0310fd194 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Wo=C5=9B?= Date: Wed, 3 Apr 2024 00:16:27 +0200 Subject: [PATCH] Add improved implementation of power operation (#358) * Adjust Pow implementation * Add PowWithPrecision method * Add PowInt32 method * Add PowBigInt method --- decimal.go | 288 ++++++++++++++++++++++++++++++++++++++++-- decimal_bench_test.go | 36 ++++++ decimal_test.go | 244 +++++++++++++++++++++++++++++++++-- 3 files changed, 546 insertions(+), 22 deletions(-) diff --git a/decimal.go b/decimal.go index 0224292..5fb1e4f 100644 --- a/decimal.go +++ b/decimal.go @@ -43,6 +43,20 @@ import ( // d4.String() // output: "0.667" var DivisionPrecision = 16 +// PowPrecisionNegativeExponent specifies the maximum precision of the result (digits after decimal point) +// when calculating decimal power. Only used for cases where the exponent is a negative number. +// This constant applies to Pow, PowInt32 and PowBigInt methods, PowWithPrecision method is not constrained by it. +// +// Example: +// +// d1, err := decimal.NewFromFloat(15.2).PowInt32(-2) +// d1.String() // output: "0.0043282548476454" +// +// decimal.PowPrecisionNegativeExponent = 24 +// d2, err := decimal.NewFromFloat(15.2).PowInt32(-2) +// d2.String() // output: "0.004328254847645429362881" +var PowPrecisionNegativeExponent = 16 + // MarshalJSONWithoutQuotes should be set to true if you want the decimal to // be JSON marshaled as a number, instead of as a string. // WARNING: this is dangerous for decimals with many digits, since many JSON @@ -649,20 +663,274 @@ func (d Decimal) Mod(d2 Decimal) Decimal { return r } -// Pow returns d to the power d2 +// Pow returns d to the power of d2. +// When exponent is negative the returned decimal will have maximum precision of PowPrecisionNegativeExponent places after decimal point. +// +// Pow returns 0 (zero-value of Decimal) instead of error for power operation edge cases, to handle those edge cases use PowWithPrecision +// Edge cases not handled by Pow: +// - 0 ** 0 => undefined value +// - 0 ** y, where y < 0 => infinity +// - x ** y, where x < 0 and y is non-integer decimal => imaginary value +// +// Example: +// +// d1 := decimal.NewFromFloat(4.0) +// d2 := decimal.NewFromFloat(4.0) +// res1 := d1.Pow(d2) +// res1.String() // output: "256" +// +// d3 := decimal.NewFromFloat(5.0) +// d4 := decimal.NewFromFloat(5.73) +// res2 := d3.Pow(d4) +// res2.String() // output: "10118.08037125" func (d Decimal) Pow(d2 Decimal) Decimal { - var temp Decimal - if d2.IntPart() == 0 { - return NewFromFloat(1) + baseSign := d.Sign() + expSign := d2.Sign() + + if baseSign == 0 { + if expSign == 0 { + return Decimal{} + } + if expSign == 1 { + return Decimal{zeroInt, 0} + } + if expSign == -1 { + return Decimal{} + } } - temp = d.Pow(d2.Div(NewFromFloat(2))) - if d2.IntPart()%2 == 0 { - return temp.Mul(temp) + + if expSign == 0 { + return Decimal{oneInt, 0} } - if d2.IntPart() > 0 { - return temp.Mul(temp).Mul(d) + + // TODO: optimize extraction of fractional part + one := Decimal{oneInt, 0} + expIntPart, expFracPart := d2.QuoRem(one, 0) + + if baseSign == -1 && !expFracPart.IsZero() { + return Decimal{} } - return temp.Mul(temp).Div(d) + + intPartPow, _ := d.PowBigInt(expIntPart.value) + + // if exponent is an integer we don't need to calculate d1**frac(d2) + if expFracPart.value.Sign() == 0 { + return intPartPow + } + + // TODO: optimize NumDigits for more performant precision adjustment + digitsBase := d.NumDigits() + digitsExponent := d2.NumDigits() + + precision := digitsBase + + if digitsExponent > precision { + precision += digitsExponent + } + + precision += 6 + + // Calculate x ** frac(y), where + // x ** frac(y) = exp(ln(x ** frac(y)) = exp(ln(x) * frac(y)) + fracPartPow, err := d.Abs().Ln(-d.exp + int32(precision)) + if err != nil { + return Decimal{} + } + + fracPartPow = fracPartPow.Mul(expFracPart) + + fracPartPow, err = fracPartPow.ExpTaylor(-d.exp + int32(precision)) + if err != nil { + return Decimal{} + } + + // Join integer and fractional part, + // base ** (expBase + expFrac) = base ** expBase * base ** expFrac + res := intPartPow.Mul(fracPartPow) + + return res +} + +// PowWithPrecision returns d to the power of d2. +// Precision parameter specifies minimum precision of the result (digits after decimal point). +// Returned decimal is not rounded to 'precision' places after decimal point. +// +// PowWithPrecision returns error when: +// - 0 ** 0 => undefined value +// - 0 ** y, where y < 0 => infinity +// - x ** y, where x < 0 and y is non-integer decimal => imaginary value +// +// Example: +// +// d1 := decimal.NewFromFloat(4.0) +// d2 := decimal.NewFromFloat(4.0) +// res1, err := d1.PowWithPrecision(d2, 2) +// res1.String() // output: "256" +// +// d3 := decimal.NewFromFloat(5.0) +// d4 := decimal.NewFromFloat(5.73) +// res2, err := d3.PowWithPrecision(d4, 5) +// res2.String() // output: "10118.080371595015625" +// +// d5 := decimal.NewFromFloat(-3.0) +// d6 := decimal.NewFromFloat(-6.0) +// res3, err := d5.PowWithPrecision(d6, 10) +// res3.String() // output: "0.0013717421" +func (d Decimal) PowWithPrecision(d2 Decimal, precision int32) (Decimal, error) { + baseSign := d.Sign() + expSign := d2.Sign() + + if baseSign == 0 { + if expSign == 0 { + return Decimal{}, fmt.Errorf("cannot represent undefined value of 0**0") + } + if expSign == 1 { + return Decimal{zeroInt, 0}, nil + } + if expSign == -1 { + return Decimal{}, fmt.Errorf("cannot represent infinity value of 0 ** y, where y < 0") + } + } + + if expSign == 0 { + return Decimal{oneInt, 0}, nil + } + + // TODO: optimize extraction of fractional part + one := Decimal{oneInt, 0} + expIntPart, expFracPart := d2.QuoRem(one, 0) + + if baseSign == -1 && !expFracPart.IsZero() { + return Decimal{}, fmt.Errorf("cannot represent imaginary value of x ** y, where x < 0 and y is non-integer decimal") + } + + intPartPow, _ := d.powBigIntWithPrecision(expIntPart.value, precision) + + // if exponent is an integer we don't need to calculate d1**frac(d2) + if expFracPart.value.Sign() == 0 { + return intPartPow, nil + } + + // TODO: optimize NumDigits for more performant precision adjustment + digitsBase := d.NumDigits() + digitsExponent := d2.NumDigits() + + if int32(digitsBase) > precision { + precision = int32(digitsBase) + } + if int32(digitsExponent) > precision { + precision += int32(digitsExponent) + } + // increase precision by 10 to compensate for errors in further calculations + precision += 10 + + // Calculate x ** frac(y), where + // x ** frac(y) = exp(ln(x ** frac(y)) = exp(ln(x) * frac(y)) + fracPartPow, err := d.Abs().Ln(precision) + if err != nil { + return Decimal{}, err + } + + fracPartPow = fracPartPow.Mul(expFracPart) + + fracPartPow, err = fracPartPow.ExpTaylor(precision) + if err != nil { + return Decimal{}, err + } + + // Join integer and fractional part, + // base ** (expBase + expFrac) = base ** expBase * base ** expFrac + res := intPartPow.Mul(fracPartPow) + + return res, nil +} + +// PowInt32 returns d to the power of exp, where exp is int32. +// Only returns error when d and exp is 0, thus result is undefined. +// +// When exponent is negative the returned decimal will have maximum precision of PowPrecisionNegativeExponent places after decimal point. +// +// Example: +// +// d1, err := decimal.NewFromFloat(4.0).PowInt32(4) +// d1.String() // output: "256" +// +// d2, err := decimal.NewFromFloat(3.13).PowInt32(5) +// d2.String() // output: "300.4150512793" +func (d Decimal) PowInt32(exp int32) (Decimal, error) { + if d.IsZero() && exp == 0 { + return Decimal{}, fmt.Errorf("cannot represent undefined value of 0**0") + } + + isExpNeg := exp < 0 + exp = abs(exp) + + n, result := d, New(1, 0) + + for exp > 0 { + if exp%2 == 1 { + result = result.Mul(n) + } + exp /= 2 + + if exp > 0 { + n = n.Mul(n) + } + } + + if isExpNeg { + return New(1, 0).DivRound(result, int32(PowPrecisionNegativeExponent)), nil + } + + return result, nil +} + +// PowBigInt returns d to the power of exp, where exp is big.Int. +// Only returns error when d and exp is 0, thus result is undefined. +// +// When exponent is negative the returned decimal will have maximum precision of PowPrecisionNegativeExponent places after decimal point. +// +// Example: +// +// d1, err := decimal.NewFromFloat(3.0).PowBigInt(big.NewInt(3)) +// d1.String() // output: "27" +// +// d2, err := decimal.NewFromFloat(629.25).PowBigInt(big.NewInt(5)) +// d2.String() // output: "98654323103449.5673828125" +func (d Decimal) PowBigInt(exp *big.Int) (Decimal, error) { + return d.powBigIntWithPrecision(exp, int32(PowPrecisionNegativeExponent)) +} + +func (d Decimal) powBigIntWithPrecision(exp *big.Int, precision int32) (Decimal, error) { + if d.IsZero() && exp.Sign() == 0 { + return Decimal{}, fmt.Errorf("cannot represent undefined value of 0**0") + } + + tmpExp := new(big.Int).Set(exp) + isExpNeg := exp.Sign() < 0 + + if isExpNeg { + tmpExp.Abs(tmpExp) + } + + n, result := d, New(1, 0) + + for tmpExp.Sign() > 0 { + if tmpExp.Bit(0) == 1 { + result = result.Mul(n) + } + tmpExp.Rsh(tmpExp, 1) + + if tmpExp.Sign() > 0 { + n = n.Mul(n) + } + } + + if isExpNeg { + return New(1, 0).DivRound(result, precision), nil + } + + return result, nil } // ExpHullAbrham calculates the natural exponent of decimal (e to the power of d) using Hull-Abraham algorithm. diff --git a/decimal_bench_test.go b/decimal_bench_test.go index 269a9f6..b1978bc 100644 --- a/decimal_bench_test.go +++ b/decimal_bench_test.go @@ -3,6 +3,7 @@ package decimal import ( "fmt" "math" + "math/big" "math/rand" "sort" "strconv" @@ -185,6 +186,41 @@ func BenchmarkDecimal_IsInteger(b *testing.B) { } } +func BenchmarkDecimal_Pow(b *testing.B) { + d1 := RequireFromString("5.2") + d2 := RequireFromString("6.3") + + for i := 0; i < b.N; i++ { + d1.Pow(d2) + } +} + +func BenchmarkDecimal_PowWithPrecision(b *testing.B) { + d1 := RequireFromString("5.2") + d2 := RequireFromString("6.3") + + for i := 0; i < b.N; i++ { + _, _ = d1.PowWithPrecision(d2, 8) + } +} +func BenchmarkDecimal_PowInt32(b *testing.B) { + d1 := RequireFromString("5.2") + d2 := int32(10) + + for i := 0; i < b.N; i++ { + _, _ = d1.PowInt32(d2) + } +} + +func BenchmarkDecimal_PowBigInt(b *testing.B) { + d1 := RequireFromString("5.2") + d2 := big.NewInt(10) + + for i := 0; i < b.N; i++ { + _, _ = d1.PowBigInt(d2) + } +} + func BenchmarkDecimal_NewFromString(b *testing.B) { count := 72 prices := make([]string, 0, count) diff --git a/decimal_test.go b/decimal_test.go index 841f205..0905ce8 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -2621,21 +2621,241 @@ func TestDecimal_Cmp2(t *testing.T) { } } -func TestPow(t *testing.T) { - a := New(4, 0) - b := New(2, 0) - x := a.Pow(b) - if x.String() != "16" { - t.Errorf("Error, saw %s", x.String()) +func TestDecimal_Pow(t *testing.T) { + for _, testCase := range []struct { + Base string + Exponent string + Expected string + }{ + {"0.0", "1.0", "0.0"}, + {"0.0", "5.7", "0.0"}, + {"0.0", "-3.2", "0.0"}, + {"3.13", "0.0", "1.0"}, + {"-591.5", "0.0", "1.0"}, + {"3.0", "3.0", "27.0"}, + {"3.0", "10.0", "59049.0"}, + {"3.13", "5.0", "300.4150512793"}, + {"4.0", "2.0", "16.0"}, + {"4.0", "-2.0", "0.0625"}, + {"629.25", "5.0", "98654323103449.5673828125"}, + {"5.0", "5.73", "10118.08037159375"}, + {"962.0", "3.2791", "6055212360.0000044205714144"}, + {"5.69169126", "5.18515912", "8242.26344757948412597909547972726268869189399260047793106028930864"}, + {"13.1337", "3.5196719618391835", "8636.856220644773844815693636723928750940666269885"}, + {"67762386.283696923", "4.85917691669163916681738", "112761146905370140621385730157437443321.91755738117317148674362233906499698561022574811238435007575701773212242750262081945556470501"}, + {"-3.0", "6.0", "729"}, + {"-13.757", "5.0", "-492740.983929899460557"}, + {"3.0", "-6.0", "0.0013717421124829"}, + {"13.757", "-5.0", "0.000002029463821"}, + {"66.12", "-7.61313", "0.000000000000013854086588876805036"}, + {"6696871.12", "-2.61313", "0.000000000000000001455988684546983"}, + {"-3.0", "-6.0", "0.0013717421124829"}, + {"-13.757", "-5.0", "-0.000002029463821"}, + } { + base, _ := NewFromString(testCase.Base) + exp, _ := NewFromString(testCase.Exponent) + expected, _ := NewFromString(testCase.Expected) + + result := base.Pow(exp) + + if result.Cmp(expected) != 0 { + t.Errorf("expected %s, got %s, for %s^%s", testCase.Expected, result.String(), testCase.Base, testCase.Exponent) + } } } -func TestNegativePow(t *testing.T) { - a := New(4, 0) - b := New(-2, 0) - x := a.Pow(b) - if x.String() != "0.0625" { - t.Errorf("Error, saw %s", x.String()) +func TestDecimal_PowWithPrecision(t *testing.T) { + for _, testCase := range []struct { + Base string + Exponent string + Precision int32 + Expected string + }{ + {"0.0", "1.0", 2, "0.0"}, + {"0.0", "5.7", 2, "0.0"}, + {"0.0", "-3.2", 2, "0.0"}, + {"3.13", "0.0", 2, "1.0"}, + {"-591.5", "0.0", 2, "1.0"}, + {"3.0", "3.0", 2, "27.0"}, + {"3.0", "10.0", 2, "59049.0"}, + {"3.13", "5.0", 5, "300.4150512793"}, + {"4.0", "2.0", 2, "16.0"}, + {"4.0", "-2.0", 2, "0.06"}, + {"4.0", "-2.0", 4, "0.0625"}, + {"629.25", "5.0", 6, "98654323103449.5673828125"}, + {"5.0", "5.73", 20, "10118.080371595019317118681359884375"}, + {"962.0", "3.2791", 15, "6055212360.000004406551603058195732"}, + {"5.69169126", "5.18515912", 4, "8242.26344757948412587366859330429895955552280978668983459852256"}, + {"13.1337", "3.5196719618391835", 8, "8636.85622064477384481569363672392591908386390769375"}, + {"67762386.283696923", "4.85917691669163916681738", 10, "112761146905370140621385730157437443321.917557381173174638304347353880676293576708009282115993465286373470882947470198597518762"}, + {"-3.0", "6.0", 2, "729"}, + {"-13.757", "5.0", 4, "-492740.983929899460557"}, + {"3.0", "-6.0", 10, "0.0013717421"}, + {"13.757", "-5.0", 20, "0.00000202946382098037"}, + {"66.12", "-7.61313", 20, "0.00000000000001385381563049821591633907104023700216"}, + {"6696871.12", "-2.61313", 24, "0.0000000000000000014558252733872790626400278983397459207418"}, + {"-3.0", "-6.0", 8, "0.00137174"}, + {"-13.757", "-5.0", 16, "-0.000002029463821"}, + } { + base, _ := NewFromString(testCase.Base) + exp, _ := NewFromString(testCase.Exponent) + expected, _ := NewFromString(testCase.Expected) + + result, _ := base.PowWithPrecision(exp, testCase.Precision) + + if result.Cmp(expected) != 0 { + t.Errorf("expected %s, got %s, for %s^%s", testCase.Expected, result.String(), testCase.Base, testCase.Exponent) + } + } +} + +func TestDecimal_PowWithPrecision_Infinity(t *testing.T) { + for _, testCase := range []struct { + Base string + Exponent string + }{ + {"0.0", "0.0"}, + {"0.0", "-2.0"}, + {"0.0", "-4.6"}, + {"-66.12", "7.61313"}, // Imaginary value + {"-5696871.12", "5.61313"}, // Imaginary value + } { + base, _ := NewFromString(testCase.Base) + exp, _ := NewFromString(testCase.Exponent) + + _, err := base.PowWithPrecision(exp, 5) + + if err == nil { + t.Errorf("lool it should be error") + } + } +} + +func TestDecimal_PowWithPrecision_UndefinedResult(t *testing.T) { + base := RequireFromString("0") + exponent := RequireFromString("0") + + _, err := base.PowWithPrecision(exponent, 4) + + if err == nil { + t.Errorf("expected error, cannot be represent undefined value of 0**0") + } +} + +func TestDecimal_PowWithPrecision_InfinityResult(t *testing.T) { + for _, testCase := range []struct { + Base string + Exponent string + }{ + {"0.0", "-2.0"}, + {"0.0", "-4.6"}, + {"0.0", "-9239.671333"}, + } { + base, _ := NewFromString(testCase.Base) + exp, _ := NewFromString(testCase.Exponent) + + _, err := base.PowWithPrecision(exp, 4) + + if err == nil { + t.Errorf("expected error, cannot represent infinity value of 0 ** y, where y < 0") + } + } +} + +func TestDecimal_PowWithPrecision_ImaginaryResult(t *testing.T) { + for _, testCase := range []struct { + Base string + Exponent string + }{ + {"-0.2261", "106.12"}, + {"-66.12", "7.61313"}, + {"-5696871.12", "5.61313"}, + } { + base, _ := NewFromString(testCase.Base) + exp, _ := NewFromString(testCase.Exponent) + + _, err := base.PowWithPrecision(exp, 4) + + if err == nil { + t.Errorf("expected error, cannot represent imaginary value of x ** y, where x < 0 and y is non-integer decimal") + } + } +} + +func TestDecimal_PowInt32(t *testing.T) { + for _, testCase := range []struct { + Decimal string + Exponent int32 + Expected string + }{ + {"0.0", 1, "0.0"}, + {"3.13", 0, "1.0"}, + {"-591.5", 0, "1.0"}, + {"3.0", 3, "27.0"}, + {"3.0", 10, "59049.0"}, + {"3.13", 5, "300.4150512793"}, + {"629.25", 5, "98654323103449.5673828125"}, + {"-3.0", 6, "729"}, + {"-13.757", 5, "-492740.983929899460557"}, + {"3.0", -6, "0.0013717421124829"}, + {"-13.757", -5, "-0.000002029463821"}, + } { + base, _ := NewFromString(testCase.Decimal) + expected, _ := NewFromString(testCase.Expected) + + result, _ := base.PowInt32(testCase.Exponent) + + if result.Cmp(expected) != 0 { + t.Errorf("expected %s, got %s, for %s**%d", testCase.Expected, result.String(), testCase.Decimal, testCase.Exponent) + } + } +} + +func TestDecimal_PowInt32_UndefinedResult(t *testing.T) { + base := RequireFromString("0") + + _, err := base.PowInt32(0) + + if err == nil { + t.Errorf("expected error, cannot be represent undefined value of 0**0") + } +} + +func TestDecimal_PowBigInt(t *testing.T) { + for _, testCase := range []struct { + Decimal string + Exponent *big.Int + Expected string + }{ + {"3.13", big.NewInt(0), "1.0"}, + {"-591.5", big.NewInt(0), "1.0"}, + {"3.0", big.NewInt(3), "27.0"}, + {"3.0", big.NewInt(10), "59049.0"}, + {"3.13", big.NewInt(5), "300.4150512793"}, + {"629.25", big.NewInt(5), "98654323103449.5673828125"}, + {"-3.0", big.NewInt(6), "729"}, + {"-13.757", big.NewInt(5), "-492740.983929899460557"}, + {"3.0", big.NewInt(-6), "0.0013717421124829"}, + {"-13.757", big.NewInt(-5), "-0.000002029463821"}, + } { + base, _ := NewFromString(testCase.Decimal) + expected, _ := NewFromString(testCase.Expected) + + result, _ := base.PowBigInt(testCase.Exponent) + + if result.Cmp(expected) != 0 { + t.Errorf("expected %s, got %s, for %s**%d", testCase.Expected, result.String(), testCase.Decimal, testCase.Exponent) + } + } +} + +func TestDecimal_PowBigInt_UndefinedResult(t *testing.T) { + base := RequireFromString("0") + + _, err := base.PowBigInt(big.NewInt(0)) + + if err == nil { + t.Errorf("expected error, undefined value of 0**0 cannot be represented") } }