https://www.kaggle.com/code/shivanandmn/beginners-guide-to-pytorch-lightning/notebook
This is the basic structure of a PyTorch Lightning model.
A LightningModule has many reserved methods, called hooks, that Lightning calls for you:
- configure_optimizers: returns the optimizer (e.g. Adam or SGD), optionally together with learning-rate schedulers
- training_step: the body of the training loop; takes batch and batch_idx as parameters and returns the loss
- validation_step: the body of the validation loop; takes batch and batch_idx as parameters
- test_step: the body of the test loop; takes batch and batch_idx as parameters
import pytorch_lightning as pl
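class LightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define layers here, e.g. self.layer = torch.nn.Linear(...)

    def forward(self, x):
        # inference: map an input batch x to predictions
        pass

    def configure_optimizers(self):
        # return an optimizer such as torch.optim.Adam(self.parameters())
        pass

    def loss_fn(self, output, target):
        # helper method (not a Lightning hook): compare predictions with targets
        pass

    def training_step(self, batch, batch_idx):
        # compute and return the loss for one training batch
        pass

    def validation_step(self, batch, batch_idx):
        # compute and log validation metrics for one batch
        pass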
- PyTorch Lightning also provides LightningDataModule, which structures your data loading and preprocessing.
- This common structure makes the code easy for anyone to read and understand.
- It lets you reuse the same data pipeline across projects, even with complex transforms and multi-GPU training.
Hooks:
- train_dataloader(), val_dataloader(), test_dataloader(): these return the DataLoader for the training, validation and test splits.
- prepare_data(): download, tokenize, or otherwise preprocess the complete dataset. When training on multiple GPUs it is called from a single process only, so any state assigned to self here is not shared across GPUs; write the results to disk instead.
- setup(stage): splitting, transforms, etc.; it runs on every process. stage is None by default, "fit" for training and "test" for testing (newer releases also use "validate" and "predict").
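class LightningDataset(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # download / tokenize once; do not assign state to self here
        pass

    def setup(self, stage=None):
        # build datasets, splits and transforms (runs on every process)
        pass

    def train_dataloader(self):
        # return a torch.utils.data.DataLoader over the training split
        pass

    def val_dataloader(self):
        pass

    def test_dataloader(self):
        pass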
- These two classes are all that is needed to train a model and evaluate it on test data.
- pl.Trainer then runs the loops. It takes many parameters, such as max_epochs, the hardware to use (gpus and tpu_cores in older releases, accelerator and devices in Lightning 2.x; CPU is the default) and logger.
- trainer.fit trains the model on the data.
dataset = LightningDataset()
model = LightningModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=model, datamodule=dataset)