Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4d81b36

Browse files
committedDec 21, 2016
A tiny fix in PyDataProvider2
* hidden decorator kwargs in DataProvider.__init__ * also add unit test for this.
1 parent 2965df5 commit 4d81b36

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed
 

‎paddle/gserver/tests/test_PyDataProvider2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from paddle.trainer.PyDataProvider2 import *
1818

1919

20-
@provider(input_types=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
20+
@provider(slots=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
2121
def test_dense_no_seq(setting, filename):
2222
for i in xrange(200):
2323
yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)]

‎python/paddle/trainer/PyDataProvider2.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def provider(input_types=None,
232232
check=False,
233233
check_fail_continue=False,
234234
init_hook=None,
235-
**kwargs):
235+
**outter_kwargs):
236236
"""
237237
Provider decorator. Use it to make a function into PyDataProvider2 object.
238238
In this function, user only need to get each sample for some train/test
@@ -318,11 +318,6 @@ def __init__(self, file_list, **kwargs):
318318
self.logger = logging.getLogger("")
319319
self.logger.setLevel(logging.INFO)
320320
self.input_types = None
321-
if 'slots' in kwargs:
322-
self.logger.warning('setting slots value is deprecated, '
323-
'please use input_types instead.')
324-
self.slots = kwargs['slots']
325-
self.slots = input_types
326321
self.should_shuffle = should_shuffle
327322

328323
true_table = [1, 't', 'true', 'on']
@@ -358,9 +353,19 @@ def __init__(self, file_list, **kwargs):
358353
self.check = check
359354
if init_hook is not None:
360355
init_hook(self, file_list=file_list, **kwargs)
356+
357+
if 'slots' in outter_kwargs:
358+
self.logger.warning('setting slots value is deprecated, '
359+
'please use input_types instead.')
360+
self.slots = outter_kwargs['slots']
361+
if input_types is not None:
362+
self.slots = input_types
363+
361364
if self.input_types is not None:
362365
self.slots = self.input_types
363-
assert self.slots is not None
366+
367+
assert self.slots is not None, \
368+
"Data Provider's input_types must be set"
364369
assert self.generator is not None
365370

366371
use_dynamic_order = False

0 commit comments

Comments
 (0)
Please sign in to comment.