Benchmark MappedCollection streaming¶
In [1]:
!lamin load laminlabs/arrayloader-benchmarks
💡 loaded instance: laminlabs/arrayloader-benchmarks
In [2]:
import lamindb as ln
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import gc
import time
💡 lamindb instance: laminlabs/arrayloader-benchmarks
In [3]:
ln.track()
💡 notebook imports: lamindb==0.67.3 psutil==5.9.5 torch==2.0.1 tqdm==4.66.1 💡 loaded: Transform(uid='oXJWvVPX89PZ5zKv', name='Benchmark MappedCollection streaming', short_name='mapped-streaming-benchmark', version='1', type='notebook', updated_at=2024-02-11 11:30:32 UTC, created_by_id=1) 💡 loaded: Run(uid='alQulQ2TwAmL0EjEiu2n', run_at=2024-02-11 12:15:06 UTC, transform_id=20, created_by_id=1)
In [4]:
collection_h5ads = ln.Collection.filter(uid="VwxM0HNDtEcNjJEKwYqO").one()
Benchmark¶
In [5]:
BATCH_SIZE = 1024
In [6]:
def benchmark(loader, n_samples):
num_iter = n_samples // BATCH_SIZE
loader_iter = loader.__iter__()
# exclude first batch from benchmark as this includes the setup time
batch = next(loader_iter)
start_time = time.time()
for i, batch in tqdm(enumerate(loader_iter), total=num_iter):
X = batch[0]
# for pytorch DataLoader
# Merlin sends to cuda by default
if hasattr(X, "is_cuda") and not X.is_cuda:
X = X.cuda()
if i == num_iter:
break
if i % 10 == 0:
gc.collect()
execution_time = time.time() - start_time
gc.collect()
time_per_sample = (1e6 * execution_time) / (num_iter * BATCH_SIZE)
print(f'time per sample: {time_per_sample:.2f} μs')
samples_per_sec = num_iter * BATCH_SIZE / execution_time
print(f'samples per sec: {samples_per_sec:.2f} samples/sec')
return samples_per_sec, time_per_sample
In [7]:
# parallel=True doesn't work with stream=True now
dataset = collection_h5ads.mapped(join=None, stream=True)
In [8]:
# cuda init
torch.ones(2).cuda()
Out[8]:
tensor([1., 1.], device='cuda:0')
In [9]:
loader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True
)
In [10]:
samples_per_sec, time_per_sample = benchmark(loader, n_samples=10000)
0%| | 0/9 [00:00<?, ?it/s]
time per sample: 435526.92 μs samples per sec: 2.30 samples/sec
In [11]:
dataset.close()
In [ ]: