diff --git a/src/bigint.rs b/src/bigint.rs index 551ecc6..c47bf88 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -77,6 +77,11 @@ impl BigInt { trailing_zeros += 1; } + // Always keep at least one digit + if trailing_zeros == current_len { + trailing_zeros -= 1; + } + let new_len = current_len - trailing_zeros; self.inner.truncate(new_len); @@ -94,6 +99,23 @@ impl BigInt { pub fn from_str_radix(s: T, radix: usize) -> BigInt { todo!(); } + + fn get_digit_count(a: &Self, b: &Self) -> usize { + let a_digits = a.inner.len(); + let b_digits = b.inner.len(); + + let digits = if b_digits > a_digits { + b_digits + } else { + a_digits + }; + + if digits > 0 { + digits + } else { + 1 + } + } } impl Add for BigInt { @@ -107,21 +129,20 @@ impl Add for BigInt { let mut out = BigInt::new_empty(); - let u_digits = self.inner.len(); - let v_digits = rhs.inner.len(); - - let digits = if v_digits > u_digits { - v_digits - } else { - u_digits - }; + let digits = Self::get_digit_count(&self, &rhs) + 1; let mut carry = 0usize; - for i in 0..=digits { + for i in 0..digits { let a = *self.inner.get(i).unwrap_or(&0usize); let b = *rhs.inner.get(i).unwrap_or(&0usize); let (res, overflowed) = a.overflowing_add(b); + if res == 0 && !overflowed { + out.inner.push(res + carry); + carry = 0; + continue; + } + if overflowed { out.inner.push(res + carry); carry = 1; @@ -134,6 +155,8 @@ impl Add for BigInt { } } + out.trim_zeros(); + out } } @@ -145,14 +168,7 @@ impl Sub for BigInt { // @TODO: handle signs let mut out = BigInt::new_empty(); - let u_digits = self.inner.len(); - let v_digits = rhs.inner.len(); - - let digits = if v_digits > u_digits { - v_digits - } else { - u_digits - }; + let digits = Self::get_digit_count(&self, &rhs); let mut borrow = 0usize; for i in 0..digits { @@ -423,19 +439,38 @@ mod tests { diff.inner[0], core::usize::MAX - 1 ); + } - let a = BigInt { + #[test] + fn test_sub_assign() { + let mut a = BigInt { inner: vec![1,0,1], sign: Positive }; let b = BigInt::from(2); - let diff = a - b; + + a -= b; + assert_eq!( - diff.inner, + a.inner, vec![core::usize::MAX, core::usize::MAX] ); } + #[test] + fn test_zeros() { + let a = BigInt::new(); + let b = BigInt::new(); + + // let c = a.clone() - b.clone(); + // assert_eq!(a.clone(), b.clone()); + // assert_eq!(c, a.clone()); + + let c = a.clone() + b.clone(); + assert_eq!(a.clone(), b.clone()); + assert_eq!(c, a.clone()); + } + #[test] fn test_not() { let a = BigInt::from(0u8);