溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù)

發(fā)布時間:2021-12-23 16:25:51 來源:億速云 閱讀:166 作者:柒染 欄目:互聯(lián)網(wǎng)科技

這篇文章給大家介紹怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù),內(nèi)容非常詳細,感興趣的小伙伴們可以參考借鑒,希望對大家能有所幫助。

動機:

有很多數(shù)據(jù)可以在實際應用中以圖表的形式表示,如引文網(wǎng)絡,社交網(wǎng)絡(追隨者圖,朋友網(wǎng)絡......),生物網(wǎng)絡或電信。

使用Graph提取的特征可以通過依賴相鄰節(jié)點之間的信息流來提高預測模型的性能。但是,表示圖形數(shù)據(jù)并不簡單,特別是如果不打算實現(xiàn)手工制作的特征,因為大多數(shù)ML模型都期望固定大小或線性輸入,而圖形數(shù)據(jù)不是這種情況。

在這篇文章中,將探討一些處理通用圖的方法,以便根據(jù)直接從數(shù)據(jù)中學習的圖表表示進行節(jié)點分類。

數(shù)據(jù)集:

在Keras引用網(wǎng)絡數(shù)據(jù)集將作為基地,在整個這個職位的實現(xiàn)和實驗。每個節(jié)點代表科學論文,節(jié)點之間的邊緣代表兩篇論文之間的引用關系。

每個節(jié)點由一組二進制特征(字袋)以及將其鏈接到其他節(jié)點的一組邊表示。

該數(shù)據(jù)集有2708個節(jié)點,分為七個類別之一。該網(wǎng)絡有5429個鏈接。每個節(jié)點也由二進制字特征表示,指示相應字的存在。總體而言,每個節(jié)點有1433個二進制(稀疏)功能。以下我們只使用140 樣品用于培訓,其余用于驗證/測試。

問題設定:

怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù)

問題:在沒有訓練樣本的情況下為圖中的節(jié)點分配類標簽。

直覺 / 假設:圖中接近的節(jié)點更可能具有相似的標簽。

解決方案:找到一種從圖中提取特征的方法,以幫助對新節(jié)點進行分類。

擬議方法:

基線模型:

怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù)

簡單的基線模型

首先嘗試使用最簡單的模型,該模型學習僅使用二進制特征預測節(jié)點類并丟棄所有圖形信息。該模型是一個完全連接的神經(jīng)網(wǎng)絡,它將二進制特征作為輸入,并輸出每個節(jié)點的類概率。

def get_features_only_model(n_features, n_classes):
    in_ = Input((n_features,))
    x = Dense(10, activation="relu", kernel_regularizer=l1(0.001))(in_)
    x = Dropout(0.5)(x)
    x = Dense(n_classes, activation="softmax")(x)
    model = Model(in_, x)
    model.compile(loss="sparse_categorical_crossentropy", metrics=['acc'], optimizer="adam")
    model.summary()
    return model

基線模型準確度:53.28%

這是將通過添加基于圖形的功能來嘗試改進的初始準確度。

添加圖表功能:

通過在預測兩個輸入節(jié)點之間的最短路徑長度的倒數(shù)的輔助任務上訓練網(wǎng)絡,通過將每個節(jié)點嵌入到矢量中來自動學習圖形特征的一種方法,如下圖和下面的代碼片段所示:

怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù)

學習每個節(jié)點的嵌入向量

def get_graph_embedding_model(n_nodes):

    in_1 = Input((1,))

    in_2 = Input((1,))

    emb = Embedding(n_nodes, 100, name="node1")

    x1 = emb(in_1)

    x2 = emb(in_2)

    x1 = Flatten()(x1)

    x1 = Dropout(0.1)(x1)

    x2 = Flatten()(x2)

    x2 = Dropout(0.1)(x2)

    x = Multiply()([x1, x2])

    x = Dropout(0.1)(x)

    x = Dense(1, activation="linear", name="spl")(x)

    model = Model([in_1, in_2], x)

    model.compile(loss="mae", optimizer="adam")

    model.summary()

    return model

下一步是使用預先訓練的節(jié)點嵌入作為分類模型的輸入。還使用學習的嵌入向量的距離添加附加輸入,該輸入是相鄰節(jié)點的平均二進制特征。

生成的分類網(wǎng)絡如下圖所示:

怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù)

使用預訓練嵌入來進行節(jié)點分類

圖嵌入分類模型準確度:73.06%

可以看到,添加學習圖形特征作為分類模型的輸入有助于顯著提高分類準確性,與基線模型相比,從53.28%到73.06%。

改進圖形功能學習:

可以通過進一步推進預訓練并使用節(jié)點嵌入網(wǎng)絡中的二進制特征,然后除了節(jié)點嵌入向量之外重新使用來自二進制特征的預訓練權重,來進一步改進先前的模型。這導致模型依賴于從圖結構中學習的二進制特征的更有用的表示。

怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù)

 

改進的圖嵌入分類模型準確度:76.35%

與以前的方法相比,這種額外的改進增加了幾個百分點。

關于怎樣使用Keras和Tensorflow學習圖形數(shù)據(jù)就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,可以學到更多知識。如果覺得文章不錯,可以把它分享出去讓更多的人看到。

向AI問一下細節(jié)

免責聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權內(nèi)容。

AI