// Package implements an arbitrary precision fixed-point decimal // // To use as part of a struct: // // type Struct struct { // Number Decimal // } // // The zero-value of a Decimal is 0, as you would expect. // // The best way to create a new Decimal is to use decimal.NewFromString, ex: // // n, err := decimal.NewFromString("-123.4567") // n.String() // output: "-123.4567" // // NOTE: this can "only" represent numbers with a maximum of 2^31 digits // after the decimal point package decimal import ( "database/sql/driver" "fmt" "math" "math/big" "strconv" "strings" ) // Number of decimal places in the result when it doesn't divide exactly // // Example: // // d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3) // d1.String() // output: "0.6666666666666667" // d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000) // d2.String() // output: "0.0000666666666667" // d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3) // d3.String() // output: "6666.6666666666666667" // decimal.DivisionPrecision = 3 // d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3) // d4.String() // output: "0.667" // var DivisionPrecision = 16 // constant, to make computations faster var Zero = New(0, 1) var tenInt = big.NewInt(10) var oneInt = big.NewInt(1) // Decimal represents a fixed-point decimal. It is immutable. // number = value * 10 ^ exp type Decimal struct { value *big.Int // NOTE(vadim): this must be an int32, because we cast it to float64 during // calculations. If exp is 64 bit, we might lose precision. // If we cared about being able to represent every possible decimal, we // could make exp a *big.Int but it would hurt performance and numbers // like that are unrealistic. exp int32 } // Returns a new fixed-point decimal, value * 10 ^ exp func New(value int64, exp int32) Decimal { return Decimal{ value: big.NewInt(value), exp: exp, } } // Returns a new Decimal from a string representation // // Example: // // d, err := NewFromString("-123.45") // d2, err := NewFromString(".0001") // func NewFromString(value string) (Decimal, error) { var intString string var exp int32 parts := strings.Split(value, ".") if len(parts) == 1 { // There is no decimal point, we can just parse the original string as // an int intString = value exp = 0 } else if len(parts) == 2 { intString = parts[0] + parts[1] expInt := -len(parts[1]) if expInt < math.MinInt32 { // NOTE(vadim): I doubt a string could realistically be this long return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", value) } exp = int32(expInt) } else { return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) } dValue := new(big.Int) _, ok := dValue.SetString(intString, 10) if !ok { return Decimal{}, fmt.Errorf("can't convert %s to decimal", value) } return Decimal{ value: dValue, exp: exp, }, nil } // Converts a float64 to Decimal // // Example: // // NewFromFloat(123.45678901234567).String() // output: "123.4567890123456" // NewFromFloat(.00000000000000001).String() // output: "0.00000000000000001" // // NOTE: this will panic on NaN, +/-inf func NewFromFloat(value float64) Decimal { floor := math.Floor(value) // fast path, where float is an int if floor == value && !math.IsInf(value, 0) { return New(int64(value), 0) } // slow path: float is a decimal // HACK(vadim): do this the slow hacky way for now because the logic to // convert a base-2 float to base-10 properly is not trivial str := strconv.FormatFloat(value, 'f', -1, 64) dec, err := NewFromString(str) if err != nil { panic(err) } return dec } // Same as NewFromFloat, except you can choose the number of fractional digits // // Example: // // NewFromFloatWithExponent(123.456, -2).String() // output: "123.46" // func NewFromFloatWithExponent(value float64, exp int32) Decimal { mul := math.Pow(10, -float64(exp)) floatValue := value * mul if math.IsNaN(floatValue) || math.IsInf(floatValue, 0) { panic(fmt.Sprintf("Cannot create a Decimal from %v", floatValue)) } dValue := big.NewInt(round(floatValue)) return Decimal{ value: dValue, exp: exp, } } // Rescale returns a rescaled version of the decimal. Returned // decimal may be less precise if the given exponent is bigger // than the initial exponent of the Decimal. // NOTE: this will truncate, NOT round // // Example: // // d := New(12345, -4) // d2 := d.rescale(-1) // d3 := d2.rescale(-4) // println(d1) // println(d2) // println(d3) // // Output: // // 1.2345 // 1.2 // 1.2000 // func (d Decimal) rescale(exp int32) Decimal { d.ensureInitialized() // NOTE(vadim): must convert exps to float64 before - to prevent overflow diff := math.Abs(float64(exp) - float64(d.exp)) value := new(big.Int).Set(d.value) expScale := new(big.Int).Exp(tenInt, big.NewInt(int64(diff)), nil) if exp > d.exp { value = value.Quo(value, expScale) } else if exp < d.exp { value = value.Mul(value, expScale) } return Decimal{ value: value, exp: exp, } } func (d Decimal) Abs() Decimal { d.ensureInitialized() d2Value := new(big.Int).Abs(d.value) return Decimal{ value: d2Value, exp: d.exp, } } // Returns d + d2 func (d Decimal) Add(d2 Decimal) Decimal { baseScale := min(d.exp, d2.exp) rd := d.rescale(baseScale) rd2 := d2.rescale(baseScale) d3Value := new(big.Int).Add(rd.value, rd2.value) return Decimal{ value: d3Value, exp: baseScale, } } // Returns d - d2 func (d Decimal) Sub(d2 Decimal) Decimal { baseScale := min(d.exp, d2.exp) rd := d.rescale(baseScale) rd2 := d2.rescale(baseScale) d3Value := new(big.Int).Sub(rd.value, rd2.value) return Decimal{ value: d3Value, exp: baseScale, } } // Returns d * d2 func (d Decimal) Mul(d2 Decimal) Decimal { d.ensureInitialized() d2.ensureInitialized() expInt64 := int64(d.exp) + int64(d2.exp) if expInt64 > math.MaxInt32 || expInt64 < math.MinInt32 { // NOTE(vadim): better to panic than give incorrect results, as // Decimals are usually used for money panic(fmt.Sprintf("exponent %v overflows an int32!", expInt64)) } d3Value := new(big.Int).Mul(d.value, d2.value) return Decimal{ value: d3Value, exp: int32(expInt64), } } // Returns d / d2. If it doesn't divide exactly, // the result will have DivisionPrecision digits after the decimal point. func (d Decimal) Div(d2 Decimal) Decimal { // NOTE(vadim): division is hard, use Rat to do it ratNum := d.Rat() ratDenom := d2.Rat() quoRat := big.NewRat(0, 1).Quo(ratNum, ratDenom) // HACK(vadim): converting from Rat to Decimal inefficiently for now ret, err := NewFromString(quoRat.FloatString(DivisionPrecision)) if err != nil { panic(err) // this should never happen } return ret } // Cmp compares x and y and returns -1, 0 or 1 // // -1 if x < y // 0 if x == y // +1 if x > y // func (d Decimal) Cmp(d2 Decimal) int { baseExp := min(d.exp, d2.exp) rd := d.rescale(baseExp) rd2 := d2.rescale(baseExp) return rd.value.Cmp(rd2.value) } func (d Decimal) Equals(d2 Decimal) bool { return d.Cmp(d2) == 0 } func (d Decimal) Exponent() int32 { return d.exp } func (d Decimal) IntPart() int64 { scaledD := d.rescale(0) return scaledD.value.Int64() } // Returns a rational number representation of the decimal func (d Decimal) Rat() *big.Rat { d.ensureInitialized() if d.exp <= 0 { denom := new(big.Int).Exp(tenInt, big.NewInt(int64(-d.exp)), nil) return new(big.Rat).SetFrac(d.value, denom) } else { mul := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil) num := new(big.Int).Mul(d.value, mul) return new(big.Rat).SetFrac(num, oneInt) } } // Returns the nearest float64 value for d and a bool indicating // whether f represents d exactly. // For more details, see the documentation for big.Rat.Float64 func (d Decimal) Float64() (f float64, exact bool) { return d.Rat().Float64() } // String returns the string representation of the decimal // with the fixed point // // Example: // // d := New(-12345, -3) // println(d.String()) // // Output: // // -12.345 // func (d Decimal) String() string { if d.exp >= 0 { return d.rescale(0).value.String() } abs := new(big.Int).Abs(d.value) str := abs.String() var intPart, fractionalPart string dExpInt := int(d.exp) if len(str) > -dExpInt { intPart = str[:len(str)+dExpInt] fractionalPart = str[len(str)+dExpInt:] } else { intPart = "0" num0s := -dExpInt - len(str) fractionalPart = strings.Repeat("0", num0s) + str } i := len(fractionalPart) - 1 for ; i >= 0; i-- { if fractionalPart[i] != '0' { break } } fractionalPart = fractionalPart[:i+1] number := intPart if len(fractionalPart) > 0 { number += "." + fractionalPart } if d.value.Sign() < 0 { return "-" + number } return number } // StringScaled first scales the decimal then calls .String() on it. func (d Decimal) StringScaled(exp int32) string { return d.rescale(exp).String() } func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { str, err := unquoteIfQuoted(decimalBytes) if err != nil { return fmt.Errorf("Error decoding string '%s': %s", decimalBytes, err) } decimal, err := NewFromString(str) *d = decimal if err != nil { return fmt.Errorf("Error decoding string '%s': %s", str, err) } return nil } func (d Decimal) MarshalJSON() ([]byte, error) { str := "\"" + d.String() + "\"" return []byte(str), nil } // This truncates off digits from the number, without rounding // // NOTE: precision is the last digit that will not be truncated (should be >= 0) // // decimal.NewFromString("123.456").Truncate(2).String() // "123.45" // func (d Decimal) Truncate(precision int32) Decimal { d.ensureInitialized() if precision >= 0 && -precision > d.exp { return d.rescale(-precision) } return d } // db deserialization func (d *Decimal) Scan(value interface{}) error { str, err := unquoteIfQuoted(value) if err != nil { return err } *d, err = NewFromString(str) return err } // db serialization func (d Decimal) Value() (driver.Value, error) { return d.String(), nil } // xml deserialization func (d *Decimal) UnmarshalText(text []byte) error { str := string(text) dec, err := NewFromString(str) *d = dec if err != nil { return fmt.Errorf("Error decoding string '%s': %s", str, err) } return nil } // xml serialization func (d Decimal) MarshalText() (text []byte, err error) { return []byte(d.String()), nil } func (d *Decimal) ensureInitialized() { if d.value == nil { d.value = new(big.Int) } } func min(x, y int32) int32 { if x >= y { return y } return x } func round(n float64) int64 { if n < 0 { return int64(n - 0.5) } return int64(n + 0.5) } func unquoteIfQuoted(value interface{}) (string, error) { bytes, ok := value.([]byte) if !ok { return "", fmt.Errorf("Could not convert value '%+v' to byte array", value) } // If the amount is quoted, strip the quotes if len(bytes) > 2 && bytes[0] == '"' && bytes[len(bytes)-1] == '"' { bytes = bytes[1 : len(bytes)-1] } return string(bytes), nil }