Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
PojavLauncherTeam
GitHub Repository: PojavLauncherTeam/mobile
Path: blob/master/test/hotspot/jtreg/vmTestbase/jit/escape/LockElision/MatMul/MatMul.java
41160 views
1
/*
2
* Copyright (c) 2010, 2020, Oracle and/or its affiliates. All rights reserved.
3
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4
*
5
* This code is free software; you can redistribute it and/or modify it
6
* under the terms of the GNU General Public License version 2 only, as
7
* published by the Free Software Foundation.
8
*
9
* This code is distributed in the hope that it will be useful, but WITHOUT
10
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12
* version 2 for more details (a copy is included in the LICENSE file that
13
* accompanied this code).
14
*
15
* You should have received a copy of the GNU General Public License version
16
* 2 along with this work; if not, write to the Free Software Foundation,
17
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18
*
19
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20
* or visit www.oracle.com if you need additional information or have any
21
* questions.
22
*/
23
24
/*
25
* @test
26
* @key randomness
27
*
28
* @summary converted from VM Testbase jit/escape/LockElision/MatMul.
29
* VM Testbase keywords: [jit, quick]
30
* VM Testbase readme:
31
* DESCRIPTION
32
* The test multiplies 2 matrices, first, by directly calculating matrix product
33
* elements, and second, by calculating them in parallel in different threads.
34
* The results are compared then.
35
* The test, in addition to required locks, introduces locks on local variables or
36
* variables not escaping from the executing thread, and nests them in many ways.
37
* In case of a buggy compiler, during lock elimination some code, required by
38
* the calculation may be eliminated as well, or the code may be overoptimized in
39
* some other way, causing difference in the execution results.
40
* The test takes the parameters -dim, which specifies the dimension of the matrices,
* and -threadCount, which specifies the base number of worker threads.
41
*
42
* @library /vmTestbase
43
* /test/lib
44
* @run main/othervm jit.escape.LockElision.MatMul.MatMul -dim 30 -threadCount 10
45
*/
46
47
package jit.escape.LockElision.MatMul;
48
49
import java.util.*;
50
import java.util.concurrent.CountDownLatch;
51
import java.util.concurrent.ExecutorService;
52
import java.util.concurrent.Executors;
53
54
import nsk.share.Consts;
55
import nsk.share.Log;
56
import nsk.share.Pair;
57
import nsk.share.test.StressOptions;
58
import vm.share.options.Option;
59
import vm.share.options.OptionSupport;
60
import vm.share.options.Options;
61
62
import jdk.test.lib.Utils;
63
64
public class MatMul {
65
66
@Option(name = "dim", description = "dimension of matrices")
67
int dim;
68
69
@Option(name = "verbose", default_value = "false",
70
description = "verbose mode")
71
boolean verbose;
72
73
@Option(name = "threadCount", description = "thread count")
74
int threadCount;
75
76
@Options
77
StressOptions stressOptions = new StressOptions();
78
79
private Log log;
80
81
public static void main(String[] args) {
82
MatMul test = new MatMul();
83
OptionSupport.setup(test, args);
84
System.exit(Consts.JCK_STATUS_BASE + test.run());
85
}
86
87
public int run() {
88
log = new Log(System.out, verbose);
89
log.display("Parallel matrix multiplication test");
90
91
Matrix a = Matrix.randomMatrix(dim);
92
Matrix b = Matrix.randomMatrix(dim);
93
long t1, t2;
94
95
t1 = System.currentTimeMillis();
96
Matrix serialResult = serialMul(a, b);
97
t2 = System.currentTimeMillis();
98
log.display("serial time: " + (t2 - t1) + "ms");
99
100
try {
101
t1 = System.currentTimeMillis();
102
Matrix parallelResult = parallelMul(a, b,
103
threadCount * stressOptions.getThreadsFactor());
104
t2 = System.currentTimeMillis();
105
log.display("parallel time: " + (t2 - t1) + "ms");
106
107
if (!serialResult.equals(parallelResult)) {
108
log.complain("a = \n" + a);
109
log.complain("b = \n" + b);
110
111
log.complain("serial: a * b = \n" + serialResult);
112
log.complain("serial: a * b = \n" + parallelResult);
113
return Consts.TEST_FAILED;
114
}
115
return Consts.TEST_PASSED;
116
117
} catch (CounterIncorrectStateException e) {
118
log.complain("incorrect state of counter " + e.counter.name);
119
log.complain("expected = " + e.counter.expected);
120
log.complain("actual " + e.counter.state());
121
return Consts.TEST_FAILED;
122
}
123
}
124
125
public static int convolution(Seq<Integer> one, Seq<Integer> two) {
126
int res = 0;
127
int upperBound = Math.min(one.size(), two.size());
128
for (int i = 0; i < upperBound; i++) {
129
res += one.get(i) * two.get(i);
130
}
131
return res;
132
}
133
134
/**
135
* calculate chunked convolutuion of two sequences
136
* <p/>
137
* This special version of this method:
138
* <pre>{@code
139
* public static int chunkedConvolution(Seq<Integer> one, Seq<Integer> two, int from, int to) {
140
* int res = 0;
141
* int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1);
142
* for (int i = from; i < upperBound; i++) {
143
* res += one.get(i) * two.get(i);
144
* }
145
* return res;
146
* }}</pre>
147
* <p/>
148
* that tries to fool the Lock Elision optimization:
149
* Most lock objects in these lines are really thread local, so related synchronized blocks (dummy blocks) can be removed.
150
* But several synchronized blocks (all that protected by Counter instances) are really necessary, and removing them we obtain
151
* an incorrect result.
152
*
153
* @param one
154
* @param two
155
* @param from - lower bound of sum
156
* @param to - upper bound of sum
157
* @param local - reference ThreadLocal that will be used for calculations
158
* @param bCounter - Counter instance, need to perfom checks
159
*/
160
public static int chunkedConvolutionWithDummy(Seq<Integer> one,
161
Seq<Integer> two, int from, int to, ThreadLocals local,
162
Counter bCounter) {
163
ThreadLocals conv_local1 = new ThreadLocals(local, "conv_local1");
164
ThreadLocals conv_local2 = new ThreadLocals(conv_local1, "conv_local2");
165
ThreadLocals conv_local3 = new ThreadLocals(null, "conv_local3");
166
int res = 0;
167
synchronized (local) {
168
local.updateHash();
169
int upperBound = 0;
170
synchronized (conv_local1) {
171
upperBound = local.min(one.size(), two.size());
172
synchronized (two) {
173
//int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1) :
174
upperBound = conv_local1.min(upperBound, to + 1);
175
synchronized (bCounter) {
176
bCounter.inc();
177
}
178
}
179
for (int i = from; i < upperBound; i++) {
180
synchronized (conv_local2) {
181
conv_local1.updateHash();
182
int prod = 0;
183
synchronized (one) {
184
int t = conv_local2.mult(one.get(i), two.get(i));
185
synchronized (conv_local3) {
186
prod = t;
187
188
}
189
//res += one.get(i) * two.get(i)
190
res = conv_local3.sum(res, prod);
191
}
192
}
193
}
194
}
195
return res;
196
}
197
}
198
199
public boolean productCheck(Matrix a, Matrix b) {
200
if (a == null || b == null) {
201
log.complain("null matrix!");
202
return false;
203
}
204
205
if (a.dim != b.dim) {
206
log.complain("matrices dimension are differs");
207
return false;
208
}
209
return true;
210
}
211
212
public Matrix serialMul(Matrix a, Matrix b) {
213
if (!productCheck(a, b)) {
214
throw new IllegalArgumentException();
215
}
216
217
Matrix result = Matrix.zeroMatrix(a.dim);
218
for (int i = 0; i < a.dim; i++) {
219
for (int j = 0; j < a.dim; j++) {
220
result.set(i, j, convolution(a.row(i), b.column(j)));
221
}
222
}
223
return result;
224
}
225
226
227
/**
228
* Parallel multiplication of matrices.
229
* <p/>
230
* This special version of this method:
231
* <pre>{@code
232
* public Matrix parallelMul1(final Matrix a, final Matrix b, int threadCount) {
233
* if (!productCheck(a, b)) {
234
* throw new IllegalArgumentException();
235
* }
236
* final int dim = a.dim;
237
* final Matrix result = Matrix.zeroMatrix(dim);
238
* <p/>
239
* ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
240
* final CountDownLatch latch = new CountDownLatch(threadCount);
241
* List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1), threadCount);
242
* for (final Pair<Integer, Integer> part : parts) {
243
* threadPool.submit(new Runnable() {
244
* @Override
245
* public void run() {
246
* for (int i = 0; i < dim; i++) {
247
* for (int j = 0; j < dim; j++) {
248
* synchronized (result) {
249
* int from = part.first;
250
* int to = part.second;
251
* result.add(i, j, chunkedConvolution(a.row(i), b.column(j), from, to));
252
* }
253
* }
254
* }
255
* latch.countDown();
256
* }
257
* });
258
* }
259
* <p/>
260
* try {
261
* latch.await();
262
* } catch (InterruptedException e) {
263
* e.printStackTrace();
264
* }
265
* threadPool.shutdown();
266
* return result;
267
* }}</pre>
268
* Lines marked with NOP comments need to fool the Lock Elision optimization:
269
* All lock objects in these lines are really thread local, so related synchronized blocks (dummy blocks) can be removed.
270
* But several synchronized blocks (that are nested in dummy blocks) are really necessary, and removing them we obtain
271
* an incorrect result.
272
*
273
* @param a first operand
274
* @param b second operand
275
* @param threadCount number of threads that will be used for calculations
276
* @return product of matrices a and b
277
*/
278
public Matrix parallelMul(final Matrix a, final Matrix b, int threadCount)
279
throws CounterIncorrectStateException {
280
if (!productCheck(a, b)) {
281
throw new IllegalArgumentException();
282
}
283
final int dim = a.dim;
284
final Matrix result = Matrix.zeroMatrix(dim);
285
286
ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
287
final CountDownLatch latch = new CountDownLatch(threadCount);
288
List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1),
289
threadCount);
290
291
final Counter lCounter1 = new Counter(threadCount, "lCounter1");
292
final Counter lCounter2 = new Counter(threadCount, "lCounter2");
293
final Counter lCounter3 = new Counter(threadCount, "lCounter3");
294
295
final Counter bCounter1 = new Counter(threadCount * dim * dim,
296
"bCounter1");
297
final Counter bCounter2 = new Counter(threadCount * dim * dim,
298
"bCounter2");
299
final Counter bCounter3 = new Counter(threadCount * dim * dim,
300
"bCounter3");
301
302
final Counter[] counters = {lCounter1, lCounter2, lCounter3,
303
bCounter1, bCounter2, bCounter3};
304
305
final Map<Pair<Integer, Integer>, ThreadLocals> locals1
306
= CollectionsUtils.newHashMap();
307
final Map<Pair<Integer, Integer>, ThreadLocals> locals2
308
= CollectionsUtils.newHashMap();
309
final Map<Pair<Integer, Integer>, ThreadLocals> locals3
310
= CollectionsUtils.newHashMap();
311
312
for (final Pair<Integer, Integer> part : parts) {
313
314
ThreadLocals local1 = new ThreadLocals(null,
315
"locals1[" + part + "]");
316
ThreadLocals local2 = new ThreadLocals(local1,
317
"locals2[" + part + "]");
318
ThreadLocals local3 = new ThreadLocals(local2,
319
"locals3[" + part + "]");
320
321
locals1.put(part, local1);
322
locals2.put(part, local2);
323
locals3.put(part, local3);
324
}
325
326
for (final Pair<Integer, Integer> part : parts) {
327
threadPool.submit(new Runnable() {
328
@Override
329
public void run() {
330
ThreadLocals local1 = locals1.get(part);
331
ThreadLocals local2 = locals2.get(part);
332
ThreadLocals local3 = locals3.get(part);
333
ThreadLocals local4 = locals3.get(part);
334
synchronized (local1) {
335
local1.updateHash();
336
synchronized (lCounter1) {
337
lCounter1.inc();
338
}
339
synchronized (lCounter3) {
340
synchronized (local2) {
341
local2.updateHash();
342
lCounter3.inc();
343
}
344
}
345
synchronized (new Object()) {
346
synchronized (lCounter2) {
347
lCounter2.inc();
348
}
349
for (int i = 0; i < dim; i++) {
350
for (int j = 0; j < dim; j++) {
351
synchronized (bCounter1) {
352
synchronized (new Object()) {
353
bCounter1.inc();
354
}
355
}
356
synchronized (local3) {
357
local3.updateHash();
358
synchronized (bCounter2) {
359
bCounter2.inc();
360
}
361
synchronized (result) {
362
local1.updateHash();
363
synchronized (local2) {
364
local2.updateHash();
365
int from = part.first;
366
int to = part.second;
367
result.add(i, j,
368
chunkedConvolutionWithDummy(
369
a.row(i),
370
b.column(j),
371
from, to,
372
local4,
373
bCounter3));
374
}
375
}
376
}
377
}
378
}
379
}
380
}
381
latch.countDown();
382
}
383
});
384
}
385
386
try {
387
latch.await();
388
} catch (InterruptedException e) {
389
e.printStackTrace();
390
}
391
392
threadPool.shutdown();
393
for (final Pair<Integer, Integer> part : parts) {
394
log.display(
395
"hash for " + part + " = " + locals1.get(part).getHash());
396
}
397
398
399
for (Counter counter : counters) {
400
if (!counter.check()) {
401
throw new CounterIncorrectStateException(counter);
402
}
403
}
404
return result;
405
}
406
407
/**
408
* Split interval into parts
409
*
410
* @param interval - pair than encode bounds of interval
411
* @param partCount - count of parts
412
* @return list of pairs than encode bounds of parts
413
*/
414
public static List<Pair<Integer, Integer>> splitInterval(
415
Pair<Integer, Integer> interval, int partCount) {
416
if (partCount == 0) {
417
throw new IllegalArgumentException();
418
}
419
420
if (partCount == 1) {
421
return CollectionsUtils.asList(interval);
422
}
423
424
int intervalSize = interval.second - interval.first + 1;
425
int partSize = intervalSize / partCount;
426
427
List<Pair<Integer, Integer>> init = splitInterval(
428
Pair.of(interval.first, interval.second - partSize),
429
partCount - 1);
430
Pair<Integer, Integer> lastPart = Pair
431
.of(interval.second - partSize + 1, interval.second);
432
433
return CollectionsUtils.append(init, lastPart);
434
}
435
436
public static class Counter {
437
private int state;
438
439
public final int expected;
440
public final String name;
441
442
public void inc() {
443
state++;
444
}
445
446
public int state() {
447
return state;
448
}
449
450
public boolean check() {
451
return state == expected;
452
}
453
454
public Counter(int expected, String name) {
455
this.expected = expected;
456
this.name = name;
457
}
458
}
459
460
private static class CounterIncorrectStateException extends Exception {
461
public final Counter counter;
462
463
public CounterIncorrectStateException(Counter counter) {
464
this.counter = counter;
465
}
466
}
467
468
private static abstract class Seq<E> implements Iterable<E> {
469
@Override
470
public Iterator<E> iterator() {
471
return new Iterator<E>() {
472
private int p = 0;
473
474
@Override
475
public boolean hasNext() {
476
return p < size();
477
}
478
479
@Override
480
public E next() {
481
return get(p++);
482
}
483
484
@Override
485
public void remove() {
486
}
487
};
488
}
489
490
public abstract E get(int i);
491
492
public abstract int size();
493
}
494
495
private static class CollectionsUtils {
496
497
public static <K, V> Map<K, V> newHashMap() {
498
return new HashMap<K, V>();
499
}
500
501
public static <E> List<E> newArrayList() {
502
return new ArrayList<E>();
503
}
504
505
public static <E> List<E> newArrayList(Collection<E> collection) {
506
return new ArrayList<E>(collection);
507
}
508
509
public static <E> List<E> asList(E e) {
510
List<E> result = newArrayList();
511
result.add(e);
512
return result;
513
}
514
515
public static <E> List<E> append(List<E> init, E last) {
516
List<E> result = newArrayList(init);
517
result.add(last);
518
return result;
519
}
520
}
521
522
private static class Matrix {
523
524
public final int dim;
525
private int[] coeffs;
526
527
private Matrix(int dim) {
528
this.dim = dim;
529
this.coeffs = new int[dim * dim];
530
}
531
532
public void set(int i, int j, int value) {
533
coeffs[i * dim + j] = value;
534
}
535
536
public void add(int i, int j, int value) {
537
coeffs[i * dim + j] += value;
538
}
539
540
public int get(int i, int j) {
541
return coeffs[i * dim + j];
542
}
543
544
public Seq<Integer> row(final int i) {
545
return new Seq<Integer>() {
546
@Override
547
public Integer get(int j) {
548
return Matrix.this.get(i, j);
549
}
550
551
@Override
552
public int size() {
553
return Matrix.this.dim;
554
}
555
};
556
}
557
558
public Seq<Integer> column(final int j) {
559
return new Seq<Integer>() {
560
@Override
561
public Integer get(int i) {
562
return Matrix.this.get(i, j);
563
}
564
565
@Override
566
public int size() {
567
return Matrix.this.dim;
568
}
569
};
570
}
571
572
@Override
573
public String toString() {
574
StringBuilder builder = new StringBuilder();
575
for (int i = 0; i < dim; i++) {
576
for (int j = 0; j < dim; j++) {
577
builder.append((j == 0) ? "" : "\t\t");
578
builder.append(get(i, j));
579
}
580
builder.append("\n");
581
}
582
return builder.toString();
583
}
584
585
@Override
586
public boolean equals(Object other) {
587
if (!(other instanceof Matrix)) {
588
return false;
589
}
590
591
Matrix b = (Matrix) other;
592
if (b.dim != this.dim) {
593
return false;
594
}
595
for (int i = 0; i < dim; i++) {
596
for (int j = 0; j < dim; j++) {
597
if (this.get(i, j) != b.get(i, j)) {
598
return false;
599
}
600
}
601
}
602
return true;
603
}
604
605
private static Random random = Utils.getRandomInstance();
606
607
public static Matrix randomMatrix(int dim) {
608
Matrix result = new Matrix(dim);
609
for (int i = 0; i < dim; i++) {
610
for (int j = 0; j < dim; j++) {
611
result.set(i, j, random.nextInt(50));
612
}
613
}
614
return result;
615
}
616
617
public static Matrix zeroMatrix(int dim) {
618
Matrix result = new Matrix(dim);
619
for (int i = 0; i < dim; i++) {
620
for (int j = 0; j < dim; j++) {
621
result.set(i, j, 0);
622
}
623
}
624
return result;
625
}
626
}
627
628
/**
629
* All instances of this class will be used in thread local context
630
*/
631
private static class ThreadLocals {
632
private static final int HASH_BOUND = 424242;
633
634
private ThreadLocals parent;
635
private int hash = 42;
636
public final String name;
637
638
public ThreadLocals(ThreadLocals parent, String name) {
639
this.parent = parent;
640
this.name = name;
641
}
642
643
public int min(int a, int b) {
644
updateHash(a + b + 1);
645
return Math.min(a, b);
646
}
647
648
public int mult(int a, int b) {
649
updateHash(a + b + 2);
650
return a * b;
651
}
652
653
public int sum(int a, int b) {
654
updateHash(a + b + 3);
655
return a + b;
656
}
657
658
659
public int updateHash() {
660
return updateHash(42);
661
}
662
663
public int updateHash(int data) {
664
hash = (hash + data) % HASH_BOUND;
665
if (parent != null) {
666
hash = parent.updateHash(hash) % HASH_BOUND;
667
}
668
return hash;
669
}
670
671
public int getHash() {
672
return hash;
673
}
674
}
675
}
676
677