mirror of
https://github.com/shopspring/decimal.git
synced 2024-11-22 20:40:48 +01:00
Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
b77bff862c
3 changed files with 149 additions and 8 deletions
|
@ -1,4 +1,4 @@
|
||||||
# decimal [![Build Status](https://travis-ci.org/shopspring/decimal.png?branch=master)](https://travis-ci.org/shopspring/decimal)
|
# decimal [![Build Status](https://travis-ci.org/shopspring/decimal.png?branch=master)](https://travis-ci.org/shopspring/decimal) [![BADGINATOR](https://badginator.herokuapp.com/shopspring/decimal.svg?image_analysis=1)](https://github.com/defunctzombie/badginator)
|
||||||
|
|
||||||
Arbitrary-precision fixed-point decimal numbers in go.
|
Arbitrary-precision fixed-point decimal numbers in go.
|
||||||
|
|
||||||
|
|
41
decimal.go
41
decimal.go
|
@ -140,7 +140,7 @@ func NewFromFloat(value float64) Decimal {
|
||||||
floor := math.Floor(value)
|
floor := math.Floor(value)
|
||||||
|
|
||||||
// fast path, where float is an int
|
// fast path, where float is an int
|
||||||
if floor == value && !math.IsInf(value, 0) {
|
if floor == value && value <= math.MaxInt64 && value >= math.MinInt64 {
|
||||||
return New(int64(value), 0)
|
return New(int64(value), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,6 +350,12 @@ func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mod returns d % d2.
|
||||||
|
func (d Decimal) Mod(d2 Decimal) Decimal {
|
||||||
|
quo := d.Div(d2).Truncate(0)
|
||||||
|
return d.Sub(d2.Mul(quo))
|
||||||
|
}
|
||||||
|
|
||||||
// Cmp compares the numbers represented by d and d2 and returns:
|
// Cmp compares the numbers represented by d and d2 and returns:
|
||||||
//
|
//
|
||||||
// -1 if d < d2
|
// -1 if d < d2
|
||||||
|
@ -357,6 +363,13 @@ func (d Decimal) DivRound(d2 Decimal, precision int32) Decimal {
|
||||||
// +1 if d > d2
|
// +1 if d > d2
|
||||||
//
|
//
|
||||||
func (d Decimal) Cmp(d2 Decimal) int {
|
func (d Decimal) Cmp(d2 Decimal) int {
|
||||||
|
d.ensureInitialized()
|
||||||
|
d2.ensureInitialized()
|
||||||
|
|
||||||
|
if d.exp == d2.exp {
|
||||||
|
return d.value.Cmp(d2.value)
|
||||||
|
}
|
||||||
|
|
||||||
baseExp := min(d.exp, d2.exp)
|
baseExp := min(d.exp, d2.exp)
|
||||||
rd := d.rescale(baseExp)
|
rd := d.rescale(baseExp)
|
||||||
rd2 := d2.rescale(baseExp)
|
rd2 := d2.rescale(baseExp)
|
||||||
|
@ -532,13 +545,29 @@ func (d Decimal) MarshalJSON() ([]byte, error) {
|
||||||
|
|
||||||
// Scan implements the sql.Scanner interface for database deserialization.
|
// Scan implements the sql.Scanner interface for database deserialization.
|
||||||
func (d *Decimal) Scan(value interface{}) error {
|
func (d *Decimal) Scan(value interface{}) error {
|
||||||
str, err := unquoteIfQuoted(value)
|
// first try to see if the data is stored in database as a Numeric datatype
|
||||||
if err != nil {
|
switch v := value.(type) {
|
||||||
|
|
||||||
|
case float64:
|
||||||
|
// numeric in sqlite3 sends us float64
|
||||||
|
*d = NewFromFloat(v)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case int64:
|
||||||
|
// at least in sqlite3 when the value is 0 in db, the data is sent
|
||||||
|
// to us as an int64 instead of a float64 ...
|
||||||
|
*d = New(v, 0)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// default is trying to interpret value stored as string
|
||||||
|
str, err := unquoteIfQuoted(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*d, err = NewFromString(str)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*d, err = NewFromString(str)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the driver.Valuer interface for database serialization.
|
// Value implements the driver.Valuer interface for database serialization.
|
||||||
|
|
114
decimal_test.go
114
decimal_test.go
|
@ -4,7 +4,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"math"
|
"math"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -32,6 +33,7 @@ var testTable = map[float64]string{
|
||||||
.1000000000000003: "0.1000000000000003",
|
.1000000000000003: "0.1000000000000003",
|
||||||
.1000000000000005: "0.1000000000000005",
|
.1000000000000005: "0.1000000000000005",
|
||||||
.1000000000000008: "0.1000000000000008",
|
.1000000000000008: "0.1000000000000008",
|
||||||
|
1e25: "10000000000000000000000000",
|
||||||
}
|
}
|
||||||
|
|
||||||
var testTableScientificNotation = map[string]string{
|
var testTableScientificNotation = map[string]string{
|
||||||
|
@ -941,6 +943,39 @@ func TestDecimal_DivRound2(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecimal_Mod(t *testing.T) {
|
||||||
|
type Inp struct {
|
||||||
|
a string
|
||||||
|
b string
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs := map[Inp]string{
|
||||||
|
Inp{"3", "2"}: "1",
|
||||||
|
Inp{"3451204593", "2454495034"}: "996709559",
|
||||||
|
Inp{"24544.95034", ".3451204593"}: "0.3283950433",
|
||||||
|
Inp{".1", ".1"}: "0",
|
||||||
|
Inp{"0", "1.001"}: "0",
|
||||||
|
Inp{"-7.5", "2"}: "-1.5",
|
||||||
|
Inp{"7.5", "-2"}: "1.5",
|
||||||
|
Inp{"-7.5", "-2"}: "-1.5",
|
||||||
|
}
|
||||||
|
|
||||||
|
for inp, res := range inputs {
|
||||||
|
a, err := NewFromString(inp.a)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
b, err := NewFromString(inp.b)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
c := a.Mod(b)
|
||||||
|
if c.String() != res {
|
||||||
|
t.Errorf("expected %s, got %s", res, c.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecimal_Overflow(t *testing.T) {
|
func TestDecimal_Overflow(t *testing.T) {
|
||||||
if !didPanic(func() { New(1, math.MinInt32).Mul(New(1, math.MinInt32)) }) {
|
if !didPanic(func() { New(1, math.MinInt32).Mul(New(1, math.MinInt32)) }) {
|
||||||
t.Fatalf("should have gotten an overflow panic")
|
t.Fatalf("should have gotten an overflow panic")
|
||||||
|
@ -1071,6 +1106,67 @@ func TestDecimal_Max(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecimal_Scan(t *testing.T) {
|
||||||
|
// test the Scan method the implements the
|
||||||
|
// sql.Scanner interface
|
||||||
|
// check for the for different type of values
|
||||||
|
// that are possible to be received from the database
|
||||||
|
// drivers
|
||||||
|
|
||||||
|
// in normal operations the db driver (sqlite at least)
|
||||||
|
// will return an int64 if you specified a numeric format
|
||||||
|
a := Decimal{}
|
||||||
|
dbvalue := float64(54.33)
|
||||||
|
expected := NewFromFloat(dbvalue)
|
||||||
|
|
||||||
|
err := a.Scan(dbvalue)
|
||||||
|
if err != nil {
|
||||||
|
// Scan failed... no need to test result value
|
||||||
|
t.Errorf("a.Scan(54.33) failed with message: %s", err)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Scan suceeded... test resulting values
|
||||||
|
if !a.Equals(expected) {
|
||||||
|
t.Errorf("%s does not equal to %s", a, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// at least SQLite returns an int64 when 0 is stored in the db
|
||||||
|
// and you specified a numeric format on the schema
|
||||||
|
dbvalue_int := int64(0)
|
||||||
|
expected = New(dbvalue_int, 0)
|
||||||
|
|
||||||
|
err = a.Scan(dbvalue_int)
|
||||||
|
if err != nil {
|
||||||
|
// Scan failed... no need to test result value
|
||||||
|
t.Errorf("a.Scan(0) failed with message: %s", err)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Scan suceeded... test resulting values
|
||||||
|
if !a.Equals(expected) {
|
||||||
|
t.Errorf("%s does not equal to %s", a, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// in case you specified a varchar in your SQL schema,
|
||||||
|
// the database driver will return byte slice []byte
|
||||||
|
value_str := "535.666"
|
||||||
|
dbvalue_str := []byte(value_str)
|
||||||
|
expected, err = NewFromString(value_str)
|
||||||
|
|
||||||
|
err = a.Scan(dbvalue_str)
|
||||||
|
if err != nil {
|
||||||
|
// Scan failed... no need to test result value
|
||||||
|
t.Errorf("a.Scan('535.666') failed with message: %s", err)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Scan suceeded... test resulting values
|
||||||
|
if !a.Equals(expected) {
|
||||||
|
t.Errorf("%s does not equal to %s", a, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// old tests after this line
|
// old tests after this line
|
||||||
|
|
||||||
func TestDecimal_Scale(t *testing.T) {
|
func TestDecimal_Scale(t *testing.T) {
|
||||||
|
@ -1153,3 +1249,19 @@ func didPanic(f func()) bool {
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DecimalSlice []Decimal
|
||||||
|
|
||||||
|
func (p DecimalSlice) Len() int { return len(p) }
|
||||||
|
func (p DecimalSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||||
|
func (p DecimalSlice) Less(i, j int) bool { return p[i].Cmp(p[j]) < 0 }
|
||||||
|
func Benchmark_Cmp(b *testing.B) {
|
||||||
|
decimals := DecimalSlice([]Decimal{})
|
||||||
|
for i := 0; i < 1000000; i++ {
|
||||||
|
decimals = append(decimals, New(int64(i), 0))
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
sort.Sort(decimals)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue