Skip to content

Commit db52e95

Browse files
committed
plienar beautified
1 parent b65876d commit db52e95

File tree

4 files changed

+594
-503
lines changed

4 files changed

+594
-503
lines changed
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from datasets import load_dataset
2+
from torch.utils.data import DataLoader
3+
from torchvision import transforms
4+
import torch
5+
import torch.nn as nn
6+
import torch.optim as optim
7+
from tqdm import tqdm
8+
import matplotlib.pyplot as plt
9+
import pandas as pd
10+
11+
exp_path = "CIFAR_100_btViT"
12+
data_path = "uoft-cs/cifar100"
13+
dataset = load_dataset(data_path, split="train", streaming=False)
14+
15+
total_samples = 1281167
16+
batch_size = 64
17+
18+
transform = transforms.Compose([
19+
transforms.ToTensor()
20+
])
21+
22+
def collate_fn(batch):
23+
images, labels = [], []
24+
for item in batch:
25+
try:
26+
image = transform(item["img"])
27+
images.append(image)
28+
labels.append(item["fine_label"])
29+
except Exception as e:
30+
print(f"Error processing image: {e}")
31+
continue
32+
return torch.stack(images), torch.tensor(labels)
33+
34+
train_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
35+
36+
from plinear.models import btViT
37+
38+
dim = 256
39+
depth = 6
40+
41+
config = {'embed_dim' : dim,
42+
'depth' : depth,
43+
'mlp_dim' : 1024,
44+
'img_size' : 32,
45+
'patch_size' : 2,
46+
'channels' : 3,
47+
'num_classes' : 100}
48+
49+
model = btViT(**config)
50+
51+
from torchinfo import summary
52+
summary(model, input_size=(1, 3, 32, 32))
53+
54+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55+
device = torch.device("mps" if torch.mps.is_available() else "cpu")
56+
print(device)
57+
model = model.to(device)
58+
59+
# 손실 함수 및 옵티마이저 설정
60+
criterion = nn.CrossEntropyLoss()
61+
optimizer = optim.Adam(model.parameters(), lr = 1)
62+
63+
# 학습 결과 저장을 위한 리스트
64+
loss_history = []
65+
accuracy_history = []
66+
67+
import time
68+
training_start_time = time.time()
69+
70+
# 학습 루프
71+
num_epochs = 10
72+
73+
for epoch in range(num_epochs):
74+
model.train()
75+
running_loss = 0.0
76+
running_corrects = 0
77+
running_samples = 0
78+
79+
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", total=total_samples // batch_size)
80+
for i, (images, labels) in enumerate(progress_bar):
81+
images, labels = images.to(device), labels.to(device)
82+
labels = labels % 1000
83+
84+
optimizer.zero_grad()
85+
outputs = model(images)
86+
loss = criterion(outputs, labels)
87+
loss.backward()
88+
optimizer.step()
89+
90+
preds = torch.argmax(outputs, dim=1)
91+
running_samples += labels.size(0)
92+
running_loss += loss.item()
93+
running_corrects += torch.sum(preds == labels).item()
94+
95+
progress_bar.set_postfix(loss=running_loss / running_samples, accuracy=running_corrects / running_samples)
96+
97+
model.push_to_hub(f'snowian/{exp_path}_{dim}_{depth}_{epoch + 1}')
98+
epoch_loss = running_loss / total_samples
99+
epoch_accuracy = running_corrects / total_samples
100+
loss_history.append(epoch_loss)
101+
accuracy_history.append(epoch_accuracy)
102+
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")
103+
104+
training_end_time = time.time() # 학습 종료 시간
105+
training_duration = training_end_time - training_start_time # 전체 학습 소요 시간
106+
print(f"Total Training Time: {training_duration:.2f} seconds")
107+
108+
# Epoch-level metrics
109+
plt.figure(figsize=(10, 5))
110+
plt.plot(range(1, num_epochs + 1), loss_history, marker='o', label="Epoch Loss")
111+
plt.xlabel("Epoch")
112+
plt.ylabel("Loss")
113+
plt.title("Loss Over Epochs")
114+
plt.legend()
115+
plt.grid()
116+
plt.savefig(f"{exp_path}/{dim}-{depth} epoch_loss.png")
117+
118+
plt.figure(figsize=(10, 5))
119+
plt.plot(range(1, num_epochs + 1), accuracy_history, marker='o', label="Epoch Accuracy")
120+
plt.xlabel("Epoch")
121+
plt.ylabel("Accuracy")
122+
plt.title("Accuracy Over Epochs")
123+
plt.legend()
124+
plt.grid()
125+
plt.savefig(f"{exp_path}/{dim}-{depth} epoch_acc.png")
126+
127+
# Save metrics to CSV
128+
metrics_data = {
129+
"Epoch": list(range(1, num_epochs + 1)),
130+
"Epoch Loss": loss_history,
131+
"Epoch Accuracy": accuracy_history,
132+
}
133+
134+
# Create DataFrame and save to CSV
135+
epoch_df = pd.DataFrame({"Epoch": metrics_data["Epoch"],
136+
"Loss": metrics_data["Epoch Loss"],
137+
"Accuracy": metrics_data["Epoch Accuracy"]})
138+
139+
epoch_df.to_csv(f"{exp_path}/{dim}-{depth} epoch_metrics.csv", index=False)
140+
141+
print("Metrics saved to CSV files.")
142+
143+
import pandas as pd
144+
145+
def evaluate_model(model, dataloader, criterion, device, save_path=None):
146+
model.eval()
147+
running_corrects = 0
148+
total_samples = 0
149+
150+
with torch.no_grad():
151+
for images, labels in tqdm(dataloader, desc="Evaluating"):
152+
images, labels = images.to(device), labels.to(device)
153+
labels = labels % 1000
154+
outputs = model(images)
155+
preds = torch.argmax(outputs, dim=1)
156+
157+
running_corrects += torch.sum(preds == labels).item()
158+
total_samples += labels.size(0)
159+
160+
accuracy = running_corrects / total_samples
161+
print(f"Test Accuracy: {accuracy:.4f}")
162+
163+
# 검증 결과 저장
164+
if save_path:
165+
results = {"Accuracy": [accuracy]}
166+
results_df = pd.DataFrame(results)
167+
results_df.to_csv(save_path, index=False)
168+
print(f"Test results saved to {save_path}")
169+
170+
return accuracy
171+
172+
# # 검증 데이터로 평가
173+
# validation_dataset = load_dataset(data_path, split="validation", streaming=False)
174+
# print(validation_dataset[:10])
175+
# validation_loader = DataLoader(validation_dataset, batch_size=batch_size, collate_fn=collate_fn)
176+
177+
# val_start_time = time.time()
178+
179+
# test_accuracy = evaluate_model(
180+
# model, validation_loader, criterion, device, save_path=f"{exp_path}/{dim}-{depth} validation_results.csv"
181+
# )
182+
183+
# val_end_time = time.time() # 테스트 종료 시간
184+
# val_duration = val_end_time - val_start_time # 테스트 소요 시간
185+
# print(f"Total Validation Time: {val_duration:.2f} seconds")
186+
187+
188+
test_dataset = load_dataset(data_path, split="test", streaming=False)
189+
print(test_dataset[:10])
190+
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)
191+
192+
test_start_time = time.time()
193+
194+
test_accuracy = evaluate_model(
195+
model, test_loader, criterion, device, save_path=f"{exp_path}/{dim}-{depth} test_results.csv"
196+
)
197+
198+
test_end_time = time.time() # 테스트 종료 시간
199+
test_duration = test_end_time - test_start_time # 테스트 소요 시간
200+
print(f"Total Test Time: {test_duration:.2f} seconds")

