# Track images for PyTorch

!lamin init --storage "mnist-100"
ℹ️ Loading schema modules: core==0.30rc5 
✅ Created & loaded instance: testuser1/mnist-100
import lamindb as ln
import pandas as pd

ℹ️ Instance: testuser1/mnist-100
ℹ️ User: testuser1
ℹ️ Added notebook: Transform(id='GfLCSYv8BjvJ', v='0', name='mnist-local', type=notebook, title='Track images for PyTorch', created_by='DzTjkKse', created_at=datetime.datetime(2023, 3, 31, 19, 11, 40))
ℹ️ Added run: Run(id='WfxNzP5SUvPD4IAjWHJZ', transform_id='GfLCSYv8BjvJ', transform_v='0', created_by='DzTjkKse', created_at=datetime.datetime(2023, 3, 31, 19, 11, 40))
Hide code cell content
# prepare local data: mirror every object under the "mnist-100/" prefix of the
# S3 bucket into a local "mnist-100/" directory with the same relative layout
import boto3
from pathlib import Path

s3 = boto3.resource("s3")
bucket = s3.Bucket("bernardo-test-bucket-1")
for obj in bucket.objects.filter(Prefix="mnist-100/"):
    # keys ending in "/" name pseudo-directories, not files; nothing to download
    if obj.key.endswith("/"):
        continue
    local_path = Path(obj.key)
    # download_file does not create missing parent directories, and keys such
    # as "mnist-100/images/..." are nested, so create them before writing
    local_path.parent.mkdir(parents=True, exist_ok=True)
    bucket.download_file(obj.key, str(local_path))

Assume we have a local directory of files that we’d like to ingest:

!ls mnist-100/images*
mnist_0.png   mnist_27.png  mnist_45.png  mnist_63.png	mnist_81.png
mnist_1.png   mnist_28.png  mnist_46.png  mnist_64.png	mnist_82.png
mnist_10.png  mnist_29.png  mnist_47.png  mnist_65.png	mnist_83.png
mnist_11.png  mnist_3.png   mnist_48.png  mnist_66.png	mnist_84.png
mnist_12.png  mnist_30.png  mnist_49.png  mnist_67.png	mnist_85.png
mnist_13.png  mnist_31.png  mnist_5.png   mnist_68.png	mnist_86.png
mnist_14.png  mnist_32.png  mnist_50.png  mnist_69.png	mnist_87.png
mnist_15.png  mnist_33.png  mnist_51.png  mnist_7.png	mnist_88.png
mnist_16.png  mnist_34.png  mnist_52.png  mnist_70.png	mnist_89.png
mnist_17.png  mnist_35.png  mnist_53.png  mnist_71.png	mnist_9.png
mnist_18.png  mnist_36.png  mnist_54.png  mnist_72.png	mnist_90.png
mnist_19.png  mnist_37.png  mnist_55.png  mnist_73.png	mnist_91.png
mnist_2.png   mnist_38.png  mnist_56.png  mnist_74.png	mnist_92.png
mnist_20.png  mnist_39.png  mnist_57.png  mnist_75.png	mnist_93.png
mnist_21.png  mnist_4.png   mnist_58.png  mnist_76.png	mnist_94.png
mnist_22.png  mnist_40.png  mnist_59.png  mnist_77.png	mnist_95.png
mnist_23.png  mnist_41.png  mnist_6.png   mnist_78.png	mnist_96.png
mnist_24.png  mnist_42.png  mnist_60.png  mnist_79.png	mnist_97.png
mnist_25.png  mnist_43.png  mnist_61.png  mnist_8.png	mnist_98.png
mnist_26.png  mnist_44.png  mnist_62.png  mnist_80.png	mnist_99.png

We also have a `.csv` file that contains the label for each image:

# Load the image-to-label table (one row per image file).
labels_df = pd.read_csv(filepath_or_buffer="mnist-100/labels.csv")
filename label
0 mnist_0.png 5
1 mnist_1.png 0
2 mnist_2.png 4
3 mnist_3.png 1
4 mnist_4.png 9

## Ingest images and labels

Let’s ingest each image in the folder as a File record by leveraging the Folder entity.

# Create a Folder record for the local image directory; each image inside it
# is meant to become a File record.
# NOTE(review): no save call is visible in this cell, but the records are
# queried again further below. Confirm whether an explicit save is required
# in this lamindb version before the query can find this record.
img_folder = ln.Folder(folder="mnist-100/images")

Let’s also ingest the labels file as a single data object.

# Register the labels CSV as a single File record.
# NOTE(review): no save call is visible here, yet this file is looked up again
# by name/suffix further below. Confirm whether an explicit save is needed first.
labels = ln.File("mnist-100/labels.csv")


We could equally well pass cloud storage locations instead of local paths!

## Create the PyTorch Dataset

Let’s query the relevant data objects to instantiate a canonical PyTorch custom image dataset.

Hide code cell content
# define the custom dataset class, as seen in the PyTorch guide
# (module paths as in that guide: `read_image` lives in torchvision.io and
# `Dataset` in torch.utils.data)

