// ai.djl.ndarray.NDList (Maven / Gradle / Ivy)
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;
import ai.djl.Device;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import com.google.gson.JsonObject;
import com.google.gson.annotations.SerializedName;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PushbackInputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
/**
 * An {@code NDList} represents a sequence of {@link NDArray}s with names.
 *
 * <p>Each {@link NDArray} in this list can optionally have a name. You can use the name to look up
 * an NDArray in the NDList.
 *
 * @see NDArray
 */
public class NDList extends ArrayList<NDArray> implements NDResource, BytesSupplier {

    private static final long serialVersionUID = 1L;

    /** Number of leading bytes inspected to detect the serialization format. */
    private static final int MAGIC_LENGTH = 9;

    /** Constructs an empty NDList. */
    public NDList() {}

    /**
     * Constructs an empty NDList with the specified initial capacity.
     *
     * @param initialCapacity the initial capacity of the list
     * @throws IllegalArgumentException if the specified initial capacity is negative
     */
    public NDList(int initialCapacity) {
        super(initialCapacity);
    }

    /**
     * Constructs and initiates an NDList with the specified {@link NDArray}s.
     *
     * @param arrays the {@link NDArray}s
     */
    public NDList(NDArray... arrays) {
        super(Arrays.asList(arrays));
    }

    /**
     * Constructs and initiates an NDList with the specified {@link NDArray}s.
     *
     * @param other the {@link NDArray}s
     */
    public NDList(Collection<NDArray> other) {
        super(other);
    }

