经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
李宏毅2022机器学习HW3 Image Classification
来源:cnblogs  时间:2024/2/23 8:56:55  对本文有异议

Homework3

数据集下载

在本地环境下进行实验总是令人安心,但是又苦于网上找不到数据集,虽然kaggle上有数据集但是下载存在问题

于是有了一个天才的想法,间接从kaggle上下载(利用output文件夹中的文件是可下载这一机制将数据集从input文件夹拷贝到output文件夹),具体操作如下图



等待数据集拷贝到output后,点击输出的蓝色链接即可下载。
相关代码由下给出

  1. !python -m zipfile -c /kaggle/working/Dataset.zip /kaggle/input/ml2022spring-hw4/Dataset # copy数据集到output文件夹,此过程可能较慢
  2. import os
  3. os.chdir('/kaggle/working')
  4. print(os.getcwd())
  5. print(os.listdir("/kaggle/working"))
  6. from IPython.display import FileLink
  7. FileLink('mycode.zip')

任务要求

Task1 模型选择

这里对改进sample code中的模型以及引用其他模型进行了尝试,但效果不佳。
如果你想对sample code中的模型进行改进,直接在类Classifier中的self.cnn以及self.fc模块中进行修改即可
如果你想引入其他的模型,可以参考下面的代码

  1. import torchvision.models as models
  2. alexNet = models.alexnet(weights=None, num_classes=11)
  3. # model = Classifier().to(device)
  4. model = alexNet.to(device)

PS.至于视频中提到的pretrained问题,这个参数以及被废除,使用参数weights代替,而且这里类别数与原模型的类别数不一致,如果使用原模型参数会报错。

Task2 数据增强

指出原代码问题
__getitem__中

  1. label = int(fname.split("/")[-1].split("_")[0])

应改为

  1. label = int(fname.split("/")[-1].split("_")[0].split("\\")[-1])

训练数据增强

这里我理解助教的意思应该是对每一个训练数据进行多种transform转换最后仍为一个样本,因此我这里的转换代码如下:

  1. transform1 = transforms.RandomHorizontalFlip()
  2. transform2 = transforms.RandomRotation(30)
  3. transform3 = transforms.ColorJitter(brightness=0.5)
  4. transform4 = transforms.RandomAffine(degrees=20, translate=(0.2, 0.2), scale=(0.7, 1.3))
  5. train_tfm = transforms.Compose([
  6. # Resize the image into a fixed shape (height = width = 128)
  7. transforms.Resize((128, 128)),
  8. # You may add some transforms here.
  9. transforms.RandomChoice([transform1, transform2, transform3, transform4]), # 对每个样本随意挑选一种转换
  10. # ToTensor() should be the last one of the transforms.
  11. transforms.ToTensor(),
  12. ])
  13. trans_size = 4

测试数据增强

这里我理解助教的意思是对一个测试样本进行多种不同的变换得到不同的测试样本进行预测,通过投票或其他方式决定这个测试样本的标签
如上面所说,在transforms.Compose中进行多种组合最后得到的仍为一个样本,因此不能在这里进行操作,这里选择在__getitem__中进行修改

  1. except:
  2. label = -1 # test has no label
  3. # multiple prediction for testing samples
  4. trans_im1 = train_tfm(im)
  5. trans_im2 = train_tfm(im)
  6. trans_im3 = train_tfm(im)
  7. trans_im4 = train_tfm(im)
  8. trans_im = torch.stack((trans_im, trans_im1, trans_im2, trans_im3, trans_im4))

这样我们就可以得到原样本以及变换后的四个样本,在测试时,依次取出进行预测并选择合适的机制得到标签即可。

