图卷积网络GCN代码

整理文档很辛苦,赏杯茶钱您下走!

免费阅读已结束,点击下载阅读编辑剩下 ...

阅读已结束,您可以下载文档离线阅读编辑

资源描述

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Semi-supervised learning: a simple GCN implementation on the karate club dataset.

Graph convolutional network (GCN) Python code, written July 2, 2018
(classification: public). Following Kipf's blog post and paper, it performs
semi-supervised node classification on the karate club data with
Python + TensorFlow.

Only one node per class carries a known label (``selected_labels_indices``);
the four-layer GCN propagates those labels over the graph. The third layer's
output (``hidden_dim[2]`` = 2 columns) is plotted as each node's "location".

NOTE: the model code uses the TensorFlow 1.x graph API (``tf.placeholder``,
``tf.Session``, ``tf.train.AdamOptimizer``). Under TensorFlow 2.x these live in
``tf.compat.v1`` and require eager execution to be disabled.

TensorFlow and matplotlib are imported inside the functions that use them, so
the NumPy preprocessing (``normalize_adjacency`` / ``data_processing``) can be
imported and reused without those heavy dependencies.
"""

__author__ = 'Zhangyijing'

import numpy as np

lr = 1e-2  # learning rate
epochs = 300  # training epochs
nodes_num = 34  # 34 members in the karate club
class_num = 4  # 4 classes
# The third entry is the location (embedding) dimension; the last is the
# number of classes.
hidden_dim = [4, 4, 2, class_num]
# Only four nodes are labelled: each class owns exactly one node with a known
# label (selected node i belongs to class i).
selected_labels_indices = [19, 16, 9, 23]
# Edge list: one edge per row, two 1-based node ids.
EDGES_FILE = './karate_edges_77.txt'


def normalize_adjacency(edges, nodes_number):
    """Return the symmetrically normalised adjacency D^-1/2 (A + I) D^-1/2.

    :param edges: integer array-like of shape (num_edges, 2) holding 1-based
        node ids. Each undirected edge may be listed once or in both directions.
    :param nodes_number: number of nodes in the graph.
    :return: float32 array of shape (nodes_number, nodes_number).
    """
    edges = np.asarray(edges, dtype=np.int64).reshape(-1, 2)
    adj = np.zeros((nodes_number, nodes_number), dtype=np.float32)
    src, dst = edges[:, 0] - 1, edges[:, 1] - 1  # node ids in the file are 1-based
    # The graph is undirected: set both directions so A is symmetric even when
    # the edge list names each edge only once.
    adj[src, dst] = 1
    adj[dst, src] = 1
    adj += np.eye(nodes_number, dtype=np.float32)  # self-loops: A_hat = A + I
    # Every node has a self-loop, so all degrees are >= 1 (no division by zero).
    d_inv_sqrt = np.power(adj.sum(axis=1), -0.5)
    return d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]


class data_processing(object):
    """Load the karate-club edge list and build the normalised adjacency matrix."""

    def __init__(self, nodes_number, edges_file=EDGES_FILE):
        """
        :param nodes_number: number of nodes in the graph.
        :param edges_file: path of the whitespace-separated edge list
            (two 1-based node ids per row).
        """
        self.__nodes_number = nodes_number
        self.__edges_file = edges_file

    def get_adjMatrix(self):
        """Return D^-1/2 (A + I) D^-1/2 for the graph read from the edge file."""
        # ndmin=2 keeps a one-edge file as shape (1, 2) instead of (2,).
        edges = np.int32(np.loadtxt(self.__edges_file, ndmin=2))
        return normalize_adjacency(edges, self.__nodes_number)


def build_graph(adjMatrix_norm):
    """Build the four-layer GCN training graph (TensorFlow 1.x graph API).

    Every layer computes ``A_norm @ H @ W + b``. The first three apply tanh and
    the last produces the class logits. Only the rows listed in
    ``selected_labels_indices`` enter the loss, which is what makes the
    training semi-supervised.

    :param adjMatrix_norm: (nodes_num, nodes_num) normalised adjacency matrix,
        e.g. from ``data_processing.get_adjMatrix``.
    :return: dict of graph handles: placeholders ``'x'`` and ``'y'``, the
        third-layer output ``'o3'``, ``'train_op'``, ``'loss_op'``,
        ``'y_predict'`` (softmax probabilities) and ``'pred_labels'``.
    """
    import tensorflow as tf

    # Node features, one row per node. The feature width must equal nodes_num
    # because W1 is (nodes_num, hidden_dim[0]); the script feeds the identity.
    x = tf.placeholder(tf.float32, [nodes_num, None])
    # One-hot labels of the labelled nodes, in selected_labels_indices order.
    y = tf.placeholder(tf.float32, [len(selected_labels_indices), class_num])

    # Convert the adjacency matrix to a tensor once instead of in every layer.
    a_norm = tf.constant(adjMatrix_norm, dtype=tf.float32)

    def _param(shape):
        # Small truncated-normal initialisation, used for every weight and bias.
        return tf.Variable(tf.truncated_normal(shape, stddev=0.01), dtype=tf.float32)

    def _gcn_layer(inputs, in_dim, out_dim):
        # One graph convolution A_norm @ inputs @ W + b, with a bias row per node.
        W = _param([in_dim, out_dim])
        bias = _param([nodes_num, out_dim])
        return tf.add(tf.matmul(tf.matmul(a_norm, inputs), W), bias)

    o1 = tf.nn.tanh(_gcn_layer(x, nodes_num, hidden_dim[0]))
    o2 = tf.nn.tanh(_gcn_layer(o1, hidden_dim[0], hidden_dim[1]))
    # o3 has hidden_dim[2] columns; it is what plot_figure draws as "location".
    o3 = tf.nn.tanh(_gcn_layer(o2, hidden_dim[1], hidden_dim[2]))
    logits = _gcn_layer(o3, hidden_dim[2], hidden_dim[3])

    # The prediction for every node.
    y_predict = tf.nn.softmax(logits)

    # Cross-entropy on the labelled nodes only.
    loss_op = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=tf.gather(logits, selected_labels_indices), labels=y))
    train_op = tf.train.AdamOptimizer(lr).minimize(loss_op)

    pred_labels = tf.argmax(y_predict, axis=1)

    return {'o3': o3,
            'train_op': train_op,
            'loss_op': loss_op,
            'y_predict': y_predict,
            'x': x,
            'y': y,
            'pred_labels': pred_labels}


def plot_figure(location, labels):
    """Scatter-plot the 2-D node locations coloured by predicted class, with edges.

    :param location: (nodes_num, >=2) array; columns 0 and 1 are plotted.
    :param labels: (nodes_num,) array of predicted class ids in 0..3.
    """
    import matplotlib.pyplot as plt

    data = np.int32(np.loadtxt(EDGES_FILE, ndmin=2))
    colors = ['b', 'g', 'k', 'm']  # one colour per class 0..3

    # Plot the scatter of the different predicted labels.
    for cls, color in enumerate(colors):
        idx = np.where(labels == cls)[0]
        plt.scatter(location[idx, 0], location[idx, 1],
                    marker='o', color=color, label=str(cls), s=20)

    # The labelled (selected) nodes are marked with 'x'.
    for cls, (node, color) in enumerate(zip(selected_labels_indices, colors)):
        plt.scatter(location[node, 0], location[node, 1],
                    marker='x', color=color, label=str(cls), s=50)

    # Plot the edges (node ids in the file are 1-based).
    for src, dst in data - 1:
        plt.plot([location[src, 0], location[dst, 0]],
                 [location[src, 1], location[dst, 1]],
                 color='r', linewidth=0.25)
    plt.show()


def main():
    """Train the GCN on the karate-club graph and plot the learned locations."""
    import tensorflow as tf

    x = np.eye(nodes_num)  # the feature description of each node (one-hot ids)
    # One-hot labels of the labelled nodes: selected node i has class i.
    y = np.eye(len(selected_labels_indices), class_num)
    data = data_processing(nodes_num)
    adjMatrix_norm = data.get_adjMatrix()  # normalised adjacency matrix
    graph = build_graph(adjMatrix_norm)
    feed = {graph['x']: x, graph['y']: y}

    print('Begin train')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            print(epoch)
            sess.run(graph['train_op'], feed_dict=feed)  # train the model
        # .eval() uses the default session set by the `with` block.
        location = graph['o3'].eval(feed_dict=feed)  # evaluate the locations
        labels = graph['pred_labels'].eval(feed_dict=feed)  # predicted labels
    plot_figure(location, labels)


if __name__ == '__main__':
    main()

1 / 7
下载文档,编辑使用

©2015-2020 m.777doc.com 三七文档.

备案号:鲁ICP备2024069028号-1 客服联系 QQ:2149211541

×
保存成功