21 lines
653 B
Python
21 lines
653 B
Python
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
|
|
|
|
class resnet50(nn.Module):
    """ResNet-50 classifier with a freshly initialised, task-specific head.

    Wraps ``torchvision.models.resnet50`` built *without* pretrained weights
    and replaces its final fully connected layer (``fc``) with a new
    ``nn.Linear`` that produces ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the replacement FC layer.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # Build the backbone with random weights. torchvision >= 0.13 selects
        # weights via ``weights=`` and deprecates ``pretrained=``; older
        # releases reject the unknown ``weights`` keyword with a TypeError,
        # so fall back to the legacy spelling there.
        try:
            self.model = models.resnet50(weights=None)
        except TypeError:
            self.model = models.resnet50(pretrained=False)

        # Replace the stock classifier with one sized for this task, keeping
        # the backbone's feature width as the input dimension.
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes, bias=True)

        # Xavier-uniform weights and a small constant bias for the new head.
        # ``nn.init.constant_`` runs under ``no_grad`` and is the supported
        # replacement for mutating ``.data`` directly.
        nn.init.xavier_uniform_(self.model.fc.weight)
        nn.init.constant_(self.model.fc.bias, 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ResNet-50 and return its output.

        Args:
            x: Input batch; presumably an image tensor shaped
                ``(N, 3, H, W)`` as torchvision ResNets expect — TODO confirm
                against the data pipeline.

        Returns:
            The wrapped model's output, whose last layer is the replacement
            ``fc`` set in ``__init__``. No activation (e.g. softmax) is
            applied here.
        """
        return self.model(x)
|