Style Transfer 0-05: A No-Blind-Spots Reading of the StyleGAN Source Code (1) - Framework Overview
Code Overview
From the README.md in the source code we know that training starts by running train.py from the repository root, after setting up the config file. In that file we see a call to:
dnnlib.submit_run(**kwargs)
The function takes a single **kwargs. That keeps things simple and clear for the author, but it leaves us readers in the dark; its main parameters are dumped a little further below. First, the implementation of submit_run:
def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
    submit_config = copy.copy(submit_config)

    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    assert submit_config.submit_target == SubmitTarget.LOCAL
    if submit_config.submit_target in {SubmitTarget.LOCAL}:
        run_dir = _create_run_dir_local(submit_config)

        submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
        submit_config.run_dir = run_dir
        _populate_run_dir(run_dir, submit_config)

    if submit_config.print_info:
        print("\nSubmit config:\n")
        pprint.pprint(submit_config, indent=4, width=200, compact=False)
        print()

    if submit_config.ask_confirmation:
        if not util.ask_yes_no("Continue submitting the job?"):
            return

    run_wrapper(submit_config)
Simply put, we don't need to worry about anything that runs before run_wrapper. Its job is to create a sub-project under the results directory (per the config file) that stores the current training configuration; if training is interrupted, we can run the sub-project to resume it. I will analyze this part in detail later if time allows.
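For intuition, here is a minimal sketch of that sub-project directory creation. This is my own simplification, not dnnlib's actual _create_run_dir_local, but it shows the numbering scheme:

import os

def create_run_dir(run_dir_root: str, run_desc: str) -> str:
    # Scan existing sub-projects such as results/00010-... to pick the next id.
    os.makedirs(run_dir_root, exist_ok=True)
    ids = [int(name[:5]) for name in os.listdir(run_dir_root) if name[:5].isdigit()]
    run_id = max(ids, default=-1) + 1
    # e.g. results/00011-sgan-result-1gpu
    run_dir = os.path.join(run_dir_root, "{0:05d}-{1}".format(run_id, run_desc))
    os.makedirs(run_dir)
    return run_dir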
The key function is run_wrapper. Let's first look at the parameter configuration it receives:
submit_config
# Root directory for runs
'run_dir_root' = {str} 'results'
# Name of the sub-project created under 'run_dir_root'
'run_desc' = {str} 'sgan-result-1gpu'
# Presumably files that are skipped when copying into the run dir
'run_dir_ignore' = {list} ['__pycache__', '*.pyproj', '*.sln', '*.suo', '.cache', '.idea', '.vs', '.vscode', 'results', 'datasets', 'cache']
'run_dir_extra_files' = {NoneType} None
'submit_target' = {SubmitTarget} SubmitTarget.LOCAL
# Number of GPUs to use
'num_gpus' = {int} 1
# Whether to print info
'print_info' = {bool} False
'ask_confirmation' = {bool} False
# Numeric prefix of the generated sub-project name, e.g. the 00002 in 00002-sgan-result-1gpu
'run_id' = {int} 11
# Name of the sub-project
'run_name' = {str} '00011-sgan-result-1gpu'
# Directory the sub-project runs in
'run_dir' = {str} 'results\\00011-sgan-result-1gpu'
# The function that runs the training: training.training_loop.training_loop
'run_func_name' = {str} 'training.training_loop.training_loop'
# Arguments for that function, i.e. the parameters of training_loop; analyzed in depth later
'run_func_kwargs' = {dict} {'mirror_augment': True, 'total_kimg': 25000, 'G_args': {'func_name': 'training.networks_stylegan.G_style'}, 'D_args': {'func_name': 'training.networks_stylegan.D_basic'}, 'G_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'D_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'G_loss_args': {'func_name': 'training.loss.G_logistic_nonsaturating'}, 'D_loss_args': {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0}, 'dataset_args': {'tfrecord_dir': 'result'}, 'sched_args': {'minibatch_base': 1, 'minibatch_dict': {4: 32, 8: 32, 16: 32, 32: 16, 64: 8, 128: 4, 256: 2, 512: 1}, 'lod_initial_resolution': 8, 'G_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}, 'D_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}}, 'grid_args': {'size': '4k', 'layout': 'random'}, 'metric_arg_list': [{'func_name': 'metrics.frechet_inception_distance.FID', 'name': 'fid50k', 'num_images': 50000, 'minibatch_per_gpu': 8}], 'tf_config': ...
# Computer user name
'user_name' = {str} 'zwh'
# Name of this particular run (don't overthink it; unimportant)
'task_name' = {str} 'zwh-00011-sgan-result-1gpu'
# Computer host name
'host_name' = {str} 'localhost'
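As a quick sanity check of the naming scheme, the format string from submit_run reproduces task_name exactly; {1:05d} zero-pads run_id to five digits:

print("{0}-{1:05d}-{2}".format('zwh', 11, 'sgan-result-1gpu'))  # -> zwh-00011-sgan-result-1gpu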
Looking at these parameters, you'll notice that most of them relate to the sub-project itself. The one to watch is 'run_func_kwargs', which is later passed through to the training_loop function. Let's first step into run_wrapper (I won't annotate the unimportant code):
def run_wrapper(submit_config: SubmitConfig) -> None:
    ......
    # This section just collects log information
    ......
    import dnnlib
    dnnlib.submit_config = submit_config
    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        print('=1' * 50)
        print(submit_config.run_func_name)
        # Call training_loop through its dotted-string name 'training.training_loop.training_loop'
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
        ......
        # Some print handling
        ......
Clearly, the heart of this is util.call_func_by_name, which loads a module from a dotted string and calls the function, here training_loop. Below is a sketch of that dispatch, followed by the expanded view of 'run_func_kwargs' that went unannotated earlier.
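This sketch assumes standard importlib semantics; dnnlib's real call_func_by_name additionally caches imported modules and resolves names more defensively:

import importlib

def call_func_by_name(func_name: str, **kwargs):
    # Split 'training.training_loop.training_loop' into module path and attribute name.
    module_name, attr = func_name.rsplit('.', 1)
    module = importlib.import_module(module_name)  # import the module from its dotted path
    return getattr(module, attr)(**kwargs)         # fetch the function and call it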
'run_func_kwargs' = {dict} {
    'mirror_augment': True,
    'total_kimg': 25000,
    'G_args': {'func_name': 'training.networks_stylegan.G_style'},
    'D_args': {'func_name': 'training.networks_stylegan.D_basic'},
    'G_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08},
    'D_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08},
    'G_loss_args': {'func_name': 'training.loss.G_logistic_nonsaturating'},
    'D_loss_args': {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0},
    'dataset_args': {'tfrecord_dir': 'result'},
    'sched_args': {'minibatch_base': 1,
                   'minibatch_dict': {4: 32, 8: 32, 16: 32, 32: 16, 64: 8, 128: 4, 256: 2, 512: 1},
                   'lod_initial_resolution': 8,
                   'G_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003},
                   'D_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}},
    'grid_args': {'size': '4k', 'layout': 'random'},
    'metric_arg_list': [{
        'func_name': 'metrics.frechet_inception_distance.FID',
        'name': 'fid50k',
        'num_images': 50000,
        'minibatch_per_gpu': 8}],
    'tf_config': ...
Expanded like this, it should be quite clear. Note that the configuration above is my own; yours will not necessarily match it. Now let's turn to the training_loop function, but before that, let's go over what its parameters mean, annotated in the same style as above:
# Discriminator network architecture
D_args = {'func_name': 'training.networks_stylegan.D_basic'}
# Discriminator loss function
D_loss_args = {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0}
# Adam optimizer hyperparameters for the discriminator
D_opt_args = {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}
# How many times the discriminator is trained per generator iteration
D_repeats = 1
# Generator network architecture
G_args = {'func_name': 'training.networks_stylegan.G_style'}
# Generator loss function
G_loss_args = {'func_name': 'training.loss.G_logistic_nonsaturating'}
# Adam optimizer hyperparameters for the generator
G_opt_args = {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}
# Half-life (in thousands of images) of the running average of generator weights kept in Gs
G_smoothing_kimg = 10.0
# Directory of the training tfrecords
dataset_args = {'tfrecord_dir': 'result'}
# Dynamic range of the images as seen by the networks
drange_net = [-1, 1]
# Layout of the exported image grid; '4k' presumably means a 4K-sized grid
grid_args = {'size': '4k', 'layout': 'random'}
# Export an image snapshot every this many ticks
image_snapshot_ticks = 1
# Parameters for the metrics used to evaluate the model
metric_arg_list = [{
    'func_name': 'metrics.frechet_inception_distance.FID',
    'name': 'fid50k',
    'num_images': 50000,
    'minibatch_per_gpu': 8}]
# Number of minibatches to run per training-parameter update
minibatch_repeats = 4
# Mirror augmentation (horizontal flips)
mirror_augment = True
# Save a network snapshot every 10 ticks
network_snapshot_ticks = 10
# Reset optimizer state whenever a new layer (lod) is introduced
reset_opt_for_new_lod = True
# This one matters: if training was interrupted for some reason, set it to the number of kimg already trained (e.g. 4000) when rerunning the sub-project
resume_kimg = 0.0
# Run id of the pretrained model to load, e.g. the 00002 prefix of 00002-sgan-result-1gpu
resume_run_id = None
resume_snapshot = None
# Wall-clock time already spent when resuming
resume_time = 0.0
# Whether to save the TensorFlow graph
save_tf_graph = False
save_weight_histograms = False
# Base minibatch size, per-resolution minibatch sizes, and per-resolution learning rates for the generator and discriminator
sched_args = {'minibatch_base': 1, 'minibatch_dict': {4: 32, 8: 32, 16: 32, 32: 16, 64: 8, 128: 4, 256: 2, 512: 1}, 'lod_initial_resolution': 8, 'G_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}, 'D_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}}
# The sub-project configuration shown earlier
submit_config = {'run_dir_root': 'results', 'run_desc': 'sgan-result-1gpu', 'run_dir_ignore': ['__pycache__', '*.pyproj', '*.sln', '*.suo', '.cache', '.idea', '.vs', '.vscode', 'results', 'datasets', 'cache'], 'run_dir_extra_files': None, 'submit_target': <SubmitTarget.LOCAL: 1>, 'num_gpus': 1, 'print_info': False, 'ask_confirmation': False, 'run_id': 6, 'run_name': '00006-sgan-result-1gpu', 'run_dir': 'results\\00006-sgan-result-1gpu', 'run_func_name': 'training.training_loop.training_loop', 'run_func_kwargs': {'mirror_augment': True, 'total_kimg': 25000, 'G_args': {'func_name': 'training.networks_stylegan.G_style'}, 'D_args': {'func_name': 'training.networks_stylegan.D_basic'}, 'G_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'D_opt_args': {'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08}, 'G_loss_args': {'func_name': 'training.loss.G_logistic_nonsaturating'}, 'D_loss_args': {'func_name': 'training.loss.D_logistic_simplegp', 'r1_gamma': 10.0}, 'dataset_args': {'
# Random seed
tf_config = {'rnd.np_random_seed': 1000}
# Total length of training, measured in thousands of real images (kimg)
total_kimg = 25000
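To make sched_args concrete, here is a sketch of how the per-resolution dicts would be consulted as the progressive-growing resolution increases. The helper is hypothetical (StyleGAN's actual logic lives in its TrainingSchedule class), and the 0.001 fallback learning rate is an assumption:

def schedule_for_resolution(res: int, sched_args: dict) -> tuple:
    # The dicts are sparse, so fall back to defaults when a resolution key is missing.
    minibatch = sched_args['minibatch_dict'].get(res, sched_args['minibatch_base'])
    g_lrate = sched_args['G_lrate_dict'].get(res, 0.001)  # assumed default lrate
    d_lrate = sched_args['D_lrate_dict'].get(res, 0.001)
    return minibatch, g_lrate, d_lrate

print(schedule_for_resolution(256, {
    'minibatch_base': 1,
    'minibatch_dict': {4: 32, 8: 32, 16: 32, 32: 16, 64: 8, 128: 4, 256: 2, 512: 1},
    'G_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003},
    'D_lrate_dict': {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003},
}))  # -> (2, 0.002, 0.002): at 256x256, minibatch 2 and G/D learning rate 0.002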
The source code itself also documents these parameters in detail, so consult it alongside my notes, and feel free to point out mistakes so I can fix them. Below is an annotated overview of the function's overall framework:
def training_loop(
    # (parameter list omitted -- see the annotations above)

    # print('sched_args: ', sched_args)

    # Initialize dnnlib and TensorFlow.
    # Set up some basics according to the sub-project configuration
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    # Load the training data; tfrecords for all resolutions are loaded
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    # If resume_run_id is given, load its pretrained model; otherwise train from scratch.
    # This is a core piece, analyzed carefully later.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            # Gs holds a long-term moving average of G's weights; think of it as a
            # dedicated copy used for snapshots and evaluation
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        # Level of detail (lod) for progressive growing; resolutions enter as powers of 2 (2, 3, 4, 5, 6, 7, ...)
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        # Learning rate
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # Total minibatch size across all GPUs
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # Minibatch size per GPU
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Decay coefficient of the Gs moving average, derived from the half-life G_smoothing_kimg; explained when we dig into the internals
        Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0
    # Set up optimization of the networks; presumably the loss functions are wired in here