Introduction
These are my notes on wandb, a tool with many uses during model training, such as monitoring training progress and hyperparameter tuning.
Example code is available on GitHub.
Why wandb?
1. Train/Validation/Test logging
2. Visualization (log, gradient, …)
3. Hyperparameter tuning
Installation
$ conda install -c conda-forge wandb
# or
$ pip install wandb
Synchronization
1. Sign up at Weights and Biases
2. Log in from your local environment
$ wandb login
Paste your API key from wandb website.
Initiation
import wandb
wandb.init(project='project_name',
entity='user_name or team_name',
name='display_name_for_this_run')
# or
wandb.init(project='project_name')
wandb.run.name = 'display_name_for_this_run'
wandb.run.save()
Experiment tracking
Track metrics
Major use case of wandb.
wandb.log
for epoch in range(1, args.epochs + 1):
    print(f'=====Epoch {epoch}')
    print("Training...")
    train_mae = train_utils.train_model(model, device, train_loader, criterion, optimizer)
    print("Evaluating...")
    valid_mae = train_utils.evaluate_model(model, device, valid_loader, metric='mse')
    print({'Train': train_mae, 'Valid': valid_mae})
    wandb.log({'epoch': epoch,
               'train_loss': train_mae,
               'valid_loss': valid_mae})
Can track & visualize various metrics per epoch.
wandb.watch
model = Net()
wandb.watch(model, criterion, log='all')
Can track & visualize modelโs gradients and topology per epoch.
Also shows a system dashboard, including:
• CPU
• GPU
• Memory utilization
• Power usage
• Temperatures
• …
Can explore all different runs in a single dashboard.
Track hyperparameters
wandb.config
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0)
...
args = parser.parse_args()
wandb.config.update(args)
# or
config = wandb.config
config.learning_rate = 0.01
config.device = 0
Can view all runs with their hyperparameters.
Visualize model
With the watch command, we can inspect the modelโs topology (network type, number of parameters, output shape, โฆ)
Inspect logs
We can also see actual raw logs printed in the console.
Data and model versioning
W&B has a built-in versioning system.
An Artifact is a versioned folder of data; it is used for dataset versioning, model versioning, and dependency tracking.
run = wandb.init(project='project_name')

artifact = wandb.Artifact('cifar10', type='dataset')
file_path = './data/cifar-10-batches-py'
artifact.add_dir(file_path)  # or artifact.add_reference('reference_url')

# to save the artifact
run.log_artifact(artifact)

# to load a saved version of the artifact
artifact = run.use_artifact('cifar10:latest')
artifact_dir = artifact.download()
Hyperparameter tuning with Sweeps
W&B Sweeps is a tool to automate hyperparameter optimization and exploration.
โ Reduce boilerplate code + Cool visualizations.
First, create a configuration dictionary:
import math

sweep_config = {
    'method': 'random',
    'metric': {'goal': 'minimize', 'name': 'loss'},
    'parameters': {
        'batch_size': {
            'distribution': 'q_log_uniform',
            'max': math.log(256),
            'min': math.log(32),
            'q': 1
        },
        'epochs': {'value': 5},
        'fc_layer_size': {'values': [128, 256, 512]},
        'learning_rate': {
            'distribution': 'uniform',
            'max': 0.1,
            'min': 0
        },
        'optimizer': {'values': ['adam', 'sgd']}
    }
}
Define the tuning method, metric, and parameters:
• tuning method (3 options): random, grid search, bayes search
• metric: the metric to optimize and whether to minimize or maximize it
• parameters: hyperparameters to be searched by Sweeps
Sweeps samples hyperparameter combinations according to the chosen method (grid search tries every combination) and records the metric for each run.
def train(config):
    with wandb.init(project='project_name', entity='username or teamname', config=config):
        config = wandb.config
        net = Net(config.fc_layer_size)
        net.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=config.learning_rate)