1 Star 0 Fork 0

Season/GHMOE

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
hetero_data_utils.py 15.57 KB
一键复制 编辑 原始数据 按行查看 历史
Season 提交于 2025-01-04 05:46 . original code
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
from sklearn.preprocessing import label_binarize
import scipy.io
from ogb.nodeproppred import NodePropPredDataset
import gdown
from os import path
from torch_geometric.data import Data
import os
from torch_geometric.utils import to_undirected
import torch_geometric.transforms as T
import pandas as pd
class NCDataset(object):
    """Minimal single-graph dataset container, modeled on OGB's
    NodePropPredDataset (https://github.com/snap-stanford/ogb) but holding
    torch tensors instead of numpy arrays.

    Attributes (filled in by the load_* helpers):
        name:  dataset identifier, e.g. ``ogbn-proteins``
        graph: dict with keys ``edge_index``, ``edge_feat``, ``node_feat``
               and ``num_nodes``
        label: node-label tensor

    Usage after construction:
        split_idx = dataset.get_idx_split()
        graph, label = dataset[0]
    """

    def __init__(self, name):
        # graph/label start empty; loaders populate them after construction
        self.name = name
        self.graph = {}
        self.label = None

    def get_idx_split(self, split_type='random', train_prop=.5, valid_prop=.25):
        """Create a random train/valid/test node split.

        train_prop and valid_prop are fractions in [0, 1]; remaining nodes
        form the test set. Only split_type='random' is supported.
        """
        if split_type == 'random':
            # ogbn-proteins keeps every node; other datasets drop label == -1
            ignore_negative = self.name != 'ogbn-proteins'
            train_idx, valid_idx, test_idx = rand_train_test_idx(
                self.label, train_prop=train_prop,
                valid_prop=valid_prop, ignore_negative=ignore_negative)
            return {'train': train_idx,
                    'valid': valid_idx,
                    'test': test_idx}

    def __getitem__(self, idx):
        # single-graph dataset: only index 0 is valid
        assert idx == 0, 'This dataset has only one graph'
        return self.graph, self.label

    def __len__(self):
        return 1

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, len(self))
def load_nc_dataset(dataname):
    """Dispatch on dataname and return the corresponding NCDataset.

    Raises ValueError for names with no registered loader.
    """
    if dataname == 'penn94':
        return load_penn94_dataset()
    if dataname == 'arxiv-year':
        return load_arxiv_year_dataset()
    if dataname == 'snap-patents':
        return load_snap_patents_mat()
    if dataname == 'pokec':
        return load_pokec_mat()
    if dataname == "genius":
        return load_genius()
    if dataname == "twitch-gamer":
        return load_twitch_gamer_dataset()
    if dataname == 'ogbn-proteins':
        return load_proteins_dataset()
    raise ValueError('Invalid dataname')
def load_penn94_dataset():
    """Build the NCDataset for the Facebook-100 Penn94 graph.

    The label is gender (metadata column 1, shifted down by one so that
    -1 marks unlabeled nodes); every other metadata column is one-hot
    encoded and concatenated into the node-feature matrix.
    """
    A, metadata = load_penn94()
    dataset = NCDataset('penn94')
    metadata = metadata.astype(int)
    label = metadata[:, 1] - 1  # gender label, -1 means unlabeled
    # all columns except the gender column become one-hot feature blocks
    feature_vals = np.hstack(
        (np.expand_dims(metadata[:, 0], 1), metadata[:, 2:]))
    onehots = [
        label_binarize(feature_vals[:, col],
                       classes=np.unique(feature_vals[:, col]))
        for col in range(feature_vals.shape[1])
    ]
    # seed with an empty (n, 0) float block so the result matches the
    # original incremental-hstack dtype even with zero feature columns
    features = np.hstack([np.empty((A.shape[0], 0))] + onehots)
    dataset.graph = {'edge_index': torch.tensor(A.nonzero(), dtype=torch.long),
                     'edge_feat': None,
                     'node_feat': torch.tensor(features, dtype=torch.float),
                     'num_nodes': metadata.shape[0]}
    dataset.label = torch.tensor(label)
    return dataset
