pytorch/ignite

Enable automatic mixed precision for XLA

Open: #1931 opened on Apr 12, 2021

Labels: enhancement, help wanted


Description

Feature

Automatic mixed precision (AMP) for XLA has landed in PyTorch 1.8.1 and the torch/xla nightly builds. We should enable it in the create_supervised_* helper functions.

Suggested solution

Remove the XLA and AMP checks in _check_arg().

  • For create_supervised_trainer, update the supervised_training_step_tpu() function to accept a scaler argument, just like supervised_training_step_amp().
  • For create_supervised_evaluator, removing the XLA and AMP checks in _check_arg() should be enough.
  • For tests, we could remove the XLA checks and run them only with PyTorch 1.8.1.
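To make the proposed argument check concrete, here is a minimal pure-Python sketch of how the mode selection could behave once the XLA + AMP rejection is dropped. It is not ignite's actual _check_arg(): the function name check_arg_sketch, the has_xla_support parameter, the "tpu_amp" mode string, and the choice to keep rejecting apex on XLA (apex is CUDA-only) are all hypothetical assumptions for illustration.

```python
from typing import Any, Optional, Tuple, Union


# Hypothetical stand-in for ignite's internal _check_arg(); the name, the
# has_xla_support flag and the "tpu_amp" mode string are assumptions.
def check_arg_sketch(
    on_tpu: bool,
    amp_mode: Optional[str],
    scaler: Union[bool, Any],
    has_xla_support: bool,
) -> Tuple[Optional[str], Union[bool, Any]]:
    """Pick a training-step mode; XLA combined with amp_mode='amp' is allowed."""
    if on_tpu and not has_xla_support:
        raise RuntimeError("XLA device requested but torch_xla is not available")
    if amp_mode not in (None, "amp", "apex"):
        raise ValueError(f"unknown amp_mode: {amp_mode!r}")
    if amp_mode == "apex" and on_tpu:
        # Assumption: apex is CUDA-only, so the sketch still rejects it on XLA.
        raise ValueError("apex cannot be used with an XLA device")
    if scaler and amp_mode != "amp":
        raise ValueError("a scaler is only used with amp_mode='amp'")
    # The old blanket "amp_mode cannot be used with xla device" check is gone,
    # so an XLA device with amp_mode='amp' maps to its own (hypothetical) mode.
    if on_tpu:
        return ("tpu_amp" if amp_mode == "amp" else "tpu"), scaler
    return amp_mode, scaler
```

For example, check_arg_sketch(on_tpu=True, amp_mode="amp", scaler=True, has_xla_support=True) returns ("tpu_amp", True), which a trainer factory could then route to a TPU step that accepts the scaler, mirroring how the AMP step already receives one.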

Additional context

This feature should not be included in an Ignite release until the next torch and torch/xla releases come out.
