经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
ResNet50的猫狗分类训练及预测
来源:cnblogs  作者:Wchime  时间:2023/4/12 15:42:15  对本文有异议

相比于之前写的ResNet18,下面的ResNet50写得更加工程化一点,这套代码还适用于其他分类任务。

我的代码文件结构

  

 

1. 数据处理

  首先已经对数据做好了分类

  

 

 

   文件夹结构是这样

  开始划分数据集

  split_data.py

  1. import os
  2. import random
  3. import shutil
  4. def move_file(target_path, save_train_path, save_val_pathm, scale=0.1):
  5. file_list = os.listdir(target_path)
  6. random.shuffle(file_list)
  7. number = int(len(file_list) * scale)
  8. train_list = file_list[number:]
  9. val_list = file_list[:number]
  10. for file in train_list:
  11. target_file_path = os.path.join(target_path, file)
  12. save_file_path = os.path.join(save_train_path, file)
  13. shutil.copyfile(target_file_path, save_file_path)
  14. for file in val_list:
  15. target_file_path = os.path.join(target_path, file)
  16. save_file_path = os.path.join(save_val_pathm, file)
  17. shutil.copyfile(target_file_path, save_file_path)
  18. def split_classify_data(base_path, save_path, scale=0.1):
  19. folder_list = os.listdir(base_path)
  20. for folder in folder_list:
  21. target_path = os.path.join(base_path, folder)
  22. save_train_path = os.path.join(save_path, 'train', folder)
  23. save_val_path = os.path.join(save_path, 'val', folder)
  24. if not os.path.exists(save_train_path):
  25. os.makedirs(save_train_path)
  26. if not os.path.exists(save_val_path):
  27. os.makedirs(save_val_path)
  28. move_file(target_path, save_train_path, save_val_path, scale)
  29. print(folder, 'finish!')
  30. if __name__ == '__main__':
  31. base_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\save_dir'
  32. save_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat'
  33. # 验证集比例
  34. scale = 0.1
  35. split_classify_data(base_path, save_path, scale)

  运行完以上代码后得到的文件夹结构如下

    

 

 

   一个训练集数据,一个验证集数据

  

