From 6aeb23c548ff903ea020c188801ed73e0c6d5421 Mon Sep 17 00:00:00 2001
From: Caj Larsson
Date: Sat, 12 Aug 2023 18:42:03 +0200
Subject: [PATCH] Jit works when you put all the inputs in the decorated fn

---
 jit-train-makemore.py | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/jit-train-makemore.py b/jit-train-makemore.py
index 72f1172..3020d6c 100644
--- a/jit-train-makemore.py
+++ b/jit-train-makemore.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes
 from tinygrad.nn import Linear, Embedding
@@ -62,30 +61,25 @@ class Model:
     def parameters(self):
         return [self.l1.weight, self.l1.bias, self.l2.weight, self.l2.bias, self.emb.weight]
 
-@TinyJit
-def jitloss(x,y):
-    return sparse_categorical_crossentropy(x, y).realize()
-
 block_size=3
 random.shuffle(words)
 X,Y = dataset(words, block_size)
 
 m = Model()
 opt = SGD(m.parameters(), lr=0.1)
 
-for step in range(1000):
-    batch_ix = npr.randint(0, X.shape[0], (32,))
-    x_batch, y_batch = Tensor(X[batch_ix], requires_grad=False), Y[batch_ix]
-
-    logits = m(x_batch)
-    loss = jitloss(logits, y_batch)
-
+@TinyJit
+def stepf(m, opt, x, y):
+    logits = m(x)
+    loss = sparse_categorical_crossentropy(logits, y)
     opt.zero_grad()
-    loss.backward()
-    if m.parameters()[0].grad is None:
-        break
-    opt.step()
+    return loss.numpy()
+
+for step in range(100000):
+    batch_ix = npr.randint(0, X.shape[0], (32,))
+    x_batch, y_batch = Tensor(X[batch_ix], requires_grad=False), Y[batch_ix]
+    loss = stepf(m, opt, x_batch, y_batch)
 
     if step % 100 == 0:
-        print(f"Step {step+1} | Loss: {loss.numpy()}")
+        print(f"Step {step+1} | Loss: {loss}")
 
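
Note on the patched code: stepf as committed stops after opt.zero_grad() and returns only the forward loss; the loss.backward() and opt.step() calls from the old loop are deleted, so this revision jits the forward pass but never updates the weights. A complete jitted training step keeps the whole update inside the decorated function. Below is a minimal self-contained sketch of that pattern; ToyModel, nll_loss and train_step are illustrative stand-ins, not the script's actual Model or sparse_categorical_crossentropy, and the import paths assume tinygrad as of mid-2023.

    import numpy as np
    from tinygrad.tensor import Tensor
    from tinygrad.jit import TinyJit            # import path as of mid-2023
    from tinygrad.nn import Linear
    from tinygrad.nn.optim import SGD

    class ToyModel:
        # illustrative stand-in for the script's Model
        def __init__(self):
            self.l1 = Linear(10, 3)
        def __call__(self, x):
            return self.l1(x).log_softmax()
        def parameters(self):
            return [self.l1.weight, self.l1.bias]

    def nll_loss(logits, y_onehot):
        # illustrative stand-in for the script's sparse_categorical_crossentropy
        return -(logits * y_onehot).sum(axis=1).mean()

    @TinyJit
    def train_step(m, opt, x, y):
        # every tensor the step reads comes in through the arguments, and the
        # forward pass, backward pass and optimizer update all run under the JIT
        loss = nll_loss(m(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss.realize()   # realize inside the JIT; call .numpy() outside

    m = ToyModel()
    opt = SGD(m.parameters(), lr=0.1)
    for step in range(300):
        xb = Tensor(np.random.randn(32, 10).astype(np.float32), requires_grad=False)
        yb = Tensor(np.eye(3, dtype=np.float32)[np.random.randint(0, 3, 32)])
        loss = train_step(m, opt, xb, yb)
        if step % 100 == 0:
            print(f"Step {step+1} | Loss: {loss.numpy()}")

Returning loss.realize() and reading .numpy() outside is deliberate: once TinyJit starts replaying captured kernels it updates the returned Tensor's buffer in place, whereas a numpy copy taken inside the traced call (as stepf does here) can go stale on replayed calls.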