package org.graalvm.compiler.hotspot.test;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.AlgorithmParameters;
import java.security.SecureRandom;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import org.graalvm.compiler.code.CompilationResult;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.hotspot.meta.HotSpotGraphBuilderPlugins;
import org.junit.Assert;
import org.junit.Test;
import jdk.vm.ci.code.InstalledCode;
import jdk.vm.ci.meta.ResolvedJavaMethod;
public class HotSpotCryptoSubstitutionTest extends HotSpotGraalCompilerTest {
/**
 * Installs the compiled code for {@code method} by delegating to the backend's
 * {@code createDefaultInstalledCode}.
 *
 * @param debug debug context for the installation
 * @param method the method whose compiled code is being installed
 * @param compResult the compilation result to install
 * @return the installed code produced by the backend
 */
@Override
protected InstalledCode addMethod(DebugContext debug, ResolvedJavaMethod method, CompilationResult compResult) {
    InstalledCode installed = getBackend().createDefaultInstalledCode(debug, method, compResult);
    return installed;
}
// AES key (128-bit) generated in the constructor; used with the "AES/CBC/..." transformations.
SecretKey aesKey;
// Triple-DES key (168-bit) generated in the constructor; used with the "DESede/CBC/..." transformations.
SecretKey desKey;
// Filled in the constructor from readClassfile16(getClass()); presumably the plaintext consumed by
// runEncryptDecrypt (not visible here) — NOTE(review): confirm.
byte[] input;
// Concatenated AES round-trip results recorded in the constructor; the test methods compare against these.
ByteArrayOutputStream aesExpected = new ByteArrayOutputStream();
// Concatenated DESede round-trip results recorded in the constructor; compared against in the CBC test.
ByteArrayOutputStream desExpected = new ByteArrayOutputStream();
/**
 * Creates the AES and DESede keys and records the baseline output of
 * {@code runEncryptDecrypt} for each CBC transformation. The test methods compare their
 * results against these baselines.
 *
 * @throws Exception if a key generator is unavailable or a baseline run fails
 */
public HotSpotCryptoSubstitutionTest() throws Exception {
    // One SecureRandom, built from a fixed seed array, is shared by both key generators.
    final SecureRandom rng = new SecureRandom(new byte[]{0x4, 0x7, 0x1, 0x1});
    final KeyGenerator aesGenerator = KeyGenerator.getInstance("AES");
    final KeyGenerator desedeGenerator = KeyGenerator.getInstance("DESede");
    aesGenerator.init(128, rng);
    desedeGenerator.init(168, rng);
    // Generation order (AES first, then DESede) matters because both draw from the same RNG.
    aesKey = aesGenerator.generateKey();
    desKey = desedeGenerator.generateKey();

    input = readClassfile16(getClass());

    for (String transformation : new String[]{"AES/CBC/NoPadding", "AES/CBC/PKCS5Padding"}) {
        aesExpected.write(runEncryptDecrypt(aesKey, transformation));
    }
    for (String transformation : new String[]{"DESede/CBC/NoPadding", "DESede/CBC/PKCS5Padding"}) {
        desExpected.write(runEncryptDecrypt(desKey, transformation));
    }
}
/**
 * Compiles and installs the AESCrypt block encrypt/decrypt methods, then checks that the AES
 * round trips still produce the baseline bytes. The test does nothing (and passes) when
 * {@code compileAndInstall} reports that nothing was installed.
 */
@Test
public void testAESCryptIntrinsics() throws Exception {
    final String aesCryptClass = "com/sun/crypto/provider/AESCrypt";
    // lookupIntrinsicName is given two candidate names; presumably it picks the one this JDK declares — confirm.
    String encryptName = HotSpotGraphBuilderPlugins.lookupIntrinsicName(runtime().getVMConfig(), aesCryptClass, "implEncryptBlock", "encryptBlock");
    String decryptName = HotSpotGraphBuilderPlugins.lookupIntrinsicName(runtime().getVMConfig(), aesCryptClass, "implDecryptBlock", "decryptBlock");
    if (!compileAndInstall("com.sun.crypto.provider.AESCrypt", encryptName, decryptName)) {
        return;
    }
    ByteArrayOutputStream actual = new ByteArrayOutputStream();
    for (String transformation : new String[]{"AES/CBC/NoPadding", "AES/CBC/PKCS5Padding"}) {
        actual.write(runEncryptDecrypt(aesKey, transformation));
    }
    Assert.assertArrayEquals(aesExpected.toByteArray(), actual.toByteArray());
}
/**
 * Compiles and installs the CipherBlockChaining encrypt/decrypt methods, then checks that both
 * the AES and the DESede round trips still produce their baseline bytes. The test does nothing
 * (and passes) when {@code compileAndInstall} reports that nothing was installed.
 */
@Test
public void testCipherBlockChainingIntrinsics() throws Exception {
    final String cbcClass = "com/sun/crypto/provider/CipherBlockChaining";
    String encryptName = HotSpotGraphBuilderPlugins.lookupIntrinsicName(runtime().getVMConfig(), cbcClass, "implEncrypt", "encrypt");
    String decryptName = HotSpotGraphBuilderPlugins.lookupIntrinsicName(runtime().getVMConfig(), cbcClass, "implDecrypt", "decrypt");
    if (!compileAndInstall("com.sun.crypto.provider.CipherBlockChaining", encryptName, decryptName)) {
        return;
    }

    ByteArrayOutputStream aesActual = new ByteArrayOutputStream();
    aesActual.write(runEncryptDecrypt(aesKey, "AES/CBC/NoPadding"));
    aesActual.write(runEncryptDecrypt(aesKey, "AES/CBC/PKCS5Padding"));
    Assert.assertArrayEquals(aesExpected.toByteArray(), aesActual.toByteArray());

    // A separate stream for the DESede results (equivalent to resetting the AES one).
    ByteArrayOutputStream desActual = new ByteArrayOutputStream();
    desActual.write(runEncryptDecrypt(desKey, "DESede/CBC/NoPadding"));
    desActual.write(runEncryptDecrypt(desKey, "DESede/CBC/PKCS5Padding"));
    Assert.assertArrayEquals(desExpected.toByteArray(), desActual.toByteArray());
}
/**
 * Tries to compile and install a substitution for each named method of {@code className}.
 * Every name is attempted; there is no early exit after the first success.
 *
 * @param className binary name of the class to load via {@code Class.forName}
 * @param methodNames names of the methods to compile and install
 * @return {@code true} if at least one method produced installed code; {@code false} if the VM
 *         config has {@code useAESIntrinsics} off, the class cannot be found, or nothing was installed
 */
private boolean compileAndInstall(String className, String... methodNames) {
    if (runtime().getVMConfig().useAESIntrinsics) {
        final Class<?> target;
        try {
            target = Class.forName(className);
        } catch (ClassNotFoundException ignored) {
            // The class is absent on this JDK, so there is nothing to install.
            return false;
        }
        int installedCount = 0;
        for (String name : methodNames) {
            if (compileAndInstallSubstitution(target, name) != null) {
                installedCount++;
            }
        }
        return installedCount > 0;
    }
    return false;
}
// Parameters taken from the cipher in the most recent encrypt() call (via Cipher.getParameters());
// decrypt() initializes its cipher with whatever value is stored here, so it must follow the matching encrypt().
// For the CBC transformations used here, these presumably carry the IV chosen during encrypt's init.
AlgorithmParameters algorithmParameters;
/**
 * Encrypts {@code indata} with {@code key} using the given transformation. The data is passed
 * through {@link Cipher#update(byte[])} and then {@link Cipher#doFinal()}.
 *
 * <p>Side effect: stores the cipher's parameters in {@link #algorithmParameters} so a following
 * decrypt() call can initialize its cipher the same way.
 *
 * @param indata plaintext to encrypt
 * @param key secret key for the cipher
 * @param algorithm JCA transformation string, e.g. "AES/CBC/NoPadding"
 * @return the concatenation of the update() output and the doFinal() output
 * @throws Exception if the transformation is unavailable or the cipher rejects the key or input
 */
private byte[] encrypt(byte[] indata, SecretKey key, String algorithm) throws Exception {
    Cipher c = Cipher.getInstance(algorithm);
    c.init(Cipher.ENCRYPT_MODE, key);
    algorithmParameters = c.getParameters();
    // Cipher.update returns null, not an empty array, for zero-length input or when a block cipher
    // has buffered too little data to emit a block. Treat that as "no output yet" to avoid an NPE.
    byte[] head = c.update(indata);
    if (head == null) {
        head = new byte[0];
    }
    byte[] tail = c.doFinal();
    byte[] result = new byte[head.length + tail.length];
    System.arraycopy(head, 0, result, 0, head.length);
    System.arraycopy(tail, 0, result, head.length, tail.length);
    return result;
}
private byte[] decrypt(byte[] indata, SecretKey key, String algorithm) throws Exception {
byte[] result = indata;
Cipher c = Cipher.getInstance(algorithm);
c.init(Cipher.DECRYPT_MODE, key, algorithmParameters);
byte[] r1 = c.update(result);
byte[] r2 = c.doFinal();
result = new byte[r1.length + r2.length];
System.arraycopy(r1, 0, result, 0, r1.length);
System.arraycopy(r2, 0, result, r1.length, r2.length);
return result;