-
Notifications
You must be signed in to change notification settings - Fork 99
Open
Description
🐞 Bug
With the same batch size but a different number of micro-batches (`chunks`), the training results are inconsistent.
I have fixed the random seed.
I compared runs with `chunks` set to 2 and to 4.
Code that reproduces
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F
from torchgpipe import GPipe
import random, os
import numpy as np
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
class SimpleDNN(nn.Module):
    """Three-layer fully connected classifier (784 -> 512 -> 256 -> 10).

    The first layer assumes flattened 28x28 input images.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 512)  # assuming input images are 28x28
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Flatten ``x`` to rows of ``fc1.in_features`` values and return logits.

        Returns a tensor with one row of 10 scores per flattened input row.
        """
        # reshape (rather than view) also accepts non-contiguous inputs; it
        # still returns a view whenever the memory layout allows one.
        x = x.reshape(-1, self.fc1.in_features)  # flatten the image
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Input pipeline: convert each image to a tensor, then normalize its single
# channel with mean 0.1307 and std 0.3081.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST training split stored under ./data (download=True is requested).
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Seed Python's, NumPy's and PyTorch's RNGs with one shared value so that
# repeated runs start from the same random state.
seed = 0
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here should only reach child processes, not this process - confirm intent.
os.environ['PYTHONHASHSEED'] = str(seed)
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
# Three pipeline stages: the MLP itself, an extra ReLU, and a final 10 -> 10
# linear layer. Only the MLP is moved to `device` here; the remaining stages
# are handed to GPipe together with the `devices` list below.
backbone = SimpleDNN().to(device)
pipeline_stages = [backbone, nn.ReLU(), nn.Linear(10, 10)]
model = nn.Sequential(*pipeline_stages)

# Number of micro-batches each mini-batch is split into (the issue compares
# 2 against 4).
chunks = 2
num_stages = len(pipeline_stages)
model = GPipe(
    model,
    balance=[1] * num_stages,  # one stage per partition
    chunks=chunks,
    devices=[device] * num_stages,  # every partition on the same device
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Train for one epoch and print the cumulative training accuracy after every
# processed batch. `total` counts samples seen, `correct` counts hits.
total = 0
correct = 0
for epoch in range(1):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Only batches that divide evenly into `chunks` micro-batches are
        # used. When comparing runs, remember that a different `chunks` value
        # can skip a different set of batches.
        if data.size(0) % chunks != 0:
            continue  # Skip batches that do not have the correct size
        data, target = data.to(device), target.to(device)

        # One optimisation step on the full mini-batch.
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit in each row.
        _, predicted = torch.max(output, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        print(f'batch {batch_idx}, Accuracy: {100 * correct / total}%')
Metadata
Metadata
Assignees
Labels
No labels