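"""Entry point for collaborative fine-tuning of MoE-LoRA clients.

Each global round, every client pulls the aggregated weights from the
server, trains its local experts for a fixed number of local steps, and
the server then aggregates the client parameters, weighted by each
client's number of training samples.
"""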
import argparse
import ast
import os
import random

import numpy as np
import torch
from tqdm import tqdm

from collab_utils.aggregation_strategies import to_aggregation_strategy
from collab_utils.clients import GeneralClient
from collab_utils.collaboration_strategies import to_collaboration_strategy
from collab_utils.server import Server
from models.lora import get_ft_model
from models.model import GPT

def parse_list(value):
    """Parse a Python-list literal passed as a CLI string."""
    return ast.literal_eval(value)

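# Example: parse_list('[8,16,8,8]') returns [8, 16, 8, 8].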
parser = argparse.ArgumentParser()
parser.add_argument("-gr", "--num_global_rounds", default=20, type=int)
parser.add_argument("-num_steps", "--num_local_steps", default=25, type=int)
parser.add_argument("-model_path", "--model_path", type=str)
parser.add_argument('-lr', "--learning_rate", default=2e-3, type=float)
parser.add_argument('-wd', "--weight_decay", default=1e-2, type=float)
parser.add_argument('-ds', '--dataset', default='agnews', type=str)
parser.add_argument('-data_path', '--data_path', type=str)
parser.add_argument('-nc', '--num_clients', default=4, type=int)
parser.add_argument('-device', '--device', default="cuda", type=str)
parser.add_argument('-el', '--expert_lora_ranks', default='[8,8,8,8]', type=parse_list, help='Python-list literal of per-client LoRA ranks, e.g. "[8,8,8,8]"')
parser.add_argument('-en', '--expert_numbers', default='[2,2,2,2]', type=parse_list, help='Python-list literal of per-client expert counts, e.g. "[2,2,2,2]"')
parser.add_argument('-k', '--topk', default=2, type=int)
parser.add_argument('-as', '--collaboration_strategy', default="all", type=str)
parser.add_argument('-aggregation_strategy', '--aggregation_strategy', default="default", type=str)
parser.add_argument('-bs', '--batch_size', default=64, type=int)
parser.add_argument('-micro_bs', '--micro_batch_size', default=64, type=int)
parser.add_argument('-wandb', '--wandb_log', action='store_true')
parser.add_argument('-wandb_proj', '--wandb_project', default="CoMoLE", type=str)
parser.add_argument('-wandb_run_name', '--wandb_run_name', default="test", type=str)
parser.add_argument('-out_dir', '--output_dir', default="../out", type=str)
parser.add_argument('-log_every', '--num_log_steps', default=1, type=int)
parser.add_argument('-eval_every', '--num_eval_steps', default=1, type=int)
parser.add_argument('-update_router_every', '--num_router_update_steps', default=1, type=int)
parser.add_argument('-seed', '--seed', default=1, type=int)
parser.add_argument('-scheduler', '--scheduler', default="cosine", type=str)
parser.add_argument('-lb_lam', '--lb_lambda', default=0.01, type=float)
parser.add_argument('-p_lam', '--p_lambda', default=0.01, type=float)
parser.add_argument('-p_strength', '--pruning_strength', default=0.99, type=float)
parser.add_argument('-is_pruning', '--is_pruning', action='store_true', help='Enable pruning if set')
parser.add_argument('-exp0_importance', '--expert0_importance', default=0.9, type=float)
parser.add_argument('-gating_update_iters', '--gating_update_iters', default=1, type=int)
parser.add_argument('-save_model', '--save_model', action='store_true')
parser.add_argument('-lora_do', '--lora_dropout', default=0.0, type=float)
parser.add_argument('-alter_on_train', '--alter_gate_update_on_train', action='store_true')
parser.add_argument('-bm', '--base_model', default="gpt2", type=str)
parser.add_argument('-is_alter', '--is_alternating', action='store_true')
parser.add_argument('-is_no_router', '--is_no_router', action='store_true')
parser.add_argument('-learning_rate_scale', '--learning_rate_scale', default=1.0, type=float)
args = parser.parse_args()
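# Example invocation (hypothetical paths; adjust to your data/output layout).
# Note that GeneralClient reads data from data_path/<num_clients>:
#   python collab_run.py -bm gpt2 -nc 4 -el '[8,8,8,8]' -en '[2,2,2,2]' \
#       -data_path ./data/agnews -out_dir ./out -k 2 -wandb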
num_gpus = 1  # single-GPU run; with num_gpus > 1, clients are placed round-robin across devices
assert len(args.expert_lora_ranks) == args.num_clients, f"Please specify a LoRA rank for each client: {args.expert_lora_ranks}."
assert len(args.expert_numbers) == args.num_clients, f"Please specify the number of experts for each client: {args.expert_numbers}."
assert (len(set(args.expert_numbers)) == 1 and args.collaboration_strategy == "all") or \
    args.collaboration_strategy != "all", f"Different numbers of experts are not supported for the `all` strategy: {args.expert_numbers}"
assert (all(value == 1 for value in args.expert_numbers) and len(set(args.expert_lora_ranks)) > 1) and args.collaboration_strategy in ("all", "ffalora") or \
    len(set(args.expert_lora_ranks)) == 1, \
    f"Heterogeneous LoRA ranks are only supported for the `all` or `ffalora` strategy with one expert per client: {args.collaboration_strategy}, {args.expert_numbers}, {args.expert_lora_ranks}"
# assert all(args.topk <= en for en in args.expert_numbers), "Each value in topk must be less than or equal to the corresponding value in expert_numbers"
collaboration_strategy = to_collaboration_strategy(args.collaboration_strategy)
aggregation_strategy = to_aggregation_strategy(args.aggregation_strategy)
print("is alternating updates:", args.is_alternating)
print("alter_on_train:", args.alter_gate_update_on_train)

# SmolLM architecture hyperparameters, shared by the client and server
# initializers below; per-run override_args take precedence when merged.
SMOLLM_CONFIG = {
    "bos_token_id": 0,
    "eos_token_id": 0,
    "hidden_size": 960,
    "intermediate_size": 2560,
    "max_position_embeddings": 2048,
    "num_attention_heads": 15,
    "num_hidden_layers": 32,
    "num_key_value_heads": 5,
    "rope_theta": 10000.0,
    "vocab_size": 49152,
}

def init_client_model(override_args):
    if args.base_model.startswith("gpt"):
        model = GPT.from_pretrained(args.base_model, override_args)
    elif "llama" in args.base_model:
        from models.modeling_llama_moe_hf import LlamaMoEForCausalLM
        from models.configuration_llama_moe import LlamaMoEConfig
        model = LlamaMoEForCausalLM.from_pretrained(args.base_model, LlamaMoEConfig(**override_args))
    elif "SmolLM" in args.base_model:
        from models.modeling_llama_moe_hf import LlamaMoEForCausalLM
        from models.configuration_llama_moe import LlamaMoEConfig
        merged_args = {**SMOLLM_CONFIG, **override_args}
        model = LlamaMoEForCausalLM.from_pretrained(args.base_model, LlamaMoEConfig(**merged_args))
    else:
        raise ValueError("Unknown model type")
    # Wrap the base model for parameter-efficient fine-tuning.
    return get_ft_model(model, collaboration_strategy)

def init_server_model(override_args):
    if args.base_model.startswith("gpt"):
        server = Server(args, GPT, config=override_args)
    elif "llama" in args.base_model:
        from models.modeling_llama_moe_hf import LlamaMoEForCausalLM
        from models.configuration_llama_moe import LlamaMoEConfig
        server = Server(args, LlamaMoEForCausalLM, LlamaMoEConfig(**override_args))
    elif "SmolLM" in args.base_model:
        from models.modeling_llama_moe_hf import LlamaMoEForCausalLM
        from models.configuration_llama_moe import LlamaMoEConfig
        merged_args = {**SMOLLM_CONFIG, **override_args}
        server = Server(args, LlamaMoEForCausalLM, LlamaMoEConfig(**merged_args))
    else:
        raise ValueError("Unknown model type")
    return server

def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

set_seed(args.seed)
if args.wandb_log:
    import wandb  # lazy import: wandb is only required when logging is enabled
    wandb.init(project=args.wandb_project, entity='ec-llm', name=args.wandb_run_name, config=vars(args))
print('=============== initializing clients and server')
acc_steps = args.batch_size // args.micro_batch_size  # gradient-accumulation steps per optimizer step
clients = {}
for client_id in range(args.num_clients):
    # Per-client overrides: expert count, LoRA rank, and routing/pruning settings.
    override_args = dict(
        expert_num=args.expert_numbers[client_id],
        lora_rank=args.expert_lora_ranks[client_id],
        lora_dropout=args.lora_dropout,
        topk_exp=min(args.topk, args.expert_numbers[client_id]),
        load_balancing_lambda=args.lb_lambda,
        pruning_lambda=args.p_lambda,
        pruning_strength=args.pruning_strength,
        pruning=args.is_pruning,
        expert0_importance=args.expert0_importance,
        is_no_router=args.is_no_router,
        device=f'cuda:{client_id % num_gpus}' if num_gpus > 1 else 'cuda')
    clients[client_id] = GeneralClient(
        args=args,
        client_id=client_id,
        model=init_client_model,
        data_path=os.path.join(args.data_path, str(args.num_clients)),
        output_dir=args.output_dir,
        override_args=override_args,
        is_shifted=args.base_model.startswith("gpt"),
        dtype=np.uint16 if args.base_model.startswith("gpt") else np.uint32)
# The server model uses the smallest expert count and the largest LoRA rank across clients.
server_override_args = dict(
    expert_num=min(args.expert_numbers),
    lora_rank=max(args.expert_lora_ranks),
    topk_exp=args.topk,
    is_no_router=args.is_no_router,
    device='cpu')
server = init_server_model(server_override_args)
print('=============== collaborative finetuning')
for round_idx in tqdm(range(args.num_global_rounds)):
    print(f"Starting training of global round: {round_idx}")
    for client_id in range(args.num_clients):
        # Pull the current server weights, then run local training.
        clients[client_id].synchronize(server.server_model, collaboration_strategy, aggregation_strategy, client_id)
        clients[client_id].train(acc_steps=acc_steps, local_num_steps=args.num_local_steps)
        print(f"Locally trained client: {client_id}")
    with torch.no_grad():
        # Aggregate the client models on the server, weighted by training-sample counts.
        server.aggregate_parameters(
            [clients[i].model for i in range(args.num_clients)],
            collaboration_strategy,
            aggregation_strategy,
            [clients[i].num_train_samples for i in range(args.num_clients)])
if args.save_model:
    for client_id in range(args.num_clients):
        out_dir = os.path.join(args.output_dir, f'client_{client_id}')
        os.makedirs(out_dir, exist_ok=True)
        clients[client_id].save_model(out_dir)