代码拉取完成,页面将自动刷新
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
from torch.nn import functional as F
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file("20B_tokenizer.json")
args = types.SimpleNamespace()
args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.85
########################################################################################################
class RWKV_RNN(torch.jit.ScriptModule):
def __init__(self, args):
super().__init__()
self.args = args
self.eval() # set torch to inference mode
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
for k in w.keys():
if '.time_' in k: w[k] = w[k].squeeze()
if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
else: w[k] = w[k].float() # convert to f32 type
self.w = types.SimpleNamespace() # set self.w from w
self.w.blocks = {}
for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
parts = k.split('.')
last = parts.pop()
here = self.w
for p in parts:
if p.isdigit():
p = int(p)
if p not in here: here[p] = types.SimpleNamespace()
here = here[p]
else:
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
here = getattr(here, p)
setattr(here, last, w[k])
def layer_norm(self, x, w):
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
@torch.jit.script_method
def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
return r * (vw @ k)
@torch.jit.script_method
def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + k
qq = torch.maximum(pp, ww)
e1 = torch.exp(pp - qq)
e2 = torch.exp(ww - qq)
a = e1 * aa + e2 * v
b = e1 * bb + e2
wkv = a / b
ww = pp + time_decay
qq = torch.maximum(ww, k)
e1 = torch.exp(ww - qq)
e2 = torch.exp(k - qq)
state[5*i+2] = e1 * aa + e2 * v
state[5*i+3] = e1 * bb + e2
state[5*i+4] = qq
return ow @ (r * wkv)
def forward(self, token, state):
with torch.no_grad():
if state == None:
state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity
x = self.w.emb.weight[token]
x = self.layer_norm(x, self.w.blocks[0].ln0)
for i in range(self.args.n_layer):
att = self.w.blocks[i].att
x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
ffn = self.w.blocks[i].ffn
x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
ffn.time_mix_k, ffn.time_mix_r,
ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
return x.float(), state
##########################################################################################################
def sample_logits(out, temperature=1.0, top_p=0.8):
probs = F.softmax(out, dim=-1).numpy()
sorted_probs = np.sort(probs)[::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return out
########################################################################################################
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
init_out, init_state = model.forward(token, init_state)
for TRIAL in range(NUM_TRIALS):
print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
all_tokens = []
out_last = 0
out, state = init_out.clone(), init_state.clone()
for i in range(LENGTH_PER_TRIAL):
token = sample_logits(out, TEMPERATURE, TOP_P)
all_tokens += [token]
tmp = tokenizer.decode(all_tokens[out_last:])
if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
print(tmp, end="", flush=True)
out_last = i + 1
out, state = model.forward(token, state)
print('\n')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。