arithmetic_test.gno

11.51 Kb · 392 lines
  1package uint256
  2
  3import (
  4	"testing"
  5)
  6
  7type binOp2Test struct {
  8	x, y, want string
  9}
 10
 11func TestAdd(t *testing.T) {
 12	tests := []binOp2Test{
 13		{"0", "1", "1"},
 14		{"1", "0", "1"},
 15		{"1", "1", "2"},
 16		{"1", "3", "4"},
 17		{"10", "10", "20"},
 18		{"18446744073709551615", "18446744073709551615", "36893488147419103230"}, // uint64 overflow
 19	}
 20
 21	for _, tt := range tests {
 22		x := MustFromDecimal(tt.x)
 23		y := MustFromDecimal(tt.y)
 24
 25		want := MustFromDecimal(tt.want)
 26		got := new(Uint).Add(x, y)
 27
 28		if got.Neq(want) {
 29			t.Errorf("Add(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
 30		}
 31	}
 32}
 33
 34func TestAddOverflow(t *testing.T) {
 35	tests := []struct {
 36		x, y     string
 37		want     string
 38		overflow bool
 39	}{
 40		{"0", "1", "1", false},
 41		{"1", "0", "1", false},
 42		{"1", "1", "2", false},
 43		{"10", "10", "20", false},
 44		{"18446744073709551615", "18446744073709551615", "36893488147419103230", false},                    // uint64 overflow, but not Uint256 overflow
 45		{"115792089237316195423570985008687907853269984665640564039457584007913129639935", "1", "0", true}, // 2^256 - 1 + 1, should overflow
 46		{"57896044618658097711785492504343953926634992332820282019728792003956564819967", "57896044618658097711785492504343953926634992332820282019728792003956564819968", "115792089237316195423570985008687907853269984665640564039457584007913129639935", false}, // (2^255 - 1) + 2^255, no overflow
 47		{"57896044618658097711785492504343953926634992332820282019728792003956564819967", "57896044618658097711785492504343953926634992332820282019728792003956564819969", "0", true},                                                                               // (2^255 - 1) + (2^255 + 1), should overflow
 48	}
 49
 50	for _, tt := range tests {
 51		x := MustFromDecimal(tt.x)
 52		y := MustFromDecimal(tt.y)
 53		want, _ := FromDecimal(tt.want)
 54
 55		got, overflow := new(Uint).AddOverflow(x, y)
 56
 57		if got.Cmp(want) != 0 || overflow != tt.overflow {
 58			t.Errorf("AddOverflow(%s, %s) = (%s, %v), want (%s, %v)",
 59				tt.x, tt.y, got.String(), overflow, tt.want, tt.overflow)
 60		}
 61	}
 62}
 63
 64func TestSub(t *testing.T) {
 65	tests := []binOp2Test{
 66		{"1", "0", "1"},
 67		{"1", "1", "0"},
 68		{"10", "10", "0"},
 69		{"31337", "1337", "30000"},
 70		{"2", "3", twoPow256Sub1}, // underflow
 71	}
 72
 73	for _, tc := range tests {
 74		x := MustFromDecimal(tc.x)
 75		y := MustFromDecimal(tc.y)
 76
 77		want := MustFromDecimal(tc.want)
 78
 79		got := new(Uint).Sub(x, y)
 80
 81		if got.Neq(want) {
 82			t.Errorf(
 83				"Sub(%s, %s) = %v, want %v",
 84				tc.x, tc.y, got.String(), want.String(),
 85			)
 86		}
 87	}
 88}
 89
 90func TestSubOverflow(t *testing.T) {
 91	tests := []struct {
 92		x, y     string
 93		want     string
 94		overflow bool
 95	}{
 96		{"1", "0", "1", false},
 97		{"1", "1", "0", false},
 98		{"10", "10", "0", false},
 99		{"31337", "1337", "30000", false},
100		{"0", "1", "115792089237316195423570985008687907853269984665640564039457584007913129639935", true},                                                                                                                                                         // 0 - 1, should underflow
101		{"57896044618658097711785492504343953926634992332820282019728792003956564819968", "1", "57896044618658097711785492504343953926634992332820282019728792003956564819967", false},                                                                             // 2^255 - 1, no underflow
102		{"57896044618658097711785492504343953926634992332820282019728792003956564819968", "57896044618658097711785492504343953926634992332820282019728792003956564819969", "115792089237316195423570985008687907853269984665640564039457584007913129639935", true}, // 2^255 - (2^255 + 1), should underflow
103	}
104
105	for _, tc := range tests {
106		x := MustFromDecimal(tc.x)
107		y := MustFromDecimal(tc.y)
108		want := MustFromDecimal(tc.want)
109
110		got, overflow := new(Uint).SubOverflow(x, y)
111
112		if got.Cmp(want) != 0 || overflow != tc.overflow {
113			t.Errorf(
114				"SubOverflow(%s, %s) = (%s, %v), want (%s, %v)",
115				tc.x, tc.y, got.String(), overflow, tc.want, tc.overflow,
116			)
117		}
118	}
119}
120
121func TestMul(t *testing.T) {
122	tests := []binOp2Test{
123		{"1", "0", "0"},
124		{"1", "1", "1"},
125		{"10", "10", "100"},
126		{"18446744073709551615", "2", "36893488147419103230"}, // uint64 overflow
127	}
128
129	for _, tt := range tests {
130		x := MustFromDecimal(tt.x)
131		y := MustFromDecimal(tt.y)
132		want := MustFromDecimal(tt.want)
133		got := new(Uint).Mul(x, y)
134
135		if got.Neq(want) {
136			t.Errorf("Mul(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
137		}
138	}
139}
140
141func TestMulOverflow(t *testing.T) {
142	tests := []struct {
143		x        string
144		y        string
145		wantZ    string
146		wantOver bool
147	}{
148		{"0x1", "0x1", "0x1", false},
149		{"0x0", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0x0", false},
150		{"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0x2", "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe", true},
151		{"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0x1", true},
152		{"0x8000000000000000000000000000000000000000000000000000000000000000", "0x2", "0x0", true},
153		{"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0x2", "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe", false},
154		{"0x100000000000000000", "0x100000000000000000", "0x10000000000000000000000000000000000", false},
155		{"0x10000000000000000000000000000000", "0x10000000000000000000000000000000", "0x100000000000000000000000000000000000000000000000000000000000000", false},
156	}
157
158	for _, tt := range tests {
159		x := MustFromHex(tt.x)
160		y := MustFromHex(tt.y)
161		wantZ := MustFromHex(tt.wantZ)
162
163		gotZ, gotOver := new(Uint).MulOverflow(x, y)
164
165		if gotZ.Neq(wantZ) {
166			t.Errorf(
167				"MulOverflow(%s, %s) = %s, want %s",
168				tt.x, tt.y, gotZ.String(), wantZ.String(),
169			)
170		}
171		if gotOver != tt.wantOver {
172			t.Errorf("MulOverflow(%s, %s) = %v, want %v", tt.x, tt.y, gotOver, tt.wantOver)
173		}
174	}
175}
176
177func TestDiv(t *testing.T) {
178	tests := []binOp2Test{
179		{"31337", "3", "10445"},
180		{"31337", "0", "0"},
181		{"0", "31337", "0"},
182		{"1", "1", "1"},
183		{"1000000000000000000", "3", "333333333333333333"},
184		{twoPow256Sub1, "2", "57896044618658097711785492504343953926634992332820282019728792003956564819967"},
185	}
186
187	for _, tt := range tests {
188		x := MustFromDecimal(tt.x)
189		y := MustFromDecimal(tt.y)
190		want := MustFromDecimal(tt.want)
191
192		got := new(Uint).Div(x, y)
193
194		if got.Neq(want) {
195			t.Errorf("Div(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
196		}
197	}
198}
199
200func TestMod(t *testing.T) {
201	tests := []binOp2Test{
202		{"31337", "3", "2"},
203		{"31337", "0", "0"},
204		{"0", "31337", "0"},
205		{"2", "31337", "2"},
206		{"1", "1", "0"},
207		{"115792089237316195423570985008687907853269984665640564039457584007913129639935", "2", "1"}, // 2^256 - 1 mod 2
208		{"115792089237316195423570985008687907853269984665640564039457584007913129639935", "3", "0"}, // 2^256 - 1 mod 3
209		{"115792089237316195423570985008687907853269984665640564039457584007913129639935", "57896044618658097711785492504343953926634992332820282019728792003956564819968", "57896044618658097711785492504343953926634992332820282019728792003956564819967"}, // 2^256 - 1 mod 2^255
210	}
211
212	for _, tt := range tests {
213		x := MustFromDecimal(tt.x)
214		y := MustFromDecimal(tt.y)
215		want := MustFromDecimal(tt.want)
216
217		got := new(Uint).Mod(x, y)
218
219		if got.Neq(want) {
220			t.Errorf("Mod(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
221		}
222	}
223}
224
225func TestMulMod(t *testing.T) {
226	tests := []struct {
227		x    string
228		y    string
229		m    string
230		want string
231	}{
232		{"0x1", "0x1", "0x2", "0x1"},
233		{"0x10", "0x10", "0x7", "0x4"},
234		{"0x100", "0x100", "0x17", "0x9"},
235		{"0x31337", "0x31337", "0x31338", "0x1"},
236		{"0x0", "0x31337", "0x31338", "0x0"},
237		{"0x31337", "0x0", "0x31338", "0x0"},
238		{"0x2", "0x3", "0x5", "0x1"},
239		{"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0x0"},
240		{"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe", "0x1"},
241		{"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0xffffffffffffffffffffffffffffffff", "0x0"},
242	}
243
244	for _, tt := range tests {
245		x := MustFromHex(tt.x)
246		y := MustFromHex(tt.y)
247		m := MustFromHex(tt.m)
248		want := MustFromHex(tt.want)
249
250		got := new(Uint).MulMod(x, y, m)
251
252		if got.Neq(want) {
253			t.Errorf(
254				"MulMod(%s, %s, %s) = %s, want %s",
255				tt.x, tt.y, tt.m, got.String(), want.String(),
256			)
257		}
258	}
259}
260
261func TestDivMod(t *testing.T) {
262	tests := []struct {
263		x       string
264		y       string
265		wantDiv string
266		wantMod string
267	}{
268		{"1", "1", "1", "0"},
269		{"10", "10", "1", "0"},
270		{"100", "10", "10", "0"},
271		{"31337", "3", "10445", "2"},
272		{"31337", "0", "0", "0"},
273		{"0", "31337", "0", "0"},
274		{"2", "31337", "0", "2"},
275	}
276
277	for _, tt := range tests {
278		x := MustFromDecimal(tt.x)
279		y := MustFromDecimal(tt.y)
280		wantDiv := MustFromDecimal(tt.wantDiv)
281		wantMod := MustFromDecimal(tt.wantMod)
282
283		gotDiv := new(Uint)
284		gotMod := new(Uint)
285		gotDiv.DivMod(x, y, gotMod)
286
287		for i := range gotDiv.arr {
288			if gotDiv.arr[i] != wantDiv.arr[i] {
289				t.Errorf("DivMod(%s, %s) got Div %v, want Div %v", tt.x, tt.y, gotDiv, wantDiv)
290				break
291			}
292		}
293		for i := range gotMod.arr {
294			if gotMod.arr[i] != wantMod.arr[i] {
295				t.Errorf("DivMod(%s, %s) got Mod %v, want Mod %v", tt.x, tt.y, gotMod, wantMod)
296				break
297			}
298		}
299	}
300}
301
302func TestNeg(t *testing.T) {
303	tests := []struct {
304		x    string
305		want string
306	}{
307		{"31337", "115792089237316195423570985008687907853269984665640564039457584007913129608599"},
308		{"115792089237316195423570985008687907853269984665640564039457584007913129608599", "31337"},
309		{"0", "0"},
310		{"2", "115792089237316195423570985008687907853269984665640564039457584007913129639934"},
311		{"1", twoPow256Sub1},
312	}
313
314	for _, tt := range tests {
315		x := MustFromDecimal(tt.x)
316		want := MustFromDecimal(tt.want)
317
318		got := new(Uint).Neg(x)
319
320		if got.Neq(want) {
321			t.Errorf("Neg(%s) = %v, want %v", tt.x, got.String(), want.String())
322		}
323	}
324}
325
326func TestExp(t *testing.T) {
327	tests := []binOp2Test{
328		{"31337", "3", "30773171189753"},
329		{"31337", "0", "1"},
330		{"0", "31337", "0"},
331		{"1", "1", "1"},
332		{"2", "3", "8"},
333		{"2", "64", "18446744073709551616"},
334		{"2", "128", "340282366920938463463374607431768211456"},
335		{"2", "255", "57896044618658097711785492504343953926634992332820282019728792003956564819968"},
336		{"2", "256", "0"}, // overflow
337	}
338
339	for _, tt := range tests {
340		x := MustFromDecimal(tt.x)
341		y := MustFromDecimal(tt.y)
342		want := MustFromDecimal(tt.want)
343
344		got := new(Uint).Exp(x, y)
345
346		if got.Neq(want) {
347			t.Errorf(
348				"Exp(%s, %s) = %v, want %v",
349				tt.x, tt.y, got.String(), want.String(),
350			)
351		}
352	}
353}
354
355func TestExp_LargeExponent(t *testing.T) {
356	tests := []struct {
357		name     string
358		base     string
359		exponent string
360		expected string
361	}{
362		{
363			name:     "2^129",
364			base:     "2",
365			exponent: "680564733841876926926749214863536422912",
366			expected: "0",
367		},
368		{
369			name:     "2^193",
370			base:     "2",
371			exponent: "12379400392853802746563808384000000000000000000",
372			expected: "0",
373		},
374	}
375
376	for _, tt := range tests {
377		t.Run(tt.name, func(t *testing.T) {
378			base := MustFromDecimal(tt.base)
379			exponent := MustFromDecimal(tt.exponent)
380			expected := MustFromDecimal(tt.expected)
381
382			result := new(Uint).Exp(base, exponent)
383
384			if result.Neq(expected) {
385				t.Errorf(
386					"Test %s failed. Expected %s, got %s",
387					tt.name, expected.String(), result.String(),
388				)
389			}
390		})
391	}
392}