Path: blob/master/test/jdk/java/foreign/CallGeneratorHelper.java
41145 views
/*1* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.2* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.3*4* This code is free software; you can redistribute it and/or modify it5* under the terms of the GNU General Public License version 2 only, as6* published by the Free Software Foundation.7*8* This code is distributed in the hope that it will be useful, but WITHOUT9* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or10* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License11* version 2 for more details (a copy is included in the LICENSE file that12* accompanied this code).13*14* You should have received a copy of the GNU General Public License version15* 2 along with this work; if not, write to the Free Software Foundation,16* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.17*18* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA19* or visit www.oracle.com if you need additional information or have any20* questions.21*22*/2324import jdk.incubator.foreign.GroupLayout;25import jdk.incubator.foreign.MemoryAddress;26import jdk.incubator.foreign.MemoryLayout;27import jdk.incubator.foreign.MemorySegment;28import jdk.incubator.foreign.ResourceScope;29import jdk.incubator.foreign.SegmentAllocator;30import jdk.incubator.foreign.ValueLayout;3132import java.lang.invoke.VarHandle;33import java.util.ArrayList;34import java.util.List;35import java.util.Stack;36import java.util.function.Consumer;37import java.util.stream.Collectors;38import java.util.stream.IntStream;3940import org.testng.annotations.*;4142import static jdk.incubator.foreign.CLinker.*;43import static org.testng.Assert.*;4445public class CallGeneratorHelper extends NativeTestHelper {4647static SegmentAllocator IMPLICIT_ALLOCATOR = (size, align) -> MemorySegment.allocateNative(size, align, ResourceScope.newImplicitScope());4849static final int MAX_FIELDS = 3;50static final int MAX_PARAMS = 3;51static final int CHUNK_SIZE = 600;5253public static void assertStructEquals(MemorySegment actual, MemorySegment expected, MemoryLayout layout) {54assertEquals(actual.byteSize(), expected.byteSize());55GroupLayout g = (GroupLayout) layout;56for (MemoryLayout field : g.memberLayouts()) {57if (field instanceof ValueLayout) {58VarHandle vh = g.varHandle(vhCarrier(field), MemoryLayout.PathElement.groupElement(field.name().orElseThrow()));59assertEquals(vh.get(actual), vh.get(expected));60}61}62}6364private static Class<?> vhCarrier(MemoryLayout layout) {65if (layout instanceof ValueLayout) {66if (isIntegral(layout)) {67if (layout.bitSize() == 64) {68return long.class;69}70return int.class;71} else if (layout.bitSize() == 32) {72return float.class;73}74return double.class;75} else {76throw new IllegalStateException("Unexpected layout: " + layout);77}78}7980enum Ret {81VOID,82NON_VOID83}8485enum StructFieldType {86INT("int", C_INT),87FLOAT("float", C_FLOAT),88DOUBLE("double", C_DOUBLE),89POINTER("void*", C_POINTER);9091final String typeStr;92final MemoryLayout layout;9394StructFieldType(String typeStr, MemoryLayout layout) {95this.typeStr = typeStr;96this.layout = layout;97}9899MemoryLayout layout() {100return layout;101}102103@SuppressWarnings("unchecked")104static List<List<StructFieldType>>[] perms = new List[10];105106static List<List<StructFieldType>> perms(int i) {107if (perms[i] == null) {108perms[i] = generateTest(i, values());109}110return perms[i];111}112}113114enum ParamType {115INT("int", C_INT),116FLOAT("float", C_FLOAT),117DOUBLE("double", C_DOUBLE),118POINTER("void*", C_POINTER),119STRUCT("struct S", null);120121private final String typeStr;122private final MemoryLayout layout;123124ParamType(String typeStr, MemoryLayout layout) {125this.typeStr = typeStr;126this.layout = layout;127}128129String type(List<StructFieldType> fields) {130return this == STRUCT ?131typeStr + "_" + sigCode(fields) :132typeStr;133}134135MemoryLayout layout(List<StructFieldType> fields) {136if (this == STRUCT) {137long offset = 0L;138List<MemoryLayout> layouts = new ArrayList<>();139for (StructFieldType field : fields) {140MemoryLayout l = field.layout();141long padding = offset % l.bitSize();142if (padding != 0) {143layouts.add(MemoryLayout.paddingLayout(padding));144offset += padding;145}146layouts.add(l.withName("field" + offset));147offset += l.bitSize();148}149return MemoryLayout.structLayout(layouts.toArray(new MemoryLayout[0]));150} else {151return layout;152}153}154155@SuppressWarnings("unchecked")156static List<List<ParamType>>[] perms = new List[10];157158static List<List<ParamType>> perms(int i) {159if (perms[i] == null) {160perms[i] = generateTest(i, values());161}162return perms[i];163}164}165166static <Z> List<List<Z>> generateTest(int i, Z[] elems) {167List<List<Z>> res = new ArrayList<>();168generateTest(i, new Stack<>(), elems, res);169return res;170}171172static <Z> void generateTest(int i, Stack<Z> combo, Z[] elems, List<List<Z>> results) {173if (i == 0) {174results.add(new ArrayList<>(combo));175} else {176for (Z z : elems) {177combo.push(z);178generateTest(i - 1, combo, elems, results);179combo.pop();180}181}182}183184@DataProvider(name = "functions")185public static Object[][] functions() {186int functions = 0;187List<Object[]> downcalls = new ArrayList<>();188for (Ret r : Ret.values()) {189for (int i = 0; i <= MAX_PARAMS; i++) {190if (r != Ret.VOID && i == 0) continue;191for (List<ParamType> ptypes : ParamType.perms(i)) {192String retCode = r == Ret.VOID ? "V" : ptypes.get(0).name().charAt(0) + "";193String sigCode = sigCode(ptypes);194if (ptypes.contains(ParamType.STRUCT)) {195for (int j = 1; j <= MAX_FIELDS; j++) {196for (List<StructFieldType> fields : StructFieldType.perms(j)) {197String structCode = sigCode(fields);198int count = functions;199int fCode = functions++ / CHUNK_SIZE;200String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);201downcalls.add(new Object[] { count, fName, r, ptypes, fields });202}203}204} else {205String structCode = sigCode(List.<StructFieldType>of());206int count = functions;207int fCode = functions++ / CHUNK_SIZE;208String fName = String.format("f%d_%s_%s_%s", fCode, retCode, sigCode, structCode);209downcalls.add(new Object[] { count, fName, r, ptypes, List.of() });210}211}212}213}214return downcalls.toArray(new Object[0][]);215}216217static <Z extends Enum<Z>> String sigCode(List<Z> elems) {218return elems.stream().map(p -> p.name().charAt(0) + "").collect(Collectors.joining());219}220221static void generateStructDecl(List<StructFieldType> fields) {222String structCode = sigCode(fields);223List<String> fieldDecls = new ArrayList<>();224for (int i = 0 ; i < fields.size() ; i++) {225fieldDecls.add(String.format("%s p%d;", fields.get(i).typeStr, i));226}227String res = String.format("struct S_%s { %s };", structCode,228fieldDecls.stream().collect(Collectors.joining(" ")));229System.out.println(res);230}231232/* this can be used to generate the test header/implementation */233public static void main(String[] args) {234boolean header = args.length > 0 && args[0].equals("header");235boolean upcall = args.length > 1 && args[1].equals("upcall");236if (upcall) {237generateUpcalls(header);238} else {239generateDowncalls(header);240}241}242243static void generateDowncalls(boolean header) {244if (header) {245System.out.println(246"#ifdef _WIN64\n" +247"#define EXPORT __declspec(dllexport)\n" +248"#else\n" +249"#define EXPORT\n" +250"#endif\n"251);252253for (int j = 1; j <= MAX_FIELDS; j++) {254for (List<StructFieldType> fields : StructFieldType.perms(j)) {255generateStructDecl(fields);256}257}258} else {259System.out.println(260"#include \"libh\"\n" +261"#ifdef __clang__\n" +262"#pragma clang optimize off\n" +263"#elif defined __GNUC__\n" +264"#pragma GCC optimize (\"O0\")\n" +265"#elif defined _MSC_BUILD\n" +266"#pragma optimize( \"\", off )\n" +267"#endif\n"268);269}270271for (Object[] downcall : functions()) {272String fName = (String)downcall[0];273Ret r = (Ret)downcall[1];274@SuppressWarnings("unchecked")275List<ParamType> ptypes = (List<ParamType>)downcall[2];276@SuppressWarnings("unchecked")277List<StructFieldType> fields = (List<StructFieldType>)downcall[3];278generateDowncallFunction(fName, r, ptypes, fields, header);279}280}281282static void generateDowncallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {283String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);284List<String> paramDecls = new ArrayList<>();285for (int i = 0 ; i < params.size() ; i++) {286paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));287}288String sig = paramDecls.isEmpty() ?289"void" :290paramDecls.stream().collect(Collectors.joining(", "));291String body = ret == Ret.VOID ? "{ }" : "{ return p0; }";292String res = String.format("EXPORT %s f%s(%s) %s", retType, fName,293sig, declOnly ? ";" : body);294System.out.println(res);295}296297static void generateUpcalls(boolean header) {298if (header) {299System.out.println(300"#ifdef _WIN64\n" +301"#define EXPORT __declspec(dllexport)\n" +302"#else\n" +303"#define EXPORT\n" +304"#endif\n"305);306307for (int j = 1; j <= MAX_FIELDS; j++) {308for (List<StructFieldType> fields : StructFieldType.perms(j)) {309generateStructDecl(fields);310}311}312} else {313System.out.println(314"#include \"libh\"\n" +315"#ifdef __clang__\n" +316"#pragma clang optimize off\n" +317"#elif defined __GNUC__\n" +318"#pragma GCC optimize (\"O0\")\n" +319"#elif defined _MSC_BUILD\n" +320"#pragma optimize( \"\", off )\n" +321"#endif\n"322);323}324325for (Object[] downcall : functions()) {326String fName = (String)downcall[0];327Ret r = (Ret)downcall[1];328@SuppressWarnings("unchecked")329List<ParamType> ptypes = (List<ParamType>)downcall[2];330@SuppressWarnings("unchecked")331List<StructFieldType> fields = (List<StructFieldType>)downcall[3];332generateUpcallFunction(fName, r, ptypes, fields, header);333}334}335336static void generateUpcallFunction(String fName, Ret ret, List<ParamType> params, List<StructFieldType> fields, boolean declOnly) {337String retType = ret == Ret.VOID ? "void" : params.get(0).type(fields);338List<String> paramDecls = new ArrayList<>();339for (int i = 0 ; i < params.size() ; i++) {340paramDecls.add(String.format("%s p%d", params.get(i).type(fields), i));341}342String paramNames = IntStream.range(0, params.size())343.mapToObj(i -> "p" + i)344.collect(Collectors.joining(","));345String sig = paramDecls.isEmpty() ?346"" :347paramDecls.stream().collect(Collectors.joining(", ")) + ", ";348String body = String.format(ret == Ret.VOID ? "{ cb(%s); }" : "{ return cb(%s); }", paramNames);349List<String> paramTypes = params.stream().map(p -> p.type(fields)).collect(Collectors.toList());350String cbSig = paramTypes.isEmpty() ?351"void" :352paramTypes.stream().collect(Collectors.joining(", "));353String cbParam = String.format("%s (*cb)(%s)",354retType, cbSig);355356String res = String.format("EXPORT %s %s(%s %s) %s", retType, fName,357sig, cbParam, declOnly ? ";" : body);358System.out.println(res);359}360361//helper methods362363@SuppressWarnings("unchecked")364static Object makeArg(MemoryLayout layout, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {365if (layout instanceof GroupLayout) {366MemorySegment segment = MemorySegment.allocateNative(layout, ResourceScope.newImplicitScope());367initStruct(segment, (GroupLayout)layout, checks, check);368return segment;369} else if (isPointer(layout)) {370MemorySegment segment = MemorySegment.allocateNative(1, ResourceScope.newImplicitScope());371if (check) {372checks.add(o -> {373try {374assertEquals(o, segment.address());375} catch (Throwable ex) {376throw new IllegalStateException(ex);377}378});379}380return segment.address();381} else if (layout instanceof ValueLayout) {382if (isIntegral(layout)) {383if (check) {384checks.add(o -> assertEquals(o, 42));385}386return 42;387} else if (layout.bitSize() == 32) {388if (check) {389checks.add(o -> assertEquals(o, 12f));390}391return 12f;392} else {393if (check) {394checks.add(o -> assertEquals(o, 24d));395}396return 24d;397}398} else {399throw new IllegalStateException("Unexpected layout: " + layout);400}401}402403static void initStruct(MemorySegment str, GroupLayout g, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {404for (MemoryLayout l : g.memberLayouts()) {405if (l.isPadding()) continue;406VarHandle accessor = g.varHandle(structFieldCarrier(l), MemoryLayout.PathElement.groupElement(l.name().get()));407List<Consumer<Object>> fieldsCheck = new ArrayList<>();408Object value = makeArg(l, fieldsCheck, check);409if (isPointer(l)) {410value = ((MemoryAddress)value).toRawLongValue();411}412//set value413accessor.set(str, value);414//add check415if (check) {416assertTrue(fieldsCheck.size() == 1);417checks.add(o -> {418MemorySegment actual = (MemorySegment)o;419try {420if (isPointer(l)) {421fieldsCheck.get(0).accept(MemoryAddress.ofLong((long)accessor.get(actual)));422} else {423fieldsCheck.get(0).accept(accessor.get(actual));424}425} catch (Throwable ex) {426throw new IllegalStateException(ex);427}428});429}430}431}432433static Class<?> structFieldCarrier(MemoryLayout layout) {434if (isPointer(layout)) {435return long.class;436} else if (layout instanceof ValueLayout) {437if (isIntegral(layout)) {438return int.class;439} else if (layout.bitSize() == 32) {440return float.class;441} else {442return double.class;443}444} else {445throw new IllegalStateException("Unexpected layout: " + layout);446}447}448449static Class<?> paramCarrier(MemoryLayout layout) {450if (layout instanceof GroupLayout) {451return MemorySegment.class;452} if (isPointer(layout)) {453return MemoryAddress.class;454} else if (layout instanceof ValueLayout) {455if (isIntegral(layout)) {456return int.class;457} else if (layout.bitSize() == 32) {458return float.class;459} else {460return double.class;461}462} else {463throw new IllegalStateException("Unexpected layout: " + layout);464}465}466}467468469