import os
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    """Map-style dataset pairing image files with labels from a CSV table.

    Mirrors the custom image dataset from the PyTorch data-loading guide.

    Args:
        annotations_file: Anything ``pd.read_csv`` accepts. Column 0 holds the
            image file name (relative to ``img_dir``), column 1 the label.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded image tensor.
        target_transform: Optional callable applied to each label.
    """

    def __init__(
        self, annotations_file, img_dir, transform=None, target_transform=None
    ):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows (samples) in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for annotation row ``idx``.

        The image is read from ``img_dir`` and cast to ``float32`` before the
        optional ``transform`` is applied.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path).to(torch.float32)
        label = self.img_labels.iloc[idx, 1]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

The canonical PyTorch dataset takes as input the path to a folder with images.

# Look up the image Folder record and resolve its path on disk.
# NOTE(review): reconstructed from a garbled line; assumes the lamindb
# `ln.select(...).one()` query API and that the folder record is named
# "images" — confirm against the installed lamindb version.
img_folder = ln.select(ln.Folder, name="images").one()
img_folderpath = img_folder.path()

It also takes the path to a CSV file with the labels.

# Look up the labels File record by name and suffix and resolve its path.
# NOTE(review): reconstructed from a garbled line; assumes the lamindb
# `ln.select(...).one()` query API and that the record is named "labels" —
# confirm against the installed lamindb version.
labels = ln.select(ln.File, name="labels", suffix=".csv").one()
labels_filepath = labels.path()

Let’s now instantiate the canonical PyTorch custom image dataset.

# Instantiate the dataset from the labels CSV and the image directory.
dataset = CustomImageDataset(
    annotations_file=labels_filepath,
    img_dir=img_folderpath,
)

## Create the PyTorch DataLoaders

from torch.utils.data import random_split, DataLoader

# define train and test splits (80% / 20%). Deriving the sizes from
# len(dataset) keeps them summing to the dataset size, which random_split
# requires for absolute lengths; for 100 samples this gives [80, 20].
n_train = int(0.8 * len(dataset))
train_subset, test_subset = random_split(dataset, [n_train, len(dataset) - n_train])

# create train and test DataLoaders from the subsets themselves. Passing
# `subset.dataset` would hand over the full underlying dataset and silently
# discard the split, leaking test samples into training.
train_loader = DataLoader(train_subset)
test_loader = DataLoader(test_subset)

## Train a simple autoencoder

We can now train a canonical PyTorch Lightning autoencoder.

Hide code cell content
import torch
from torch import optim, nn
from torchmetrics import Accuracy
import pytorch_lightning as pl

# Two halves of the autoencoder: the encoder squeezes a flattened 28 * 28
# input into a 3-dimensional code; the decoder expands it back to 28 * 28.
encoder = nn.Sequential(
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 3),
)
decoder = nn.Sequential(
    nn.Linear(3, 64),
    nn.ReLU(),
    nn.Linear(64, 28 * 28),
)

class LitAutoEncoder(pl.LightningModule):
    """Lightning module that trains an encoder/decoder pair to reconstruct inputs.

    Args:
        encoder: Module mapping a flattened sample to a latent code.
        decoder: Module mapping a latent code back to the flattened sample size.
    """

    def __init__(self, encoder, decoder):
        # nn.Module (and therefore LightningModule) must be initialised before
        # submodules can be assigned as attributes.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Digit labels run from 0 to 9 (label 9 appears in the data), so a
        # multiclass metric needs 10 classes; 9 would leave label 9 out of range.
        # NOTE(review): no step in this module updates this metric yet.
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

    def training_step(self, batch, batch_idx):
        """Return the MSE reconstruction loss for one batch (labels are unused)."""
        x, _ = batch
        x = x.view(x.size(0), -1)  # flatten each sample to one vector
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# Wrap the encoder/decoder pair in the Lightning module.
autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)
2023-03-31 19:12:11,880:INFO - Created a temporary directory at /tmp/tmpwf1fbj75
2023-03-31 19:12:11,881:INFO - Writing /tmp/tmpwf1fbj75/_remote_module_non_scriptable.py
# Train for up to 5 epochs, capping each epoch at 100 training batches.
trainer = pl.Trainer(limit_train_batches=100, max_epochs=5)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
2023-03-31 19:12:12,121:INFO - GPU available: False, used: False
2023-03-31 19:12:12,122:INFO - TPU available: False, using: 0 TPU cores
2023-03-31 19:12:12,123:INFO - IPU available: False, using: 0 IPUs
2023-03-31 19:12:12,124:INFO - HPU available: False, using: 0 HPUs
/home/runner/work/pytorch-lamin-mnist/pytorch-lamin-mnist/.nox/build-3-9/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/ UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
2023-03-31 19:12:12,451:WARNING - Missing logger folder: /home/runner/work/pytorch-lamin-mnist/pytorch-lamin-mnist/docs/guide/lightning_logs
2023-03-31 19:12:12,456:INFO - 
  | Name          | Type               | Params
0 | encoder       | Sequential         | 50.4 K
1 | decoder       | Sequential         | 51.2 K
2 | test_accuracy | MulticlassAccuracy | 0     
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
2023-03-31 19:12:16,166:INFO - `Trainer.fit` stopped: `max_epochs=5` reached.