diff --git a/egs/musdb18/X-UMX/train.py b/egs/musdb18/X-UMX/train.py index fb1f4a9bc..0dec7bed3 100755 --- a/egs/musdb18/X-UMX/train.py +++ b/egs/musdb18/X-UMX/train.py @@ -15,7 +15,7 @@ from asteroid.engine.system import System from asteroid.engine.optimizers import make_optimizer from asteroid.models import XUMX -from asteroid.models.x_umx import _STFT, _Spectrogram +from asteroid.models.x_umx import _STFT, _Spectrogram, _ISTFT from asteroid.losses import singlesrc_mse from torch.nn.modules.loss import _Loss from torch import nn @@ -79,14 +79,14 @@ def freq_domain_loss(s_hat, gt_spec, combination=True): calculated frequency-domain loss """ - n_src = len(s_hat) + n_src, _, _, n_channel, _ = s_hat.shape idx_list = [i for i in range(n_src)] inferences = [] refrences = [] for i, s in enumerate(s_hat): inferences.append(s) - refrences.append(gt_spec[..., 2 * i : 2 * i + 2, :]) + refrences.append(gt_spec[..., n_channel * i : n_channel * (i + 1), :]) assert inferences[0].shape == refrences[0].shape _loss_mse = 0.0 @@ -143,7 +143,7 @@ def time_domain_loss(mix, time_hat, gt_time, combination=True): # Prepare Data and Fix Shape mix_ref = [mix] - mix_ref.extend([gt_time[..., 2 * i : 2 * i + 2, :] for i in range(n_src)]) + mix_ref.extend([gt_time[..., n_channel * i : n_channel * (i + 1), :] for i in range(n_src)]) mix_ref = torch.stack(mix_ref) mix_ref = mix_ref.view(-1, time_length) time_hat = time_hat.view(n_batch * n_channel * time_hat.shape[0], time_hat.shape[-1]) @@ -250,6 +250,8 @@ def __init__( _STFT(window_length=window_length, n_fft=in_chan, n_hop=n_hop), _Spectrogram(spec_power=spec_power, mono=(nb_channels == 1)), ) + self.istft = _ISTFT(window=self.transform[0].window, n_fft=in_chan, hop_length=n_hop) + self.nb_channels = nb_channels self._combi = loss_combine_sources self._multi = loss_use_multidomain self.coef = mix_coef @@ -277,12 +279,33 @@ def forward(self, est_targets, targets, return_est=False, **kwargs): # Fix shape and apply transformation of targets n_batch, n_src, n_channel, time_length = targets.shape - targets = targets.view(n_batch, n_src * n_channel, time_length) - Y = self.transform(targets)[0] + + # downmix in the frequency domain + if n_channel == 2 and self.nb_channels == 1: + Y = [] + signals = [] + for i in range(n_src): + spec, ang = self.transform(targets[:, i, ...]) + Y.append(spec.clone()) + spec = spec.permute(1, 2, 3, 0) + sig_downmix = self.istft(spec.unsqueeze(0), ang.unsqueeze(0)) + signals.append(sig_downmix.permute(1, 0, 2, 3)) + targets = torch.cat(signals, 1) + mixture_t = torch.sum(targets, 1) + targets = targets.squeeze(2) + Y = torch.cat(Y, dim=2) + else: + targets = targets.view(n_batch, n_src * self.nb_channels, time_length) + Y = self.transform(targets)[0] if self._multi: n_src = spec_hat.shape[0] - mixture_t = sum([targets[:, 2 * i : 2 * i + 2, ...] for i in range(n_src)]) + mixture_t = sum( + [ + targets[:, self.nb_channels * i : self.nb_channels * (i + 1), ...] + for i in range(n_src) + ] + ) loss_f = freq_domain_loss(spec_hat, Y, combination=self._combi) loss_t = time_domain_loss(mixture_t, time_hat, targets, combination=self._combi) loss = float(self.coef) * loss_t + loss_f