Yuwei Sun commited on
Commit
d7774e6
·
1 Parent(s): 8f2dd3f

Upload FedKA-Digit-Five.ipynb

Browse files
Files changed (1) hide show
  1. FedKA-Digit-Five.ipynb +540 -0
FedKA-Digit-Five.ipynb ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "LcO6E1lNBh1P"
7
+ },
8
+ "source": [
9
+ "## 1. Import necessary packages"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "dBaDYH8WBh1U"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "import numpy as np\n",
21
+ "import tensorflow as tf\n",
22
+ "import scipy.io\n",
23
+ "from torch.utils.data import TensorDataset\n",
24
+ "from torch.utils.data import DataLoader\n",
25
+ "import torch\n",
26
+ "import matplotlib.pyplot as plt\n",
27
+ "import random\n",
28
+ "import os\n",
29
+ "import torch.backends.cudnn as cudnn\n",
30
+ "import torch.optim as optim\n",
31
+ "import torch.utils.data\n",
32
+ "from torchvision import datasets\n",
33
+ "from torchvision import transforms\n",
34
+ "import torch.nn as nn\n",
35
+ "from torch.autograd import Function\n",
36
+ "\n",
37
+ "cudnn.benchmark = False\n",
38
+ "cudnn.deterministic = True\n",
39
+ "cuda = True\n",
40
+ "\n",
41
+ "lr = 3e-4\n",
42
+ "batch_size = 16\n",
43
+ "image_size = 28\n",
44
+ "n_epoch = 200\n",
45
+ "\n",
46
+ "def dataprocess(data, target):\n",
47
+ " data = torch.from_numpy(data).float()\n",
48
+ " target = torch.from_numpy(target).long() \n",
49
+ " dataset = TensorDataset(data, target)\n",
50
+ " trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
51
+ "\n",
52
+ " return trainloader"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {
58
+ "id": "8n9FMTyTBh1U"
59
+ },
60
+ "source": [
61
+ "## 2. Prepare the datasets for clients (source) and the cloud (target)"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {
68
+ "id": "Zocl_klGBh1V"
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "mat = scipy.io.loadmat('Digit-Five/mnist_data.mat')\n",
73
+ "data = np.transpose((np.array((tf.image.grayscale_to_rgb(tf.convert_to_tensor(mat['train_28'])))).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2))\n",
74
+ "target = np.argmax((mat['label_train']), axis = 1)\n",
75
+ "c1_mt = [data, target]\n",
76
+ "\n",
77
+ "mat = scipy.io.loadmat('Digit-Five/mnistm_with_label.mat')\n",
78
+ "data = np.transpose((np.array((tf.convert_to_tensor(mat['train']))).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2)) \n",
79
+ "target = np.argmax((mat['label_train']), axis = 1)\n",
80
+ "c2_mm =[data, target]\n",
81
+ "\n",
82
+ "mat = scipy.io.loadmat('Digit-Five/usps_28x28.mat')\n",
83
+ "data = np.transpose((np.array((tf.image.grayscale_to_rgb(tf.convert_to_tensor(mat['dataset'][0][0].reshape(-1,28,28,1))))).astype('float32')).reshape(-1,28,28,3), (0,3,1,2))\n",
84
+ "target = mat['dataset'][0][1].flatten()\n",
85
+ "c3_up = [data, target]\n",
86
+ "\n",
87
+ "mat = scipy.io.loadmat('Digit-Five/svhn_train_32x32.mat')\n",
88
+ "data = np.transpose((np.array((tf.image.resize(np.moveaxis(mat['X'], -1, 0), [28,28]) )).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2))\n",
89
+ "target = (mat['y']-1).flatten()\n",
90
+ "c4_sv = [data, target]\n",
91
+ "\n",
92
+ "mat = scipy.io.loadmat('Digit-Five/syn_number.mat')\n",
93
+ "data = np.transpose((np.array((tf.image.resize(mat['train_data'], [28,28]) )).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2)) \n",
94
+ "target = mat['train_label'].flatten()\n",
95
+ "c5_sy = [data, target]"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {
102
+ "id": "6uECrd2eBh1V"
103
+ },
104
+ "outputs": [],
105
+ "source": [
106
+ "c1 = dataprocess(c1_mt[0], c1_mt[1])\n",
107
+ "c2 = dataprocess(c2_mm[0], c2_mm[1])\n",
108
+ "c3 = dataprocess(c3_up[0], c3_up[1])\n",
109
+ "c4 = dataprocess(c4_sv[0], c4_sv[1])\n",
110
+ "c5 = dataprocess(c5_sy[0], c5_sy[1])\n",
111
+ "\n",
112
+ "data_all = [c1_mt[0],c2_mm[0], c3_up[0], c4_sv[0], c5_sy[0]]"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "metadata": {
118
+ "id": "Th2GBdjoBh1X"
119
+ },
120
+ "source": [
121
+ "## 3. Define the MK-MMD loss (guassian kernel)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "metadata": {
128
+ "id": "gjz5bZS0Bh1X"
129
+ },
130
+ "outputs": [],
131
+ "source": [
132
+ "class MMD_loss(nn.Module):\n",
133
+ " def __init__(self, kernel_mul = 2.0, kernel_num = 5):\n",
134
+ " super(MMD_loss, self).__init__()\n",
135
+ " self.kernel_num = kernel_num\n",
136
+ " self.kernel_mul = kernel_mul\n",
137
+ " self.fix_sigma = None\n",
138
+ " return\n",
139
+ "\n",
140
+ " def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):\n",
141
+ " n_samples = int(source.size()[0])+int(target.size()[0])\n",
142
+ " total = torch.cat([source, target], dim=0)\n",
143
+ "\n",
144
+ " total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n",
145
+ " total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n",
146
+ " L2_distance = ((total0-total1)**2).sum(2) \n",
147
+ " if fix_sigma:\n",
148
+ " bandwidth = fix_sigma\n",
149
+ " else:\n",
150
+ " bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)\n",
151
+ " bandwidth /= kernel_mul ** (kernel_num // 2)\n",
152
+ " bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]\n",
153
+ " kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]\n",
154
+ " return sum(kernel_val)\n",
155
+ "\n",
156
+ " def forward(self, source, target):\n",
157
+ " batch_size = int(source.size()[0])\n",
158
+ " kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)\n",
159
+ " XX = kernels[:batch_size, :batch_size]\n",
160
+ " YY = kernels[batch_size:, batch_size:]\n",
161
+ " XY = kernels[:batch_size, batch_size:]\n",
162
+ " YX = kernels[batch_size:, :batch_size]\n",
163
+ " loss = torch.mean(XX + YY - XY -YX)\n",
164
+ " return loss"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "metadata": {
170
+ "id": "bw4hooRKBh1Y"
171
+ },
172
+ "source": [
173
+ "## 4. Define the model architecture following the Reverse Layer"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {
180
+ "id": "_dH5urjaBh1Y"
181
+ },
182
+ "outputs": [],
183
+ "source": [
184
+ "class ReverseLayerF(Function):\n",
185
+ "\n",
186
+ " @staticmethod\n",
187
+ " def forward(ctx, x, alpha):\n",
188
+ " ctx.alpha = alpha\n",
189
+ "\n",
190
+ " return x.view_as(x)\n",
191
+ "\n",
192
+ " @staticmethod\n",
193
+ " def backward(ctx, grad_output):\n",
194
+ " output = grad_output.neg() * ctx.alpha\n",
195
+ "\n",
196
+ " return output, None\n",
197
+ "\n",
198
+ "class CNNModel(nn.Module):\n",
199
+ "\n",
200
+ " def __init__(self):\n",
201
+ " super(CNNModel, self).__init__()\n",
202
+ " self.feature = nn.Sequential()\n",
203
+ " self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))\n",
204
+ " self.feature.add_module('f_bn1', nn.BatchNorm2d(64))\n",
205
+ " self.feature.add_module('f_pool1', nn.MaxPool2d(2))\n",
206
+ " self.feature.add_module('f_relu1', nn.ReLU(True))\n",
207
+ " self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))\n",
208
+ " self.feature.add_module('f_bn2', nn.BatchNorm2d(50))\n",
209
+ " self.feature.add_module('f_drop1', nn.Dropout2d())\n",
210
+ " self.feature.add_module('f_pool2', nn.MaxPool2d(2))\n",
211
+ " self.feature.add_module('f_relu2', nn.ReLU(True))\n",
212
+ " \n",
213
+ " self.class_classifier = nn.Sequential()\n",
214
+ " self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))\n",
215
+ " self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))\n",
216
+ " self.class_classifier.add_module('c_relu1', nn.ReLU(True))\n",
217
+ " self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))\n",
218
+ " self.class_classifier.add_module('c_softmax', nn.LogSoftmax())\n",
219
+ " \n",
220
+ " self.domain_classifier = nn.Sequential()\n",
221
+ " self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))\n",
222
+ " self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))\n",
223
+ " self.domain_classifier.add_module('d_relu1', nn.ReLU(True))\n",
224
+ " self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))\n",
225
+ " self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))\n",
226
+ "\n",
227
+ " def forward(self, input_data, alpha):\n",
228
+ " input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)\n",
229
+ " feature = self.feature(input_data)\n",
230
+ " feature = feature.view(-1, 50 * 4 * 4)\n",
231
+ " reverse_feature = ReverseLayerF.apply(feature, alpha)\n",
232
+ " class_output = self.class_classifier(feature)\n",
233
+ " domain_output = self.domain_classifier(reverse_feature)\n",
234
+ "\n",
235
+ " return class_output, domain_output"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "markdown",
240
+ "metadata": {
241
+ "id": "YAZ2eOkbBh1Y"
242
+ },
243
+ "source": [
244
+ "## 5. Federated Knowledge Alignment (FedKA) "
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "metadata": {
251
+ "id": "Su0klLOiBh1Z",
252
+ "scrolled": true
253
+ },
254
+ "outputs": [],
255
+ "source": [
256
+ "from tqdm import notebook\n",
257
+ "\n",
258
+ "def run(method, voting):\n",
259
+ " learned_models = []\n",
260
+ " global_net = CNNModel()\n",
261
+ " global_optimizer = optim.Adam(global_net.parameters(), lr=lr) \n",
262
+ " clients = []\n",
263
+ " optims = []\n",
264
+ " client_num = 4\n",
265
+ " for n in range(client_num):\n",
266
+ " local_net = CNNModel()\n",
267
+ " local_optimizer = optim.Adam(local_net.parameters(), lr=lr) \n",
268
+ " clients.append(local_net)\n",
269
+ " optims.append(local_optimizer)\n",
270
+ "\n",
271
+ " loss_class = torch.nn.NLLLoss()\n",
272
+ " loss_domain = torch.nn.NLLLoss()\n",
273
+ " loss_class = loss_class.cuda()\n",
274
+ " loss_domain = loss_domain.cuda()\n",
275
+ " loss_mmd = MMD_loss() \n",
276
+ " \n",
277
+ " acc_list = []\n",
278
+ " for epoch in notebook.tqdm(range(n_epoch)):\n",
279
+ " print(f\"===========Round {epoch} ===========\")\n",
280
+ " if cuda:\n",
281
+ " global_net =global_net.cuda()\n",
282
+ "\n",
283
+ " data_target_iter = iter(cloud_dataset[0]) \n",
284
+ " loss_epoch = []\n",
285
+ " \n",
286
+ " acc = []\n",
287
+ " for n in range(client_num):\n",
288
+ " # Enumerating batches from the dataloader provides a random selection of 512 samples every round.\n",
289
+ " for i, (s_img, s_label) in enumerate(clients_datasets[n]): \n",
290
+ " len_dataloader = 32\n",
291
+ " if i > 31: \n",
292
+ " break\n",
293
+ "\n",
294
+ " p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n",
295
+ " alpha = 2. / (1. + np.exp(-5 * p)) - 1\n",
296
+ "\n",
297
+ " optims[n].zero_grad()\n",
298
+ " batch_size = len(s_label)\n",
299
+ "\n",
300
+ " input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)\n",
301
+ " class_label = torch.LongTensor(batch_size)\n",
302
+ " domain_label = torch.zeros(batch_size)\n",
303
+ " domain_label = domain_label.long()\n",
304
+ "\n",
305
+ " if cuda:\n",
306
+ " clients[n] = clients[n].cuda()\n",
307
+ " s_img = s_img.cuda()\n",
308
+ " s_label = s_label.cuda()\n",
309
+ " input_img = input_img.cuda()\n",
310
+ " class_label = class_label.cuda()\n",
311
+ " domain_label = domain_label.cuda()\n",
312
+ "\n",
313
+ " input_img.resize_as_(s_img).copy_(s_img)\n",
314
+ " class_label.resize_as_(s_label).copy_(s_label)\n",
315
+ " class_output, domain_output = clients[n](input_data=input_img, alpha=alpha)\n",
316
+ "\n",
317
+ " err_s_label = loss_class(class_output, class_label)\n",
318
+ " err_s_domain = loss_domain(domain_output, domain_label)\n",
319
+ "\n",
320
+ " t_img, _ = data_target_iter.next()\n",
321
+ " batch_size = len(t_img)\n",
322
+ " input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)\n",
323
+ " domain_label = torch.ones(batch_size)\n",
324
+ " domain_label = domain_label.long()\n",
325
+ "\n",
326
+ " if cuda:\n",
327
+ " t_img = t_img.cuda()\n",
328
+ " input_img = input_img.cuda()\n",
329
+ " domain_label = domain_label.cuda()\n",
330
+ "\n",
331
+ " input_img.resize_as_(t_img).copy_(t_img)\n",
332
+ "\n",
333
+ " _, domain_output = clients[n](input_data=input_img, alpha=alpha)\n",
334
+ " err_t_domain = loss_domain(domain_output, domain_label)\n",
335
+ "\n",
336
+ " if method == 0 or method == 3:\n",
337
+ " err = err_s_label\n",
338
+ " else:\n",
339
+ " err = err_s_label + err_s_domain + err_t_domain\n",
340
+ "\n",
341
+ " err.backward()\n",
342
+ " optims[n].step()\n",
343
+ " \n",
344
+ " # mmd loss\n",
345
+ " mmd_loss_total = 0\n",
346
+ " for i, (s_img, s_label) in enumerate(clients_datasets[n]): \n",
347
+ " len_dataloader = 32\n",
348
+ " if i > 31: \n",
349
+ " break\n",
350
+ "\n",
351
+ " p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n",
352
+ " alpha = 2. / (1. + np.exp(-5* p)) - 1\n",
353
+ " batch_size = len(s_label)\n",
354
+ " t_img, _ = data_target_iter.next()\n",
355
+ " s_img = s_img.cuda()\n",
356
+ " t_img = t_img.cuda()\n",
357
+ "\n",
358
+ " hidden_c = clients[n].feature(s_img).reshape(batch_size, -1)\n",
359
+ " hidden_avg = global_net.feature(t_img).reshape(batch_size, -1)\n",
360
+ " mmd_loss = loss_mmd(hidden_c, hidden_avg)\n",
361
+ " mmd_loss_total =+mmd_loss*alpha\n",
362
+ " \n",
363
+ " if method == 2 or method == 3:\n",
364
+ " if (i+1) % 8 == 0:\n",
365
+ " optims[n].zero_grad()\n",
366
+ " err_mmd = mmd_loss_total/8\n",
367
+ " err_mmd.backward()\n",
368
+ " optims[n].step()\n",
369
+ " mmd_loss_total = 0 \n",
370
+ " \n",
371
+ " # FedAvg\n",
372
+ " global_sd = global_net.state_dict()\n",
373
+ " for key in global_sd:\n",
374
+ " global_sd[key] = torch.sum(torch.stack([model.state_dict()[key] for m, model in enumerate(clients)]), axis = 0)/client_num\n",
375
+ " # update the global model\n",
376
+ " global_net.load_state_dict(global_sd) \n",
377
+ " \n",
378
+ " \n",
379
+ " if voting:\n",
380
+ " total = 0\n",
381
+ " num_correct = 0\n",
382
+ " for i, (images, labels) in enumerate(cloud_dataset[0]):\n",
383
+ " len_dataloader = 128\n",
384
+ " if i > 127: \n",
385
+ " break\n",
386
+ "\n",
387
+ " p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n",
388
+ " alpha = 2. / (1. + np.exp(-5* p)) - 1\n",
389
+ "\n",
390
+ " images =images.cuda()\n",
391
+ " labels = labels.cuda()\n",
392
+ " global_optimizer.zero_grad()\n",
393
+ " class_output, _ = global_net(images, 0)\n",
394
+ "\n",
395
+ " votes = []\n",
396
+ " for n in range(client_num):\n",
397
+ " clients[n] = clients[n].cuda()\n",
398
+ " output,_ = clients[n](images, 0)\n",
399
+ " pred = torch.argmax(output, 1)\n",
400
+ " votes.append(pred)\n",
401
+ "\n",
402
+ " class_label = torch.Tensor([int(max(set(batch), key = batch.count).cpu().data.numpy()) for batch in [list(i) for i in zip(*votes)]]).to(torch.int64)\n",
403
+ " class_label = class_label.cuda()\n",
404
+ "\n",
405
+ " total += labels.size(0)\n",
406
+ " num_correct += (class_label == labels).sum().item()\n",
407
+ "\n",
408
+ " err = loss_class(class_output, class_label)*alpha\n",
409
+ " err.backward()\n",
410
+ " global_optimizer.step()\n",
411
+ " \n",
412
+ " print(f'Voting accuracy: {num_correct * 100 / total}% Adoption rate: {alpha*100}%')\n",
413
+ " \n",
414
+ " \n",
415
+ " # Evaluation every round \n",
416
+ " # Target task \n",
417
+ " with torch.no_grad():\n",
418
+ " num_correct = 0\n",
419
+ " total = 0\n",
420
+ "\n",
421
+ " for i, (images, labels) in enumerate(cloud_dataset[0]):\n",
422
+ " if i > 312:\n",
423
+ " break\n",
424
+ " \n",
425
+ " if cuda:\n",
426
+ " global_net = global_net.cuda()\n",
427
+ " images =images.cuda()\n",
428
+ " labels = labels.cuda()\n",
429
+ " \n",
430
+ " output,_ = global_net(images, 0)\n",
431
+ " pred = torch.argmax(output, 1)\n",
432
+ " total += labels.size(0)\n",
433
+ " num_correct += (pred == labels).sum().item()\n",
434
+ " \n",
435
+ " print(f'Global: Accuracy of the model on {total} test images: {num_correct * 100 / total}% \\n')\n",
436
+ " acc.append(num_correct * 100 / total)\n",
437
+ " \n",
438
+ " acc_list.append(acc)\n",
439
+ "\n",
440
+ " for n in range(client_num):\n",
441
+ " clients[n].load_state_dict(global_sd)\n",
442
+ "\n",
443
+ " return acc_list"
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "markdown",
448
+ "metadata": {
449
+ "id": "aS8nJ2sdBh1a"
450
+ },
451
+ "source": [
452
+ "## 6. Run experiments\n",
453
+ "\n",
454
+ "### Method 0: Source Only\n",
455
+ "\n",
456
+ "### Method 1: $f$-DANN\n",
457
+ "\n",
458
+ "### Method 2: FedKA"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": null,
464
+ "metadata": {
465
+ "id": "eeqIdMAXBh1b",
466
+ "outputId": "38dd4827-bd80-4f6d-9112-82dda29ab28c"
467
+ },
468
+ "outputs": [],
469
+ "source": [
470
+ "import matplotlib.pyplot as plt\n",
471
+ "\n",
472
+ "methods = [2]\n",
473
+ "voting = True\n",
474
+ "\n",
475
+ "# run 5 tasks\n",
476
+ "for t in range(5):\n",
477
+ " acc = []\n",
478
+ " clients_datasets = [c1, c2, c3, c4, c5]\n",
479
+ " cloud_dataset = [clients_datasets.pop(t)]\n",
480
+ " target = f\"c{t+1}\"\n",
481
+ "\n",
482
+ " # use three seeds\n",
483
+ " for s in range(3):\n",
484
+ " torch.manual_seed(s)\n",
485
+ " random.seed(s)\n",
486
+ " np.random.seed(s)\n",
487
+ " \n",
488
+ " acc_m = []\n",
489
+ " for method in methods:\n",
490
+ "\n",
491
+ " print(f\"Task: c{t+1} Seed: {s} Method: {method}\")\n",
492
+ "\n",
493
+ " evl = run(method, voting)\n",
494
+ " \n",
495
+ " result = np.array((evl)).T\n",
496
+ " plt.plot(result[0], label = \"Global\", color = 'C4')\n",
497
+ " plt.legend()\n",
498
+ " plt.show()\n",
499
+ "\n",
500
+ " acc_m.append(max(np.array((evl))[:,0]))\n",
501
+ " acc.append(acc_m)\n",
502
+ " \n",
503
+ " print(f'Task: c{t+1} Mean: {np.mean((np.array((acc)).T), axis =1)}')\n",
504
+ " print(f'Task: c{t+1} Std: {np.std((np.array((acc)).T), axis =1)}')"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": []
513
+ }
514
+ ],
515
+ "metadata": {
516
+ "colab": {
517
+ "name": "proposal.ipynb",
518
+ "provenance": []
519
+ },
520
+ "kernelspec": {
521
+ "display_name": "Python 3 (ipykernel)",
522
+ "language": "python",
523
+ "name": "python3"
524
+ },
525
+ "language_info": {
526
+ "codemirror_mode": {
527
+ "name": "ipython",
528
+ "version": 3
529
+ },
530
+ "file_extension": ".py",
531
+ "mimetype": "text/x-python",
532
+ "name": "python",
533
+ "nbconvert_exporter": "python",
534
+ "pygments_lexer": "ipython3",
535
+ "version": "3.7.11"
536
+ }
537
+ },
538
+ "nbformat": 4,
539
+ "nbformat_minor": 1
540
+ }