Track images for PyTorch#
!lamin init --storage "mnist-100"
ℹ️ Loading schema modules: core==0.30rc5
✅ Created & loaded instance: testuser1/mnist-100
import lamindb as ln
import pandas as pd
ln.track()
ℹ️ Instance: testuser1/mnist-100
ℹ️ User: testuser1
ℹ️ Added notebook: Transform(id='GfLCSYv8BjvJ', v='0', name='mnist-local', type=notebook, title='Track images for PyTorch', created_by='DzTjkKse', created_at=datetime.datetime(2023, 3, 31, 19, 11, 40))
ℹ️ Added run: Run(id='WfxNzP5SUvPD4IAjWHJZ', transform_id='GfLCSYv8BjvJ', transform_v='0', created_by='DzTjkKse', created_at=datetime.datetime(2023, 3, 31, 19, 11, 40))
Show code cell content
# prepare local data: mirror the s3 prefix into the working directory
import boto3
from pathlib import Path

s3 = boto3.resource("s3")
bucket = s3.Bucket("bernardo-test-bucket-1")
for obj in bucket.objects.filter(Prefix="mnist-100/"):
    if not obj.key.endswith("/"):  # skip "directory" placeholder keys
        # parents=True: keys like "mnist-100/images/x.png" need the whole
        # directory chain created, not just the immediate parent
        Path(obj.key).parent.mkdir(parents=True, exist_ok=True)
        bucket.download_file(obj.key, obj.key)
Assume we have a local directory of files that we’d like to ingest:
!ls mnist-100/images*
mnist_0.png mnist_27.png mnist_45.png mnist_63.png mnist_81.png
mnist_1.png mnist_28.png mnist_46.png mnist_64.png mnist_82.png
mnist_10.png mnist_29.png mnist_47.png mnist_65.png mnist_83.png
mnist_11.png mnist_3.png mnist_48.png mnist_66.png mnist_84.png
mnist_12.png mnist_30.png mnist_49.png mnist_67.png mnist_85.png
mnist_13.png mnist_31.png mnist_5.png mnist_68.png mnist_86.png
mnist_14.png mnist_32.png mnist_50.png mnist_69.png mnist_87.png
mnist_15.png mnist_33.png mnist_51.png mnist_7.png mnist_88.png
mnist_16.png mnist_34.png mnist_52.png mnist_70.png mnist_89.png
mnist_17.png mnist_35.png mnist_53.png mnist_71.png mnist_9.png
mnist_18.png mnist_36.png mnist_54.png mnist_72.png mnist_90.png
mnist_19.png mnist_37.png mnist_55.png mnist_73.png mnist_91.png
mnist_2.png mnist_38.png mnist_56.png mnist_74.png mnist_92.png
mnist_20.png mnist_39.png mnist_57.png mnist_75.png mnist_93.png
mnist_21.png mnist_4.png mnist_58.png mnist_76.png mnist_94.png
mnist_22.png mnist_40.png mnist_59.png mnist_77.png mnist_95.png
mnist_23.png mnist_41.png mnist_6.png mnist_78.png mnist_96.png
mnist_24.png mnist_42.png mnist_60.png mnist_79.png mnist_97.png
mnist_25.png mnist_43.png mnist_61.png mnist_8.png mnist_98.png
mnist_26.png mnist_44.png mnist_62.png mnist_80.png mnist_99.png
And a `.csv` file containing the labels for each of the images.
# load the label table; one row per image with columns (filename, label)
labels_df = pd.read_csv("mnist-100/labels.csv")
labels_df.head()
filename | label | |
---|---|---|
0 | mnist_0.png | 5 |
1 | mnist_1.png | 0 |
2 | mnist_2.png | 4 |
3 | mnist_3.png | 1 |
4 | mnist_4.png | 9 |
Ingest images and labels#
Let’s ingest each image in the folder as a `File` record by leveraging the `Folder` entity.
# register every image in the local folder as a File record via the Folder entity
img_folder = ln.Folder(folder="mnist-100/images")
ln.add(img_folder);
Let’s also ingest the labels file as a single data object.
# ingest the labels csv as a single data object
labels = ln.File("mnist-100/labels.csv")
ln.add(labels);
Important
We can equally well pass cloud locations!
Create the PyTorch Dataset#
Let’s query the relevant data objects to instantiate a canonical PyTorch custom image dataset.
Show code cell content
# define the custom dataset class, as seen in the PyTorch guide
import os
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    """Map-style dataset over a folder of images plus a csv of labels.

    The annotations csv is expected to have the image filename in its
    first column and the label in its second column.
    """

    def __init__(
        self, annotations_file, img_dir, transform=None, target_transform=None
    ):
        # read the (filename, label) table once up front
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # one sample per annotation row
        return len(self.img_labels)

    def __getitem__(self, idx):
        # resolve the row once, then pull filename and label from it
        row = self.img_labels.iloc[idx]
        img_path = os.path.join(self.img_dir, row.iloc[0])
        image = read_image(img_path).to(torch.float32)
        label = row.iloc[1]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
