decimal/decimal.go
2018-04-22 16:31:26 -04:00

1147 lines
29 KiB
Go

// Package decimal implements an arbitrary precision fixed-point decimal.
//
// To use as part of a struct:
//
// type Struct struct {
// Number Decimal
// }
//
// The zero-value of a Decimal is 0, as you would expect.
//
// The best way to create a new Decimal is to use decimal.NewFromString, ex:
//
// n, err := decimal.NewFromString("-123.4567")
// n.String() // output: "-123.4567"
//
// NOTE: This can "only" represent numbers with a maximum of 2^31 digits
// after the decimal point.
package decimal
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"math"
"math/big"
"strconv"
"strings"
"gopkg.in/mgo.v2/bson"
)
// DivisionPrecision is the number of decimal places in the result when it
// doesn't divide exactly. Div passes it to DivRound as the precision.
//
// Example:
//
// d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3))
// d1.String() // output: "0.6666666666666667"
// d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000))
// d2.String() // output: "0.0000666666666667"
// d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3))
// d3.String() // output: "6666.6666666666666667"
// decimal.DivisionPrecision = 3
// d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3))
// d4.String() // output: "0.667"
//
var DivisionPrecision = 16
// MarshalJSONWithoutQuotes should be set to true if you want the decimal to
// be JSON marshaled as a number, instead of as a string. MarshalJSON reads it
// on every call.
// WARNING: this is dangerous for decimals with many digits, since many JSON
// unmarshallers (ex: JavaScript's) will unmarshal JSON numbers to IEEE 754
// double-precision floating point numbers, which means you can potentially
// silently lose precision.
var MarshalJSONWithoutQuotes = false
// Zero is a ready-made Decimal equal to 0, to make computations faster.
// NOTE(review): it is stored as 0 * 10^1 (exponent 1, not 0).
var Zero = New(0, 1)
// fiveDec is the Decimal 5, used in cash rounding.
var fiveDec = New(5, 0)
// Shared big.Int constants used as operands throughout the package. They are
// package-level pointers, so treat them as read-only: pass them as arguments,
// never as the receiver of a mutating big.Int method.
var zeroInt = big.NewInt(0)
var oneInt = big.NewInt(1)
var twoInt = big.NewInt(2)
var fourInt = big.NewInt(4)
var fiveInt = big.NewInt(5)
var tenInt = big.NewInt(10)
var twentyInt = big.NewInt(20)
// Decimal represents a fixed-point decimal. It is immutable.
// number = value * 10 ^ exp
//
// The zero value has a nil coefficient and exponent 0; most methods call
// ensureInitialized to replace the nil coefficient with 0 before using it.
type Decimal struct {
// value is the unscaled integer coefficient of the number.
value *big.Int
// exp is the power of ten the coefficient is multiplied by.
// NOTE(vadim): this must be an int32, because we cast it to float64 during
// calculations. If exp is 64 bit, we might lose precision.
// If we cared about being able to represent every possible decimal, we
// could make exp a *big.Int but it would hurt performance and numbers
// like that are unrealistic.
exp int32
}
// New builds the fixed-point decimal value * 10 ^ exp from a machine integer
// coefficient and a base-10 exponent.
func New(value int64, exp int32) Decimal {
	coefficient := big.NewInt(value)
	return Decimal{value: coefficient, exp: exp}
}
// NewFromBigInt builds the Decimal value * 10 ^ exp from an arbitrary
// precision coefficient. The coefficient is copied, so later changes to value
// do not affect the returned Decimal.
func NewFromBigInt(value *big.Int, exp int32) Decimal {
	coefficient := new(big.Int).Set(value)
	return Decimal{value: coefficient, exp: exp}
}
// NewFromString returns a new Decimal from a string representation.
//
// The accepted form is an optionally signed run of base-10 digits with at most
// one decimal point, optionally followed by "e" or "E" and a base-10 integer
// exponent. Trailing zeros of the fractional part are dropped so that equal
// numbers get the same coefficient and exponent.
//
// An error is returned when the exponent is not a valid integer, when the
// mantissa has more than one decimal point or is not numeric, or when the
// resulting exponent does not fit in an int32. Every error message quotes the
// complete input string.
//
// Example:
//
// d, err := NewFromString("-123.45")
// d2, err := NewFromString(".0001")
//
func NewFromString(value string) (Decimal, error) {
	// value is shortened below when an exponent is present; errors must still
	// report what the caller actually passed in.
	originalInput := value
	var intString string
	var exp int64

	// Check if number is using scientific notation
	if eIndex := strings.IndexAny(value, "Ee"); eIndex != -1 {
		expInt, err := strconv.ParseInt(value[eIndex+1:], 10, 32)
		if err != nil {
			if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
				return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", originalInput)
			}
			return Decimal{}, fmt.Errorf("can't convert %s to decimal: exponent is not numeric", originalInput)
		}
		value = value[:eIndex]
		exp = expInt
	}

	parts := strings.Split(value, ".")
	switch len(parts) {
	case 1:
		// There is no decimal point, we can just parse the mantissa as an int
		intString = value
	case 2:
		// strip the insignificant digits for more accurate comparisons.
		decimalPart := strings.TrimRight(parts[1], "0")
		intString = parts[0] + decimalPart
		exp -= int64(len(decimalPart))
	default:
		return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", originalInput)
	}

	dValue := new(big.Int)
	if _, ok := dValue.SetString(intString, 10); !ok {
		return Decimal{}, fmt.Errorf("can't convert %s to decimal", originalInput)
	}

	if exp < math.MinInt32 || exp > math.MaxInt32 {
		// NOTE(vadim): I doubt a string could realistically be this long
		return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", originalInput)
	}

	return Decimal{
		value: dValue,
		exp:   int32(exp),
	}, nil
}
// RequireFromString parses value with NewFromString and returns the result.
// It panics with the parse error if NewFromString reports one.
//
// Example:
//
// d := RequireFromString("-123.45")
// d2 := RequireFromString(".0001")
//
func RequireFromString(value string) Decimal {
	parsed, parseErr := NewFromString(value)
	if parseErr == nil {
		return parsed
	}
	panic(parseErr)
}
// NewFromFloat converts a float64 to Decimal by calling
// NewFromFloatWithExponent with the lowest possible exponent
// (math.MinInt32), i.e. without asking for any limit on fractional digits.
//
// NOTE: some float64 numbers can take up about 300 bytes of memory in decimal representation.
// Consider using NewFromFloatWithExponent if space is more important than precision.
//
// NOTE: this will panic on NaN, +/-inf
func NewFromFloat(value float64) Decimal {
	const lowestExp int32 = math.MinInt32
	return NewFromFloatWithExponent(value, lowestExp)
}
// NewFromFloatWithExponent converts a float64 to Decimal, with an arbitrary
// number of fractional digits.
//
// exp is the requested base-10 exponent of the result. A negative exp that
// asks for more fractional digits than the value needs is raised (see the
// clamp below), so the returned exponent can be larger than the one requested.
// When digits have to be dropped the magnitude is rounded to nearest, with
// halves rounded up, before the sign is applied.
//
// It panics when value is NaN or +/-Inf. For +/-0 it returns the zero-value
// Decimal.
//
// Example:
//
// NewFromFloatWithExponent(123.456, -2).String() // output: "123.46"
//
func NewFromFloatWithExponent(value float64, exp int32) Decimal {
if math.IsNaN(value) || math.IsInf(value, 0) {
panic(fmt.Sprintf("Cannot create a Decimal from %v", value))
}
// Split the IEEE 754 bit pattern: 52 mantissa bits, 11 exponent bits, 1 sign bit.
bits := math.Float64bits(value)
mant := bits & (1<<52 - 1)
exp2 := int32((bits >> 52) & (1<<11 - 1))
sign := bits >> 63
if exp2 == 0 {
// specials
if mant == 0 {
// +0 or -0
return Decimal{}
} else {
// subnormal
exp2++
}
} else {
// normal
// restore the implicit leading mantissa bit
mant |= 1 << 52
}
// Remove the exponent bias (1023) and treat the mantissa as an integer (52
// fractional bits), so that |value| == mant * 2^exp2.
exp2 -= 1023 + 52
// normalizing base-2 values
// (strip trailing zero bits; mant is non-zero here, so the loop terminates)
for mant&1 == 0 {
mant = mant >> 1
exp2++
}
// maximum number of fractional base-10 digits to represent 2^N exactly cannot be more than -N if N<0
if exp < 0 && exp < exp2 {
if exp2 < 0 {
exp = exp2
} else {
exp = 0
}
}
// representing 10^M * 2^N as 5^M * 2^(M+N)
exp2 -= exp
// temp accumulates the divisor, dMant the dividend.
temp := big.NewInt(1)
dMant := big.NewInt(int64(mant))
// applying 5^M
if exp > 0 {
temp = temp.SetInt64(int64(exp))
temp = temp.Exp(fiveInt, temp, nil)
} else if exp < 0 {
// negate after widening so that exp == math.MinInt32 cannot overflow
temp = temp.SetInt64(-int64(exp))
temp = temp.Exp(fiveInt, temp, nil)
dMant = dMant.Mul(dMant, temp)
temp = temp.SetUint64(1)
}
// applying 2^(M+N)
if exp2 > 0 {
dMant = dMant.Lsh(dMant, uint(exp2))
} else if exp2 < 0 {
temp = temp.Lsh(temp, uint(-exp2))
}
// rounding and downscaling
if exp > 0 || exp2 < 0 {
// add half of the divisor (rounded down) before the truncating division
halfDown := new(big.Int).Rsh(temp, 1)
dMant = dMant.Add(dMant, halfDown)
dMant = dMant.Quo(dMant, temp)
}
if sign == 1 {
dMant = dMant.Neg(dMant)
}
return Decimal{
value: dMant,
exp: exp,
}
}
// rescale returns a copy of d expressed with the given exponent. Moving to a
// larger exponent divides the coefficient by a power of ten and therefore may
// lose digits; moving to a smaller exponent multiplies it and is exact.
// NOTE: this will truncate, NOT round
//
// Example:
//
// d := New(12345, -4)
// d2 := d.rescale(-1)
// d3 := d2.rescale(-4)
// println(d)
// println(d2)
// println(d3)
//
// Output:
//
// 1.2345
// 1.2
// 1.2000
//
func (d Decimal) rescale(exp int32) Decimal {
	d.ensureInitialized()

	// NOTE(vadim): the exponents are widened to float64 before subtracting so
	// that the difference cannot overflow an int32.
	distance := int64(math.Abs(float64(exp) - float64(d.exp)))
	factor := new(big.Int).Exp(tenInt, big.NewInt(distance), nil)

	scaled := new(big.Int)
	switch {
	case exp > d.exp:
		// fewer digits: truncating division
		scaled.Quo(d.value, factor)
	case exp < d.exp:
		// more digits: exact multiplication
		scaled.Mul(d.value, factor)
	default:
		scaled.Set(d.value)
	}
	return Decimal{value: scaled, exp: exp}
}
// Abs returns |d|, keeping d's exponent.
func (d Decimal) Abs() Decimal {
	d.ensureInitialized()
	magnitude := new(big.Int).Abs(d.value)
	return Decimal{value: magnitude, exp: d.exp}
}
// Add returns d + d2. Both operands are rescaled to the smaller of the two
// exponents, which is also the exponent of the result.
func (d Decimal) Add(d2 Decimal) Decimal {
	scale := min(d.exp, d2.exp)
	lhs, rhs := d.rescale(scale), d2.rescale(scale)
	sum := new(big.Int).Add(lhs.value, rhs.value)
	return Decimal{value: sum, exp: scale}
}
// Sub returns d - d2. Both operands are rescaled to the smaller of the two
// exponents, which is also the exponent of the result.
func (d Decimal) Sub(d2 Decimal) Decimal {
	scale := min(d.exp, d2.exp)
	lhs, rhs := d.rescale(scale), d2.rescale(scale)
	difference := new(big.Int).Sub(lhs.value, rhs.value)
	return Decimal{value: difference, exp: scale}
}
// Neg returns -d, keeping d's exponent.
func (d Decimal) Neg() Decimal {
	d.ensureInitialized()
	negated := new(big.Int).Neg(d.value)
	return Decimal{value: negated, exp: d.exp}
}
// Mul returns d * d2. The coefficients are multiplied and the exponents are
// added; it panics if the summed exponent does not fit in an int32.
func (d Decimal) Mul(d2 Decimal) Decimal {
	d.ensureInitialized()
	d2.ensureInitialized()

	// Add in 64 bits so the range check itself cannot overflow.
	sumExp := int64(d.exp) + int64(d2.exp)
	if sumExp < math.MinInt32 || sumExp > math.MaxInt32 {
		// NOTE(vadim): better to panic than give incorrect results, as
		// Decimals are usually used for money
		panic(fmt.Sprintf("exponent %v overflows an int32!", sumExp))
	}

	product := new(big.Int).Mul(d.value, d2.value)
	return Decimal{value: product, exp: int32(sumExp)}
}
// Shift moves the decimal point of d by shift base-10 places: the coefficient
// is copied unchanged and shift is added to the exponent, so a positive shift
// multiplies by a power of ten and a negative shift divides by one.
// The exponent addition is plain int32 arithmetic and is not range-checked.
func (d Decimal) Shift(shift int32) Decimal {
	d.ensureInitialized()
	shifted := Decimal{exp: d.exp + shift}
	shifted.value = new(big.Int).Set(d.value)
	return shifted
}
// Div returns d / d2, rounded by DivRound to DivisionPrecision digits after
// the decimal point.
func (d Decimal) Div(d2 Decimal) Decimal {
	precision := int32(DivisionPrecision)
	return d.DivRound(d2, precision)
}
// QuoRem does division with remainder.
// d.QuoRem(d2,precision) returns quotient q and remainder r such that
// d = d2 * q + r, q an integer multiple of 10^(-precision)
// 0 <= r < abs(d2) * 10 ^(-precision) if d>=0
// 0 >= r > -abs(d2) * 10 ^(-precision) if d<0
// Note that precision<0 is allowed as input.
//
// The quotient has exponent -precision. It panics with "decimal division by 0"
// when d2 is zero and with "overflow in decimal QuoRem" when an exponent
// needed for the computation or the results does not fit in an int32.
func (d Decimal) QuoRem(d2 Decimal, precision int32) (Decimal, Decimal) {
	d.ensureInitialized()
	d2.ensureInitialized()
	if d2.value.Sign() == 0 {
		panic("decimal division by 0")
	}

	// All exponent arithmetic is done in int64: subtracting int32 values first
	// and widening afterwards would wrap around before the range check runs.
	scale := -int64(precision)
	e := int64(d.exp) - int64(d2.exp) - scale
	// Exponent of the remainder; see the two cases below.
	remExp := int64(d.exp)
	if e >= 0 {
		remExp = scale + int64(d2.exp)
	}
	if e > math.MaxInt32 || e < math.MinInt32 ||
		scale > math.MaxInt32 ||
		remExp > math.MaxInt32 || remExp < math.MinInt32 {
		panic("overflow in decimal QuoRem")
	}

	// With d = a * 10^ea and d2 = b * 10^eb, bring both operands to a common
	// exponent so that one integer division yields q and r.
	var num, den *big.Int
	if e < 0 {
		// num = a
		// den = b * 10^(scale + eb - ea)
		pow := new(big.Int).Exp(tenInt, big.NewInt(-e), nil)
		num = d.value
		den = pow.Mul(d2.value, pow)
	} else {
		// num = a * 10^(ea - eb - scale)
		// den = b
		pow := new(big.Int).Exp(tenInt, big.NewInt(e), nil)
		num = pow.Mul(d.value, pow)
		den = d2.value
	}

	// big.Int.QuoRem truncates toward zero, so r takes the sign of d.
	q, r := new(big.Int).QuoRem(num, den, new(big.Int))
	return Decimal{value: q, exp: int32(scale)}, Decimal{value: r, exp: int32(remExp)}
}
// DivRound divides d by d2 and rounds the quotient to an integer multiple of
// 10^(-precision). A remainder of exactly half a unit is rounded away from
// zero: up for a positive quotient, down for a negative one.
// Note that precision<0 is allowed as input.
func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal {
	// QuoRem initializes both operands and rejects a zero divisor.
	quo, rem := d.QuoRem(d2, precision)

	// The quotient must be adjusted when |rem| * 10^precision >= |d2| / 2.
	// Doubling the left-hand side avoids having to halve d2.
	doubled := new(big.Int).Abs(rem.value)
	doubled.Lsh(doubled, 1)
	scaledRem := Decimal{value: doubled, exp: rem.exp + precision}
	if scaledRem.Cmp(d2.Abs()) < 0 {
		return quo
	}

	// One unit in the last retained place, applied away from zero.
	unit := New(1, -precision)
	if d.value.Sign()*d2.value.Sign() < 0 {
		return quo.Sub(unit)
	}
	return quo.Add(unit)
}
// Mod returns d % d2: the remainder r of the division truncated toward zero,
// so that d = d2*q + r for an integer q and r has the sign of d.
//
// The remainder is taken from QuoRem at precision 0, which is exact. Going
// through Div instead would first round the quotient to DivisionPrecision
// digits, and a quotient such as 1.99999999999999999999 would be rounded up to
// 2 before truncation, giving a remainder with the wrong sign and magnitude.
//
// It panics if d2 is zero.
func (d Decimal) Mod(d2 Decimal) Decimal {
	_, rem := d.QuoRem(d2, 0)
	return rem
}
// Pow returns d to the power d2.
//
// Only the integer part of d2 (as returned by IntPart) is used; any fractional
// part of the exponent is ignored. A zero exponent yields 1. For a negative
// exponent the result is built with Div and is therefore rounded according to
// DivisionPrecision, and it panics if d is zero.
//
// The exponent is halved with integer arithmetic. Halving it as a Decimal
// would round each half to DivisionPrecision digits, which can change its
// integer part and never reaches zero when DivisionPrecision is 0.
func (d Decimal) Pow(d2 Decimal) Decimal {
	// Exponentiation by squaring on the integer exponent. Go's integer
	// division and remainder truncate toward zero, so n/2 reaches 0 for
	// negative n as well.
	var pow func(n int64) Decimal
	pow = func(n int64) Decimal {
		if n == 0 {
			return New(1, 0)
		}
		half := pow(n / 2)
		squared := half.Mul(half)
		switch {
		case n%2 == 0:
			return squared
		case n > 0:
			return squared.Mul(d)
		default:
			return squared.Div(d)
		}
	}
	return pow(d2.IntPart())
}
// Cmp compares the numbers represented by d and d2 and returns:
//
// -1 if d < d2
// 0 if d == d2
// +1 if d > d2
//
// Operands with different exponents are first rescaled to the smaller one so
// that their coefficients can be compared directly.
func (d Decimal) Cmp(d2 Decimal) int {
	d.ensureInitialized()
	d2.ensureInitialized()

	lhs, rhs := d, d2
	if d.exp != d2.exp {
		common := min(d.exp, d2.exp)
		lhs, rhs = d.rescale(common), d2.rescale(common)
	}
	return lhs.value.Cmp(rhs.value)
}
// Equal reports whether d and d2 represent the same number, regardless of how
// each one is split into coefficient and exponent.
func (d Decimal) Equal(d2 Decimal) bool {
	order := d.Cmp(d2)
	return order == 0
}
// Equals reports whether d and d2 represent the same number.
//
// Deprecated: use Equal instead.
func (d Decimal) Equals(d2 Decimal) bool {
	return d.Cmp(d2) == 0
}
// GreaterThan (GT) reports whether d > d2.
func (d Decimal) GreaterThan(d2 Decimal) bool {
	return d.Cmp(d2) > 0
}
// GreaterThanOrEqual (GTE) reports whether d >= d2.
func (d Decimal) GreaterThanOrEqual(d2 Decimal) bool {
	return d.Cmp(d2) >= 0
}
// LessThan (LT) reports whether d < d2.
func (d Decimal) LessThan(d2 Decimal) bool {
	return d.Cmp(d2) < 0
}
// LessThanOrEqual (LTE) reports whether d <= d2.
func (d Decimal) LessThanOrEqual(d2 Decimal) bool {
	return d.Cmp(d2) <= 0
}
// Sign returns:
//
// -1 if d < 0
// 0 if d == 0
// +1 if d > 0
//
// A Decimal whose coefficient has not been initialized counts as 0.
func (d Decimal) Sign() int {
	if d.value != nil {
		return d.value.Sign()
	}
	return 0
}
// Exponent returns the base-10 exponent (scale) that d's coefficient is
// multiplied by.
func (d Decimal) Exponent() int32 {
	exponent := d.exp
	return exponent
}
// Coefficient returns the coefficient of the decimal. It is scaled by 10^Exponent()
//
// The result is a fresh copy, so mutating it does not mutate the Decimal. For
// the zero-value Decimal, whose internal coefficient is nil, it returns 0
// instead of dereferencing the nil pointer.
func (d Decimal) Coefficient() *big.Int {
	d.ensureInitialized()
	return new(big.Int).Set(d.value)
}
// IntPart returns the integer component of the decimal: d is rescaled to
// exponent 0 (which truncates any fractional digits) and the coefficient is
// converted with big.Int.Int64.
func (d Decimal) IntPart() int64 {
	return d.rescale(0).value.Int64()
}
// Rat returns d as an exact rational number: coefficient / 10^-exp for a
// non-positive exponent, coefficient * 10^exp otherwise.
func (d Decimal) Rat() *big.Rat {
	d.ensureInitialized()

	if d.exp > 0 {
		scale := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
		numerator := scale.Mul(d.value, scale)
		return new(big.Rat).SetFrac(numerator, oneInt)
	}

	// NOTE(vadim): must negate after casting to prevent int32 overflow
	denominator := new(big.Int).Exp(tenInt, big.NewInt(-int64(d.exp)), nil)
	return new(big.Rat).SetFrac(d.value, denominator)
}
// Float64 converts d through its rational form and returns the nearest
// float64 value together with a bool indicating whether f represents d
// exactly.
// For more details, see the documentation for big.Rat.Float64
func (d Decimal) Float64() (f float64, exact bool) {
	f, exact = d.Rat().Float64()
	return f, exact
}
// String returns the fixed-point string representation of the decimal, with
// trailing zeros of the fractional part removed.
//
// Example:
//
// d := New(-12345, -3)
// println(d.String())
//
// Output:
//
// -12.345
//
func (d Decimal) String() string {
	const trimTrailingZeros = true
	return d.string(trimTrailingZeros)
}
// StringFixed rounds d with Round(places) and formats the result as a
// fixed-point string, keeping trailing zeros so that places digits follow the
// decimal point.
//
// Example:
//
// NewFromFloat(0).StringFixed(2) // output: "0.00"
// NewFromFloat(0).StringFixed(0) // output: "0"
// NewFromFloat(5.45).StringFixed(0) // output: "5"
// NewFromFloat(5.45).StringFixed(1) // output: "5.5"
// NewFromFloat(5.45).StringFixed(2) // output: "5.45"
// NewFromFloat(5.45).StringFixed(3) // output: "5.450"
// NewFromFloat(545).StringFixed(-1) // output: "550"
//
func (d Decimal) StringFixed(places int32) string {
	return d.Round(places).string(false)
}
// StringFixedBank rounds d with RoundBank(places) (ties go to the even digit)
// and formats the result as a fixed-point string, keeping trailing zeros so
// that places digits follow the decimal point.
//
// Example:
//
// New(0, 0).StringFixedBank(2) // output: "0.00"
// New(0, 0).StringFixedBank(0) // output: "0"
// New(545, -2).StringFixedBank(0) // output: "5"
// New(545, -2).StringFixedBank(1) // output: "5.4"
// New(545, -2).StringFixedBank(2) // output: "5.45"
// New(545, -2).StringFixedBank(3) // output: "5.450"
// New(545, 0).StringFixedBank(-1) // output: "540"
//
func (d Decimal) StringFixedBank(places int32) string {
	return d.RoundBank(places).string(false)
}
// StringFixedCash rounds d with RoundCash(interval) and formats the result as
// a fixed-point string without trimming trailing zeros. See RoundCash for the
// supported intervals.
func (d Decimal) StringFixedCash(interval uint8) string {
	return d.RoundCash(interval).string(false)
}
// Round rounds the decimal to places decimal places. A value exactly halfway
// between two candidates is rounded away from zero.
// If places < 0, it will round the integer part to the nearest 10^(-places).
//
// Example:
//
// NewFromFloat(5.45).Round(1).String() // output: "5.5"
// NewFromFloat(545).Round(-1).String() // output: "550"
//
func (d Decimal) Round(places int32) Decimal {
// truncate to places + 1
// (rescale returns a fresh coefficient, so the in-place updates below do not
// touch d)
ret := d.rescale(-places - 1)
// add sign(d) * 0.5
if ret.value.Sign() < 0 {
ret.value.Sub(ret.value, fiveInt)
} else {
ret.value.Add(ret.value, fiveInt)
}
// floor for positive numbers, ceil for negative numbers
// DivMod is Euclidean: ret.value becomes the quotient and m >= 0 the modulus.
_, m := ret.value.DivMod(ret.value, tenInt, new(big.Int))
// the last digit has been divided away, so the exponent moves up by one
ret.exp++
// For a negative quotient with a non-zero modulus the Euclidean quotient is
// one below the truncated one; step back toward zero.
if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 {
ret.value.Add(ret.value, oneInt)
}
return ret
}
// RoundBank rounds the decimal to places decimal places.
// If the final digit to round is equidistant from the nearest two integers the
// rounded value is taken as the even number
//
// If places < 0, it will round the integer part to the nearest 10^(-places).
//
// Examples:
//
// New(545, -2).RoundBank(1).String() // output: "5.4"
// New(545, 0).RoundBank(-1).String() // output: "540"
// New(546, -2).RoundBank(1).String() // output: "5.5"
// New(546, 0).RoundBank(-1).String() // output: "550"
// New(555, -2).RoundBank(1).String() // output: "5.6"
// New(555, 0).RoundBank(-1).String() // output: "560"
//
func (d Decimal) RoundBank(places int32) Decimal {
// Start from the half-away-from-zero result and correct it only on exact ties.
round := d.Round(places)
remainder := d.Sub(round).Abs()
// half a unit of the last retained place
half := New(5, -places-1)
// An exact tie that Round pushed to an odd last digit is moved one step back
// toward zero, which lands on the even neighbour.
if remainder.Cmp(half) == 0 && round.value.Bit(0) != 0 {
if round.value.Sign() < 0 {
round.value.Add(round.value, oneInt)
} else {
round.value.Sub(round.value, oneInt)
}
}
return round
}
// RoundCash aka Cash/Penny/öre rounding rounds decimal to a specific
// interval. The amount payable for a cash transaction is rounded to the nearest
// multiple of the minimum currency unit available. The following intervals are
// available: 5, 10, 15, 25, 50 and 100; any other number throws a panic.
// 5: 5 cent rounding 3.43 => 3.45
// 10: 10 cent rounding 3.45 => 3.50 (5 gets rounded up)
// 15: 10 cent rounding 3.45 => 3.40 (5 gets rounded down)
// 25: 25 cent rounding 3.41 => 3.50
// 50: 50 cent rounding 3.75 => 4.00
// 100: 100 cent rounding 3.50 => 4.00
// For more details: https://en.wikipedia.org/wiki/Cash_rounding
//
// The amount is multiplied by the number of intervals per unit, rounded to an
// integer with Round, divided back and truncated to two decimal places.
//
// Interval 15 differs from interval 10 only for amounts that lie exactly
// halfway between two multiples of 0.10: those go to the lower multiple.
// Amounts that are not exact ties (3.455) and amounts that already are a
// multiple of 0.10 (3.50) are treated exactly as with interval 10.
func (d Decimal) RoundCash(interval uint8) Decimal {
	var iVal *big.Int
	tiesDown := false
	switch interval {
	case 5:
		iVal = twentyInt
	case 10:
		iVal = tenInt
	case 15:
		iVal = tenInt
		tiesDown = true
	case 25:
		iVal = fourInt
	case 50:
		iVal = twoInt
	case 100:
		iVal = oneInt
	default:
		panic(fmt.Sprintf("Decimal does not support this Cash rounding interval `%d`. Supported: 5, 10, 15, 25, 50, 100", interval))
	}
	dVal := Decimal{
		value: iVal,
	}

	// TODO: optimize those calculations to reduce the high allocations (~29 allocs).
	scaled := d.Mul(dVal)
	rounded := scaled.Round(0)
	// Round sends exact ties away from zero. For a positive amount that is the
	// upper neighbour, so step down by one; for a negative amount the tie has
	// already gone to the lower neighbour.
	if tiesDown && d.Sign() > 0 && scaled.Sub(rounded).Abs().Cmp(New(5, -1)) == 0 {
		rounded = rounded.Sub(New(1, 0))
	}
	return rounded.Div(dVal).Truncate(2)
}
// Floor returns the nearest integer value less than or equal to d. A Decimal
// with a non-negative exponent is already an integer and is returned as is.
func (d Decimal) Floor() Decimal {
	d.ensureInitialized()
	if d.exp >= 0 {
		return d
	}

	// NOTE(vadim): must negate after casting to prevent int32 overflow
	divisor := new(big.Int).Exp(big.NewInt(10), big.NewInt(-int64(d.exp)), nil)
	// Euclidean division by a positive divisor rounds toward negative infinity.
	floored := new(big.Int).Div(d.value, divisor)
	return Decimal{value: floored, exp: 0}
}
// Ceil returns the nearest integer value greater than or equal to d. A Decimal
// with a non-negative exponent is already an integer and is returned as is.
func (d Decimal) Ceil() Decimal {
	d.ensureInitialized()
	if d.exp >= 0 {
		return d
	}

	// NOTE(vadim): must negate after casting to prevent int32 overflow
	divisor := new(big.Int).Exp(big.NewInt(10), big.NewInt(-int64(d.exp)), nil)
	// Euclidean division yields the floor and a non-negative modulus; any
	// non-zero modulus means the ceiling is one above the floor.
	quo, rem := new(big.Int).DivMod(d.value, divisor, new(big.Int))
	if rem.Sign() != 0 {
		quo.Add(quo, oneInt)
	}
	return Decimal{value: quo, exp: 0}
}
// Truncate cuts off digits beyond precision decimal places, without rounding.
// d is returned unchanged when precision is negative or when d already has no
// more than precision fractional digits.
//
// NOTE: precision is the last digit that will not be truncated (must be >= 0).
//
// Example:
//
// decimal.NewFromString("123.456").Truncate(2).String() // "123.45"
//
func (d Decimal) Truncate(precision int32) Decimal {
	d.ensureInitialized()
	if precision < 0 || d.exp >= -precision {
		return d
	}
	return d.rescale(-precision)
}
// UnmarshalJSON implements the json.Unmarshaler interface. It accepts the
// number either bare or wrapped in double quotes. The JSON literal null leaves
// the receiver untouched.
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
	if string(decimalBytes) == "null" {
		return nil
	}

	unquoted, err := unquoteIfQuoted(decimalBytes)
	if err != nil {
		return fmt.Errorf("Error decoding string '%s': %s", decimalBytes, err)
	}

	parsed, parseErr := NewFromString(unquoted)
	// The receiver is overwritten with the parse result even when parsing
	// failed.
	*d = parsed
	if parseErr == nil {
		return nil
	}
	return fmt.Errorf("Error decoding string '%s': %s", unquoted, parseErr)
}
// MarshalJSON implements the json.Marshaler interface. The decimal is written
// as a JSON string unless MarshalJSONWithoutQuotes is set, in which case it is
// written as a bare JSON number.
func (d Decimal) MarshalJSON() ([]byte, error) {
	text := d.String()
	if !MarshalJSONWithoutQuotes {
		text = "\"" + text + "\""
	}
	return []byte(text), nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. It
// reads the layout written by MarshalBinary: the exponent as a 4-byte
// big-endian integer, followed by the gob encoding of the coefficient.
//
// It returns an error when data is shorter than the 4-byte exponent or when
// the coefficient cannot be decoded; in both cases the receiver is left
// unchanged.
func (d *Decimal) UnmarshalBinary(data []byte) error {
	// Slicing data[:4] below would panic on shorter input.
	if len(data) < 4 {
		return fmt.Errorf("error decoding binary %v: expected at least 4 bytes, got %d", data, len(data))
	}

	// Extract the exponent
	exp := int32(binary.BigEndian.Uint32(data[:4]))

	// Extract the value
	value := new(big.Int)
	if err := value.GobDecode(data[4:]); err != nil {
		return err
	}

	d.exp = exp
	d.value = value
	return nil
}
// MarshalBinary implements the encoding.BinaryMarshaler interface. The output
// is the exponent as a 4-byte big-endian integer followed by the gob encoding
// of the coefficient.
func (d Decimal) MarshalBinary() (data []byte, err error) {
	// The exponent goes first because it has a fixed size.
	var header [4]byte
	binary.BigEndian.PutUint32(header[:], uint32(d.exp))

	payload, err := d.value.GobEncode()
	if err != nil {
		return nil, err
	}
	return append(header[:], payload...), nil
}
// GetBSON implements the bson.Getter interface. The decimal is converted to a
// Mongo Decimal128 by parsing its string form.
func (d Decimal) GetBSON() (interface{}, error) {
	encoded, err := bson.ParseDecimal128(d.String())
	if err == nil {
		return encoded, nil
	}
	return nil, err
}
// SetBSON implements the bson.Setter interface. The raw value is unmarshaled
// as a Mongo Decimal128 and its string form is parsed with NewFromString. The
// receiver is only assigned when both steps succeed.
func (d *Decimal) SetBSON(raw bson.Raw) error {
	var wire bson.Decimal128
	if err := raw.Unmarshal(&wire); err != nil {
		return err
	}

	parsed, err := NewFromString(wire.String())
	if err != nil {
		return err
	}
	*d = parsed
	return nil
}
// Scan implements the sql.Scanner interface for database deserialization.
// Numeric driver values (float32, float64, int64) are converted directly;
// anything else is treated as text, optionally wrapped in double quotes, and
// parsed with NewFromString.
func (d *Decimal) Scan(value interface{}) error {
	switch src := value.(type) {
	case int64:
		// at least in sqlite3 when the value is 0 in db, the data is sent
		// to us as an int64 instead of a float64 ...
		*d = New(src, 0)
	case float64:
		// numeric in sqlite3 sends us float64
		*d = NewFromFloat(src)
	case float32:
		*d = NewFromFloat(float64(src))
	default:
		// default is trying to interpret value stored as string
		text, err := unquoteIfQuoted(src)
		if err != nil {
			return err
		}
		parsed, err := NewFromString(text)
		// The receiver is assigned even when parsing failed.
		*d = parsed
		return err
	}
	return nil
}
// Value implements the driver.Valuer interface for database serialization.
// The decimal is handed to the driver as its String form.
func (d Decimal) Value() (driver.Value, error) {
	text := d.String()
	return text, nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface for XML
// deserialization. The text is parsed with NewFromString.
func (d *Decimal) UnmarshalText(text []byte) error {
	input := string(text)
	parsed, parseErr := NewFromString(input)
	// The receiver is overwritten with the parse result even when parsing
	// failed.
	*d = parsed
	if parseErr == nil {
		return nil
	}
	return fmt.Errorf("Error decoding string '%s': %s", input, parseErr)
}
// MarshalText implements the encoding.TextMarshaler interface for XML
// serialization. The text is the String form of the decimal.
func (d Decimal) MarshalText() (text []byte, err error) {
	text = []byte(d.String())
	return text, nil
}
// GobEncode implements the gob.GobEncoder interface for gob serialization by
// delegating to MarshalBinary.
func (d Decimal) GobEncode() ([]byte, error) {
	encoded, err := d.MarshalBinary()
	return encoded, err
}
// GobDecode implements the gob.GobDecoder interface for gob serialization by
// delegating to UnmarshalBinary.
func (d *Decimal) GobDecode(data []byte) error {
	err := d.UnmarshalBinary(data)
	return err
}
// StringScaled rescales the decimal to the given exponent (truncating, not
// rounding) and then calls .String() on the result.
// NOTE: buggy and unintuitive.
//
// Deprecated: use StringFixed instead.
func (d Decimal) StringScaled(exp int32) string {
	scaled := d.rescale(exp)
	return scaled.String()
}
// string formats d in fixed-point notation. A non-negative exponent produces a
// plain integer. Otherwise the digits are split into an integer and a
// fractional part, padding with leading zeros where the coefficient has fewer
// digits than the exponent requires. When trimTrailingZeros is set, trailing
// zeros of the fractional part are removed, and the decimal point with them if
// nothing is left.
func (d Decimal) string(trimTrailingZeros bool) string {
	if d.exp >= 0 {
		return d.rescale(0).value.String()
	}

	digits := new(big.Int).Abs(d.value).String()

	// NOTE(vadim): this cast to int will cause bugs if d.exp == INT_MIN
	// and you are on a 32-bit machine. Won't fix this super-edge case.
	fracLen := -int(d.exp)

	var intPart, fractionalPart string
	if split := len(digits) - fracLen; split > 0 {
		intPart, fractionalPart = digits[:split], digits[split:]
	} else {
		// every digit is fractional; pad up to the required width
		intPart = "0"
		fractionalPart = strings.Repeat("0", -split) + digits
	}

	if trimTrailingZeros {
		fractionalPart = strings.TrimRight(fractionalPart, "0")
	}

	number := intPart
	if fractionalPart != "" {
		number += "." + fractionalPart
	}
	if d.value.Sign() < 0 {
		return "-" + number
	}
	return number
}
// ensureInitialized gives d a zero coefficient if it does not have one yet, so
// that the zero-value Decimal can be used like any other.
func (d *Decimal) ensureInitialized() {
	if d.value != nil {
		return
	}
	d.value = new(big.Int)
}
// Min returns the smallest Decimal that was passed in the arguments. When
// several arguments compare equal, the earliest one is returned.
//
// To call this function with an array, you must do:
//
// Min(arr[0], arr[1:]...)
//
// This makes it harder to accidentally call Min with 0 arguments.
func Min(first Decimal, rest ...Decimal) Decimal {
	smallest := first
	for i := range rest {
		if rest[i].Cmp(smallest) < 0 {
			smallest = rest[i]
		}
	}
	return smallest
}
// Max returns the largest Decimal that was passed in the arguments. When
// several arguments compare equal, the earliest one is returned.
//
// To call this function with an array, you must do:
//
// Max(arr[0], arr[1:]...)
//
// This makes it harder to accidentally call Max with 0 arguments.
func Max(first Decimal, rest ...Decimal) Decimal {
	largest := first
	for i := range rest {
		if rest[i].Cmp(largest) > 0 {
			largest = rest[i]
		}
	}
	return largest
}
// Sum returns first plus every Decimal in rest, added in argument order.
func Sum(first Decimal, rest ...Decimal) Decimal {
	acc := first
	for i := range rest {
		acc = acc.Add(rest[i])
	}
	return acc
}
// Avg returns the average of first and rest: their Sum divided with Div by
// the number of arguments.
func Avg(first Decimal, rest ...Decimal) Decimal {
	n := int64(len(rest) + 1)
	return Sum(first, rest...).Div(New(n, 0))
}
// min returns the smaller of two int32 values.
func min(x, y int32) int32 {
	if x < y {
		return x
	}
	return y
}
// unquoteIfQuoted converts a string or []byte to a string, dropping one pair
// of surrounding double quotes when the input is longer than two bytes and
// both starts and ends with one. Any other dynamic type yields an error.
func unquoteIfQuoted(value interface{}) (string, error) {
	var raw []byte
	switch v := value.(type) {
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return "", fmt.Errorf("Could not convert value '%+v' to byte array of type '%T'",
			value, value)
	}

	// If the amount is quoted, strip the quotes
	if n := len(raw); n > 2 && raw[0] == '"' && raw[n-1] == '"' {
		raw = raw[1 : n-1]
	}
	return string(raw), nil
}
// NullDecimal represents a nullable decimal with compatibility for
// scanning null values from the database.
type NullDecimal struct {
// Decimal holds the number; it is only meaningful when Valid is true.
Decimal Decimal
// Valid is false when the value is NULL (database) or null (JSON).
Valid bool
}
// Scan implements the sql.Scanner interface for database deserialization. A
// nil value marks the NullDecimal as invalid; any other value marks it as
// valid and is passed on to Decimal.Scan.
func (d *NullDecimal) Scan(value interface{}) error {
	d.Valid = value != nil
	if !d.Valid {
		return nil
	}
	return d.Decimal.Scan(value)
}
// Value implements the driver.Valuer interface for database serialization. An
// invalid NullDecimal is written as nil; otherwise Decimal.Value is used.
func (d NullDecimal) Value() (driver.Value, error) {
	if d.Valid {
		return d.Decimal.Value()
	}
	return nil, nil
}
// UnmarshalJSON implements the json.Unmarshaler interface. The JSON literal
// null marks the NullDecimal as invalid; anything else marks it as valid and
// is passed on to Decimal.UnmarshalJSON.
func (d *NullDecimal) UnmarshalJSON(decimalBytes []byte) error {
	isNull := string(decimalBytes) == "null"
	d.Valid = !isNull
	if isNull {
		return nil
	}
	return d.Decimal.UnmarshalJSON(decimalBytes)
}
// MarshalJSON implements the json.Marshaler interface. An invalid NullDecimal
// is written as the JSON literal null; otherwise Decimal.MarshalJSON is used.
func (d NullDecimal) MarshalJSON() ([]byte, error) {
	if d.Valid {
		return d.Decimal.MarshalJSON()
	}
	return []byte("null"), nil
}