Path: blob/master/test/jdk/java/foreign/StdLibTest.java
41144 views
/*1* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.2* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.3*4* This code is free software; you can redistribute it and/or modify it5* under the terms of the GNU General Public License version 2 only, as6* published by the Free Software Foundation.7*8* This code is distributed in the hope that it will be useful, but WITHOUT9* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or10* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License11* version 2 for more details (a copy is included in the LICENSE file that12* accompanied this code).13*14* You should have received a copy of the GNU General Public License version15* 2 along with this work; if not, write to the Free Software Foundation,16* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.17*18* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA19* or visit www.oracle.com if you need additional information or have any20* questions.21*/2223/*24* @test25* @requires ((os.arch == "amd64" | os.arch == "x86_64") & sun.arch.data.model == "64") | os.arch == "aarch64"26* @run testng/othervm --enable-native-access=ALL-UNNAMED StdLibTest27*/2829import java.lang.invoke.MethodHandle;30import java.lang.invoke.MethodHandles;31import java.lang.invoke.MethodType;32import java.time.Instant;33import java.time.LocalDateTime;34import java.time.ZoneOffset;35import java.time.ZonedDateTime;36import java.util.ArrayList;37import java.util.Arrays;38import java.util.Collections;39import java.util.LinkedHashSet;40import java.util.List;41import java.util.Optional;42import java.util.Set;43import java.util.function.BiConsumer;44import java.util.function.Function;45import java.util.stream.Collectors;46import java.util.stream.Stream;4748import jdk.incubator.foreign.*;4950import static jdk.incubator.foreign.MemoryAccess.*;5152import org.testng.annotations.*;5354import static jdk.incubator.foreign.CLinker.*;55import static org.testng.Assert.*;5657@Test58public class StdLibTest {5960final static CLinker abi = CLinker.getInstance();6162private StdLibHelper stdLibHelper = new StdLibHelper();6364@Test(dataProvider = "stringPairs")65void test_strcat(String s1, String s2) throws Throwable {66assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2);67}6869@Test(dataProvider = "stringPairs")70void test_strcmp(String s1, String s2) throws Throwable {71assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2)));72}7374@Test(dataProvider = "strings")75void test_puts(String s) throws Throwable {76assertTrue(stdLibHelper.puts(s) >= 0);77}7879@Test(dataProvider = "strings")80void test_strlen(String s) throws Throwable {81assertEquals(stdLibHelper.strlen(s), s.length());82}8384@Test(dataProvider = "instants")85void test_time(Instant instant) throws Throwable {86StdLibHelper.Tm tm = stdLibHelper.gmtime(instant.getEpochSecond());87LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC);88assertEquals(tm.sec(), localTime.getSecond());89assertEquals(tm.min(), localTime.getMinute());90assertEquals(tm.hour(), localTime.getHour());91//day pf year in Java has 1-offset92assertEquals(tm.yday(), localTime.getDayOfYear() - 1);93assertEquals(tm.mday(), localTime.getDayOfMonth());94//days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset95assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1);96//month in Java has 1-offset97assertEquals(tm.mon(), localTime.getMonth().getValue() - 1);98assertEquals(tm.isdst(), ZoneOffset.UTC.getRules()99.isDaylightSavings(Instant.ofEpochMilli(instant.getEpochSecond() * 1000)));100}101102@Test(dataProvider = "ints")103void test_qsort(List<Integer> ints) throws Throwable {104if (ints.size() > 0) {105int[] input = ints.stream().mapToInt(i -> i).toArray();106int[] sorted = stdLibHelper.qsort(input);107Arrays.sort(input);108assertEquals(sorted, input);109}110}111112@Test113void test_rand() throws Throwable {114int val = stdLibHelper.rand();115for (int i = 0 ; i < 100 ; i++) {116int newVal = stdLibHelper.rand();117if (newVal != val) {118return; //ok119}120val = newVal;121}122fail("All values are the same! " + val);123}124125@Test(dataProvider = "printfArgs")126void test_printf(List<PrintfArg> args) throws Throwable {127String formatArgs = args.stream()128.map(a -> a.format)129.collect(Collectors.joining(","));130131String formatString = "hello(" + formatArgs + ")\n";132133String expected = String.format(formatString, args.stream()134.map(a -> a.javaValue).toArray());135136int found = stdLibHelper.printf(formatString, args);137assertEquals(found, expected.length());138}139140@Test(dataProvider = "printfArgs")141void test_vprintf(List<PrintfArg> args) throws Throwable {142String formatArgs = args.stream()143.map(a -> a.format)144.collect(Collectors.joining(","));145146String formatString = "hello(" + formatArgs + ")\n";147148String expected = String.format(formatString, args.stream()149.map(a -> a.javaValue).toArray());150151int found = stdLibHelper.vprintf(formatString, args);152assertEquals(found, expected.length());153}154155static class StdLibHelper {156157static final SymbolLookup LOOKUP;158159static {160System.loadLibrary("StdLib");161SymbolLookup stdLibLookup = SymbolLookup.loaderLookup();162MemorySegment funcs = stdLibLookup.lookup("funcs").get()163.asSegment(C_POINTER.byteSize() * 3, ResourceScope.newImplicitScope());164165SymbolLookup fallbackLookup = name -> switch (name) {166case "printf" -> Optional.of(MemoryAccess.getAddressAtIndex(funcs, 0));167case "vprintf" -> Optional.of(MemoryAccess.getAddressAtIndex(funcs, 1));168case "gmtime" -> Optional.of(MemoryAccess.getAddressAtIndex(funcs, 2));169default -> Optional.empty();170};171172LOOKUP = name -> CLinker.systemLookup().lookup(name).or(() -> fallbackLookup.lookup(name));173}174175final static MethodHandle strcat = abi.downcallHandle(LOOKUP.lookup("strcat").get(),176MethodType.methodType(MemoryAddress.class, MemoryAddress.class, MemoryAddress.class),177FunctionDescriptor.of(C_POINTER, C_POINTER, C_POINTER));178179final static MethodHandle strcmp = abi.downcallHandle(LOOKUP.lookup("strcmp").get(),180MethodType.methodType(int.class, MemoryAddress.class, MemoryAddress.class),181FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER));182183final static MethodHandle puts = abi.downcallHandle(LOOKUP.lookup("puts").get(),184MethodType.methodType(int.class, MemoryAddress.class),185FunctionDescriptor.of(C_INT, C_POINTER));186187final static MethodHandle strlen = abi.downcallHandle(LOOKUP.lookup("strlen").get(),188MethodType.methodType(int.class, MemoryAddress.class),189FunctionDescriptor.of(C_INT, C_POINTER));190191final static MethodHandle gmtime = abi.downcallHandle(LOOKUP.lookup("gmtime").get(),192MethodType.methodType(MemoryAddress.class, MemoryAddress.class),193FunctionDescriptor.of(C_POINTER, C_POINTER));194195final static MethodHandle qsort = abi.downcallHandle(LOOKUP.lookup("qsort").get(),196MethodType.methodType(void.class, MemoryAddress.class, long.class, long.class, MemoryAddress.class),197FunctionDescriptor.ofVoid(C_POINTER, C_LONG_LONG, C_LONG_LONG, C_POINTER));198199final static FunctionDescriptor qsortComparFunction = FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER);200201final static MethodHandle qsortCompar;202203final static MethodHandle rand = abi.downcallHandle(LOOKUP.lookup("rand").get(),204MethodType.methodType(int.class),205FunctionDescriptor.of(C_INT));206207final static MethodHandle vprintf = abi.downcallHandle(LOOKUP.lookup("vprintf").get(),208MethodType.methodType(int.class, MemoryAddress.class, VaList.class),209FunctionDescriptor.of(C_INT, C_POINTER, C_VA_LIST));210211final static MemoryAddress printfAddr = LOOKUP.lookup("printf").get();212213final static FunctionDescriptor printfBase = FunctionDescriptor.of(C_INT, C_POINTER);214215static {216try {217//qsort upcall handle218qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare",219MethodType.methodType(int.class, MemorySegment.class, MemoryAddress.class, MemoryAddress.class));220} catch (ReflectiveOperationException ex) {221throw new IllegalStateException(ex);222}223}224225String strcat(String s1, String s2) throws Throwable {226try (ResourceScope scope = ResourceScope.newConfinedScope()) {227MemorySegment buf = MemorySegment.allocateNative(s1.length() + s2.length() + 1, scope);228MemorySegment other = toCString(s2, scope);229char[] chars = s1.toCharArray();230for (long i = 0 ; i < chars.length ; i++) {231setByteAtOffset(buf, i, (byte)chars[(int)i]);232}233setByteAtOffset(buf, chars.length, (byte)'\0');234return toJavaString(((MemoryAddress)strcat.invokeExact(buf.address(), other.address())));235}236}237238int strcmp(String s1, String s2) throws Throwable {239try (ResourceScope scope = ResourceScope.newConfinedScope()) {240MemorySegment ns1 = toCString(s1, scope);241MemorySegment ns2 = toCString(s2, scope);242return (int)strcmp.invokeExact(ns1.address(), ns2.address());243}244}245246int puts(String msg) throws Throwable {247try (ResourceScope scope = ResourceScope.newConfinedScope()) {248MemorySegment s = toCString(msg, scope);249return (int)puts.invokeExact(s.address());250}251}252253int strlen(String msg) throws Throwable {254try (ResourceScope scope = ResourceScope.newConfinedScope()) {255MemorySegment s = toCString(msg, scope);256return (int)strlen.invokeExact(s.address());257}258}259260Tm gmtime(long arg) throws Throwable {261try (ResourceScope scope = ResourceScope.newConfinedScope()) {262MemorySegment time = MemorySegment.allocateNative(8, scope);263setLong(time, arg);264return new Tm((MemoryAddress)gmtime.invokeExact(time.address()));265}266}267268static class Tm {269270//Tm pointer should never be freed directly, as it points to shared memory271private final MemorySegment base;272273static final long SIZE = 56;274275Tm(MemoryAddress addr) {276this.base = addr.asSegment(SIZE, ResourceScope.globalScope());277}278279int sec() {280return getIntAtOffset(base, 0);281}282int min() {283return getIntAtOffset(base, 4);284}285int hour() {286return getIntAtOffset(base, 8);287}288int mday() {289return getIntAtOffset(base, 12);290}291int mon() {292return getIntAtOffset(base, 16);293}294int year() {295return getIntAtOffset(base, 20);296}297int wday() {298return getIntAtOffset(base, 24);299}300int yday() {301return getIntAtOffset(base, 28);302}303boolean isdst() {304byte b = getByteAtOffset(base, 32);305return b != 0;306}307}308309int[] qsort(int[] arr) throws Throwable {310//init native array311try (ResourceScope scope = ResourceScope.newConfinedScope()) {312SegmentAllocator allocator = SegmentAllocator.ofScope(scope);313MemorySegment nativeArr = allocator.allocateArray(C_INT, arr);314315//call qsort316MemoryAddress qsortUpcallStub = abi.upcallStub(qsortCompar.bindTo(nativeArr), qsortComparFunction, scope);317318qsort.invokeExact(nativeArr.address(), (long)arr.length, C_INT.byteSize(), qsortUpcallStub);319320//convert back to Java array321return nativeArr.toIntArray();322}323}324325static int qsortCompare(MemorySegment base, MemoryAddress addr1, MemoryAddress addr2) {326return getIntAtOffset(base, addr1.segmentOffset(base)) -327getIntAtOffset(base, addr2.segmentOffset(base));328}329330int rand() throws Throwable {331return (int)rand.invokeExact();332}333334int printf(String format, List<PrintfArg> args) throws Throwable {335try (ResourceScope scope = ResourceScope.newConfinedScope()) {336MemorySegment formatStr = toCString(format, scope);337return (int)specializedPrintf(args).invokeExact(formatStr.address(),338args.stream().map(a -> a.nativeValue(scope)).toArray());339}340}341342int vprintf(String format, List<PrintfArg> args) throws Throwable {343try (ResourceScope scope = ResourceScope.newConfinedScope()) {344MemorySegment formatStr = toCString(format, scope);345VaList vaList = VaList.make(b -> args.forEach(a -> a.accept(b, scope)), scope);346return (int)vprintf.invokeExact(formatStr.address(), vaList);347}348}349350private MethodHandle specializedPrintf(List<PrintfArg> args) {351//method type352MethodType mt = MethodType.methodType(int.class, MemoryAddress.class);353FunctionDescriptor fd = printfBase;354for (PrintfArg arg : args) {355mt = mt.appendParameterTypes(arg.carrier);356fd = fd.withAppendedArgumentLayouts(arg.layout);357}358MethodHandle mh = abi.downcallHandle(printfAddr, mt, fd);359return mh.asSpreader(1, Object[].class, args.size());360}361}362363/*** data providers ***/364365@DataProvider366public static Object[][] ints() {367return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream()368.map(l -> new Object[] { l })369.toArray(Object[][]::new);370}371372@DataProvider373public static Object[][] strings() {374return perms(0, new String[] { "a", "b", "c" }).stream()375.map(l -> new Object[] { String.join("", l) })376.toArray(Object[][]::new);377}378379@DataProvider380public static Object[][] stringPairs() {381Object[][] strings = strings();382Object[][] stringPairs = new Object[strings.length * strings.length][];383int pos = 0;384for (Object[] s1 : strings) {385for (Object[] s2 : strings) {386stringPairs[pos++] = new Object[] { s1[0], s2[0] };387}388}389return stringPairs;390}391392@DataProvider393public static Object[][] instants() {394Instant start = ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant();395Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant();396Object[][] instants = new Object[100][];397for (int i = 0 ; i < instants.length ; i++) {398Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond())));399instants[i] = new Object[] { instant };400}401return instants;402}403404@DataProvider405public static Object[][] printfArgs() {406ArrayList<List<PrintfArg>> res = new ArrayList<>();407List<List<PrintfArg>> perms = new ArrayList<>(perms(0, PrintfArg.values()));408for (int i = 0 ; i < 100 ; i++) {409Collections.shuffle(perms);410res.addAll(perms);411}412return res.stream()413.map(l -> new Object[] { l })414.toArray(Object[][]::new);415}416417enum PrintfArg implements BiConsumer<VaList.Builder, ResourceScope> {418419INTEGRAL(int.class, asVarArg(C_INT), "%d", scope -> 42, 42, VaList.Builder::vargFromInt),420STRING(MemoryAddress.class, asVarArg(C_POINTER), "%s", scope -> toCString("str", scope).address(), "str", VaList.Builder::vargFromAddress),421CHAR(byte.class, asVarArg(C_CHAR), "%c", scope -> (byte) 'h', 'h', (builder, layout, value) -> builder.vargFromInt(C_INT, (int)value)),422DOUBLE(double.class, asVarArg(C_DOUBLE), "%.4f", scope ->1.2345d, 1.2345d, VaList.Builder::vargFromDouble);423424final Class<?> carrier;425final ValueLayout layout;426final String format;427final Function<ResourceScope, ?> nativeValueFactory;428final Object javaValue;429@SuppressWarnings("rawtypes")430final VaListBuilderCall builderCall;431432<Z> PrintfArg(Class<?> carrier, ValueLayout layout, String format, Function<ResourceScope, Z> nativeValueFactory, Object javaValue, VaListBuilderCall<Z> builderCall) {433this.carrier = carrier;434this.layout = layout;435this.format = format;436this.nativeValueFactory = nativeValueFactory;437this.javaValue = javaValue;438this.builderCall = builderCall;439}440441@Override442@SuppressWarnings("unchecked")443public void accept(VaList.Builder builder, ResourceScope scope) {444builderCall.build(builder, layout, nativeValueFactory.apply(scope));445}446447interface VaListBuilderCall<V> {448void build(VaList.Builder builder, ValueLayout layout, V value);449}450451public Object nativeValue(ResourceScope scope) {452return nativeValueFactory.apply(scope);453}454}455456static <Z> Set<List<Z>> perms(int count, Z[] arr) {457if (count == arr.length) {458return Set.of(List.of());459} else {460return Arrays.stream(arr)461.flatMap(num -> {462Set<List<Z>> perms = perms(count + 1, arr);463return Stream.concat(464//take n465perms.stream().map(l -> {466List<Z> li = new ArrayList<>(l);467li.add(num);468return li;469}),470//drop n471perms.stream());472}).collect(Collectors.toCollection(LinkedHashSet::new));473}474}475}476477478