package org.graalvm.compiler.code;
import static jdk.vm.ci.meta.MetaUtil.identityHashCodeString;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.BiConsumer;
import org.graalvm.compiler.code.DataSection.Data;
import org.graalvm.compiler.serviceprovider.BufferUtil;
import jdk.vm.ci.code.site.DataSectionReference;
import jdk.vm.ci.meta.SerializableConstant;
import jdk.vm.ci.meta.VMConstant;
public final class DataSection implements Iterable<Data> {
/**
 * Callback handed to every {@link Data#emit} call, through which an emitter can associate a
 * position with a {@link VMConstant}.
 * <p>
 * NOTE(review): none of the emitters visible here call it; presumably {@code position} is a buffer
 * position that needs a relocation/patch for {@code c} — confirm against the implementations.
 */
public interface Patches {
/**
 * Registers {@code c} at {@code position}.
 *
 * @param position position associated with the constant (presumably in the buffer being emitted
 *            into — TODO confirm)
 * @param c the constant to be patched in
 */
void registerPatch(int position, VMConstant c);
}
/**
 * Base class for a single item that can be placed into a {@link DataSection}.
 * <p>
 * Every item has an alignment requirement (which can be raised later through
 * {@link #updateAlignment}), a fixed size in bytes, and a {@link DataSectionReference} that is
 * assigned when the item is first inserted into a section.
 */
public abstract static class Data {

    private int alignment;
    private final int size;
    /** Assigned by {@link DataSection#insertData}; {@code null} until the item is inserted. */
    private DataSectionReference ref;

    protected Data(int alignment, int size) {
        this.size = size;
        this.alignment = alignment;
        this.ref = null;
    }

    /**
     * Writes the bytes of this item into {@code buffer}, starting at the buffer's current position.
     *
     * @param buffer destination buffer
     * @param patches callback for registering constant patches
     */
    protected abstract void emit(ByteBuffer buffer, Patches patches);

    /**
     * Raises the alignment requirement by combining the current alignment with
     * {@code newAlignment} through {@code lcm}. Passing the current alignment is a no-op.
     *
     * @param newAlignment additional alignment requirement
     */
    public void updateAlignment(int newAlignment) {
        if (newAlignment != alignment) {
            alignment = lcm(alignment, newAlignment);
        }
    }

    /** @return the current alignment requirement of this item */
    public int getAlignment() {
        return alignment;
    }

    /** @return the size of this item in bytes */
    public int getSize() {
        return size;
    }

    /** Hashing is not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public int hashCode() {
        throw new UnsupportedOperationException("hashCode");
    }

    /** @return the value of {@code identityHashCodeString(this)} */
    @Override
    public String toString() {
        return identityHashCodeString(this);
    }

    /**
     * Two items are equal when they have the same alignment, the same size and equal section
     * references. This item must already have been inserted (its reference is asserted non-null).
     */
    @Override
    public boolean equals(Object obj) {
        assert ref != null;
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Data)) {
            return false;
        }
        Data other = (Data) obj;
        return alignment == other.alignment && size == other.size && ref.equals(other.ref);
    }
}
/**
 * A data item whose contents are a fixed byte array, written out verbatim. Its size is the array
 * length.
 */
public static final class RawData extends Data {
    /** Bytes to emit; the array is stored as given, not copied. */
    private final byte[] data;

    public RawData(byte[] data, int alignment) {
        super(alignment, data.length);
        this.data = data;
    }

    @Override
    protected void emit(ByteBuffer buffer, Patches patches) {
        // Equivalent to buffer.put(data): transfers the whole array in one bulk operation.
        buffer.put(data, 0, data.length);
    }
}
/**
 * A data item backed by a {@link SerializableConstant}. Its size is the constant's serialized
 * size.
 */
public static final class SerializableData extends Data {
    private final SerializableConstant constant;

    /** Creates an item with an alignment of 1. */
    public SerializableData(SerializableConstant constant) {
        this(constant, 1);
    }

    public SerializableData(SerializableConstant constant, int alignment) {
        super(alignment, constant.getSerializedSize());
        this.constant = constant;
    }