def load_penn94():
    """Read the raw Penn94 .mat file from the Facebook-100 collection.

    Returns (A, metadata): the adjacency matrix and the local-info
    metadata array. Metadata columns are: student/faculty, gender, major,
    second major/minor, dorm/house, year, high school; 0 denotes a
    missing entry. (Other FB100 schools, e.g. Rutgers89 or Cornell5, use
    the same file layout.)
    """
    mat = scipy.io.loadmat('data/facebook100/penn94.mat')
    return mat['A'], mat['local_info']
def load_arxiv_year_dataset(nclass=5):
    """ogbn-arxiv relabeled by publication year, bucketed into nclass
    even-quantile classes (the heterophilous "arxiv-year" benchmark).
    """
    dataset = NCDataset('arxiv-year')
    ogb_dataset = NodePropPredDataset(name='ogbn-arxiv')
    graph = ogb_dataset.graph
    graph['edge_index'] = torch.as_tensor(graph['edge_index'])
    graph['node_feat'] = torch.as_tensor(graph['node_feat'])
    dataset.graph = graph
    # label = which year-quantile bucket each paper falls into
    years = dataset.graph['node_year'].flatten()
    label = even_quantile_labels(years, nclass, verbose=False)
    dataset.label = torch.as_tensor(label).reshape(-1, 1)
    return dataset
def even_quantile_labels(vals, nclasses, verbose=True):
    """Partition vals into nclasses by an even quantile-based split.

    Class k collects values in the half-open interval
    [k/nclasses quantile, (k+1)/nclasses quantile); the last class takes
    everything from the (nclasses-1)/nclasses quantile upward.

    Args:
        vals: 1-D np array of numeric values (NaNs are ignored by the
            quantile computation and keep the -1 "unassigned" label).
        nclasses: number of quantile buckets.
        verbose: if True, print the value interval of each class.

    Returns:
        np array of int class labels, same length as vals.
    """
    label = -1 * np.ones(vals.shape[0], dtype=int)
    interval_lst = []
    lower = -np.inf
    for k in range(nclasses - 1):
        upper = np.nanquantile(vals, (k + 1) / nclasses)
        interval_lst.append((lower, upper))
        # boolean & instead of elementwise *: same result, clearer intent
        inds = (vals >= lower) & (vals < upper)
        label[inds] = k
        lower = upper
    # everything at or above the last computed quantile is the top class
    label[vals >= lower] = nclasses - 1
    interval_lst.append((lower, np.inf))
    if verbose:
        print('Class Label Intervals:')
        for class_idx, interval in enumerate(interval_lst):
            # bug fix: original printed a stray ']' after the half-open ')'
            print(f'Class {class_idx}: [{interval[0]}, {interval[1]})')
    return label
def load_snap_patents_mat(nclass=5):
    """snap-patents citation graph; labels are patent years bucketed into
    nclass even-quantile classes. Downloads data/snap_patents.mat from
    Google Drive on first use.
    """
    mat_path = 'data/snap_patents.mat'
    if not path.exists(mat_path):
        p = dataset_drive_url['snap-patents']
        print(f"Snap patents url: {p}")
        gdown.download(id=p, output=mat_path, quiet=False)
    fulldata = scipy.io.loadmat(mat_path)
    dataset = NCDataset('snap_patents')
    dataset.graph = {
        'edge_index': torch.tensor(fulldata['edge_index'], dtype=torch.long),
        'edge_feat': None,
        # node features are stored sparse in the .mat file
        'node_feat': torch.tensor(fulldata['node_feat'].todense(),
                                  dtype=torch.float),
        'num_nodes': int(fulldata['num_nodes']),
    }
    years = fulldata['years'].flatten()
    label = even_quantile_labels(years, nclass, verbose=False)
    dataset.label = torch.tensor(label, dtype=torch.long)
    return dataset
