Skip to content

Commit

Permalink
Update model_lib.py
Browse files (browse the repository at this point in the history)
Add some print statements to show progress without showing all the logging info.
  • Loading branch information
crazydonkey200 authored Mar 1, 2025
1 parent fcfa082 commit 806bc73
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions hero/model_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,7 @@ def train_one_step_fn(state, batch, lr, add_log_info=False):
while steps < config.num_train_steps:
with jax.profiler.StepTraceAnnotation('train', step_num=steps):
logging.info('steps: %s', steps)
print(f'steps: {steps}')
t1 = time.time()
batch = next(train_set_iter)
batch = jax.tree_util.tree_map(build_global_array_fn, batch)
Expand Down Expand Up @@ -1734,8 +1735,10 @@ def train_one_step_fn(state, batch, lr, add_log_info=False):
train_loss = loss.addressable_data(0)
train_loss = np.array(train_loss)
logging.info('train_loss: %s', train_loss)
print(f'train_loss: {train_loss}')
step_time = time.time() - prev_step_timestamp
logging.info('%s secs per step.', step_time)
print(f'{step_time} secs per step')
prev_step_timestamp = time.time()
metrics_aggregator.add('train_loss', train_loss)

Expand Down Expand Up @@ -1802,6 +1805,8 @@ def train_one_step_fn(state, batch, lr, add_log_info=False):
validation_tokens=total_num_tokens,
validation_eval_time=validation_eval_time))
writer.flush()
print(f'validation_loss: {mean_eval_loss}')
print(f'validation_eval_time: {validation_eval_time}')
steps += 1
# Ensure all the checkpoints are saved.
mngr.close()
Expand Down

0 comments on commit 806bc73

Please sign in to comment.