    @Override
    protected void emit(ByteBuffer buffer, Patches patches) {
        final int before = buffer.position();
        constant.serialize(buffer);
        // Sanity check: the constant must write exactly as many bytes as it reported.
        assert buffer.position() - before == constant.getSerializedSize() : "wrong number of bytes written";
    }
}
/**
 * A data item consisting only of zero bytes.
 * <p>
 * {@link #create} returns specialised instances for sizes 1, 2, 4 and 8, each of which writes its
 * zeros with a single primitive put. All other sizes use the generic {@link #emit} of this class.
 */
public static class ZeroData extends Data {

    protected ZeroData(int alignment, int size) {
        super(alignment, size);
    }

    /**
     * Creates a zero-filled item.
     *
     * @param alignment alignment requirement of the item
     * @param size number of zero bytes
     * @return a specialised instance for sizes 1, 2, 4 and 8, otherwise a generic one
     */
    public static ZeroData create(int alignment, int size) {
        if (size == 1) {
            return new ZeroData(alignment, size) {
                @Override
                protected void emit(ByteBuffer buffer, Patches patches) {
                    buffer.put((byte) 0);
                }
            };
        } else if (size == 2) {
            return new ZeroData(alignment, size) {
                @Override
                protected void emit(ByteBuffer buffer, Patches patches) {
                    buffer.putShort((short) 0);
                }
            };
        } else if (size == 4) {
            return new ZeroData(alignment, size) {
                @Override
                protected void emit(ByteBuffer buffer, Patches patches) {
                    buffer.putInt(0);
                }
            };
        } else if (size == 8) {
            return new ZeroData(alignment, size) {
                @Override
                protected void emit(ByteBuffer buffer, Patches patches) {
                    buffer.putLong(0);
                }
            };
        }
        return new ZeroData(alignment, size);
    }

    @Override
    protected void emit(ByteBuffer buffer, Patches patches) {
        int remaining = getSize();
        // Whole longs while strictly more than 8 bytes are left, then finish byte by byte.
        for (; remaining > 8; remaining -= 8) {
            buffer.putLong(0L);
        }
        for (; remaining > 0; remaining--) {
            buffer.put((byte) 0);
        }
    }
}
/**
 * A data item made of several nested items that are emitted back to back, with no padding
 * written between them.
 */
public static final class PackedData extends Data {
    private final Data[] nested;

    private PackedData(int alignment, int size, Data[] nested) {
        super(alignment, size);
        this.nested = nested;
    }

    /**
     * Packs {@code nested} into a single item.
     * <p>
     * The combined alignment folds all nested alignments through {@code DataSection.lcm}, starting
     * from 1. The size is the sum of the nested sizes. Each nested item is expected to start at an
     * offset that is a multiple of its own alignment; this is checked only by an assertion.
     *
     * @param nested items to pack, in emission order (the array is stored, not copied)
     * @return the packed item
     */
    public static PackedData create(Data[] nested) {
        int totalSize = 0;
        int combinedAlignment = 1;
        for (Data item : nested) {
            int itemAlignment = item.getAlignment();
            assert totalSize % itemAlignment == 0 : "invalid alignment in packed constants";
            combinedAlignment = DataSection.lcm(combinedAlignment, itemAlignment);
            totalSize += item.getSize();
        }
        return new PackedData(combinedAlignment, totalSize, nested);
    }

    @Override
    protected void emit(ByteBuffer buffer, Patches patches) {
        for (int i = 0; i < nested.length; i++) {
            nested[i].emit(buffer, patches);
        }
    }
}
/**
 * Items of this section. They are kept in insertion order (via insertData/addAll) until close()
 * sorts them by alignment.
 */
private final ArrayList<Data> dataItems = new ArrayList<>();
/** Set to true by close(). */
private boolean closed;
/** Alignment of the whole section; computed by close(). */
private int sectionAlignment;
/** Total size in bytes of the section; computed by close(). */
private int sectionSize;
/** Hashing is not supported for data sections; always throws {@link UnsupportedOperationException}. */
@Override
public int hashCode() {
    throw new UnsupportedOperationException("hashCode");
}

/** @return the value of {@code identityHashCodeString(this)} */
@Override
public String toString() {
    return identityHashCodeString(this);
}

/**
 * Two sections are equal when they have the same closed state, section alignment and section
 * size, and their item lists are equal element by element, in order.
 */