def load_pokec_mat():
    """Load the Pokec social network from data/pokec.mat, downloading the
    file from Google Drive if it is missing.
    """
    mat_path = 'data/pokec.mat'
    if not path.exists(mat_path):
        gdown.download(id=dataset_drive_url['pokec'],
                       output=mat_path, quiet=False)
    fulldata = scipy.io.loadmat(mat_path)
    dataset = NCDataset('pokec')
    dataset.graph = {
        'edge_index': torch.tensor(fulldata['edge_index'], dtype=torch.long),
        'edge_feat': None,
        'node_feat': torch.tensor(fulldata['node_feat']).float(),
        'num_nodes': int(fulldata['num_nodes']),
    }
    dataset.label = torch.tensor(fulldata['label'].flatten(), dtype=torch.long)
    return dataset
def load_genius():
    """Load the genius graph from data/genius.mat (assumed present)."""
    dataset = NCDataset('genius')
    fulldata = scipy.io.loadmat('data/genius.mat')
    # labels come first: the node count is derived from their length
    label = torch.tensor(fulldata['label'], dtype=torch.long).squeeze()
    dataset.graph = {
        'edge_index': torch.tensor(fulldata['edge_index'], dtype=torch.long),
        'edge_feat': None,
        'node_feat': torch.tensor(fulldata['node_feat'], dtype=torch.float),
        'num_nodes': label.shape[0],
    }
    dataset.label = label
    return dataset
def load_twitch_gamer_dataset(task="mature", normalize=True):
    """Load the twitch-gamer graph, downloading the two CSVs if missing.

    task selects the node-label column; when normalize is True the node
    features are standardized to zero mean / unit std per dimension.
    """
    feat_path = 'data/twitch-gamer_feat.csv'
    edge_path = 'data/twitch-gamer_edges.csv'
    if not path.exists(feat_path):
        gdown.download(id=dataset_drive_url['twitch-gamer_feat'],
                       output=feat_path, quiet=False)
    if not path.exists(edge_path):
        gdown.download(id=dataset_drive_url['twitch-gamer_edges'],
                       output=edge_path, quiet=False)
    edges = pd.read_csv(edge_path)
    nodes = pd.read_csv(feat_path)
    label, features = load_twitch_gamer(nodes, task)
    node_feat = torch.tensor(features, dtype=torch.float)
    if normalize:
        # standardize each feature dimension
        node_feat = node_feat - node_feat.mean(dim=0, keepdim=True)
        node_feat = node_feat / node_feat.std(dim=0, keepdim=True)
    dataset = NCDataset("twitch-gamer")
    dataset.graph = {
        'edge_index': torch.tensor(edges.to_numpy()).t().type(torch.LongTensor),
        'node_feat': node_feat,
        'edge_feat': None,
        'num_nodes': len(nodes),
    }
    dataset.label = torch.tensor(label)
    return dataset
def load_twitch_gamer(nodes, task="dead_account"):
    """Extract (label, features) numpy arrays from the twitch-gamer node table.

    Args:
        nodes: DataFrame with at least 'numeric_id', 'created_at',
            'updated_at' and 'language' columns plus the task column.
        task: name of the column used as label (e.g. "dead_account",
            "mature"); if None, no label is extracted and every remaining
            column becomes a feature.

    Returns:
        (label, features): label is a 1-D np array, or None when task is
        None; features is a 2-D np array of the remaining columns.
    """
    nodes = nodes.drop('numeric_id', axis=1)
    # turn YYYY-MM-DD date strings into sortable integers YYYYMMDD
    nodes['created_at'] = nodes.created_at.replace('-', '', regex=True).astype(int)
    nodes['updated_at'] = nodes.updated_at.replace('-', '', regex=True).astype(int)
    # integer-encode the language column (order of first appearance)
    one_hot = {k: v for v, k in enumerate(nodes['language'].unique())}
    lang_encoding = [one_hot[lang] for lang in nodes['language']]
    nodes['language'] = lang_encoding
    if task is not None:
        label = nodes[task].to_numpy()
        features = nodes.drop(task, axis=1).to_numpy()
    else:
        # bug fix: the original returned unbound `label`/`features`
        # (UnboundLocalError) whenever task was None
        label = None
        features = nodes.to_numpy()
    return label, features
