diff --git a/decimal.go b/decimal.go index 4034815..88a1eac 100644 --- a/decimal.go +++ b/decimal.go @@ -1513,7 +1513,7 @@ func (d Decimal) StringFixedCash(interval uint8) string { return rounded.string(false) } -// Round rounds the decimal to places decimal places. +// Round rounds the decimal to places decimal places (half away from zero). // If places < 0, it will round the integer part to the nearest 10^(-places). // // Example: @@ -1544,6 +1544,37 @@ func (d Decimal) Round(places int32) Decimal { return ret } +// RoundHalfTowardZero rounds the decimal to places decimal places (half toward zero). +// If places < 0, it will round the integer part to the nearest 10^(-places). +// +// Example: +// +// NewFromFloat(5.45).RoundHalfTowardZero(1).String() // output: "5.4" +// NewFromFloat(545).RoundHalfTowardZero(-1).String() // output: "540" +func (d Decimal) RoundHalfTowardZero(places int32) Decimal { + if d.exp == -places { + return d + } + // truncate to places + 1 + ret := d.rescale(-places - 1) + + // add sign(d) * 0.4 + if ret.value.Sign() < 0 { + ret.value.Sub(ret.value, fourInt) + } else { + ret.value.Add(ret.value, fourInt) + } + + // floor for positive numbers, ceil for negative numbers + _, m := ret.value.DivMod(ret.value, tenInt, new(big.Int)) + ret.exp++ + if ret.value.Sign() < 0 && m.Cmp(zeroInt) != 0 { + ret.value.Add(ret.value, oneInt) + } + + return ret +} + // RoundHalfUp rounds the decimal half towards +infinity. // // Example: @@ -1560,7 +1591,7 @@ func (d Decimal) RoundHalfUp(places int32) Decimal { // truncate to places + 1 ret := d.rescale(-places - 1) - // add sign(d) * 0.5 + // add sign(d) * 0.5 if sign(d) >= 0 else sign(d) * 0.4 if ret.value.Sign() < 0 { ret.value.Sub(ret.value, fourInt) } else { @@ -1593,7 +1624,7 @@ func (d Decimal) RoundHalfDown(places int32) Decimal { // truncate to places + 1 ret := d.rescale(-places - 1) - // add sign(d) * 0.5 + // add sign(d) * 0.5 if sign(d) < 0 else sign(d) * 0.4 if ret.value.Sign() < 0 { ret.value.Sub(ret.value, fiveInt) } else { diff --git a/decimal_test.go b/decimal_test.go index 36b4710..593844e 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -3753,3 +3753,32 @@ func TestDecimal_RoundHalfDown(t *testing.T) { }) } } + +func TestDecimal_RoundHalfTowardZero(t *testing.T) { + tests := []struct { + name string + d Decimal + places int32 + want Decimal + }{ + { + name: "5.45", + d: NewFromFloat(5.45), + places: 1, + want: NewFromFloat(5.4), + }, + { + name: "545", + d: NewFromInt(545), + places: -1, + want: NewFromInt(540), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.d.RoundHalfTowardZero(tt.places); !got.Equal(tt.want) { + t.Errorf("RoundHalfDown() = %v, want %v", got, tt.want) + } + }) + } +}