@Override
public boolean equals(Object obj) {
    if (obj == this) {
        return true;
    }
    if (!(obj instanceof DataSection)) {
        return false;
    }
    DataSection other = (DataSection) obj;
    return closed == other.closed
                    && sectionAlignment == other.sectionAlignment
                    && sectionSize == other.sectionSize
                    && Objects.equals(dataItems, other.dataItems);
}
/**
 * Inserts {@code data} into this section and returns its reference.
 * <p>
 * If the item already carries a reference (from an earlier insertion into this or any other
 * section), it is not added again and the existing reference is returned. The check-and-assign is
 * done while holding the monitor of {@code data}.
 *
 * @param data the item to insert
 * @return the reference of {@code data}
 */
public DataSectionReference insertData(Data data) {
    checkOpen();
    synchronized (data) {
        if (data.ref != null) {
            return data.ref;
        }
        data.ref = new DataSectionReference();
        dataItems.add(data);
        return data.ref;
    }
}
/**
 * Moves every item of {@code other} to the end of this section, then empties {@code other}.
 * Both sections are passed through checkOpen() first. Every moved item is asserted to already have
 * a reference.
 *
 * @param other the section whose items are taken over
 */
public void addAll(DataSection other) {
    checkOpen();
    other.checkOpen();
    for (Data item : other.dataItems) {
        assert item.ref != null;
        dataItems.add(item);
    }
    other.dataItems.clear();
}
/** @return {@code true} once {@link #close()} has been called on this section */
public boolean closed() {
    return this.closed;
}
/**
 * Marks this section as closed and computes its layout.
 * <p>
 * Items are sorted by ascending alignment. {@code List.sort} is stable, so items with equal
 * alignment keep their relative order. Each item is then placed at
 * {@code align(position, alignment)} and its reference receives that offset. The section alignment
 * is the {@code lcm} fold of all item alignments, starting from 1. The section size is the end
 * offset of the last item. ({@code align} and {@code lcm} are defined elsewhere in this class;
 * presumably they round up to a multiple and compute the least common multiple — confirm.)
 * <p>
 * NOTE(review): ascending order places low-alignment items first. That can force padding before a
 * later, more strictly aligned item; a descending order would usually need less padding. The order
 * is kept as-is here because changing it would change every assigned offset.
 */
public void close() {
    checkOpen();
    closed = true;

    // Integer.compare instead of subtraction: "a - b" can overflow and yield the wrong sign.
    dataItems.sort((a, b) -> Integer.compare(a.alignment, b.alignment));
    int position = 0;
    int alignment = 1;
    for (Data d : dataItems) {
        alignment = lcm(alignment, d.alignment);
        position = align(position, d.alignment);
        d.ref.setOffset(position);
        position += d.size;
    }
    sectionAlignment = alignment;
    sectionSize = position;
}
/**
 * @return the total size in bytes computed by close(); guarded by checkClosed() (defined
 *         elsewhere, presumably rejecting a section that is still open)
 */
public int getSectionSize() {
    checkClosed();
    return this.sectionSize;
}

/**
 * @return the section alignment computed by close(); guarded by checkClosed() (defined elsewhere,
 *         presumably rejecting a section that is still open)
 */
public int getSectionAlignment() {
    checkClosed();
    return this.sectionAlignment;
}
/**
 * Writes the section into {@code buffer}, reporting patches through {@code patch}. This is the
 * three-argument overload with a per-item emit callback that does nothing.
 *
 * @param buffer destination buffer
 * @param patch callback for registering constant patches
 */
public void buildDataSection(ByteBuffer buffer, Patches patch) {
    BiConsumer<DataSectionReference, Integer> ignoreEmit = (ref, size) -> {
    };
    buildDataSection(buffer, patch, ignoreEmit);
}
public void buildDataSection(ByteBuffer buffer, Patches patch, BiConsumer<DataSectionReference, Integer> onEmit) {
checkClosed();
assert buffer.remaining() >= sectionSize;
int start = buffer.position();
for (Data d : dataItems) {
BufferUtil.asBaseBuffer(buffer).position(start + d.ref.getOffset());
onEmit.accept(d.ref, d.getSize());
d.emit(buffer, patch);
}