// Path: blob/master/test/jdk/java/foreign/TestSpliterator.java
// 41144 views
/*1* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.2* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.3*4* This code is free software; you can redistribute it and/or modify it5* under the terms of the GNU General Public License version 2 only, as6* published by the Free Software Foundation.7*8* This code is distributed in the hope that it will be useful, but WITHOUT9* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or10* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License11* version 2 for more details (a copy is included in the LICENSE file that12* accompanied this code).13*14* You should have received a copy of the GNU General Public License version15* 2 along with this work; if not, write to the Free Software Foundation,16* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.17*18* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA19* or visit www.oracle.com if you need additional information or have any20* questions.21*/2223/*24* @test25* @run testng TestSpliterator26*/2728import jdk.incubator.foreign.MemoryLayout;29import jdk.incubator.foreign.MemoryLayouts;30import jdk.incubator.foreign.MemorySegment;31import jdk.incubator.foreign.ResourceScope;32import jdk.incubator.foreign.SequenceLayout;3334import java.lang.invoke.VarHandle;35import java.util.LinkedList;36import java.util.List;37import java.util.Spliterator;38import java.util.concurrent.CountedCompleter;39import java.util.concurrent.RecursiveTask;40import java.util.concurrent.atomic.AtomicLong;41import java.util.stream.LongStream;42import java.util.stream.StreamSupport;4344import org.testng.annotations.*;4546import static org.testng.Assert.*;4748public class TestSpliterator {4950static final VarHandle INT_HANDLE = MemoryLayout.sequenceLayout(MemoryLayouts.JAVA_INT)51.varHandle(int.class, MemoryLayout.PathElement.sequenceElement());5253final static int CARRIER_SIZE = 4;5455@Test(dataProvider = 
"splits")56public void testSum(int size, int threshold) {57SequenceLayout layout = MemoryLayout.sequenceLayout(size, MemoryLayouts.JAVA_INT);5859//setup60try (ResourceScope scope = ResourceScope.newSharedScope()) {61MemorySegment segment = MemorySegment.allocateNative(layout, scope);62for (int i = 0; i < layout.elementCount().getAsLong(); i++) {63INT_HANDLE.set(segment, (long) i, i);64}65long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();66//serial67long serial = sum(0, segment);68assertEquals(serial, expected);69//parallel counted completer70long parallelCounted = new SumSegmentCounted(null, segment.spliterator(layout.elementLayout()), threshold).invoke();71assertEquals(parallelCounted, expected);72//parallel recursive action73long parallelRecursive = new SumSegmentRecursive(segment.spliterator(layout.elementLayout()), threshold).invoke();74assertEquals(parallelRecursive, expected);75//parallel stream76long streamParallel = segment.elements(layout.elementLayout()).parallel()77.reduce(0L, TestSpliterator::sumSingle, Long::sum);78assertEquals(streamParallel, expected);79}80}8182@Test83public void testSumSameThread() {84SequenceLayout layout = MemoryLayout.sequenceLayout(1024, MemoryLayouts.JAVA_INT);8586//setup87MemorySegment segment = MemorySegment.allocateNative(layout, ResourceScope.newImplicitScope());88for (int i = 0; i < layout.elementCount().getAsLong(); i++) {89INT_HANDLE.set(segment, (long) i, i);90}91long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();9293//check that a segment w/o ACQUIRE access mode can still be used from same thread94AtomicLong spliteratorSum = new AtomicLong();95segment.spliterator(layout.elementLayout())96.forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s)));97assertEquals(spliteratorSum.get(), expected);98}99100@Test(expectedExceptions = IllegalArgumentException.class)101public void testBadSpliteratorElementSizeTooBig() {102MemorySegment.ofArray(new 
byte[2]).spliterator(MemoryLayouts.JAVA_INT);103}104105@Test(expectedExceptions = IllegalArgumentException.class)106public void testBadStreamElementSizeTooBig() {107MemorySegment.ofArray(new byte[2]).elements(MemoryLayouts.JAVA_INT);108}109110@Test(expectedExceptions = IllegalArgumentException.class)111public void testBadSpliteratorElementSizeNotMultiple() {112MemorySegment.ofArray(new byte[7]).spliterator(MemoryLayouts.JAVA_INT);113}114115@Test(expectedExceptions = IllegalArgumentException.class)116public void testBadStreamElementSizeNotMultiple() {117MemorySegment.ofArray(new byte[7]).elements(MemoryLayouts.JAVA_INT);118}119120@Test(expectedExceptions = IllegalArgumentException.class)121public void testBadSpliteratorElementSizeZero() {122MemorySegment.ofArray(new byte[7]).spliterator(MemoryLayout.sequenceLayout(0, MemoryLayouts.JAVA_INT));123}124125@Test(expectedExceptions = IllegalArgumentException.class)126public void testBadStreamElementSizeZero() {127MemorySegment.ofArray(new byte[7]).elements(MemoryLayout.sequenceLayout(0, MemoryLayouts.JAVA_INT));128}129130static long sumSingle(long acc, MemorySegment segment) {131return acc + (int)INT_HANDLE.get(segment, 0L);132}133134static long sum(long start, MemorySegment segment) {135long sum = start;136int length = (int)segment.byteSize();137for (int i = 0 ; i < length / CARRIER_SIZE ; i++) {138sum += (int)INT_HANDLE.get(segment, (long)i);139}140return sum;141}142143static class SumSegmentCounted extends CountedCompleter<Long> {144145final long threshold;146long localSum = 0;147List<SumSegmentCounted> children = new LinkedList<>();148149private Spliterator<MemorySegment> segmentSplitter;150151SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) {152super(parent);153this.segmentSplitter = segmentSplitter;154this.threshold = threshold;155}156157@Override158public void compute() {159Spliterator<MemorySegment> sub;160while (segmentSplitter.estimateSize() > threshold 
&&161(sub = segmentSplitter.trySplit()) != null) {162addToPendingCount(1);163SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold);164children.add(child);165child.fork();166}167segmentSplitter.forEachRemaining(slice -> {168localSum += sumSingle(0, slice);169});170tryComplete();171}172173@Override174public Long getRawResult() {175long sum = localSum;176for (SumSegmentCounted c : children) {177sum += c.getRawResult();178}179return sum;180}181}182183static class SumSegmentRecursive extends RecursiveTask<Long> {184185final long threshold;186private final Spliterator<MemorySegment> splitter;187private long result;188189SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) {190this.splitter = splitter;191this.threshold = threshold;192}193194@Override195protected Long compute() {196if (splitter.estimateSize() > threshold) {197SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold);198sub.fork();199return compute() + sub.join();200} else {201splitter.forEachRemaining(slice -> {202result += sumSingle(0, slice);203});204return result;205}206}207}208209@DataProvider(name = "splits")210public Object[][] splits() {211return new Object[][] {212{ 10, 1 },213{ 100, 1 },214{ 1000, 1 },215{ 10000, 1 },216{ 10, 10 },217{ 100, 10 },218{ 1000, 10 },219{ 10000, 10 },220{ 10, 100 },221{ 100, 100 },222{ 1000, 100 },223{ 10000, 100 },224{ 10, 1000 },225{ 100, 1000 },226{ 1000, 1000 },227{ 10000, 1000 },228{ 10, 10000 },229{ 100, 10000 },230{ 1000, 10000 },231{ 10000, 10000 },232};233}234}235236237