def load_proteins_dataset():
    """ogbn-proteins wrapped as an NCDataset, keeping OGB's official split
    (get_idx_split is overridden so the random splitter is never used).
    """
    ogb_dataset = NodePropPredDataset(name='ogbn-proteins')
    dataset = NCDataset('ogbn-proteins')

    def protein_orig_split(**kwargs):
        # delegate to OGB's fixed split and convert indices to tensors
        split_idx = ogb_dataset.get_idx_split()
        return {key: torch.as_tensor(split_idx[key])
                for key in ('train', 'valid', 'test')}

    dataset.get_idx_split = protein_orig_split
    dataset.graph = ogb_dataset.graph
    dataset.label = torch.as_tensor(ogb_dataset.labels)
    dataset.graph['edge_index'] = torch.as_tensor(dataset.graph['edge_index'])
    dataset.graph['edge_feat'] = torch.as_tensor(dataset.graph['edge_feat'])
    return dataset
def rand_train_test_idx(label, train_prop=.5, valid_prop=.25, ignore_negative=True):
    """Randomly split node indices into (train, valid, test).

    When ignore_negative is True, nodes labeled -1 are excluded and the
    returned indices point into the original label tensor; otherwise the
    split is over positions 0..n-1 directly.
    """
    if ignore_negative:
        labeled_nodes = torch.where(label != -1)[0]
    else:
        labeled_nodes = label
    n = labeled_nodes.shape[0]
    n_train = int(n * train_prop)
    n_valid = int(n * valid_prop)
    # shuffle positions once, then slice the permutation into three parts
    perm = torch.as_tensor(np.random.permutation(n))
    parts = (perm[:n_train],
             perm[n_train:n_train + n_valid],
             perm[n_train + n_valid:])
    if not ignore_negative:
        return parts
    # map permuted positions back to original node ids
    return tuple(labeled_nodes[idx] for idx in parts)
def load_fixed_splits(dataset):
    """Load the saved fixed train/valid/test splits for a dataset.

    Downloads the split file from Google Drive on first use. Returns a
    list of {'train', 'valid', 'test'} dicts of index tensors, one per run.
    """
    name = dataset
    split_path = f'./data/splits/{name}-splits.npy'
    if not os.path.exists(split_path):
        # only datasets with a registered drive id can be downloaded
        assert dataset in splits_drive_url.keys()
        gdown.download(id=splits_drive_url[dataset], output=split_path, quiet=False)
    splits_lst = np.load(split_path, allow_pickle=True)
    # normalize every saved index array to a torch tensor, in place
    for split in splits_lst:
        for key in split:
            if not torch.is_tensor(split[key]):
                split[key] = torch.as_tensor(split[key])
    return splits_lst
# Google Drive file ids for the saved fixed splits (used by load_fixed_splits).
splits_drive_url = {
    'snap-patents' : '12xbBRqd8mtG_XkNLH8dRRNZJvVM4Pw-N',
    'pokec' : '1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_',
}
# Google Drive file ids for the raw dataset files (used by the load_* helpers
# above via gdown.download). The wiki/yelp entries are not referenced in this
# file — presumably used by loaders elsewhere in the project.
dataset_drive_url = {
    'twitch-gamer_feat' : '1fA9VIIEI8N0L27MSQfcBzJgRQLvSbrvR',
    'twitch-gamer_edges' : '1XLETC6dG3lVl7kDmytEJ52hvDMVdxnZ0',
    'snap-patents' : '1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia',
    'pokec' : '1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y',
    'yelp-chi': '1fAXtTVQS4CfEk4asqrFw9EPmlUPGbGtJ',
    'wiki_views': '1p5DlVHrnFgYm3VsNIzahSsvCD424AyvP', # Wiki 1.9M
    'wiki_edges': '14X7FlkjrlUgmnsYtPwdh-gGuFla4yb5u', # Wiki 1.9M
    'wiki_features': '1ySNspxbK-snNoAZM7oxiWGvOnTRdSyEK' # Wiki 1.9M
}
def to_sparse_tensor(edge_index, edge_feat, num_nodes):
    """Convert (edge_index, edge_feat) into a torch_sparse SparseTensor.

    Edges are sorted so the result is in CSR order, and both the CSR row
    pointer and the CSC permutation are pre-computed.

    Args:
        edge_index: (2, E) long tensor of edges.
        edge_feat: (E, ...) tensor of per-edge values.
        num_nodes: number of nodes N; result is N x N.
    """
    # bug fix: SparseTensor was never imported at module level, so calling
    # this function raised NameError; import locally from torch_sparse
    # (already a dependency via torch_geometric's ToSparseTensor transform).
    from torch_sparse import SparseTensor
    num_edges = edge_index.size(1)
    (row, col), N, E = edge_index, num_nodes, num_edges
    # sort edges by (col, row) so the transposed adjacency is CSR-sorted
    perm = (col * N + row).argsort()
    row, col = row[perm], col[perm]
    value = edge_feat[perm]
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(N, N), is_sorted=True)
    # Pre-process some important attributes.
    adj_t.storage.rowptr()
    adj_t.storage.csr2csc()
    return adj_t