2.数据集的导入

  我这个文件写了一个数据集的导入和一个学习率更新的函数。数据导入是通用的

  tools.py

  1. import os
  2. import time
  3. import cv2
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torch.optim as optim
  9. import torchvision
  10. from torch.autograd.variable import Variable
  11. from torch.utils.tensorboard import SummaryWriter
  12. from torchvision import datasets, transforms
  13. from torch.utils.data import Dataset, DataLoader
  14. from torch.optim.lr_scheduler import ExponentialLR, LambdaLR
  15. from torchvision.models import ResNet50_Weights
  16. from tqdm import tqdm
  17. from classify_cfg import *
  18. mean = MEAN
  19. std = STD
  20. def get_dataset(base_dir='', input_size=160):
  21. dateset = dict()
  22. transform_train = transforms.Compose([
  23. # 分辨率重置为input_size
  24. transforms.Resize(input_size),
  25. transforms.RandomRotation(15),
  26. # 对加载的图像作归一化处理, 并裁剪为[input_sizexinput_sizex3]大小的图像(因为这图片像素不一致直接统一)
  27. transforms.CenterCrop(input_size),
  28. transforms.ToTensor(),
  29. transforms.Normalize(mean=mean, std=std)
  30. ])
  31. transform_val = transforms.Compose([
  32. transforms.Resize(input_size),
  33. transforms.RandomRotation(15),
  34. transforms.CenterCrop(input_size),
  35. transforms.ToTensor(),
  36. transforms.Normalize(mean=mean, std=std)
  37. ])
  38. base_dir_train = os.path.join(base_dir, 'train')
  39. train_dataset = datasets.ImageFolder(root=base_dir_train, transform=transform_train)
  40. # print("train_dataset=" + repr(train_dataset[1][0].size()))
  41. # print("train_dataset.class_to_idx=" + repr(train_dataset.class_to_idx))
  42. # print(train_dataset.classes)
  43. classes = train_dataset.classes
  44. # classes = train_dataset.class_to_idx
  45. classes_num = len(train_dataset.classes)
  46. base_dir_val = os.path.join(base_dir, 'val')
  47. val_dataset = datasets.ImageFolder(root=base_dir_val, transform=transform_val)
  48. dateset['train'] = train_dataset
  49. dateset['val'] = val_dataset
  50. return dateset, classes, classes_num
  51. def update_lr(epoch, epochs):
  52. """
  53. 假设开始的学习率lr是0.001,训练次数epochs是100
  54. 当epoch<33时是lr * 1
  55. 当33<=epoch<=66 时是lr * 0.5
  56. 当66<epoch时是lr * 0.1
  57. """
  58. if epoch == 0 or epochs // 3 > epoch:
  59. return 1
  60. elif (epochs // 3 * 2 >= epoch) and (epochs // 3 <= epoch):
  61. return 0.5
  62. else:
  63. return 0.1

 

3.训练模型

  数据集导入好了以后,选择模型,选择优化器等等,然后开始训练。

  mytrain.py

  1. import os
  2. import time
  3. import cv2
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. import torchvision
  9. from torch.autograd.variable import Variable
  10. from torch.utils.tensorboard import SummaryWriter
  11. from torch.utils.data import Dataset, DataLoader
  12. from torch.optim.lr_scheduler import ExponentialLR, LambdaLR
  13. from torchvision.models import ResNet50_Weights
  14. # from tqdm import tqdm
  15. from classify_cfg import *
  16. from tools import get_dataset, update_lr
  17. def train(model, dateset, epochs, batch_size, device, optimizer, scheduler, criterion, save_path):
  18. train_loader = DataLoader(dateset.get('train'), batch_size=batch_size, shuffle=True)
  19. val_loader = DataLoader(dateset.get('val'), batch_size=batch_size, shuffle=True)
  20. # 保存为tensorboard文件
  21. write = SummaryWriter(save_path)
  22. # 训练过程写入txt
  23. f = open(os.path.join(save_path, 'log.txt'), 'w', encoding='utf-8')
  24. best_acc = 0
  25. for epoch in range(epochs):
  26. train_correct = 0.0
  27. model.train()
  28. sum_loss = 0.0
  29. accuracy = -1
  30. total_num = len(train_loader.dataset)
  31. # print(total_num, len(train_loader))
  32. # loop = tqdm(enumerate(train_loader), total=len(train_loader))
  33. batch_count = 0
  34. for batch_idx, (data, target) in enumerate(train_loader):
  35. start_time = time.time()
  36. data, target = Variable(data).to(device), Variable(target).to(device)
  37. output = model(data)
  38. loss = criterion(output, target)
  39. optimizer.zero_grad()
  40. loss.backward()
  41. optimizer.step()
  42. print_loss = loss.data.item()
  43. sum_loss += print_loss
  44. train_predict = torch.max(output.data, 1)[1]
  45. if torch.cuda.is_available():
  46. train_correct += (train_predict.cuda() == target.cuda()).sum()
  47. else:
  48. train_correct += (train_predict == target).sum()
  49. accuracy = (train_correct / total_num) * 100
  50. # loop.set_description(f'Epoch [{epoch+1}/{epochs}]')
  51. # loop.set_postfix(loss=loss.item(), acc='{:.3f}'.format(accuracy))
  52. batch_count += len(data)
  53. end_time = time.time()
  54. s = f'Epoch:[{epoch+1}/{epochs}] Batch:[{batch_count}/{total_num}] train_acc: {"{:.2f}".format(accuracy)} ' f'train_loss: {"{:.3f}".format(loss.item())} time: {int((end_time-start_time)*1000)} ms'
  55. # print(f'Epoch:[{epoch+1}/{epochs}]', f'Batch:[{batch_count}/{total_num}]',
  56. # 'train_acc:', '{:.2f}'.format(accuracy), 'train_loss:', '{:.3f}'.format(loss.item()),
  57. # 'time:', f'{int((end_time-start_time)*1000)} ms')
  58. print(s)
  59. f.write(s+'\n')
  60. write.add_scalar('train_acc', accuracy, epoch)
  61. write.add_scalar('train_loss', loss.item(), epoch)
  62. # print(optimizer.param_groups[0]['lr'])
  63. scheduler.step()
  64. if best_acc < accuracy:
  65. best_acc = accuracy
  66. torch.save(model, os.path.join(save_path, 'best.pt'))
  67. if epoch+1 == epochs:
  68. torch.save(model, os.path.join(save_path, 'last.pt'))
  69. # 预测验证集
  70. # if (epoch+1) % 5 == 0 or epoch+1 == epochs:
  71. model.eval()
  72. test_loss = 0.0
  73. correct = 0.0
  74. total_num = len(val_loader.dataset)
  75. # print(total_num, len(val_loader))
  76. with torch.no_grad():
  77. for data, target in val_loader:
  78. data, target = Variable(data).to(device), Variable(target).to(device)
  79. output = model(data)
  80. loss = criterion(output, target)
  81. _, pred = torch.max(output.data, 1)
  82. if torch.cuda.is_available():
  83. correct += torch.sum(pred.cuda() == target.cuda())
  84. else:
  85. correct += torch.sum(pred == target)
  86. print_loss = loss.data.item()
  87. test_loss += print_loss
  88. acc = correct / total_num * 100
  89. avg_loss = test_loss / len(val_loader)
  90. s = f"val acc: {'{:.2f}'.format(acc)} val loss: {'{:.3f}'.format(avg_loss)}"
  91. # print('val acc: ', '{:.2f}'.format(acc), 'val loss: ', '{:.3f}'.format(avg_loss))
  92. print(s)
  93. f.write(s+'\n')
  94. write.add_scalar('val_acc', acc, epoch)
  95. write.add_scalar('val_loss', avg_loss, epoch)
  96. # loop.set_postfix(val_loss='{:.3f}'.format(avg_loss), val_acc='{:.3f}'.format(acc))
  97. f.close()
  98. if __name__ == '__main__':
  99. device = DEVICE
  100. epochs = EPOCHS
  101. batch_size = BATCH_SIZE
  102. input_size = INPUT_SIZE
  103. lr = LR
  104. # ---------------------------训练-------------------------------------
  105. # 图片的路径
  106. base_dir = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat'
  107. # 保存的路径
  108. save_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat_save'
  109. dateset, classes, classes_num = get_dataset(base_dir, input_size=input_size)
  110. # model = torchvision.models.resnet50(pretrained=True)
  111. model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
  112. num_ftrs = model.fc.in_features
  113. model.fc = nn.Linear(num_ftrs, classes_num)
  114. model.to(DEVICE)
  115. # # 损失函数,交叉熵损失函数
  116. criteon = nn.CrossEntropyLoss()
  117. # 选择优化器
  118. optimizer = optim.SGD(model.parameters(), lr=lr)
  119. # 学习率更新
  120. # scheduler = ExponentialLR(optimizer, gamma=0.9)
  121. scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: update_lr(epoch, epochs))
  122. # 开始训练
  123. train(model, dateset, epochs, batch_size, device, optimizer, scheduler, criteon, save_path)
  124. # 将label保存起来
  125. with open(os.path.join(save_path, 'labels.txt'), 'w', encoding='utf-8') as f:
  126. f.write(f'{classes_num} {classes}')

  训练结束以后,在保存路径下会得到下面的文件

  

 

  最好的模型,最后一次的模型,标签的列表,训练的记录和tensorboard记录

  在该路径下执行 tensorboard --logdir=.  

  

 

  然后在浏览器打开给出的地址,即可看到数据训练过程的绘图

 4.对图片进行预测

  考虑到用户通常是在网页或手机端上传一张图片进行预测,所以这里直接采用二进制数据作为输入。

  mypredict.py

  

  1. import cv2
  2. import numpy as np
  3. import torch
  4. from classify_cfg import *
  5.  
  6.  
  7.  
  8. def img_process(img_betys, img_size, device):
  9. img_arry = np.asarray(bytearray(img_betys), dtype='uint8')
  10. # im0 = cv2.imread(img_betys)
  11. im0 = cv2.imdecode(img_arry, cv2.IMREAD_COLOR)
  12. image = cv2.resize(im0, (img_size, img_size))
  13. image = np.float32(image) / 255.0
  14. image[:, :, ] -= np.float32(mean)
  15. image[:, :, ] /= np.float32(std)
  16. image = image.transpose((2, 0, 1))
  17. im = torch.from_numpy(image).unsqueeze(0)
  18. im = im.to(device)
  19. return im
  20. def predict(model_path, img, device):
  21. model = torch.load(model_path)
  22. model.to(device)
  23. model.eval()
  24. predicts = model(img)
  25. # print(predicts)
  26. _, preds = torch.max(predicts, 1)
  27. pred = torch.squeeze(preds)
  28. # print(pred)
  29. return pred
  30. if __name__ == '__main__':
  31. mean = MEAN
  32. std = STD
  33. device = DEVICE
  34. classes = ['', '']
  35. # # 预测
  36. model_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\dog_cat_save\best.pt'
  37. img_path = r'C:\Users\Administrator.DESKTOP-161KJQD\Desktop\save_dir\狗\000000.jpg'
  38. with open(img_path, 'rb') as f:
  39. img_betys = f.read()
  40. img =img_process(img_betys, 160, device)
  41. # print(img.shape)
  42. # print(img)
  43. pred = predict(model_path, img, device)
  44. print(classes[int(pred)])

 

原文链接:https://www.cnblogs.com/moon3496694/p/17310038.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号