这里我们选择助教说的第一种方式

  1. model_best = Classifier().to(device)
  2. model_best.load_state_dict(torch.load(f"{_exp_name}_best.ckpt"))
  3. model_best.eval()
  4. prediction = []
  5. with torch.no_grad():
  6. for data,_ in test_loader:
  7. # multiple prediction
  8. data = data.to(device)
  9. # original test batch sample
  10. test_pred = model_best(data[:, 0, :, :, :])
  11. test_pred1 = model_best(data[:, 1, :, :, :])
  12. test_pred2 = model_best(data[:, 2, :, :, :])
  13. test_pred3 = model_best(data[:, 3, :, :, :])
  14. test_pred4 = model_best(data[:, 4, :, :, :])
  15. test_pred = test_pred *0.65 + 0.35*(test_pred1 + test_pred2 + test_pred3 + test_pred4)/4.0
  16. test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)
  17. prediction += test_label.squeeze().tolist()

Mixup数据增强

这里助教所说应该是对训练数据进行Mixup增强,但是效果也是不太行。
下面给出__getitem__完整代码,这里面对train samples进行Mixup增强,对test samples进行了多变换增强。

  1. def __getitem__(self,idx):
  2. # original image procession
  3. fname = self.files[idx]
  4. im = Image.open(fname)
  5. # original image transform
  6. trans_im = self.transform(im)
  7. #
  8. # im = self.data[idx]
  9. try:
  10. # label = int(fname.split("/")[-1].split("_")[0])
  11. label = int(fname.split("/")[-1].split("_")[0].split("\\")[-1])
  12. # mixup augmentation for train samples
  13. fname1 = self.files[random.randint(0, self.__len__() - 1)]
  14. im1 = Image.open(fname1)
  15. trans_mix_im1 = self.transform(im1)
  16. label1 = int(fname1.split("/")[-1].split("_")[0].split("\\")[-1])
  17. trans_im = 0.5*trans_im + 0.5*trans_mix_im1
  18. label = [label, label1]
  19. except:
  20. label = -1 # test has no label
  21. # multiple prediction for testing samples
  22. trans_im1 = train_tfm(im)
  23. trans_im2 = train_tfm(im)
  24. trans_im3 = train_tfm(im)
  25. trans_im4 = train_tfm(im)
  26. trans_im = torch.stack((trans_im, trans_im1, trans_im2, trans_im3, trans_im4))
  27. return trans_im,label

loss修改

  1. loss = criterion(logits, labels[0].to(device))
  2. loss1 = criterion(logits, labels[1].to(device))
  3. loss = loss + loss1
  1. def mixup_accuracy(output, target1, target2):
  2. """
  3. 计算 Mixup 样本的准确率
  4. :param output: 模型的输出,形状为 (batch_size, num_classes)
  5. :param target1: 第一个样本的标签,形状为 (batch_size,)
  6. :param target2: 第二个样本的标签,形状为 (batch_size,)
  7. :return: 准确率
  8. """
  9. # 计算模型对混合样本的预测结果
  10. output = output.to(device)
  11. target1 = target1.to(device)
  12. target2 = target2.to(device)
  13. _, pred_indices = output.topk(2, dim=-1)
  14. # 取出预测结果中最大的两个值对应的索引
  15. pred1_indices = pred_indices[:, 0] # 第一个最大值的索引
  16. pred2_indices = pred_indices[:, 1] # 第二个最大值的索引
  17. # 计算混合样本的准确率
  18. acc1 = (((pred1_indices == target1) | (pred1_indices == target2))).float()
  19. acc2 = (((pred2_indices == target1) | (pred2_indices == target2))).float()
  20. acc = (acc1 + acc2 == 2.0).float().mean()
  21. return acc

Task3 Cross Validation & Ensemble

ChatGpt告诉我这么做

  1. k = 4
  2. kf = KFold(n_splits=k, shuffle=True, random_state=42)
  3. for fold, (train_idx, valid_idx) in enumerate(kf.split(dataset)):
  4. train_set = Subset(dataset, train_idx)
  5. valid_set = Subset(dataset, valid_idx)
  6. # Create data loaders for training and validation sets
  7. train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
  8. valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True)

最终结果

这一套组合拳下来,结果糟糕透了,不知道哪里出了问题!

原文链接:https://www.cnblogs.com/hywang1211/p/18028294

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号