I wanted to get more familiar with PyTorch and also see if I could actually detect AI-generated images, the kind of stuff that's getting harder to spot as these models get better.
The bet: can I build a binary classifier that separates real from fake? To keep things tractable, I'm scoping this down hard: just cars and car accidents. Classic computer vision dataset vibes, but with a twist.
The Data Situation
Data is always the first problem. I got some friends to generate accident images using GPT-4, DALL-E, Nova, and Titan. Combined with real images I scraped, we're sitting at... 578 samples. Yeah. Not great, not terrible for a quick experiment.
Everything lives in S3, so instead of downloading gigs of images locally (annoying), I built a custom dataloader that pulls directly from S3. I built a DataFrame with file paths and labels and did some cleaning: some scraped images had weird extensions, which I filtered out with df[~df['file_path'].str.endswith('x')]. The data was quite imbalanced towards AI images, so I split it with stratification on the label to keep the real/fake ratio the same in every split.
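For reference, a stratified split could look something like the sketch below. It is not the exact logic used here: the `stratified_split` helper, the 80/20 ratio, the seed, and the toy `s3://example-bucket/...` paths are all assumptions for illustration. Only the `file_path` and `label` column names come from the DataFrame described above.

```python
import pandas as pd


def stratified_split(df, label_col="label", test_frac=0.2, seed=42):
    """Hold out test_frac of every class so both splits keep the same class ratio.

    Assumes df has a unique index (the default RangeIndex is fine).
    """
    # Sample the same fraction from each label group, then keep the rest for training.
    test = df.groupby(label_col).sample(frac=test_frac, random_state=seed)
    train = df.drop(test.index)
    return train.reset_index(drop=True), test.reset_index(drop=True)


# Toy stand-in for the real DataFrame: 70 AI images (label 1), 30 real ones (label 0).
df = pd.DataFrame({
    "file_path": [f"s3://example-bucket/img_{i}.png" for i in range(100)],
    "label": [1] * 70 + [0] * 30,
})
train, test = stratified_split(df)
print(train["label"].value_counts().to_dict())
print(test["label"].value_counts().to_dict())
```

Both splits end up with the same 70/30 ratio as the full toy set, which is the whole point when one class dominates.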
Custom Dataset Class
Here's the PyTorch Dataset implementation. Nothing fancy, just S3 reads plus the standard ImageNet normalization:
from io import BytesIO

import boto3
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CarImageDataset(Dataset):
    def __init__(self, annotations_file, transform=None):
        # annotations_file is the DataFrame with 'file_path' and 'label' columns
        self.img_labels = annotations_file
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((512, 512)),
                transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)
                )
            ])
        else:
            self.transform = transform

    def load_from_s3(self, s3_path):
        client = boto3.client('s3')
        # 's3://bucket/key' -> ('bucket', 'key')
        bucket, key = s3_path[5:].split('/', 1)
        resp = client.get_object(Bucket=bucket, Key=key)
        data = resp['Body'].read()
        try:
            image = Image.open(BytesIO(data)).convert('RGB')
        except Exception:
            # Re-raise: swallowing the error would leave `image` undefined at the return
            print(f'Error: Load failed for file {s3_path}.')
            raise
        return image

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = self.img_labels['file_path'].iloc[idx]
        image = self.load_from_s3(img_path)
        label = self.img_labels['label'].iloc[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


# `train` is the training split of the DataFrame described above
data = CarImageDataset(train)
dataloader = DataLoader(data, batch_size=64, num_workers=0, shuffle=True)
Btw, this is clearly not optimal. Things to revise: add some caching so every image isn't re-downloaded from S3 each epoch, change the default transform to keep the images at higher quality, and a lot of other small things.
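As a rough idea of what the caching could look like, here is a minimal sketch of a local disk cache keyed by a hash of the S3 path. The `DiskCache` class, the `fetch` callable, and the example path are hypothetical names made up for this sketch, and the demo uses a fake fetch function so it runs without AWS credentials.

```python
import hashlib
import tempfile
from pathlib import Path


class DiskCache:
    """Cache raw image bytes on local disk, keyed by a hash of the S3 path."""

    def __init__(self, cache_dir, fetch):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.fetch = fetch  # callable: s3_path -> bytes (e.g. a boto3 download)

    def get(self, s3_path):
        name = hashlib.sha256(s3_path.encode("utf-8")).hexdigest()
        cached = self.cache_dir / name
        if cached.exists():
            return cached.read_bytes()  # cache hit: no network call
        data = self.fetch(s3_path)      # cache miss: download once
        cached.write_bytes(data)
        return data


# Demo with a fake fetch function standing in for the S3 download.
calls = []


def fake_fetch(s3_path):
    calls.append(s3_path)
    return b"fake image bytes"


cache = DiskCache(tempfile.mkdtemp(), fake_fetch)
first = cache.get("s3://example-bucket/cars/img_0.png")
second = cache.get("s3://example-bucket/cars/img_0.png")
print(len(calls))  # the fake fetch ran only once
```

In the dataset class, load_from_s3 could hand the S3 read to something like cache.get(...) and pass the returned bytes to Image.open(BytesIO(...)), so only the first epoch pays for the downloads.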
Choosing a Model
Next I had to choose a model. I picked ResNet-18 with the IMAGENET1K_V1 weights and swapped the final fully connected layer for one with 2 outputs:
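from torch import nn
from torchvision import models

def load_model(weights='IMAGENET1K_V1', n_out=2):
    model = models.resnet18(weights=weights)
    # Swap the 1000-class ImageNet head for an n_out-way classifier
    model.fc = nn.Linear(model.fc.in_features, n_out, bias=True)
    return model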
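I also had to pick a loss function. I started with nn.CrossEntropyLoss() and, after trying it out, switched to nn.BCEWithLogitsLoss(), which treats the two outputs as independent yes/no predictions and needs float one-hot targets. For the optimizer I went with AdamW because it decouples weight decay from the gradient update. In Adam, weight_decay is added to the gradient as L2 regularization and then gets rescaled by the adaptive step, which is not the same thing as true weight decay. That is a design difference rather than a bug in Adam, and with weight_decay=0 the two optimizers behave identically.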
from torch import optim
model = load_model()
loss = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
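To make the Adam vs. AdamW difference concrete, here is a small sketch that takes one optimizer step on a single weight. The `one_step` helper, the toy loss p**2, and the lr and weight_decay values are assumptions picked for illustration, not settings from this project.

```python
import torch


def one_step(opt_cls, weight_decay):
    """Take a single optimizer step on loss = p**2, starting from p = 1.0."""
    p = torch.nn.Parameter(torch.tensor([1.0]))
    opt = opt_cls([p], lr=0.1, weight_decay=weight_decay)
    (p ** 2).sum().backward()
    opt.step()
    return p.detach().clone()


# No weight decay: both optimizers take exactly the same step.
adam_no_wd = one_step(torch.optim.Adam, 0.0)
adamw_no_wd = one_step(torch.optim.AdamW, 0.0)
print(torch.allclose(adam_no_wd, adamw_no_wd))

# With weight decay: Adam folds the decay into the gradient (where the adaptive
# scaling absorbs it), while AdamW shrinks the weight directly.
adam_wd = one_step(torch.optim.Adam, 0.1)
adamw_wd = one_step(torch.optim.AdamW, 0.1)
print(adam_wd.item(), adamw_wd.item())
```

The second pair of values differs, so the choice between the two optimizers only matters once weight decay is non-zero.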
Training Loop
My training loop was somewhat basic:
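import torch
from tqdm import tqdm

# Use an accelerator (CUDA, MPS, ...) if one is available.
# torch.accelerator only exists in recent PyTorch releases.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'
model.to(device)  # the model has to live on the same device as the batches
model.train()
for epoch in range(num_epochs):  # num_epochs is set elsewhere
    running_loss = 0.0
    correct = 0
    total = 0
    loop = tqdm(dataloader, leave=True)
    for images, labels in loop:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # BCEWithLogitsLoss expects float targets shaped like the logits
        y_onehot = nn.functional.one_hot(labels, num_classes=2).float()
        loss_val = loss(outputs, y_onehot)
        running_loss += loss_val.item()
        loss_val.backward()
        optimizer.step()
        # Predicted class = index of the larger logit
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        ...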
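The loop above only tracks accuracy on the batches it trains on. A held-out check could look roughly like the sketch below. The `evaluate` helper is a hypothetical name, and the tiny linear model and random tensors are stand-ins so the snippet runs on its own. In practice the ResNet and a DataLoader over the held-out split would go in their place.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model, dataloader, device="cpu"):
    """Return the accuracy of a 2-logit classifier over a dataloader."""
    was_training = model.training
    model.eval()  # use running batch-norm statistics, disable dropout
    correct, total = 0, 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    model.train(was_training)  # restore the previous mode
    return correct / total


# Tiny stand-in model and random data, just to show the call.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))
toy_data = TensorDataset(torch.randn(16, 3, 8, 8), torch.randint(0, 2, (16,)))
accuracy = evaluate(toy_model, DataLoader(toy_data, batch_size=4))
print(f"accuracy: {accuracy:.2f}")
```

Switching to eval mode matters for ResNet in particular, since its batch-norm layers behave differently during training and inference.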
Results and Conclusion
With this setup the results were mixed: accuracy landed anywhere between 60% and 80% depending on how the data was shuffled.
My conclusion is that I could predict whether an image was AI-generated, but only after scoping the problem down by a lot. ResNet was also a poor fit for such a small training set: it has roughly 12 million parameters, so fine-tuning all of them on a few hundred images did not work well. I also need better data, ideally more balanced between fake and real images. In the next part I will go through other techniques to make this approach better!