diff --git a/decimal.go b/decimal.go index a37a230..4034815 100644 --- a/decimal.go +++ b/decimal.go @@ -1544,6 +1544,72 @@ func (d Decimal) Round(places int32) Decimal { return ret } +// RoundHalfUp rounds the decimal half towards +infinity. +// +// Example: +// +// NewFromFloat(545).RoundHalfUp(-2).String() // output: "500" +// NewFromFloat(500).RoundHalfUp(-2).String() // output: "500" +// NewFromFloat(1.1001).RoundHalfUp(2).String() // output: "1.10" +// NewFromFloat(-1.454).RoundHalfUp(1).String() // output: "-1.4" +// NewFromFloat(-1.464).RoundHalfUp(1).String() // output: "-1.5" +func (d Decimal) RoundHalfUp(places int32) Decimal { + if d.exp == -places { + return d + } + // truncate to places + 1 + ret := d.rescale(-places - 1) + + // add sign(d) * 0.5 + if ret.value.Sign() < 0 { + ret.value.Sub(ret.value, fourInt) + } else { + ret.value.Add(ret.value, fiveInt) + } + + // floor for positive numbers, ceil for negative numbers + _, m := ret.value.DivMod(ret.value, tenInt, new(big.Int)) + ret.exp++ + if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 { + ret.value.Add(ret.value, oneInt) + } + + return ret +} + +// RoundHalfDown rounds the decimal half towards -infinity. +// +// Example: +// +// NewFromFloat(550).RoundHalfDown(-2).String() // output: "500" +// NewFromFloat(560).RoundHalfDown(-2).String() // output: "600" +// NewFromFloat(1.1001).RoundHalfDown(2).String() // output: "1.11" +// NewFromFloat(-1.454).RoundHalfDown(1).String() // output: "-1.5" +// NewFromFloat(-1.444).RoundHalfDown(1).String() // output: "-1.4" +func (d Decimal) RoundHalfDown(places int32) Decimal { + if d.exp == -places { + return d + } + // truncate to places + 1 + ret := d.rescale(-places - 1) + + // add sign(d) * 0.5 + if ret.value.Sign() < 0 { + ret.value.Sub(ret.value, fiveInt) + } else { + ret.value.Add(ret.value, fourInt) + } + + // floor for positive numbers, ceil for negative numbers + _, m := ret.value.DivMod(ret.value, tenInt, new(big.Int)) + ret.exp++ + if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 { + ret.value.Add(ret.value, oneInt) + } + + return ret +} + // RoundCeil rounds the decimal towards +infinity. // // Example: diff --git a/decimal_test.go b/decimal_test.go index d398f2d..36b4710 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -3647,3 +3647,109 @@ func ExampleNewFromFloat() { //0.123123123123123 //-10000000000000 } + +func TestDecimal_RoundHalfUp(t *testing.T) { + tests := []struct { + name string + d Decimal + places int32 + want Decimal + }{ + { + name: "550", + d: NewFromInt(550), + places: -2, + want: NewFromInt(600), + }, + { + name: "545", + d: NewFromInt(545), + places: -2, + want: NewFromInt(500), + }, + { + name: "500", + d: NewFromInt(500), + places: -2, + want: NewFromInt(500), + }, + { + name: "1.1001", + d: NewFromFloat(1.1001), + places: 2, + want: NewFromFloat(1.10), + }, + { + name: "-1.454", + d: NewFromFloat(-1.454), + places: 1, + want: NewFromFloat(-1.4), + }, + { + name: "-1.464", + d: NewFromFloat(-1.464), + places: 1, + want: NewFromFloat(-1.5), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.d.RoundHalfUp(tt.places); !got.Equal(tt.want) { + t.Errorf("RoundHalfUp() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDecimal_RoundHalfDown(t *testing.T) { + tests := []struct { + name string + d Decimal + places int32 + want Decimal + }{ + { + name: "550", + d: NewFromInt(550), + places: -2, + want: NewFromInt(500), + }, + { + name: "560", + d: NewFromInt(560), + places: -2, + want: NewFromInt(600), + }, + { + name: "500", + d: NewFromInt(500), + places: -2, + want: NewFromInt(500), + }, + { + name: "1.1001", + d: NewFromFloat(1.1001), + places: 2, + want: NewFromFloat(1.10), + }, + { + name: "-1.454", + d: NewFromFloat(-1.454), + places: 1, + want: NewFromFloat(-1.5), + }, + { + name: "-1.444", + d: NewFromFloat(-1.444), + places: 1, + want: NewFromFloat(-1.4), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.d.RoundHalfDown(tt.places); !got.Equal(tt.want) { + t.Errorf("RoundHalfDown() = %v, want %v", got, tt.want) + } + }) + } +}