Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions Lib/test/test_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from test.support.os_helper import FakePath, unlink
import io
import os
import re
import struct
import tempfile
import sys
Expand Down Expand Up @@ -323,14 +324,14 @@ def test_read_wrong_number_of_channels(self):
b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 0, 11025, 11025, 1, 8)
b += b'data' + struct.pack('<L', 0)
with self.assertRaisesRegex(wave.Error, 'bad # of channels'):
with self.assertRaisesRegex(wave.Error, 'bad # of channels: 0'):
wave.open(io.BytesIO(b))

def test_read_wrong_sample_width(self):
b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 11025, 11025, 1, 0)
b += b'data' + struct.pack('<L', 0)
with self.assertRaisesRegex(wave.Error, 'bad sample width'):
with self.assertRaisesRegex(wave.Error, 'bad sample width: 0'):
wave.open(io.BytesIO(b))

def test_open_in_write_raises(self):
Expand Down Expand Up @@ -430,6 +431,36 @@ def test_setframerate_rounds(self, arg, expected):
f.setframerate(arg)
self.assertEqual(f.getframerate(), expected)

@support.subTests('nchannels', (0, -1))
def test_setnchannels_error_includes_value(self, nchannels):
msg = re.escape(f'bad # of channels: {nchannels!r}')
with wave.open(io.BytesIO(), 'wb') as f:
with self.assertRaisesRegex(wave.Error, msg):
f.setnchannels(nchannels)
with self.assertRaises(wave.Error):
f.close()

@support.subTests('sampwidth', (0, 5))
def test_setsampwidth_error_includes_value(self, sampwidth):
msg = re.escape(f'bad sample width: {sampwidth!r}')
with wave.open(io.BytesIO(), 'wb') as f:
f.setnchannels(1)
with self.assertRaisesRegex(wave.Error, msg):
f.setsampwidth(sampwidth)
with self.assertRaises(wave.Error):
f.close()

@support.subTests('arg', (-1, 0, 0.4))
def test_setframerate_error_includes_value(self, arg):
msg = re.escape(f'bad frame rate: {arg!r}')
with wave.open(io.BytesIO(), 'wb') as f:
f.setnchannels(1)
f.setsampwidth(2)
with self.assertRaisesRegex(wave.Error, msg):
f.setframerate(arg)
with self.assertRaises(wave.Error):
f.close()

def test_write_odd_data_chunk_pads_and_updates_riff_size(self):
# gh-117716: odd-sized data chunks must be padded with one zero byte.
with io.BytesIO() as output:
Expand Down
16 changes: 8 additions & 8 deletions Lib/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,9 @@ def _read_fmt_chunk(self, chunk):
raise Error(subformat_msg)
self._sampwidth = (sampwidth + 7) // 8
if not self._sampwidth:
raise Error('bad sample width')
raise Error(f'bad sample width: {sampwidth!r}')
if not self._nchannels:
raise Error('bad # of channels')
raise Error(f'bad # of channels: {self._nchannels!r}')
self._framesize = self._nchannels * self._sampwidth
self._comptype = 'NONE'
self._compname = 'not compressed'
Expand Down Expand Up @@ -495,7 +495,7 @@ def setnchannels(self, nchannels):
if self._datawritten:
raise Error('cannot change parameters after starting to write')
if nchannels < 1:
raise Error('bad # of channels')
raise Error(f'bad # of channels: {nchannels!r}')
self._nchannels = nchannels

def getnchannels(self):
Expand All @@ -510,7 +510,7 @@ def setsampwidth(self, sampwidth):
if sampwidth not in (4, 8):
raise Error('unsupported sample width for IEEE float format')
elif sampwidth < 1 or sampwidth > 4:
raise Error('bad sample width')
raise Error(f'bad sample width: {sampwidth!r}')
self._sampwidth = sampwidth

def getsampwidth(self):
Expand All @@ -521,10 +521,10 @@ def getsampwidth(self):
def setframerate(self, framerate):
if self._datawritten:
raise Error('cannot change parameters after starting to write')
framerate = int(round(framerate))
if framerate <= 0:
raise Error('bad frame rate')
self._framerate = framerate
rounded_framerate = int(round(framerate))
if rounded_framerate <= 0:
raise Error(f'bad frame rate: {framerate!r}')
self._framerate = rounded_framerate

def getframerate(self):
if not self._framerate:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Error messages in :mod:`wave` for invalid channel count, sample width, and
frame rate now include the offending value.
Loading