    /**
     * Decodes NDList from byte array.
     *
     * <p>The format is detected from the leading bytes: {@code PK} is treated as a npz (zip)
     * archive, {@code \x93NUM} as a single npy array, a <code>{</code> at offset 8 as safetensors,
     * and anything else as the native NDList format (a 4-byte element count followed by the
     * encoded arrays). The 4-byte encoding of an empty native NDList is accepted as well.
     *
     * @param manager manager assigned to {@link NDArray}
     * @param byteArray byte array to load from
     * @return {@code NDList}
     * @throws IllegalArgumentException if the input is too short or cannot be decoded
     */
    public static NDList decode(NDManager manager, byte[] byteArray) {
        if (byteArray.length < MAGIC_LENGTH) {
            // encode() of an empty list writes only the 4-byte size, accept it here
            if (isEmptyNdList(byteArray, byteArray.length)) {
                return new NDList();
            }
            throw new IllegalArgumentException("Invalid input length: " + byteArray.length);
        }
        try {
            if (isNpz(byteArray)) {
                return decodeNumpy(manager, new ByteArrayInputStream(byteArray));
            } else if (isNpy(byteArray)) {
                return new NDList(
                        NDSerializer.decodeNumpy(manager, new ByteArrayInputStream(byteArray)));
            } else if (isSafetensors(byteArray)) {
                return decodeSafetensors(manager, new ByteArrayInputStream(byteArray));
            }

            ByteBuffer bb = ByteBuffer.wrap(byteArray);
            int size = bb.getInt();
            if (size < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + size);
            }
            // size comes from the input, so the list is deliberately not pre-sized with it
            NDList list = new NDList();
            for (int i = 0; i < size; i++) {
                list.add(NDSerializer.decode(manager, bb));
            }
            return list;
        } catch (IOException | BufferUnderflowException e) {
            throw new IllegalArgumentException("Invalid NDArray input", e);
        }
    }

    /**
     * Decodes NDList from {@link InputStream}.
     *
     * <p>The format is detected the same way as in {@link #decode(NDManager, byte[])}. Up to nine
     * bytes are read from the stream for detection and pushed back before decoding.
     *
     * @param manager manager assigned to {@link NDArray}
     * @param is input stream contains the ndlist information
     * @return {@code NDList}
     * @throws IllegalArgumentException if the stream is too short or cannot be decoded
     */
    public static NDList decode(NDManager manager, InputStream is) {
        try {
            byte[] magic = new byte[MAGIC_LENGTH];
            int read = 0;
            while (read < MAGIC_LENGTH) {
                int n = is.read(magic, read, MAGIC_LENGTH - read);
                if (n < 0) {
                    break;
                }
                read += n;
            }
            if (read < MAGIC_LENGTH) {
                // encode() of an empty list writes only the 4-byte size, accept it here
                if (isEmptyNdList(magic, read)) {
                    return new NDList();
                }
                throw new EOFException(
                        "Expected at least " + MAGIC_LENGTH + " bytes, but got " + read);
            }

            PushbackInputStream pis = new PushbackInputStream(is, MAGIC_LENGTH);
            pis.unread(magic);
            if (isNpz(magic)) {
                // assume this is npz file
                return decodeNumpy(manager, pis);
            } else if (isNpy(magic)) {
                return new NDList(NDSerializer.decodeNumpy(manager, pis));
            } else if (isSafetensors(magic)) {
                return decodeSafetensors(manager, pis);
            }

            DataInputStream dis = new DataInputStream(pis);
            int size = dis.readInt();
            if (size < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + size);
            }
            NDList list = new NDList();
            for (int i = 0; i < size; i++) {
                list.add(manager.decode(dis));
            }
            return list;
        } catch (IOException e) {
            throw new IllegalArgumentException("Malformed data", e);
        }
    }

    /** Returns {@code true} if the leading bytes are the zip signature used by npz files. */
    private static boolean isNpz(byte[] magic) {
        return magic[0] == 'P' && magic[1] == 'K';
    }

    /** Returns {@code true} if the leading bytes are the npy magic string. */
    private static boolean isNpy(byte[] magic) {
        return magic[0] == (byte) 0x93 && magic[1] == 'N' && magic[2] == 'U' && magic[3] == 'M';
    }

    /** Returns {@code true} if the JSON header of a safetensors file starts at offset 8. */
    private static boolean isSafetensors(byte[] magic) {
        return magic[8] == '{';
    }

    /** Returns {@code true} if the input is exactly the 4-byte encoding of an empty NDList. */
    private static boolean isEmptyNdList(byte[] bytes, int length) {
        return length == 4 && bytes[0] == 0 && bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 0;
    }

    /**
     * Decodes a safetensors stream: an 8-byte little-endian header length, a JSON header, and the
     * raw tensor data addressed by the {@code data_offsets} in the header.
     *
     * @param manager manager assigned to {@link NDArray}
     * @param is the stream positioned at the start of the safetensors content
     * @return the decoded arrays, in the order their names appear in the JSON header
     * @throws IOException if the stream ends early or the header is malformed
     */
    private static NDList decodeSafetensors(NDManager manager, InputStream is) throws IOException {
        DataInputStream dis =
                is instanceof DataInputStream ? (DataInputStream) is : new DataInputStream(is);

        byte[] buf = new byte[8];
        dis.readFully(buf);
        long headerLength = ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).getLong();
        // reject lengths that cannot be used as an array size instead of failing with an
        // ArithmeticException or NegativeArraySizeException
        if (headerLength < 0 || headerLength > Integer.MAX_VALUE) {
            throw new IOException("Invalid safetensors header length: " + headerLength);
        }

        buf = new byte[(int) headerLength];
        dis.readFully(buf);
        String json = new String(buf, StandardCharsets.UTF_8);
        // rust implementation sort by name, our implementation preserve the order.
        JsonObject jsonObject = JsonUtils.GSON.fromJson(json, JsonObject.class);
        if (jsonObject == null) {
            throw new IOException("Malformed safetensors metadata: " + json);
        }

        List<Pair<String, SafeTensor>> list = new ArrayList<>();
        int max = 0;
        for (String key : jsonObject.keySet()) {
            if ("__metadata__".equals(key)) {
                continue;
            }
            SafeTensor value = JsonUtils.GSON.fromJson(jsonObject.get(key), SafeTensor.class);
            if (value == null
                    || value.offsets == null
                    || value.offsets.length != 2
                    || value.offsets[0] < 0
                    || value.offsets[1] < value.offsets[0]) {
                throw new IOException("Malformed safetensors metadata: " + json);
            }
            max = Math.max(max, value.offsets[1]);
            list.add(new Pair<>(key, value));
        }

        // the data section is read in one piece, up to the largest end offset in the header
        buf = new byte[max];
        dis.readFully(buf);

        NDList ret = new NDList(list.size());
        for (Pair<String, SafeTensor> pair : list) {
            SafeTensor st = pair.getValue();
            Shape shape = new Shape(st.shape);
            ByteBuffer bb = ByteBuffer.wrap(buf, st.offsets[0], st.size());
            bb.order(ByteOrder.LITTLE_ENDIAN);
            DataType dataType = DataType.fromSafetensors(st.dtype);
            NDArray array = manager.create(bb, shape, dataType);
            array.setName(pair.getKey());
            ret.add(array);
        }
        return ret;
    }

    /**
     * Decodes a npz archive, one {@link NDArray} per zip entry.
     *
     * <p>An entry named {@code xxx.npy} yields an array named {@code xxx}, unless the entry name
     * starts with {@code arr_}, in which case the array is left unnamed.
     *
     * @param manager manager assigned to {@link NDArray}
     * @param is the stream positioned at the start of the zip content
     * @return the decoded arrays in archive order
     * @throws IOException if failed on IO operation
     */
    private static NDList decodeNumpy(NDManager manager, InputStream is) throws IOException {
        NDList list = new NDList();
        // not closed here; closing the ZipInputStream would also close the caller's stream
        ZipInputStream zis = new ZipInputStream(is);
        ZipEntry entry;
        while ((entry = zis.getNextEntry()) != null) {
            String name = entry.getName();
            NDArray array = NDSerializer.decodeNumpy(manager, zis);
            if (!name.startsWith("arr_") && name.endsWith(".npy")) {
                array.setName(name.substring(0, name.length() - 4));
            }
            list.add(array);
        }
        return list;
    }

    /**
     * Returns the first {@link NDArray} in this NDList whose name equals the specified name.
     *
     * @param name the name of the NDArray
     * @return the first matching NDArray, or {@code null} if there is none
     */
    public NDArray get(String name) {
        for (NDArray array : this) {
            if (name.equals(array.getName())) {
                return array;
            }
        }
        return null;
    }

    /**
     * Removes the first {@link NDArray} whose name equals the specified name from this NDList if
     * it is present.
     *
     * <p>If this list does not contain an NDArray with that name, it is unchanged.
     *
     * @param name the name of the NDArray to be removed from this NDList, if present
     * @return the element that was removed, or {@code null} if there is none
     */
    public NDArray remove(String name) {
        int index = 0;
        for (NDArray array : this) {
            if (name.equals(array.getName())) {
                // safe while iterating: the method returns before the iterator is used again
                remove(index);
                return array;
            }
            ++index;
        }
        return null;
    }

    /**
     * Returns {@code true} if this NDList contains an NDArray with the specified name.
     *
     * @param name the name of the NDArray to look for
     * @return {@code true} if this list contains an NDArray with the specified name
     */
    public boolean contains(String name) {
        for (NDArray array : this) {
            if (name.equals(array.getName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the head (first) NDArray of the NDList.
     *
     * @return the head NDArray
     * @throws IndexOutOfBoundsException if the list is empty
     */
    public NDArray head() {
        return get(0);
    }

    /**
     * Returns the only element if this is a singleton NDList or throws an exception if multiple
     * elements.
     *
     * @return the head NDArray
     * @throws IndexOutOfBoundsException if the list does not contain exactly one element
     */
    public NDArray singletonOrThrow() {
        if (size() != 1) {
            throw new IndexOutOfBoundsException(
                    "Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was "
                            + size());
        }
        return get(0);
    }

    /**
     * Appends all of the NDArrays in the specified NDList to the end of this NDList, in the order
     * that they are returned by the specified NDList's iterator.
     *
     * @param other the NDList containing NDArray to be added to this list
     * @return this NDList after the addition
     */
    public NDList addAll(NDList other) {
        for (NDArray array : other) {
            add(array);
        }
        return this;
    }

    /**
     * Returns a new NDList holding the elements of this NDList from the specified fromIndex,
     * inclusive, to the end.
     *
     * <p>The returned list is a copy of the range, not a view: it refers to the same {@link
     * NDArray} instances, but adding or removing elements in one list does not affect the other.
     *
     * @param fromIndex the start index (inclusive)
     * @return a new NDList with the selected portion of this NDList
     */
    public NDList subNDList(int fromIndex) {
        return subNDList(fromIndex, size());
    }

    /**
     * Returns a new NDList holding the elements of this NDList between the specified fromIndex,
     * inclusive, and toIndex, exclusive.
     *
     * <p>The returned list is a copy of the range, not a view: it refers to the same {@link
     * NDArray} instances, but adding or removing elements in one list does not affect the other.
     *
     * @param fromIndex the start index (inclusive)
     * @param toIndex the end index (exclusive)
     * @return a new NDList with the selected portion of this NDList
     */
    public NDList subNDList(int fromIndex, int toIndex) {
        return new NDList(subList(fromIndex, toIndex));
    }

    /**
     * Converts all the {@code NDArray} in {@code NDList} to a different {@link Device}.
     *
     * @param device the {@link Device} to be set
     * @param copy set {@code true} if you want to return a copy of the underlying NDArray
     * @return a new {@code NDList} with the NDArrays on specified {@link Device}, or this NDList
     *     if {@code copy} is {@code false} and every NDArray is already on the device
     */
    public NDList toDevice(Device device, boolean copy) {
        if (!copy) {
            // if all arrays in NDList are already on device, return itself.
            // Compared with equals() rather than ==, so an equal but distinct Device instance is
            // recognized too; this is the same as == if Device does not override equals().
            if (stream().allMatch(array -> Objects.equals(array.getDevice(), device))) {
                return this;
            }
        }
        NDList newNDList = new NDList(size());
        for (NDArray array : this) {
            newNDList.add(array.toDevice(device, copy));
        }
        return newNDList;
    }

    /** {@inheritDoc} */
    @Override
    public NDManager getManager() {
        // the manager of the first array; throws IndexOutOfBoundsException on an empty list
        return head().getManager();
    }

    /** {@inheritDoc} */
    @Override
    public List<NDArray> getResourceNDArrays() {
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public void attach(NDManager manager) {
        forEach(array -> array.attach(manager));
    }

    /** {@inheritDoc} */
    @Override
    public void tempAttach(NDManager manager) {
        forEach(array -> array.tempAttach(manager));
    }

    /** {@inheritDoc} */
    @Override
    public void detach() {
        forEach(NDResource::detach);
    }

    /**
     * Encodes the NDList to byte array.
     *
     * @return the byte array
     */
    public byte[] encode() {
        return encode(Encoding.ND_LIST);
    }

    /**
     * Encodes the NDList to byte array.
     *
     * @param encoding encode mode, one of ndlist/npz/safetensor format
     * @return the byte array
     */
    public byte[] encode(Encoding encoding) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            encode(baos, encoding);
            return baos.toByteArray();
        } catch (IOException e) {
            throw new AssertionError("NDList is not writable", e);
        }
    }

    /**
     * Writes the encoded NDList to {@code OutputStream}.
     *
     * @param os the {@code OutputStream} to be written to
     * @throws IOException if failed on IO operation
     */
    public void encode(OutputStream os) throws IOException {
        encode(os, Encoding.ND_LIST);
    }

    /**
     * Writes the encoded NDList to {@code OutputStream}.
     *
     * <p>Any encoding other than {@link Encoding#NPZ} and {@link Encoding#SAFETENSORS} is written
     * in the native NDList format. The stream is not closed.
     *
     * @param os the {@code OutputStream} to be written to
     * @param encoding encode mode, one of ndlist/npz/safetensor format
     * @throws IOException if failed on IO operation
     */
    public void encode(OutputStream os, Encoding encoding) throws IOException {
        if (encoding == Encoding.NPZ) {
            encodeNpz(os);
        } else if (encoding == Encoding.SAFETENSORS) {
            encodeSafetensors(os);
        } else {
            encodeNdList(os);
        }
    }

    /** Writes this list as a npz archive; unnamed arrays are stored as {@code arr_N.npy}. */
    private void encodeNpz(OutputStream os) throws IOException {
        // finished and flushed but not closed, so the caller's stream is left open
        ZipOutputStream zos = new ZipOutputStream(os);
        int unnamed = 0;
        for (NDArray nd : this) {
            String name = nd.getName();
            if (name == null) {
                zos.putNextEntry(new ZipEntry("arr_" + unnamed + ".npy"));
                ++unnamed;
            } else {
                zos.putNextEntry(new ZipEntry(name + ".npy"));
            }
            NDSerializer.encodeAsNumpy(nd, zos);
        }
        zos.finish();
        zos.flush();
    }

    /** Writes this list in safetensors format; unnamed arrays are stored as {@code arr_N}. */
    private void encodeSafetensors(OutputStream os) throws IOException {
        // insertion-ordered, so the header lists the tensors in list order and decoding the
        // output returns the arrays in their original order
        Map<String, SafeTensor> map = new LinkedHashMap<>();
        int unnamed = 0;
        int offset = 0;
        for (NDArray nd : this) {
            String name = nd.getName();
            if (name == null) {
                name = "arr_" + unnamed;
                ++unnamed;
            }
            SafeTensor st = new SafeTensor();
            st.dtype = nd.getDataType().asSafetensors();
            st.shape = nd.getShape().getShape();
            long size = nd.getDataType().getNumOfBytes() * nd.size();
            // addExact: fail loudly instead of wrapping around once the data exceeds 2GB
            int limit = Math.addExact(offset, Math.toIntExact(size));
            st.offsets = new int[] {offset, limit};
            map.put(name, st);
            offset = limit;
        }

        byte[] json = JsonUtils.GSON.toJson(map).getBytes(StandardCharsets.UTF_8);
        ByteBuffer buf = ByteBuffer.allocate(8);
        buf.order(ByteOrder.LITTLE_ENDIAN);
        buf.putLong(0, json.length);
        os.write(buf.array());
        os.write(json);
        for (NDArray nd : this) {
            os.write(nd.toByteArray());
        }
    }

    /** Writes this list in the native format: element count followed by each encoded array. */
    private void encodeNdList(OutputStream os) throws IOException {
        DataOutputStream dos = new DataOutputStream(os);
        dos.writeInt(size());
        for (NDArray nd : this) {
            NDSerializer.encode(nd, dos);
        }
        dos.flush();
    }

    /** {@inheritDoc} */
    @Override
    public byte[] getAsBytes() {
        return encode();
    }

    /** {@inheritDoc} */
    @Override
    public ByteBuffer toByteBuffer() {
        return ByteBuffer.wrap(encode());
    }

    /**
     * Gets all of shapes in the {@code NDList}.
     *
     * @return shapes in {@code NDList}
     */
    public Shape[] getShapes() {
        return stream().map(NDArray::getShape).toArray(Shape[]::new);
    }

    /** {@inheritDoc} */
    @Override
    public void close() {
        forEach(NDArray::close);
        clear();
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder(200);
        builder.append("NDList size: ").append(size()).append('\n');
        int index = 0;
        for (NDArray array : this) {
            String name = array.getName();
            builder.append(index++).append(' ');
            if (name != null) {
                builder.append(name);
            }
            builder.append(": ")
                    .append(array.getShape())
                    .append(' ')
                    .append(array.getDataType())
                    .append('\n');
        }
        return builder.toString();
    }

    /** An enum represents NDList serialization format. */
    public enum Encoding {
        ND_LIST,
        NPZ,
        SAFETENSORS
    }

    /** One tensor entry of the safetensors JSON header. */
    private static final class SafeTensor {

        String dtype;
        long[] shape;

        /** Start (inclusive) and end (exclusive) byte offsets within the data section. */
        @SerializedName("data_offsets")
        int[] offsets;

        /** Returns the number of data bytes covered by {@link #offsets}. */
        int size() {
            return offsets[1] - offsets[0];
        }
    }
}