The canonical PyTorch dataset takes as input the path to a folder with images.
# query the registered folder back and resolve its filesystem path
img_folder = ln.select(ln.Folder).one()
img_folderpath = img_folder.path()
As well as the path to a csv file with labels.
# query the labels csv back (the only .csv File) and resolve its path
labels = ln.select(ln.File, suffix=".csv").one()
labels_filepath = labels.path()
Let’s now instantiate the canonical PyTorch custom image dataset.
dataset = CustomImageDataset(labels_filepath, img_folderpath)
Create the PyTorch DataLoaders#
from torch.utils.data import random_split, DataLoader

# define train and test splits (80/20 over the 100 samples)
train_subset, test_subset = random_split(dataset, [80, 20])

# create train and test DataLoaders from the splits.
# NOTE: pass the Subset objects themselves -- `subset.dataset` refers to
# the full underlying dataset, which would feed all 100 samples to both
# loaders and leak test data into training.
train_loader = DataLoader(train_subset)
test_loader = DataLoader(test_subset)
Train a simple autoencoder#
We can now train a canonical PyTorch Lightning autoencoder.
Show code cell content
import torch
from torch import optim, nn
from torchmetrics import Accuracy
import pytorch_lightning as pl
# torch.set_default_dtype(torch.uint8)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
class LitAutoEncoder(pl.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.test_accuracy = Accuracy(task="multiclass", num_classes=9)
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = nn.functional.mse_loss(x_hat, x)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer
autoencoder = LitAutoEncoder(encoder, decoder)
2023-03-31 19:12:11,880:INFO - Created a temporary directory at /tmp/tmpwf1fbj75
2023-03-31 19:12:11,881:INFO - Writing /tmp/tmpwf1fbj75/_remote_module_non_scriptable.py
# train for 5 epochs, capping each epoch at 100 batches
trainer = pl.Trainer(limit_train_batches=100, max_epochs=5)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
2023-03-31 19:12:12,121:INFO - GPU available: False, used: False
2023-03-31 19:12:12,122:INFO - TPU available: False, using: 0 TPU cores
2023-03-31 19:12:12,123:INFO - IPU available: False, using: 0 IPUs
2023-03-31 19:12:12,124:INFO - HPU available: False, using: 0 HPUs
/home/runner/work/pytorch-lamin-mnist/pytorch-lamin-mnist/.nox/build-3-9/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
warning_cache.warn(
2023-03-31 19:12:12,451:WARNING - Missing logger folder: /home/runner/work/pytorch-lamin-mnist/pytorch-lamin-mnist/docs/guide/lightning_logs
2023-03-31 19:12:12,456:INFO -
| Name | Type | Params
-----------------------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
2 | test_accuracy | MulticlassAccuracy | 0
-----------------------------------------------------
101 K Trainable params
0 Non-trainable params
101 K Total params
0.407 Total estimated model params size (MB)
2023-03-31 19:12:16,166:INFO - `Trainer.fit` stopped: `max_epochs=5` reached.