arithmetic.gno

12.12 Kb · 472 lines
  1// arithmetic provides arithmetic operations for Uint objects.
  2// This includes basic binary operations such as addition, subtraction, multiplication, division, and modulo operations
  3// as well as overflow checks, and negation. These functions are essential for numeric
  4// calculations using 256-bit unsigned integers.
  5package uint256
  6
  7import (
  8	"math/bits"
  9)
 10
 11// Add sets z to the sum x+y
 12func (z *Uint) Add(x, y *Uint) *Uint {
 13	var carry uint64
 14	z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0)
 15	z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry)
 16	z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry)
 17	z.arr[3], _ = bits.Add64(x.arr[3], y.arr[3], carry)
 18	return z
 19}
 20
 21// AddOverflow sets z to the sum x+y, and returns z and whether overflow occurred
 22func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) {
 23	var carry uint64
 24	z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0)
 25	z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry)
 26	z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry)
 27	z.arr[3], carry = bits.Add64(x.arr[3], y.arr[3], carry)
 28	return z, carry != 0
 29}
 30
 31// Sub sets z to the difference x-y
 32func (z *Uint) Sub(x, y *Uint) *Uint {
 33	var carry uint64
 34	z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0)
 35	z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry)
 36	z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry)
 37	z.arr[3], _ = bits.Sub64(x.arr[3], y.arr[3], carry)
 38	return z
 39}
 40
 41// SubOverflow sets z to the difference x-y and returns z and true if the operation underflowed
 42func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) {
 43	var carry uint64
 44	z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0)
 45	z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry)
 46	z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry)
 47	z.arr[3], carry = bits.Sub64(x.arr[3], y.arr[3], carry)
 48	return z, carry != 0
 49}
 50
 51// Neg returns -x mod 2^256.
 52func (z *Uint) Neg(x *Uint) *Uint {
 53	return z.Sub(new(Uint), x)
 54}
 55
 56// commented out for possible overflow
 57// Mul sets z to the product x*y
 58func (z *Uint) Mul(x, y *Uint) *Uint {
 59	var (
 60		res              Uint
 61		carry            uint64
 62		res1, res2, res3 uint64
 63	)
 64
 65	carry, res.arr[0] = bits.Mul64(x.arr[0], y.arr[0])
 66	carry, res1 = umulHop(carry, x.arr[1], y.arr[0])
 67	carry, res2 = umulHop(carry, x.arr[2], y.arr[0])
 68	res3 = x.arr[3]*y.arr[0] + carry
 69
 70	carry, res.arr[1] = umulHop(res1, x.arr[0], y.arr[1])
 71	carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry)
 72	res3 = res3 + x.arr[2]*y.arr[1] + carry
 73
 74	carry, res.arr[2] = umulHop(res2, x.arr[0], y.arr[2])
 75	res3 = res3 + x.arr[1]*y.arr[2] + carry
 76
 77	res.arr[3] = res3 + x.arr[0]*y.arr[3]
 78
 79	return z.Set(&res)
 80}
 81
 82// MulOverflow sets z to the product x*y, and returns z and  whether overflow occurred
 83func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) {
 84	p := umul(x, y)
 85	copy(z.arr[:], p[:4])
 86	return z, (p[4] | p[5] | p[6] | p[7]) != 0
 87}
 88
 89// commented out for possible overflow
 90// Div sets z to the quotient x/y for returns z.
 91// If y == 0, z is set to 0
 92func (z *Uint) Div(x, y *Uint) *Uint {
 93	if y.IsZero() || y.Gt(x) {
 94		return z.Clear()
 95	}
 96	if x.Eq(y) {
 97		return z.SetOne()
 98	}
 99	// Shortcut some cases
100	if x.IsUint64() {
101		return z.SetUint64(x.Uint64() / y.Uint64())
102	}
103
104	// At this point, we know
105	// x/y ; x > y > 0
106
107	var quot Uint
108	udivrem(quot.arr[:], x.arr[:], y)
109	return z.Set(&quot)
110}
111
112// MulMod calculates the modulo-m multiplication of x and y and
113// returns z.
114// If m == 0, z is set to 0 (OBS: differs from the big.Int)
115func (z *Uint) MulMod(x, y, m *Uint) *Uint {
116	if x.IsZero() || y.IsZero() || m.IsZero() {
117		return z.Clear()
118	}
119	p := umul(x, y)
120
121	if m.arr[3] != 0 {
122		mu := Reciprocal(m)
123		r := reduce4(p, m, mu)
124		return z.Set(&r)
125	}
126
127	var (
128		pl Uint
129		ph Uint
130	)
131
132	pl = Uint{arr: [4]uint64{p[0], p[1], p[2], p[3]}}
133	ph = Uint{arr: [4]uint64{p[4], p[5], p[6], p[7]}}
134
135	// If the multiplication is within 256 bits use Mod().
136	if ph.IsZero() {
137		return z.Mod(&pl, m)
138	}
139
140	var quot [8]uint64
141	rem := udivrem(quot[:], p[:], m)
142	return z.Set(&rem)
143}
144
145// Mod sets z to the modulus x%y for y != 0 and returns z.
146// If y == 0, z is set to 0 (OBS: differs from the big.Uint)
147func (z *Uint) Mod(x, y *Uint) *Uint {
148	if x.IsZero() || y.IsZero() {
149		return z.Clear()
150	}
151	switch x.Cmp(y) {
152	case -1:
153		// x < y
154		copy(z.arr[:], x.arr[:])
155		return z
156	case 0:
157		// x == y
158		return z.Clear() // They are equal
159	}
160
161	// At this point:
162	// x != 0
163	// y != 0
164	// x > y
165
166	// Shortcut trivial case
167	if x.IsUint64() {
168		return z.SetUint64(x.Uint64() % y.Uint64())
169	}
170
171	var quot Uint
172	*z = udivrem(quot.arr[:], x.arr[:], y)
173	return z
174}
175
176// DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0.
177// If y == 0, both z and m are set to 0 (OBS: differs from the big.Int)
178func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) {
179	if y.IsZero() {
180		return z.Clear(), m.Clear()
181	}
182	var quot Uint
183	*m = udivrem(quot.arr[:], x.arr[:], y)
184	*z = quot
185	return z, m
186}
187
188// Exp sets z = base**exponent mod 2**256, and returns z.
189func (z *Uint) Exp(base, exponent *Uint) *Uint {
190	res := Uint{arr: [4]uint64{1, 0, 0, 0}}
191	multiplier := *base
192	expBitLen := exponent.BitLen()
193
194	curBit := 0
195	word := exponent.arr[0]
196	for ; curBit < expBitLen && curBit < 64; curBit++ {
197		if word&1 == 1 {
198			res.Mul(&res, &multiplier)
199		}
200		multiplier.squared()
201		word >>= 1
202	}
203
204	word = exponent.arr[1]
205	for ; curBit < expBitLen && curBit < 128; curBit++ {
206		if word&1 == 1 {
207			res.Mul(&res, &multiplier)
208		}
209		multiplier.squared()
210		word >>= 1
211	}
212
213	word = exponent.arr[2]
214	for ; curBit < expBitLen && curBit < 192; curBit++ {
215		if word&1 == 1 {
216			res.Mul(&res, &multiplier)
217		}
218		multiplier.squared()
219		word >>= 1
220	}
221
222	word = exponent.arr[3]
223	for ; curBit < expBitLen && curBit < 256; curBit++ {
224		if word&1 == 1 {
225			res.Mul(&res, &multiplier)
226		}
227		multiplier.squared()
228		word >>= 1
229	}
230	return z.Set(&res)
231}
232
233func (z *Uint) squared() {
234	var (
235		res                    Uint
236		carry0, carry1, carry2 uint64
237		res1, res2             uint64
238	)
239
240	carry0, res.arr[0] = bits.Mul64(z.arr[0], z.arr[0])
241	carry0, res1 = umulHop(carry0, z.arr[0], z.arr[1])
242	carry0, res2 = umulHop(carry0, z.arr[0], z.arr[2])
243
244	carry1, res.arr[1] = umulHop(res1, z.arr[0], z.arr[1])
245	carry1, res2 = umulStep(res2, z.arr[1], z.arr[1], carry1)
246
247	carry2, res.arr[2] = umulHop(res2, z.arr[0], z.arr[2])
248
249	res.arr[3] = 2*(z.arr[0]*z.arr[3]+z.arr[1]*z.arr[2]) + carry0 + carry1 + carry2
250
251	z.Set(&res)
252}
253
254// udivrem divides u by d and produces both quotient and remainder.
255// The quotient is stored in provided quot - len(u)-len(d)+1 words.
256// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words.
257// See Knuth, Volume 2, section 4.3.1, Algorithm D.
258func udivrem(quot, u []uint64, d *Uint) (rem Uint) {
259	var dLen int
260	for i := len(d.arr) - 1; i >= 0; i-- {
261		if d.arr[i] != 0 {
262			dLen = i + 1
263			break
264		}
265	}
266
267	shift := uint(bits.LeadingZeros64(d.arr[dLen-1]))
268
269	var dnStorage Uint
270	dn := dnStorage.arr[:dLen]
271	for i := dLen - 1; i > 0; i-- {
272		dn[i] = (d.arr[i] << shift) | (d.arr[i-1] >> (64 - shift))
273	}
274	dn[0] = d.arr[0] << shift
275
276	var uLen int
277	for i := len(u) - 1; i >= 0; i-- {
278		if u[i] != 0 {
279			uLen = i + 1
280			break
281		}
282	}
283
284	if uLen < dLen {
285		copy(rem.arr[:], u)
286		return rem
287	}
288
289	var unStorage [9]uint64
290	un := unStorage[:uLen+1]
291	un[uLen] = u[uLen-1] >> (64 - shift)
292	for i := uLen - 1; i > 0; i-- {
293		un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift))
294	}
295	un[0] = u[0] << shift
296
297	// TODO: Skip the highest word of numerator if not significant.
298
299	if dLen == 1 {
300		r := udivremBy1(quot, un, dn[0])
301		rem.SetUint64(r >> shift)
302		return rem
303	}
304
305	udivremKnuth(quot, un, dn)
306
307	for i := 0; i < dLen-1; i++ {
308		rem.arr[i] = (un[i] >> shift) | (un[i+1] << (64 - shift))
309	}
310	rem.arr[dLen-1] = un[dLen-1] >> shift
311
312	return rem
313}
314
315// umul computes full 256 x 256 -> 512 multiplication.
316func umul(x, y *Uint) [8]uint64 {
317	var (
318		res                           [8]uint64
319		carry, carry4, carry5, carry6 uint64
320		res1, res2, res3, res4, res5  uint64
321	)
322
323	carry, res[0] = bits.Mul64(x.arr[0], y.arr[0])
324	carry, res1 = umulHop(carry, x.arr[1], y.arr[0])
325	carry, res2 = umulHop(carry, x.arr[2], y.arr[0])
326	carry4, res3 = umulHop(carry, x.arr[3], y.arr[0])
327
328	carry, res[1] = umulHop(res1, x.arr[0], y.arr[1])
329	carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry)
330	carry, res3 = umulStep(res3, x.arr[2], y.arr[1], carry)
331	carry5, res4 = umulStep(carry4, x.arr[3], y.arr[1], carry)
332
333	carry, res[2] = umulHop(res2, x.arr[0], y.arr[2])
334	carry, res3 = umulStep(res3, x.arr[1], y.arr[2], carry)
335	carry, res4 = umulStep(res4, x.arr[2], y.arr[2], carry)
336	carry6, res5 = umulStep(carry5, x.arr[3], y.arr[2], carry)
337
338	carry, res[3] = umulHop(res3, x.arr[0], y.arr[3])
339	carry, res[4] = umulStep(res4, x.arr[1], y.arr[3], carry)
340	carry, res[5] = umulStep(res5, x.arr[2], y.arr[3], carry)
341	res[7], res[6] = umulStep(carry6, x.arr[3], y.arr[3], carry)
342
343	return res
344}
345
// umulStep computes (hi * 2^64 + lo) = z + (x * y) + carry.
// The sum cannot overflow 128 bits: even with all four inputs at their
// maximum the result is exactly 2^128 - 1.
func umulStep(z, x, y, carry uint64) (hi, lo uint64) {
	var c uint64
	hi, lo = bits.Mul64(x, y)

	// Fold in the incoming carry, then the accumulator word.
	lo, c = bits.Add64(lo, carry, 0)
	hi += c
	lo, c = bits.Add64(lo, z, 0)
	hi += c

	return hi, lo
}
355
// umulHop computes (hi * 2^64 + lo) = z + (x * y).
func umulHop(z, x, y uint64) (hi, lo uint64) {
	hi, lo = bits.Mul64(x, y)
	sum, c := bits.Add64(lo, z, 0)
	// The carry out of the low word goes into the high word.
	return hi + c, sum
}
363
364// udivremBy1 divides u by single normalized word d and produces both quotient and remainder.
365// The quotient is stored in provided quot.
366func udivremBy1(quot, u []uint64, d uint64) (rem uint64) {
367	reciprocal := reciprocal2by1(d)
368	rem = u[len(u)-1] // Set the top word as remainder.
369	for j := len(u) - 2; j >= 0; j-- {
370		quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal)
371	}
372	return rem
373}
374
375// udivremKnuth implements the division of u by normalized multiple word d from the Knuth's division algorithm.
376// The quotient is stored in provided quot - len(u)-len(d) words.
377// Updates u to contain the remainder - len(d) words.
378func udivremKnuth(quot, u, d []uint64) {
379	dh := d[len(d)-1]
380	dl := d[len(d)-2]
381	reciprocal := reciprocal2by1(dh)
382
383	for j := len(u) - len(d) - 1; j >= 0; j-- {
384		u2 := u[j+len(d)]
385		u1 := u[j+len(d)-1]
386		u0 := u[j+len(d)-2]
387
388		var qhat, rhat uint64
389		if u2 >= dh { // Division overflows.
390			qhat = ^uint64(0)
391			// TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case).
392		} else {
393			qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal)
394			ph, pl := bits.Mul64(qhat, dl)
395			if ph > rhat || (ph == rhat && pl > u0) {
396				qhat--
397				// TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case).
398			}
399		}
400
401		// Multiply and subtract.
402		borrow := subMulTo(u[j:], d, qhat)
403		u[j+len(d)] = u2 - borrow
404		if u2 < borrow { // Too much subtracted, add back.
405			qhat--
406			u[j+len(d)] += addTo(u[j:], d)
407		}
408
409		quot[j] = qhat // Store quotient digit.
410	}
411}
412
413// isBitSet returns true if bit n-th is set, where n = 0 is LSB.
414// The n must be <= 255.
415func (z *Uint) isBitSet(n uint) bool {
416	return (z.arr[n/64] & (1 << (n % 64))) != 0
417}
418
// addTo computes x += y over the first len(y) words of x and returns the
// carry out of the last word added.
// Requires len(x) >= len(y).
func addTo(x, y []uint64) uint64 {
	var c uint64
	for i, yi := range y {
		x[i], c = bits.Add64(x[i], yi, c)
	}
	return c
}
428
// subMulTo computes x -= y * multiplier over the first len(y) words of x
// and returns the borrow left over after the last word.
// Requires len(x) >= len(y).
func subMulTo(x, y []uint64, multiplier uint64) uint64 {
	var borrow uint64
	for i, yi := range y {
		hi, lo := bits.Mul64(yi, multiplier)
		// Take off the incoming borrow first, then the low product word.
		diff, b1 := bits.Sub64(x[i], borrow, 0)
		diff, b2 := bits.Sub64(diff, lo, 0)
		x[i] = diff
		// The high product word plus both borrow bits go to the next word.
		borrow = hi + b1 + b2
	}
	return borrow
}
442
// reciprocal2by1 computes <^d, ^0> / d, i.e. the quotient of the two-word
// value with high word ^d and low word all ones, divided by d.
// d must be normalized (top bit set); otherwise ^d >= d and bits.Div64
// panics.
func reciprocal2by1(d uint64) uint64 {
	const allOnes = ^uint64(0)
	quo, _ := bits.Div64(^d, allOnes, d) // the remainder is not needed
	return quo
}
448
// udivrem2by1 divides <uh, ul> / d and produces both quotient and remainder.
// It uses the provided reciprocal of d (see reciprocal2by1).
// Implementation ported from https://github.com/chfast/intx and is based on
// "Improved division by invariant integers", Algorithm 4.
func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) {
	// Quotient estimate: high word of reciprocal*uh + <uh, ul>, plus one.
	quot, lo := bits.Mul64(reciprocal, uh)
	lo, c := bits.Add64(lo, ul, 0)
	quot += uh + c + 1

	// Candidate remainder; all arithmetic wraps modulo 2^64.
	rem = ul - quot*d

	// First correction: the estimate was one too large.
	if rem > lo {
		quot--
		rem += d
	}

	// Second correction: the estimate was one too small.
	if rem >= d {
		quot++
		rem -= d
	}

	return quot, rem
}