def load_hetero_data(args):
    """Load a heterophily benchmark graph and package it as a PyG Data object.

    Args:
        args: config namespace; fields read here are ``dataset``,
            ``rand_split``, ``train_prop``, ``valid_prop``, ``run`` and
            ``split_run_idx`` (presumably an argparse namespace — confirm
            against the caller).

    Returns:
        (None, data, split_idx): ``data`` is a torch_geometric Data object
        with ``adj_t`` (via ToSparseTensor), labels, num_features/classes/
        nodes and boolean split masks; ``split_idx`` is the train/valid/test
        index dict selected by ``args.split_run_idx``.
    """
    dataset_name = args.dataset
    raw_dataset = load_nc_dataset(dataset_name)
    # ensure labels are 2-D (num_nodes, num_targets) for the class count below
    if len(raw_dataset.label.shape) == 1:
        raw_dataset.label = raw_dataset.label.unsqueeze(1)
    # ogbn-proteins / wiki have no saved fixed splits: always split randomly,
    # generating one independent split per run
    if args.rand_split or args.dataset in ['ogbn-proteins', 'wiki']:
        split_idx_lst = [raw_dataset.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop)
                         for _ in range(args.run)]
    else:
        split_idx_lst = load_fixed_splits(dataset_name)
    if dataset_name == 'ogbn-proteins':
        # proteins only has edge features: build a SparseTensor adjacency,
        # average edge features into node features, then drop the edge values
        raw_dataset.graph['edge_index'] = to_sparse_tensor(raw_dataset.graph['edge_index'],
                                                           raw_dataset.graph['edge_feat'], raw_dataset.graph['num_nodes'])
        raw_dataset.graph['node_feat'] = raw_dataset.graph['edge_index'].mean(dim=1)
        raw_dataset.graph['edge_index'].set_value_(None)
        raw_dataset.graph['edge_feat'] = None
    n = raw_dataset.graph['num_nodes']
    # infer the number of classes for non one-hot and one-hot labels
    c = max(raw_dataset.label.max().item() + 1, raw_dataset.label.shape[1])
    d = raw_dataset.graph['node_feat'].shape[1]
    data = Data(x=raw_dataset.graph['node_feat'], edge_index=raw_dataset.graph['edge_index'])
    data = T.ToSparseTensor()(data)
    data.y = raw_dataset.label
    data.num_features = d
    data.num_classes = c
    data.num_nodes = n
    # pick the split for this run and attach boolean masks to the Data object
    split_idx = split_idx_lst[args.split_run_idx]
    train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
    data = process_hetero(data, train_idx, valid_idx, test_idx)
    return None, data, split_idx
def process_hetero(data, train_idx, valid_idx, test_idx):
    """Attach boolean train/val/test node masks to a PyG Data object and
    return it (mutated in place).
    """
    num_nodes = data.num_nodes
    for attr, idx in (('train_mask', train_idx),
                      ('val_mask', valid_idx),
                      ('test_mask', test_idx)):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[idx] = True
        setattr(data, attr, mask)
    return data
def create_mask(n, pos):
    """Return a length-n boolean tensor that is True exactly at indices pos."""
    flags = torch.full((n,), False, dtype=torch.bool)
    flags[pos] = True
    return flags
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/HelloSeason/ghmoe.git
[email protected]:HelloSeason/ghmoe.git
HelloSeason
ghmoe
GHMOE
master

搜索帮助