// Path: test/jdk/sun/security/util/math/TestIntegerModuloP.java
/*1* Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.2* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.3*4* This code is free software; you can redistribute it and/or modify it5* under the terms of the GNU General Public License version 2 only, as6* published by the Free Software Foundation.7*8* This code is distributed in the hope that it will be useful, but WITHOUT9* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or10* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License11* version 2 for more details (a copy is included in the LICENSE file that12* accompanied this code).13*14* You should have received a copy of the GNU General Public License version15* 2 along with this work; if not, write to the Free Software Foundation,16* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.17*18* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA19* or visit www.oracle.com if you need additional information or have any20* questions.21*/2223/*24* @test25* @bug 8181594 820864826* @summary Test proper operation of integer field arithmetic27* @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly28* @build BigIntegerModuloP29* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 030* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 131* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 232* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP256 32 533* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP384 48 634* @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP521 66 735* @run main TestIntegerModuloP sun.security.util.math.intpoly.P256OrderField 32 836* @run main TestIntegerModuloP sun.security.util.math.intpoly.P384OrderField 48 937* 
@run main TestIntegerModuloP sun.security.util.math.intpoly.P521OrderField 66 1038* @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve25519OrderField 32 1139* @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve448OrderField 56 1240*/4142import sun.security.util.math.*;43import sun.security.util.math.intpoly.*;44import java.util.function.*;4546import java.util.*;47import java.math.*;48import java.nio.*;4950public class TestIntegerModuloP {5152static BigInteger TWO = BigInteger.valueOf(2);5354// The test has a list of functions, and it selects randomly from that list5556// The function types57interface ElemFunction extends BiFunction58<MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }59interface ElemArrayFunction extends BiFunction60<MutableIntegerModuloP, IntegerModuloP, byte[]> { }61interface TriConsumer <T, U, V> {62void accept(T t, U u, V v);63}64interface ElemSetFunction extends TriConsumer65<MutableIntegerModuloP, IntegerModuloP, byte[]> { }6667// The lists of functions. Multiple lists are needed because the test68// respects the limitations of the arithmetic implementations.69static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();70static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();71static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();72static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();7374static void setUpFunctions(IntegerFieldModuloP field, int length) {7576ADD_FUNCTIONS.clear();77MULT_FUNCTIONS.clear();78SET_FUNCTIONS.clear();79ARRAY_FUNCTIONS.clear();8081byte highByte = (byte)82(field.getSize().bitLength() > length * 8 ? 
1 : 0);8384// add functions are (im)mutable add/subtract85ADD_FUNCTIONS.add(IntegerModuloP::add);86ADD_FUNCTIONS.add(IntegerModuloP::subtract);87ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);88ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);89// also include functions that return the first/second argument90ADD_FUNCTIONS.add((a, b) -> a);91ADD_FUNCTIONS.add((a, b) -> b);9293// mult functions are (im)mutable multiply and square94MULT_FUNCTIONS.add(IntegerModuloP::multiply);95MULT_FUNCTIONS.add((a, b) -> a.square());96MULT_FUNCTIONS.add((a, b) -> b.square());97MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);98MULT_FUNCTIONS.add((a, b) -> a.setSquare());99// also test multiplication by a small value100MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(101b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));102103// set functions are setValue with various argument types104SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));105SET_FUNCTIONS.add((a, b, c) ->106a.setValue(c, 0, c.length, (byte) 0));107SET_FUNCTIONS.add((a, b, c) ->108a.setValue(c, 0, c.length / 2, (byte) 0));109SET_FUNCTIONS.add((a, b, c) ->110a.setValue(ByteBuffer.wrap(c, 0, c.length / 2).order(ByteOrder.LITTLE_ENDIAN),111c.length / 2, highByte));112113// array functions return the (possibly modified) value as byte array114ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length));115ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));116}117118public static void main(String[] args) {119120String className = args[0];121final int length = Integer.parseInt(args[1]);122int seed = Integer.parseInt(args[2]);123124Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class;125try {126Class<? 
extends IntegerFieldModuloP> clazz =127Class.forName(className).asSubclass(fieldBaseClass);128IntegerFieldModuloP field =129clazz.getDeclaredConstructor().newInstance();130131setUpFunctions(field, length);132133runFieldTest(field, length, seed);134} catch (Exception ex) {135throw new RuntimeException(ex);136}137System.out.println("All tests passed");138}139140static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {141142if (!e1.asBigInteger().equals(e2.asBigInteger())) {143throw new RuntimeException("values not equal: "144+ e1.asBigInteger() + " != " + e2.asBigInteger());145}146}147148// A class that holds pairs of actual/expected values, and allows149// computation on these pairs.150static class TestPair<T extends IntegerModuloP> {151private final T test;152private final T baseline;153154public TestPair(T test, T baseline) {155this.test = test;156this.baseline = baseline;157}158159public T getTest() {160return test;161}162public T getBaseline() {163return baseline;164}165166private void assertEqual() {167TestIntegerModuloP.assertEqual(test, baseline);168}169170public TestPair<MutableIntegerModuloP> mutable() {171return new TestPair<>(test.mutable(), baseline.mutable());172}173174public175<R extends IntegerModuloP, X extends IntegerModuloP>176TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {177X testResult = func.apply(test, right.test);178X baselineResult = func.apply(baseline, right.baseline);179return new TestPair(testResult, baselineResult);180}181182public183<U extends IntegerModuloP, V>184void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {185func.accept(test, right.test, argV);186func.accept(baseline, right.baseline, argV);187}188189public190<R extends IntegerModuloP>191void applyAndCheckArray(BiFunction<T, R, byte[]> func,192TestPair<R> right) {193byte[] testResult = func.apply(test, right.test);194byte[] baselineResult = func.apply(baseline, right.baseline);195if (!Arrays.equals(testResult, baselineResult)) {196throw new 
RuntimeException("Array values do not match: "197+ HexFormat.of().withUpperCase().formatHex(testResult) + " != "198+ HexFormat.of().withUpperCase().formatHex(baselineResult));199}200}201202}203204static TestPair<IntegerModuloP>205applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,206TestPair<IntegerModuloP> right) {207208TestPair<IntegerModuloP> result = left.apply(func, right);209result.assertEqual();210left.assertEqual();211right.assertEqual();212213return result;214}215216static void217setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,218TestPair<IntegerModuloP> right, byte[] argV) {219220left.apply(func, right, argV);221left.assertEqual();222right.assertEqual();223}224225static TestPair<MutableIntegerModuloP>226applyAndCheckMutable(ElemFunction func,227TestPair<MutableIntegerModuloP> left,228TestPair<IntegerModuloP> right) {229230TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);231232TestPair<MutableIntegerModuloP> mutableResult = result.mutable();233mutableResult.assertEqual();234result.assertEqual();235left.assertEqual();236right.assertEqual();237238return mutableResult;239}240241static void242cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,243TestPair<MutableIntegerModuloP> right) {244245left.getTest().conditionalSwapWith(right.getTest(), swap);246left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);247248left.assertEqual();249right.assertEqual();250251}252253// Request arithmetic that should overflow, and ensure that overflow is254// detected.255static void runOverflowTest(TestPair<IntegerModuloP> elem) {256257TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();258259try {260for (int i = 0; i < 1000; i++) {261applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);262}263applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);264} catch (ArithmeticException ex) {265// this is expected266}267268mutableElem = elem.mutable();269try {270for (int i = 
0; i < 1000; i++) {271elem = applyAndCheck(IntegerModuloP::add,272mutableElem, elem);273}274applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);275} catch (ArithmeticException ex) {276// this is expected277}278}279280// Run a large number of random operations and ensure that281// results are correct282static void runOperationsTest(Random random, int length,283TestPair<IntegerModuloP> elem,284TestPair<IntegerModuloP> right) {285286TestPair<MutableIntegerModuloP> left = elem.mutable();287288for (int i = 0; i < 10000; i++) {289290ElemFunction addFunc1 =291ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));292TestPair<MutableIntegerModuloP> result1 =293applyAndCheckMutable(addFunc1, left, right);294295// left could have been modified, so turn it back into a summand296applyAndCheckMutable((a, b) -> a.setSquare(), left, right);297298ElemFunction addFunc2 =299ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));300TestPair<IntegerModuloP> result2 =301applyAndCheck(addFunc2, left, right);302303if (elem.test.getField() instanceof IntegerPolynomial) {304IntegerPolynomial field =305(IntegerPolynomial) elem.test.getField();306int numAdds = field.getMaxAdds();307for (int j = 1; j < numAdds; j++) {308ElemFunction addFunc3 = ADD_FUNCTIONS.309get(random.nextInt(ADD_FUNCTIONS.size()));310result2 = applyAndCheck(addFunc3, left, right);311}312}313314ElemFunction multFunc2 =315MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));316TestPair<MutableIntegerModuloP> multResult =317applyAndCheckMutable(multFunc2, result1, result2);318319int swap = random.nextInt(2);320cswapAndCheck(swap, left, multResult);321322ElemSetFunction setFunc =323SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));324byte[] valueArr = new byte[2 * length];325random.nextBytes(valueArr);326setAndCheck(setFunc, result1, result2, valueArr);327328// left could have been modified, so to turn it back into a summand329applyAndCheckMutable((a, b) -> a.setSquare(), left, 
right);330331ElemArrayFunction arrayFunc =332ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));333left.applyAndCheckArray(arrayFunc, right);334}335}336337// Run all the tests for a given field338static void runFieldTest(IntegerFieldModuloP testField,339int length, int seed) {340System.out.println("Testing: " + testField.getClass().getSimpleName());341342Random random = new Random(seed);343344IntegerFieldModuloP baselineField =345new BigIntegerModuloP(testField.getSize());346347int numBits = testField.getSize().bitLength();348BigInteger r =349new BigInteger(numBits, random).mod(testField.getSize());350TestPair<IntegerModuloP> rand =351new TestPair(testField.getElement(r), baselineField.getElement(r));352353runOverflowTest(rand);354355// check combinations of operations for different kinds of elements356List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();357testElements.add(rand);358testElements.add(new TestPair(testField.get0(), baselineField.get0()));359testElements.add(new TestPair(testField.get1(), baselineField.get1()));360byte[] testArr = {121, 37, -100, -5, 76, 33};361testElements.add(new TestPair(testField.getElement(testArr),362baselineField.getElement(testArr)));363364testArr = new byte[length];365random.nextBytes(testArr);366testElements.add(new TestPair(testField.getElement(testArr),367baselineField.getElement(testArr)));368369random.nextBytes(testArr);370byte highByte = (byte) (numBits > length * 8 ? 1 : 0);371testElements.add(372new TestPair(373testField.getElement(testArr, 0, testArr.length, highByte),374baselineField.getElement(testArr, 0, testArr.length, highByte)375)376);377378for (int i = 0; i < testElements.size(); i++) {379for (int j = 0; j < testElements.size(); j++) {380runOperationsTest(random, length, testElements.get(i),381testElements.get(j));382}383}384}385}386387388389