decimal/decimal.go
2015-02-12 12:10:49 -05:00

471 lines
11 KiB
Go

// Package decimal implements an arbitrary precision fixed-point decimal.
//
// To use as part of a struct:
//
// type Struct struct {
// Number Decimal
// }
//
// The zero-value of a Decimal is 0, as you would expect.
//
// The best way to create a new Decimal is to use decimal.NewFromString, ex:
//
// n, err := decimal.NewFromString("-123.4567")
// n.String() // output: "-123.4567"
//
// NOTE: this can "only" represent numbers with a maximum of 2^31 digits
// after the decimal point
package decimal
import (
"database/sql/driver"
"fmt"
"math"
"math/big"
"strconv"
"strings"
)
// DivisionPrecision is the number of decimal places kept in the result of
// Div when the division is not exact.
//
// Example:
//
// d1 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3))
// d1.String() // output: "0.6666666666666667"
// d2 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(30000))
// d2.String() // output: "0.0000666666666667"
// d3 := decimal.NewFromFloat(20000).Div(decimal.NewFromFloat(3))
// d3.String() // output: "6666.6666666666666667"
// decimal.DivisionPrecision = 3
// d4 := decimal.NewFromFloat(2).Div(decimal.NewFromFloat(3))
// d4.String() // output: "0.667"
//
var DivisionPrecision = 16
// Zero is a ready-made Decimal equal to 0 (stored as 0 * 10^1), kept as a
// package-level value to make computations faster.
var Zero = New(0, 1)
// tenInt is the shared big.Int 10, used as the base when computing powers
// of ten for rescaling and rational conversion.
var tenInt = big.NewInt(10)
// oneInt is the shared big.Int 1, used as the denominator when a Decimal
// with a positive exponent is converted to a big.Rat.
var oneInt = big.NewInt(1)
// Decimal represents a fixed-point decimal. It is immutable.
// number = value * 10 ^ exp
type Decimal struct {
// value is the unscaled integer part of the number. A nil value is
// replaced by a zero big.Int in ensureInitialized, so the zero Decimal
// behaves as 0.
value *big.Int
// NOTE(vadim): this must be an int32, because we cast it to float64 during
// calculations. If exp is 64 bit, we might lose precision.
// If we cared about being able to represent every possible decimal, we
// could make exp a *big.Int but it would hurt performance and numbers
// like that are unrealistic.
exp int32
}
// New returns a new fixed-point decimal equal to value * 10 ^ exp.
//
// Example:
//
//	New(12345, -3).String() // output: "12.345"
func New(value int64, exp int32) Decimal {
	unscaled := big.NewInt(value)
	return Decimal{value: unscaled, exp: exp}
}
// NewFromString returns a new Decimal parsed from its base-10 string
// representation: an optional leading sign, digits, and at most one decimal
// point.
//
// An error is returned when the string contains more than one '.', when the
// fractional part is too long for the exponent to fit in an int32, when a
// sign appears after the decimal point, or when the digits are not a valid
// base-10 integer.
//
// Example:
//
//	d, err := NewFromString("-123.45")
//	d2, err := NewFromString(".0001")
func NewFromString(value string) (Decimal, error) {
	var intString string
	var exp int32

	parts := strings.Split(value, ".")
	switch len(parts) {
	case 1:
		// There is no decimal point, so the original string can be parsed
		// as an integer with exponent 0.
		intString = value
	case 2:
		frac := parts[1]
		// A sign is only legal at the very start of the number. Without this
		// check ".-5" would be concatenated into "-5" and silently accepted
		// as -0.5 (likewise ".+5" as 0.5).
		if strings.HasPrefix(frac, "-") || strings.HasPrefix(frac, "+") {
			return Decimal{}, fmt.Errorf("can't convert %s to decimal", value)
		}
		intString = parts[0] + frac
		expInt := -len(frac)
		if expInt < math.MinInt32 {
			// NOTE(vadim): I doubt a string could realistically be this long
			return Decimal{}, fmt.Errorf("can't convert %s to decimal: fractional part too long", value)
		}
		exp = int32(expInt)
	default:
		return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value)
	}

	dValue := new(big.Int)
	if _, ok := dValue.SetString(intString, 10); !ok {
		return Decimal{}, fmt.Errorf("can't convert %s to decimal", value)
	}

	return Decimal{
		value: dValue,
		exp:   exp,
	}, nil
}
// NewFromFloat converts a float64 to a Decimal.
//
// Example:
//
//	NewFromFloat(123.5).String() // output: "123.5"
//	NewFromFloat(.00000000000000001).String() // output: "0.00000000000000001"
//
// NOTE: this will panic on NaN, +/-inf
func NewFromFloat(value float64) Decimal {
	// Fast path: the float is an integer that fits in an int64.
	// The range check matters: converting an out-of-range float (e.g. 1e30)
	// to int64 does not yield a meaningful value, so such inputs must take
	// the string path instead. float64(math.MaxInt64) rounds up to 2^63,
	// which is itself out of range, hence the strict '<'.
	if math.Floor(value) == value &&
		value >= float64(math.MinInt64) && value < float64(math.MaxInt64) {
		return New(int64(value), 0)
	}

	// Slow path: the float has a fractional part, is too large for int64,
	// or is NaN/Inf (which NewFromString rejects, triggering the panic).
	// HACK(vadim): do this the slow hacky way for now because the logic to
	// convert a base-2 float to base-10 properly is not trivial
	str := strconv.FormatFloat(value, 'f', -1, 64)
	dec, err := NewFromString(str)
	if err != nil {
		panic(err)
	}
	return dec
}
// NewFromFloatWithExponent converts a float64 to a Decimal with the given
// exponent: the float is multiplied by 10^-exp, rounded to an integer with
// round, and stored together with exp.
//
// It panics when the scaled float is NaN or infinite.
//
// Example:
//
//	NewFromFloatWithExponent(123.456, -2).String() // output: "123.46"
func NewFromFloatWithExponent(value float64, exp int32) Decimal {
	scaled := value * math.Pow(10, -float64(exp))
	if math.IsInf(scaled, 0) || math.IsNaN(scaled) {
		panic(fmt.Sprintf("Cannot create a Decimal from %v", scaled))
	}
	return Decimal{
		value: big.NewInt(round(scaled)),
		exp:   exp,
	}
}
// rescale returns a copy of the decimal expressed with the given exponent.
// The result may be less precise if exp is bigger than the decimal's
// current exponent.
// NOTE: this will truncate, NOT round
//
// Example:
//
//	d := New(12345, -4)  // value 12345, exp -4 (1.2345)
//	d2 := d.rescale(-1)  // value 12,    exp -1 (1.2)
//	d3 := d2.rescale(-4) // value 12000, exp -4 (still 1.2; digits are not restored)
func (d Decimal) rescale(exp int32) Decimal {
	d.ensureInitialized()

	// NOTE(vadim): must convert exps to float64 before - to prevent overflow
	delta := math.Abs(float64(exp) - float64(d.exp))
	factor := new(big.Int).Exp(tenInt, big.NewInt(int64(delta)), nil)

	scaled := new(big.Int).Set(d.value)
	switch {
	case exp > d.exp:
		// Fewer fractional digits: drop the extra ones (Quo truncates).
		scaled.Quo(scaled, factor)
	case exp < d.exp:
		// More fractional digits: pad with zeros.
		scaled.Mul(scaled, factor)
	}

	return Decimal{value: scaled, exp: exp}
}
// Abs returns the absolute value of d, keeping the same exponent.
func (d Decimal) Abs() Decimal {
	d.ensureInitialized()
	return Decimal{value: new(big.Int).Abs(d.value), exp: d.exp}
}
// Add returns d + d2. Both operands are rescaled to the smaller of the two
// exponents, which is also the exponent of the result.
func (d Decimal) Add(d2 Decimal) Decimal {
	scale := min(d.exp, d2.exp)
	lhs, rhs := d.rescale(scale), d2.rescale(scale)
	return Decimal{
		value: new(big.Int).Add(lhs.value, rhs.value),
		exp:   scale,
	}
}
// Sub returns d - d2. Both operands are rescaled to the smaller of the two
// exponents, which is also the exponent of the result.
func (d Decimal) Sub(d2 Decimal) Decimal {
	scale := min(d.exp, d2.exp)
	lhs, rhs := d.rescale(scale), d2.rescale(scale)
	return Decimal{
		value: new(big.Int).Sub(lhs.value, rhs.value),
		exp:   scale,
	}
}
// Mul returns d * d2. The unscaled values are multiplied and the exponents
// added; it panics if the summed exponent does not fit in an int32.
func (d Decimal) Mul(d2 Decimal) Decimal {
	d.ensureInitialized()
	d2.ensureInitialized()

	// Add in 64 bits so an int32 overflow can be detected.
	sum := int64(d.exp) + int64(d2.exp)
	if sum < math.MinInt32 || sum > math.MaxInt32 {
		// NOTE(vadim): better to panic than give incorrect results, as
		// Decimals are usually used for money
		panic(fmt.Sprintf("exponent %v overflows an int32!", sum))
	}

	return Decimal{
		value: new(big.Int).Mul(d.value, d2.value),
		exp:   int32(sum),
	}
}
// Div returns d / d2. If it doesn't divide exactly, the result will have
// DivisionPrecision digits after the decimal point.
//
// NOTE(review): a zero d2 reaches big.Rat.Quo with a zero divisor, which
// panics according to the math/big documentation; confirm if callers rely
// on it.
func (d Decimal) Div(d2 Decimal) Decimal {
	// NOTE(vadim): division is hard, use Rat to do it
	quotient := new(big.Rat).Quo(d.Rat(), d2.Rat())

	// HACK(vadim): converting from Rat to Decimal inefficiently for now
	text := quotient.FloatString(DivisionPrecision)
	result, err := NewFromString(text)
	if err != nil {
		panic(err) // this should never happen
	}
	return result
}
// Cmp compares d and d2 and returns:
//
//	-1 if d <  d2
//	 0 if d == d2
//	+1 if d >  d2
//
// Both operands are rescaled to the smaller exponent before their unscaled
// values are compared.
func (d Decimal) Cmp(d2 Decimal) int {
	common := min(d.exp, d2.exp)
	lhs, rhs := d.rescale(common), d2.rescale(common)
	return lhs.value.Cmp(rhs.value)
}
// Equals reports whether d and d2 represent the same number, i.e. whether
// Cmp returns 0 for them.
func (d Decimal) Equals(d2 Decimal) bool {
	order := d.Cmp(d2)
	return order == 0
}
// Exponent returns the exp component of the decimal, i.e. the power of ten
// by which the unscaled value is multiplied.
func (d Decimal) Exponent() int32 {
return d.exp
}
// IntPart returns the integer part of the decimal as an int64. The decimal
// is rescaled to exponent 0 (truncating any fraction) and the result is
// converted with big.Int.Int64.
func (d Decimal) IntPart() int64 {
	whole := d.rescale(0).value
	return whole.Int64()
}
// Rat returns a rational number representation of the decimal.
func (d Decimal) Rat() *big.Rat {
	d.ensureInitialized()

	if d.exp > 0 {
		// Positive exponent: the number is the integer value * 10^exp.
		scale := new(big.Int).Exp(tenInt, big.NewInt(int64(d.exp)), nil)
		return new(big.Rat).SetFrac(new(big.Int).Mul(d.value, scale), oneInt)
	}

	// Zero or negative exponent: the number is value / 10^(-exp).
	denom := new(big.Int).Exp(tenInt, big.NewInt(int64(-d.exp)), nil)
	return new(big.Rat).SetFrac(d.value, denom)
}
// Float64 returns the nearest float64 value for d and a bool indicating
// whether f represents d exactly. The conversion is delegated to
// big.Rat.Float64; see its documentation for more details.
func (d Decimal) Float64() (f float64, exact bool) {
	f, exact = d.Rat().Float64()
	return f, exact
}
// String returns the string representation of the decimal with the fixed
// point. Trailing zeros in the fractional part are removed, and the decimal
// point is omitted when no fractional digits remain.
//
// Example:
//
//	d := New(-12345, -3)
//	println(d.String())
//
// Output:
//
//	-12.345
func (d Decimal) String() string {
	// Non-negative exponent: the number is an integer.
	if d.exp >= 0 {
		return d.rescale(0).value.String()
	}

	digits := new(big.Int).Abs(d.value).String()
	scale := -int(d.exp) // number of fractional digits

	var whole, frac string
	if split := len(digits) - scale; split > 0 {
		whole, frac = digits[:split], digits[split:]
	} else {
		// Fewer digits than fractional places: left-pad with zeros.
		whole = "0"
		frac = strings.Repeat("0", -split) + digits
	}

	frac = strings.TrimRight(frac, "0")

	out := whole
	if frac != "" {
		out += "." + frac
	}
	if d.value.Sign() < 0 {
		out = "-" + out
	}
	return out
}
// StringScaled first rescales the decimal to the given exponent (which
// truncates when exp is larger than the current exponent) and then returns
// the String form of the result.
func (d Decimal) StringScaled(exp int32) string {
	scaled := d.rescale(exp)
	return scaled.String()
}
// UnmarshalJSON decodes a JSON value into d. Surrounding double quotes, if
// present, are stripped before the text is parsed with NewFromString.
// The receiver is overwritten with the parse result even when parsing fails.
func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
	text, unquoteErr := unquoteIfQuoted(decimalBytes)
	if unquoteErr != nil {
		return fmt.Errorf("Error decoding string '%s': %s", decimalBytes, unquoteErr)
	}

	parsed, parseErr := NewFromString(text)
	*d = parsed
	if parseErr == nil {
		return nil
	}
	return fmt.Errorf("Error decoding string '%s': %s", text, parseErr)
}
// MarshalJSON encodes d as a JSON string: the String form of the decimal
// wrapped in double quotes. It never returns an error.
func (d Decimal) MarshalJSON() ([]byte, error) {
	text := d.String()
	out := make([]byte, 0, len(text)+2)
	out = append(out, '"')
	out = append(out, text...)
	out = append(out, '"')
	return out, nil
}
// Truncate cuts digits off the number without rounding.
//
// NOTE: precision is the last digit that will not be truncated (should be >= 0)
//
// d is returned unchanged when precision is negative or when d has no more
// than precision fractional digits.
//
// Example:
//
//	d, _ := decimal.NewFromString("123.456")
//	d.Truncate(2).String() // "123.45"
func (d Decimal) Truncate(precision int32) Decimal {
	d.ensureInitialized()
	if precision < 0 || -precision <= d.exp {
		return d
	}
	return d.rescale(-precision)
}
// Scan implements database deserialization. The incoming value must be a
// []byte (anything else is rejected by unquoteIfQuoted); surrounding double
// quotes are stripped and the text is parsed with NewFromString.
// The receiver is overwritten with the parse result even when parsing fails.
func (d *Decimal) Scan(value interface{}) error {
	text, err := unquoteIfQuoted(value)
	if err != nil {
		return err
	}

	parsed, parseErr := NewFromString(text)
	*d = parsed
	return parseErr
}
// Value implements database serialization: the decimal is handed to the
// driver as its String form. It never returns an error.
func (d Decimal) Value() (driver.Value, error) {
	var out driver.Value = d.String()
	return out, nil
}
// UnmarshalText implements text (e.g. XML) deserialization by parsing the
// bytes with NewFromString. The receiver is overwritten with the parse
// result even when parsing fails.
func (d *Decimal) UnmarshalText(text []byte) error {
	input := string(text)

	parsed, err := NewFromString(input)
	*d = parsed
	if err == nil {
		return nil
	}
	return fmt.Errorf("Error decoding string '%s': %s", input, err)
}
// MarshalText implements text (e.g. XML) serialization: it returns the
// String form of the decimal as bytes. It never returns an error.
func (d Decimal) MarshalText() (text []byte, err error) {
	text = []byte(d.String())
	return text, nil
}
// ensureInitialized gives d a zero big.Int when its value is nil, so that a
// zero-value Decimal can be used in arithmetic.
func (d *Decimal) ensureInitialized() {
	if d.value != nil {
		return
	}
	d.value = new(big.Int)
}
// min returns the smaller of the two int32 values.
func min(x, y int32) int32 {
	if x < y {
		return x
	}
	return y
}
// round converts n to an int64, rounding halves away from zero: 0.5 is
// pushed in the direction of n's sign and the int64 conversion then drops
// the fraction.
func round(n float64) int64 {
	half := 0.5
	if n < 0 {
		half = -half
	}
	return int64(n + half)
}
// unquoteIfQuoted converts value, which must be a []byte, to a string,
// stripping one pair of surrounding double quotes if present.
//
// It returns an error when value is not a []byte.
func unquoteIfQuoted(value interface{}) (string, error) {
	raw, ok := value.([]byte)
	if !ok {
		return "", fmt.Errorf("Could not convert value '%+v' to byte array",
			value)
	}

	// If the amount is quoted, strip the quotes. The length check is >= 2 so
	// that the empty quoted string `""` is unquoted too, while a lone `"`
	// (which is both first and last byte) is left alone.
	if len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"' {
		raw = raw[1 : len(raw)-1]
	}
	return string(raw), nil
}