plinear/btnn/linear.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,23 @@
77
class Linear(nn.Module):
88
def __init__(self, x, y):
99
super(Linear, self).__init__()
10-
self.real_pos = nn.Linear(x, y)
11-
self.real_neg = nn.Linear(x, y)
10+
self.pr = nn.Linear(x, y)
11+
self.nr = nn.Linear(x, y)
1212

1313
torch.nn.init.uniform_(self.real_pos.weight, -1, 1)
1414
torch.nn.init.uniform_(self.real_neg.weight, -1, 1)
1515

1616
def forward(self, x):
17-
w_pos = self.real_pos.weight
18-
w_neg = self.real_neg.weight
19-
tern_pos = posNet(w_pos)
20-
tern_neg = posNet(w_neg)
17+
pr = self.pr.weight
18+
nr = self.nr.weight
19+
qpr = posNet(pr)
20+
qnr = posNet(nr)
2121

2222
# Apply quantization using posNet with detach
23-
tern_pos = tern_pos - w_pos.detach() + w_pos
24-
tern_neg = tern_neg - w_neg.detach() + w_neg
23+
qpr = qpr - pr.detach() + pr
24+
qnr = qnr - nr.detach() + nr
2525

2626
# Compute linear transformations
27-
y_pos = F.linear(x, tern_pos)
28-
y_neg = F.linear(x, tern_neg)
27+
yr = F.linear(x, qpr) - F.linear(x, qnr)
2928

30-
# Combine positive and negative parts
31-
y = y_pos - y_neg
32-
33-
return y
29+
return yr

0 commit comments

Comments
 (0)