
Commit 9fac6c1

Merge pull request #1230 from hugy718/dev-postgresql
Add convolution prototype network used in TED project
2 parents c8c455b + 2fb3451 commit 9fac6c1

3 files changed, +321 -0 lines changed
Lines changed: 11 additions & 0 deletions
# Convolutional Prototype Learning

We have successfully applied the idea of prototype loss to various medical image classification tasks to improve performance, for example detecting thyroid eye disease (TED) from CT images. Here we provide an implementation of the convolutional prototype model in Singa. Due to data privacy, we are not able to release the CT image dataset used. The training script `./train.py` demonstrates how to apply this model to the CIFAR-10 dataset.

## Run

From the Singa project root directory, run `python examples/healthcare/application/TED_CT_Detection/train.py`.

## Reference

[Robust Classification with Convolutional Prototype Learning](https://arxiv.org/abs/1805.03438)
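
For intuition: the prototype layer keeps a set of learnable prototypes and uses the temperature-scaled negative squared Euclidean distance between a feature vector and each prototype as the class logits. Below is a minimal NumPy sketch of that computation (illustration only, with toy shapes; the actual Singa implementation is in the model file added by this commit):

```python
import numpy as np

# toy shapes for illustration
batch, feat_dim, proto_count, temp = 4, 8, 10, 10.0
feat = np.random.randn(batch, feat_dim).astype(np.float32)          # backbone features
proto = np.random.randn(feat_dim, proto_count).astype(np.float32)   # learnable prototypes

# squared Euclidean distance expanded as ||f||^2 + ||p||^2 - 2 f.p
dist = (
    (feat ** 2).sum(axis=1, keepdims=True)
    + (proto ** 2).sum(axis=0, keepdims=True)
    - 2.0 * feat @ proto
)
logits = -dist / temp  # a closer prototype gives a larger logit
```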
Lines changed: 119 additions & 0 deletions
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from singa import layer
from singa import model
import singa.tensor as tensor
from singa import autograd
from singa.tensor import Tensor

class CPLayer(layer.Layer):
    """Convolutional prototype layer.

    Keeps a set of learnable prototypes and scores a feature vector by its
    temperature-scaled negative squared Euclidean distance to each prototype.
    """

    def __init__(self, prototype_count=2, temp=10.0):
        super(CPLayer, self).__init__()
        self.prototype_count = prototype_count
        self.temp = temp

    def initialize(self, x):
        self.feature_dim = x.shape[1]
        self.prototype = tensor.random(
            (self.feature_dim, self.prototype_count), device=x.device
        )

    def forward(self, feat):
        self.device_check(feat, self.prototype)
        self.dtype_check(feat, self.prototype)

        # ||feat||^2, tiled to (batch, prototype_count)
        feat_sq = autograd.mul(feat, feat)
        feat_sq_sum = autograd.reduce_sum(feat_sq, axes=[1], keepdims=1)
        feat_sq_sum_tile = autograd.tile(feat_sq_sum, repeats=[1, self.prototype_count])

        # ||prototype||^2, tiled to (batch, prototype_count)
        prototype_sq = autograd.mul(self.prototype, self.prototype)
        prototype_sq_sum = autograd.reduce_sum(prototype_sq, axes=[0], keepdims=1)
        prototype_sq_sum_tile = autograd.tile(prototype_sq_sum, repeats=feat.shape[0])

        # -2 * feat . prototype
        cross_term = autograd.matmul(feat, self.prototype)
        cross_term_scale = Tensor(
            shape=cross_term.shape, device=cross_term.device, requires_grad=False
        ).set_value(-2)
        cross_term_scaled = autograd.mul(cross_term, cross_term_scale)

        # squared Euclidean distance: ||f||^2 + ||p||^2 - 2 f.p
        dist = autograd.add(feat_sq_sum_tile, prototype_sq_sum_tile)
        dist = autograd.add(dist, cross_term_scaled)

        # logits are the temperature-scaled negative distances
        logits_coeff = (
            tensor.ones((feat.shape[0], self.prototype.shape[1]), device=feat.device)
            * -1.0
            / self.temp
        )
        logits_coeff.requires_grad = False
        logits = autograd.mul(logits_coeff, dist)

        return logits

    def get_params(self):
        return {self.prototype.name: self.prototype}

    def set_params(self, parameters):
        self.prototype.copy_from(parameters[self.prototype.name])

class CPL(model.Model):

    def __init__(
        self,
        backbone: model.Model,
        prototype_count=2,
        lamb=0.5,
        temp=10,
        label=None,
        prototype_weight=None,
    ):
        super(CPL, self).__init__()
        # config
        self.lamb = lamb  # prototype loss weight
        self.prototype_weight = prototype_weight
        self.prototype_label = label

        # layers
        self.backbone = backbone
        self.cplayer = CPLayer(prototype_count=prototype_count, temp=temp)
        # loss
        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()

    def forward(self, x):
        feat = self.backbone.forward(x)
        logits = self.cplayer(feat)
        return logits

    def train_one_batch(self, x, y, dist_option="plain", spars=None):
        # dist_option and spars are accepted to match the call convention used
        # in train.py; a plain optimizer update is applied here.
        out = self.forward(x)
        loss = self.softmax_cross_entropy(out, y)
        self.optimizer(loss)
        return out, loss

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer

def create_model(backbone, prototype_count=2, lamb=0.5, temp=10.0):
    model = CPL(backbone, prototype_count=prototype_count, lamb=lamb, temp=temp)
    return model


__all__ = ["CPL", "create_model"]
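
A short usage sketch of this module, mirroring how `train.py` (below) drives it; the CIFAR-10 backbone and tensor shapes are taken from that script, and names such as `net` are purely illustrative:

```python
# Illustrative only: builds a CPL model around the example CNN backbone,
# following the same calls that train.py below makes.
from singa import device, opt, tensor
import examples.cnn.model.cnn as cnn
import model as cpl  # this module

dev = device.create_cuda_gpu_on(0)
backbone = cnn.create_model(num_channels=3, num_classes=10)
net = cpl.create_model(backbone, prototype_count=10, lamb=0.5, temp=10.0)
net.set_optimizer(opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5))

# placeholder input used to compile the computational graph
tx = tensor.Tensor((64, 3, backbone.input_size, backbone.input_size), dev)
net.compile([tx], is_train=True, use_graph=True, sequential=True)
```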
Lines changed: 191 additions & 0 deletions
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from singa import device
from singa import opt
from singa import tensor
import argparse
import numpy as np
import time
from PIL import Image

import sys

sys.path.append(".")
print(sys.path)

import examples.cnn.model.cnn as cnn
from examples.cnn.data import cifar10
import model as cpl

def accuracy(pred, target):
    # y is network output to be compared with ground truth (int)
    y = np.argmax(pred, axis=1)
    a = y == target
    correct = np.array(a, "int").sum()
    return correct


def resize_dataset(x, image_size):
    num_data = x.shape[0]
    dim = x.shape[1]
    X = np.zeros(shape=(num_data, dim, image_size, image_size), dtype=np.float32)
    for n in range(0, num_data):
        for d in range(0, dim):
            X[n, d, :, :] = np.array(
                Image.fromarray(x[n, d, :, :]).resize(
                    (image_size, image_size), Image.BILINEAR
                ),
                dtype=np.float32,
            )
    return X

def run(
    local_rank,
    max_epoch,
    batch_size,
    sgd,
    graph,
    verbosity,
    dist_option="plain",
    spars=None,
):
    dev = device.create_cuda_gpu_on(local_rank)
    dev.SetRandSeed(0)
    np.random.seed(0)

    train_x, train_y, val_x, val_y = cifar10.load()

    num_channels = train_x.shape[1]
    data_size = np.prod(train_x.shape[1 : train_x.ndim]).item()
    num_classes = (np.max(train_y) + 1).item()

    backbone = cnn.create_model(num_channels=num_channels, num_classes=num_classes)
    model = cpl.create_model(backbone, prototype_count=10, lamb=0.5, temp=10)

    if backbone.dimension == 4:
        tx = tensor.Tensor(
            (batch_size, num_channels, backbone.input_size, backbone.input_size), dev
        )
        train_x = resize_dataset(train_x, backbone.input_size)
        val_x = resize_dataset(val_x, backbone.input_size)
    elif backbone.dimension == 2:
        # flatten images for backbones that take 2-D input
        tx = tensor.Tensor((batch_size, data_size), dev)
        train_x = np.reshape(train_x, (train_x.shape[0], -1))
        val_x = np.reshape(val_x, (val_x.shape[0], -1))

    ty = tensor.Tensor((batch_size,), dev, tensor.int32)
    num_train_batch = train_x.shape[0] // batch_size
    num_val_batch = val_x.shape[0] // batch_size
    idx = np.arange(train_x.shape[0], dtype=np.int32)

    model.set_optimizer(sgd)
    model.compile([tx], is_train=True, use_graph=graph, sequential=True)
    dev.SetVerbosity(verbosity)

    for epoch in range(max_epoch):
        print(f"Epoch {epoch}")
        np.random.shuffle(idx)

        train_correct = np.zeros(shape=[1], dtype=np.float32)
        test_correct = np.zeros(shape=[1], dtype=np.float32)
        train_loss = np.zeros(shape=[1], dtype=np.float32)

        model.train()
        for b in range(num_train_batch):
            x = train_x[idx[b * batch_size : (b + 1) * batch_size]]
            y = train_y[idx[b * batch_size : (b + 1) * batch_size]]
            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)

            out, loss = model(tx, ty, dist_option, spars)
            train_correct += accuracy(tensor.to_numpy(out), y)
            train_loss += tensor.to_numpy(loss)[0]
        print(
            "Training loss = %f, training accuracy = %f"
            % (train_loss, train_correct / (num_train_batch * batch_size)),
            flush=True,
        )

        model.eval()
        for b in range(num_val_batch):
            x = val_x[b * batch_size : (b + 1) * batch_size]
            y = val_y[b * batch_size : (b + 1) * batch_size]

            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)

            # in eval mode the model call dispatches to forward(x)
            out_test = model(tx)
            test_correct += accuracy(tensor.to_numpy(out_test), y)
        print(
            "Evaluation accuracy = %f"
            % (test_correct / (num_val_batch * batch_size)),
            flush=True,
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a CPL model")
    parser.add_argument(
        "-m",
        "--max-epoch",
        default=20,
        type=int,
        help="maximum epochs",
        dest="max_epoch",
    )
    parser.add_argument(
        "-b", "--batch-size", default=64, type=int, help="batch size", dest="batch_size"
    )
    parser.add_argument(
        "-l",
        "--learning-rate",
        default=0.005,
        type=float,
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "-i",
        "--device-id",
        default=0,
        type=int,
        help="which GPU to use",
        dest="device_id",
    )
    parser.add_argument(
        "-g",
        "--disable-graph",
        default="True",
        action="store_false",
        help="disable graph",
        dest="graph",
    )
    parser.add_argument(
        "-v",
        "--log-verbosity",
        default=0,
        type=int,
        help="logging verbosity",
        dest="verbosity",
    )
    args = parser.parse_args()
    print(args)

    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
    run(
        args.device_id, args.max_epoch, args.batch_size, sgd, args.graph, args.verbosity
    )
