JIT works when you put all the inputs in the decorated fn

master
Caj Larsson 1 year ago
parent c0c2da78c5
commit 6aeb23c548

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes
 from tinygrad.nn import Linear, Embedding
@@ -62,30 +61,25 @@ class Model:
     def parameters(self):
         return [self.l1.weight, self.l1.bias, self.l2.weight, self.l2.bias, self.emb.weight]
-@TinyJit
-def jitloss(x,y):
-    return sparse_categorical_crossentropy(x, y).realize()
 block_size=3
 random.shuffle(words)
 X,Y = dataset(words, block_size)
 m = Model()
 opt = SGD(m.parameters(), lr=0.1)
-for step in range(1000):
-    batch_ix = npr.randint(0, X.shape[0], (32,))
-    x_batch, y_batch = Tensor(X[batch_ix], requires_grad=False), Y[batch_ix]
-    logits = m(x_batch)
-    loss = jitloss(logits, y_batch)
-    opt.zero_grad()
-    loss.backward()
-    if m.parameters()[0].grad is None:
-        break
-    opt.step()
-    print(f"Step {step+1} | Loss: {loss.numpy()}")
+@TinyJit
+def stepf(m, opt, x, y):
+    logits = m(x)
+    loss = sparse_categorical_crossentropy(logits, y)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    return loss.numpy()
+for step in range(100000):
+    batch_ix = npr.randint(0, X.shape[0], (32,))
+    x_batch, y_batch = Tensor(X[batch_ix], requires_grad=False), Y[batch_ix]
+    loss = stepf(m, opt, x_batch, y_batch)
+    if step % 100 == 0:
+        print(f"Step {step+1} | Loss: {loss}")
